# Copyright (c) 2026 Salvatore Pennacchio <jtatopenn@libero.it>
# Distributed under the Business Source License 1.1 (BSL 1.1)
# See LICENSE.md in the project root for full license terms.
import numpy as np
from typing import List, Tuple
from dataclasses import dataclass
# ── optional JAX import (same pattern as registry) ────────────────────────────
try:
    # Preferred source of truth: the flag already computed by the registry.
    from .registry import HAS_JAX
except ImportError:
    # The registry flag could not be imported, so probe for JAX directly.
    try:
        import jax
    except ImportError:
        HAS_JAX = False
    else:
        HAS_JAX = True

if HAS_JAX:
    import jax
    import jax.numpy as jnp

    # The kernels below request complex128 / float64 / int64 dtypes, which
    # JAX only honours when 64-bit mode is switched on.
    jax.config.update("jax_enable_x64", True)
# ─────────────────────────────────────────────────────────────────────────────
# Gate-ID encoding (shared between _apply_gate_fast_step and the beast-mode
# command builder in the dashboard).
#
#   0  I  (identity)        7  T
#   1  H                    8  Tdg
#   2  X                    9  Rx(θ)
#   3  Y                   10  Ry(θ)
#   4  Z                   11  Rz(θ)
#   5  S                   12  Phase / P(θ) / U1(θ)
#   6  Sdg                 ── 2-qubit gates ──
#                          20  CX / CNOT
#                          21  CZ
#                          22  CP(θ) / CRZ(θ)
#                          23  SWAP
# ─────────────────────────────────────────────────────────────────────────────
if HAS_JAX:
@jax.jit
def _apply_gate_fast_step(sv: "jnp.ndarray",
                          operation: "jnp.ndarray"):
    """
    Apply a single quantum gate to statevector *sv*.

    Parameters
    ----------
    sv : complex128 statevector of shape (2**n,)
    operation : float64 array [g_id, q1, q2, param]
        g_id  - gate identifier: 0-12 are 1-qubit gates
                (I, H, X, Y, Z, S, Sdg, T, Tdg, Rx, Ry, Rz, P),
                20-23 are 2-qubit gates (CX, CZ, CP, SWAP)
        q1    - target qubit (1-qubit gates) or control qubit (2-qubit)
        q2    - target qubit for 2-qubit gates; unused for 1-qubit
        param - rotation angle in radians; 0.0 for non-parametric gates

    Returns
    -------
    (new_sv, None) - compatible with jax.lax.scan

    Notes
    -----
    A gate identifier outside 0-12 and 20-23 leaves the state unchanged
    instead of being applied as some other gate.  Qubit indices are not
    validated here.
    """
    # Work in a single dtype so every lax.cond / lax.switch branch below
    # returns exactly the same type.
    sv = sv.astype(jnp.complex128)

    g_id = operation[0].astype(jnp.int32)
    q1 = operation[1].astype(jnp.int64)
    q2 = operation[2].astype(jnp.int64)
    param = operation[3]

    dim = sv.shape[0]
    idx_full = jnp.arange(dim, dtype=jnp.int64)

    # Scalars shared by the gate definitions.
    inv2 = 0.5 ** 0.5                      # 1 / sqrt(2)
    half_p = param * 0.5
    cos_p = jnp.cos(half_p).astype(jnp.complex128)
    sin_p = jnp.sin(half_p).astype(jnp.complex128)
    exp_pos = jnp.exp(1j * param).astype(jnp.complex128)
    exp_ph4 = jnp.exp(1j * jnp.pi / 4.0).astype(jnp.complex128)
    exp_mh4 = jnp.exp(-1j * jnp.pi / 4.0).astype(jnp.complex128)
    rz_lo = jnp.exp(-1j * half_p).astype(jnp.complex128)
    rz_hi = jnp.exp(1j * half_p).astype(jnp.complex128)

    def _mat(m00, m01, m10, m11):
        # Build one 2x2 complex gate matrix.
        return jnp.array([[m00, m01], [m10, m11]], dtype=jnp.complex128)

    # 1-qubit gate matrices, indexed by gate id 0..12.
    gate_table = jnp.stack([
        _mat(1.0, 0.0, 0.0, 1.0),                        # 0  I
        _mat(inv2, inv2, inv2, -inv2),                   # 1  H
        _mat(0.0, 1.0, 1.0, 0.0),                        # 2  X
        _mat(0.0, -1j, 1j, 0.0),                         # 3  Y
        _mat(1.0, 0.0, 0.0, -1.0),                       # 4  Z
        _mat(1.0, 0.0, 0.0, 1j),                         # 5  S
        _mat(1.0, 0.0, 0.0, -1j),                        # 6  Sdg
        _mat(1.0, 0.0, 0.0, exp_ph4),                    # 7  T
        _mat(1.0, 0.0, 0.0, exp_mh4),                    # 8  Tdg
        _mat(cos_p, -1j * sin_p, -1j * sin_p, cos_p),    # 9  Rx(theta)
        _mat(cos_p, -sin_p, sin_p, cos_p),               # 10 Ry(theta)
        _mat(rz_lo, 0.0, 0.0, rz_hi),                    # 11 Rz(theta)
        _mat(1.0, 0.0, 0.0, exp_pos),                    # 12 P(theta)
    ])

    # Negative ids are clamped to 0 (identity).  Ids above 12 never reach
    # the 1-qubit path (see the final cond), so the upper clamp only keeps
    # the table lookup in range.
    safe_gid = jnp.clip(g_id, 0, 12)
    g_1q = gate_table[safe_gid]

    # -- 1-qubit application ------------------------------------------
    def do_1q(state):
        stride = jnp.int64(1) << q1
        in_zero_slot = (idx_full & stride) == 0
        # Every index i is paired with i ^ stride (same basis state with
        # qubit q1 flipped).
        partner = state[idx_full ^ stride]
        # In a |0> slot: state = a0, partner = a1.
        new_when_0 = g_1q[0, 0] * state + g_1q[0, 1] * partner
        # In a |1> slot the roles are swapped: partner = a0, state = a1.
        new_when_1 = g_1q[1, 0] * partner + g_1q[1, 1] * state
        return jnp.where(in_zero_slot, new_when_0, new_when_1)

    # -- 2-qubit application ------------------------------------------
    def do_2q(state):
        ctrl_bit = jnp.int64(1) << q1
        trgt_bit = jnp.int64(1) << q2
        ctrl_set = (idx_full & ctrl_bit) != 0
        trgt_set = (idx_full & trgt_bit) != 0
        both_set = ctrl_set & trgt_set

        def apply_cx(s):
            # Flip the target bit wherever the control bit is set.
            return jnp.where(ctrl_set, s[idx_full ^ trgt_bit], s)

        def apply_cz(s):
            # Negate the amplitude where both bits are set.
            return jnp.where(both_set, -s, s)

        def apply_cp(s):
            # Phase e^{i*theta} where both bits are set.
            return jnp.where(both_set, exp_pos * s, s)

        def apply_swap(s):
            # Exchange amplitudes of basis states whose two bits differ.
            bits_differ = ctrl_set ^ trgt_set
            return jnp.where(bits_differ, s[idx_full ^ ctrl_bit ^ trgt_bit], s)

        def leave_unchanged(s):
            return s

        # 20=CX, 21=CZ, 22=CP, 23=SWAP; anything else is a no-op rather
        # than silently falling through to SWAP.
        is_known = (g_id >= 20) & (g_id <= 23)
        branch = jnp.where(is_known, g_id - 20, 4)
        return jax.lax.switch(
            branch,
            [apply_cx, apply_cz, apply_cp, apply_swap, leave_unchanged],
            state,
        )

    # g_id <= 12 -> 1-qubit path; everything else -> 2-qubit path.
    new_sv = jax.lax.cond(g_id <= 12, do_1q, do_2q, sv)
    return new_sv, None
