# (Upload page residue: "Tatopenn's picture" / "Upload 20 files" /
#  commit 4eff328 verified)
# Copyright (c) 2026 Salvatore Pennacchio <jtatopenn@libero.it>
# Distributed under the Business Source License 1.1 (BSL 1.1)
# See LICENSE.md in the project root for full license terms.
import numpy as np
from typing import List, Tuple
from dataclasses import dataclass
# ── optional JAX import (same pattern as registry) ────────────────────────────
try:
from .registry import HAS_JAX
except ImportError:
try:
import jax
HAS_JAX = True
except ImportError:
HAS_JAX = False
if HAS_JAX:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
# ─────────────────────────────────────────────────────────────────────────────
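# Gate-ID encoding (shared between _apply_gate_fast_step and the beast-mode
# command builder in the dashboard).
#
#   0  I (identity)          7  T
#   1  H                     8  Tdg
#   2  X                     9  Rx(θ)
#   3  Y                    10  Ry(θ)
#   4  Z                    11  Rz(θ)
#   5  S                    12  Phase / P(θ) / U1(θ)
#   6  Sdg
#   ── 2-qubit gates ──
#  20  CX / CNOT
#  21  CZ
#  22  CP(θ) / CRZ(θ) — the kernel applies CP(θ) = diag(1, 1, 1, e^{iθ}) in
#      the |control, target⟩ basis. NOTE: CRZ(θ) = diag(1, 1, e^{-iθ/2}, e^{iθ/2})
#      is NOT the same gate; it equals CP(θ) followed by P(-θ/2) on the control,
#      so a CRZ sent as ID 22 needs that extra phase gate.
#  23  SWAP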
# ─────────────────────────────────────────────────────────────────────────────
if HAS_JAX:
    @jax.jit
    def _apply_gate_fast_step(sv: "jnp.ndarray",
                              operation: "jnp.ndarray"):
        """
        Apply a single quantum gate to statevector *sv*.

        Written as a ``jax.lax.scan`` body: the statevector is the carry and
        one row of the compiled-op table is the per-step input.

        Parameters
        ----------
        sv : statevector of shape (2**n,); cast to complex128 on entry.
            Qubit ``q`` is bit ``q`` of the basis-state index (qubit 0 is the
            least-significant bit).
        operation : float64 array [g_id, q1, q2, param]
            g_id  — gate identifier (see the gate-ID table in this module)
            q1    — target qubit (1-qubit gates) or control qubit (2-qubit)
            q2    — target qubit for 2-qubit gates; unused for 1-qubit
            param — rotation angle in radians; 0.0 for non-parametric gates

        Notes
        -----
        * IDs not in the table act as the identity: negative IDs are clamped
          to 0 (I), and IDs 13-19 or >= 24 select the identity branch of the
          2-qubit dispatcher.
        * 2-qubit gates assume q1 != q2. This is not checked, because a
          jitted function cannot raise on traced values.

        Returns
        -------
        (new_sv, None) — compatible with jax.lax.scan
        """
        # Normalise dtype so every lax.cond / lax.switch branch below agrees.
        sv = jnp.asarray(sv, dtype=jnp.complex128)

        g_id = operation[0].astype(jnp.int32)
        q1 = operation[1].astype(jnp.int32)
        q2 = operation[2].astype(jnp.int32)
        param = operation[3]

        # Basis-state indices 0 .. 2**n - 1, shared by both gate paths.
        idx_full = jnp.arange(sv.shape[0], dtype=jnp.int64)

        # Scalars used by the fixed-phase and parametric gates.
        inv2 = jnp.float64(1.0 / jnp.sqrt(2.0))
        half_p = param * jnp.float64(0.5)
        cos_p = jnp.cos(half_p).astype(jnp.complex128)
        sin_p = jnp.sin(half_p).astype(jnp.complex128)
        exp_pos = jnp.exp(1j * param).astype(jnp.complex128)
        exp_ph4 = jnp.exp(1j * jnp.pi / 4.0).astype(jnp.complex128)
        exp_mh4 = jnp.exp(-1j * jnp.pi / 4.0).astype(jnp.complex128)

        # ── 1-qubit gate matrix selection via lax.switch ──────────────
        # Only used when g_id <= 12; negative IDs are clamped to 0 (I).
        safe_gid = jnp.clip(g_id, 0, 12)
        g_1q = jax.lax.switch(
            safe_gid,
            [
                # 0 I
                lambda _: jnp.eye(2, dtype=jnp.complex128),
                # 1 H
                lambda _: jnp.array(
                    [[inv2, inv2],
                     [inv2, -inv2]], dtype=jnp.complex128),
                # 2 X
                lambda _: jnp.array(
                    [[0.0 + 0j, 1.0 + 0j],
                     [1.0 + 0j, 0.0 + 0j]], dtype=jnp.complex128),
                # 3 Y
                lambda _: jnp.array(
                    [[0.0 + 0j, -1j],
                     [1j, 0.0 + 0j]], dtype=jnp.complex128),
                # 4 Z
                lambda _: jnp.array(
                    [[1.0 + 0j, 0.0 + 0j],
                     [0.0 + 0j, -1.0 + 0j]], dtype=jnp.complex128),
                # 5 S
                lambda _: jnp.array(
                    [[1.0 + 0j, 0.0 + 0j],
                     [0.0 + 0j, 1j]], dtype=jnp.complex128),
                # 6 Sdg
                lambda _: jnp.array(
                    [[1.0 + 0j, 0.0 + 0j],
                     [0.0 + 0j, -1j]], dtype=jnp.complex128),
                # 7 T
                lambda _: jnp.array(
                    [[1.0 + 0j, 0.0 + 0j],
                     [0.0 + 0j, exp_ph4]], dtype=jnp.complex128),
                # 8 Tdg
                lambda _: jnp.array(
                    [[1.0 + 0j, 0.0 + 0j],
                     [0.0 + 0j, exp_mh4]], dtype=jnp.complex128),
                # 9 Rx(θ) = [[cos θ/2, -i sin θ/2], [-i sin θ/2, cos θ/2]]
                lambda _: jnp.array(
                    [[cos_p, -1j * sin_p],
                     [-1j * sin_p, cos_p]], dtype=jnp.complex128),
                # 10 Ry(θ) = [[cos θ/2, -sin θ/2], [sin θ/2, cos θ/2]]
                lambda _: jnp.array(
                    [[cos_p, -sin_p],
                     [sin_p, cos_p]], dtype=jnp.complex128),
                # 11 Rz(θ) = [[e^{-iθ/2}, 0], [0, e^{iθ/2}]]
                lambda _: jnp.array(
                    [[jnp.exp(-1j * half_p), 0.0 + 0j],
                     [0.0 + 0j, jnp.exp(1j * half_p)]],
                    dtype=jnp.complex128),
                # 12 Phase / P(θ) / U1(θ) = [[1, 0], [0, e^{iθ}]]
                lambda _: jnp.array(
                    [[1.0 + 0j, 0.0 + 0j],
                     [0.0 + 0j, exp_pos]], dtype=jnp.complex128),
            ],
            None,  # branches ignore their operand; passed positionally
        )

        # ── 1-qubit application ────────────────────────────────────────
        def do_1q(state):
            stride = jnp.int64(1) << q1.astype(jnp.int64)
            bit_is_0 = (idx_full & stride) == 0
            # Every index i is paired with i ^ stride (bit q1 flipped). For a
            # |0⟩ slot the partner holds the |1⟩ amplitude and vice versa, so
            # both halves are updated in one vectorised pass:
            #   new_0 = g00*a0 + g01*a1,   new_1 = g10*a0 + g11*a1
            amp_pair = state[idx_full ^ stride]
            new_when_0 = g_1q[0, 0] * state + g_1q[0, 1] * amp_pair
            new_when_1 = g_1q[1, 0] * amp_pair + g_1q[1, 1] * state
            return jnp.where(bit_is_0, new_when_0, new_when_1)

        # ── 2-qubit application ───────────────────────────────────────
        def do_2q(state):
            ctrl_mask = jnp.int64(1) << q1.astype(jnp.int64)
            trgt_mask = jnp.int64(1) << q2.astype(jnp.int64)
            ctrl_set = (idx_full & ctrl_mask) != 0
            trgt_set = (idx_full & trgt_mask) != 0
            both_set = ctrl_set & trgt_set

            # CX: where the control bit is set, take the amplitude of the
            # index with the target bit flipped.
            def apply_cx(s):
                return jnp.where(ctrl_set, s[idx_full ^ trgt_mask], s)

            # CZ: negate the |11⟩ components.
            def apply_cz(s):
                return jnp.where(both_set, -s, s)

            # CP(θ): phase e^{iθ} on the |11⟩ components.
            def apply_cp(s):
                return jnp.where(both_set, exp_pos * s, s)

            # SWAP: exchange |01⟩ <-> |10⟩ (exactly one of the two bits set)
            # with all other bits unchanged.
            def apply_swap(s):
                exactly_one = ctrl_set != trgt_set
                partner = idx_full ^ ctrl_mask ^ trgt_mask
                return jnp.where(exactly_one, s[partner], s)

            # Any other ID (13-19, >= 24) leaves the state unchanged.
            def apply_identity(s):
                return s

            # 20..23 -> branches 0..3; everything else -> identity (4).
            branch = jnp.where((g_id >= 20) & (g_id <= 23), g_id - 20, 4)
            return jax.lax.switch(
                branch,
                [apply_cx, apply_cz, apply_cp, apply_swap, apply_identity],
                state,
            )

        # ── branch on 1-qubit vs 2-qubit ─────────────────────────────
        # g_id <= 12 → 1-qubit path; otherwise → 2-qubit dispatcher.
        # Both outputs are cast to complex128 so the cond branch types match.
        is_1q = g_id <= 12
        new_sv = jax.lax.cond(
            is_1q,
            lambda s: do_1q(s).astype(jnp.complex128),
            lambda s: do_2q(s).astype(jnp.complex128),
            sv,
        )
        return new_sv, None
if HAS_JAX:
    @jax.jit
    def _compile_and_run_circuit_jit(state_vector: "jnp.ndarray",
                                     compiled_ops: "jnp.ndarray") -> "jnp.ndarray":
        """
        Execute a pre-compiled gate sequence on *state_vector* via jax.lax.scan.

        Parameters
        ----------
        state_vector : array of shape (2**n,)
            Initial statevector; cast to complex128 before the scan.
        compiled_ops : array of shape (n_gates, 4)
            Each row = [g_id, q1, q2, param]; cast to float64.

        Returns
        -------
        Final complex128 statevector after all gates.
        """
        # lax.scan requires the carry to keep the same dtype on every step,
        # and _apply_gate_fast_step always returns complex128. Without this
        # cast, a complex64 or real-valued initial state fails at trace time
        # with a carry-type mismatch.
        initial = jnp.asarray(state_vector, dtype=jnp.complex128)
        ops = jnp.asarray(compiled_ops, dtype=jnp.float64)
        final_sv, _ = jax.lax.scan(_apply_gate_fast_step, initial, ops)
        return final_sv
# ─────────────────────────────────────────────────────────────────────────────
# QuantumTranspiler
# ─────────────────────────────────────────────────────────────────────────────
class QuantumTranspiler:
    """
    Gate-level transpiler for circuits given as lists of tuples.

    Only two rewrites are performed: CCX (Toffoli) is expanded into the
    {H, T, Tdg, CX} basis, and SWAP into three CX gates. Every other gate
    tuple is passed through unchanged; no further optimisation is done.
    """

    @staticmethod
    def decompose_toffoli(c1: int, c2: int, t: int) -> List[Tuple]:
        """
        Decompose CCX (Toffoli) with controls *c1*, *c2* and target *t*.

        Uses the standard 6-CNOT Clifford+T circuit. Gate count:
        6 CX + 9 single-qubit (2 H, 4 T, 3 Tdg) = 15 total.
        """
        return [
            ('h', t),
            ('cx', c2, t), ('tdg', t),
            ('cx', c1, t), ('t', t),
            ('cx', c2, t), ('tdg', t),
            ('cx', c1, t),
            ('t', c2), ('t', t),
            # ('cx', c1, c2) and ('h', t) act on disjoint qubits, so their
            # relative order here does not matter.
            ('cx', c1, c2), ('h', t),
            ('t', c1), ('tdg', c2),
            ('cx', c1, c2),
        ]

    @staticmethod
    def decompose_swap(q1: int, q2: int) -> List[Tuple]:
        """Decompose SWAP(q1, q2) into 3 CX gates: CX(q1,q2) CX(q2,q1) CX(q1,q2)."""
        return [('cx', q1, q2), ('cx', q2, q1), ('cx', q1, q2)]

    @classmethod
    def transpile(cls, circuit: List[Tuple]) -> List[Tuple]:
        """
        Expand CCX → 15 native gates and SWAP → 3 CX.

        Gate names are matched case-insensitively ('CCX', 'Swap', ...).
        All other gates are passed through unchanged (same tuple, original
        casing). This is a single pass: the emitted gates are not
        expanded again.

        Parameters
        ----------
        circuit : list of tuples (gate_name, qubit, ...)
            For 'ccx' the operands cmd[1:4] are (control1, control2, target).
            For 'swap' the operands cmd[1:3] are the two qubits.
            Operands beyond those slices are ignored.

        Returns
        -------
        Expanded circuit as a new list of tuples.
        """
        out: List[Tuple] = []
        for cmd in circuit:
            # The gate name is the first tuple element.
            name = cmd[0].lower()
            if name == 'ccx':
                out.extend(cls.decompose_toffoli(*cmd[1:4]))
            elif name == 'swap':
                out.extend(cls.decompose_swap(*cmd[1:3]))
            else:
                out.append(cmd)
        return out