@jax.jit
def _compile_and_run_circuit_jit(state_vector: "jnp.ndarray",
                                 compiled_ops: "jnp.ndarray") -> "jnp.ndarray":
    """
    Run a pre-compiled gate sequence over *state_vector* with jax.lax.scan.

    Parameters
    ----------
    state_vector : complex128 array of shape (2**n,)
    compiled_ops : float64 array of shape (n_gates, 4)
        each row = [g_id, q1, q2, param]

    Returns
    -------
    The statevector obtained after the last row has been applied.
    """
    # scan threads the statevector through one row of compiled_ops per step;
    # the per-step outputs are all None, so only the final carry is kept.
    scan_result = jax.lax.scan(_apply_gate_fast_step, state_vector, compiled_ops)
    return scan_result[0]
# ─────────────────────────────────────────────────────────────────────────────
# QuantumTranspiler
# ─────────────────────────────────────────────────────────────────────────────
class QuantumTranspiler:
    """
    Gate-level transpiler that rewrites composite gates as simpler ones.

    A circuit is a list of tuples ``(gate_name, qubit, ...)``.  ``ccx`` is
    expanded into H / T / Tdg / CX gates and ``swap`` into CX gates; every
    other command is kept as it is.
    """

    @staticmethod
    def decompose_toffoli(c1: int, c2: int, t: int) -> List[Tuple]:
        """
        Decompose CCX (Toffoli) into the standard H / T / Tdg / CX sequence.

        Gate count: 6 CX + 9 single-qubit (2 H, 4 T, 3 Tdg) = 15 total.

        Parameters
        ----------
        c1, c2 : control qubits
        t : target qubit

        Returns
        -------
        List of 15 gate tuples, in application order.
        """
        # Gates on the target, interleaved with CX from each control.
        on_target = [
            ('h', t),
            ('cx', c2, t), ('tdg', t),
            ('cx', c1, t), ('t', t),
            ('cx', c2, t), ('tdg', t),
            ('cx', c1, t),
        ]
        # Remaining phase gates plus the two CX between the controls.
        on_controls = [
            ('t', c2), ('t', t),
            ('cx', c1, c2), ('h', t),
            ('t', c1), ('tdg', c2),
            ('cx', c1, c2),
        ]
        return on_target + on_controls

    @staticmethod
    def decompose_swap(q1: int, q2: int) -> List[Tuple]:
        """Rewrite SWAP(q1, q2) as CX(q1, q2), CX(q2, q1), CX(q1, q2)."""
        forward = ('cx', q1, q2)
        backward = ('cx', q2, q1)
        return [forward, backward, forward]

    @classmethod
    def transpile(cls, circuit: List[Tuple]) -> List[Tuple]:
        """
        Expand every CCX into 15 gates and every SWAP into 3 CX gates.

        Gate names are matched case-insensitively; all other commands are
        passed through unchanged and in their original order.

        Parameters
        ----------
        circuit : list of tuples (gate_name, qubit, ...)

        Returns
        -------
        Expanded circuit as a new list of tuples.
        """
        # gate name -> (expander, exclusive end index of its qubit arguments)
        expanders = {
            'ccx': (cls.decompose_toffoli, 4),
            'swap': (cls.decompose_swap, 3),
        }
        expanded: List[Tuple] = []
        for command in circuit:
            rule = expanders.get(command[0].lower())
            if rule is None:
                expanded.append(command)
                continue
            expand, stop = rule
            expanded.extend(expand(*command[1:stop]))
        return expanded
|