|
|
|
|
|
|
|
|
|
|
| import numpy as np
|
| from typing import List, Tuple
|
| from dataclasses import dataclass
|
|
|
|
|
| try:
|
| from .registry import HAS_JAX
|
| except ImportError:
|
| try:
|
| import jax
|
| HAS_JAX = True
|
| except ImportError:
|
| HAS_JAX = False
|
|
|
| if HAS_JAX:
|
| import jax
|
| import jax.numpy as jnp
|
| jax.config.update("jax_enable_x64", True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if HAS_JAX:

    # Gate-id table understood by the JIT kernels below.  Each compiled
    # operation is a row ``[g_id, q1, q2, param]``.
    #
    #   single-qubit gates (q1 = target qubit, q2 unused)
    #      0 I      1 H      2 X      3 Y      4 Z
    #      5 S      6 Sdg    7 T      8 Tdg
    #      9 RX(param)   10 RY(param)   11 RZ(param)   12 P(param)
    #   two-qubit gates (q1 = control / first qubit, q2 = target / second)
    #     20 CX     21 CZ     22 CP(param)
    #     any other id > 12 is applied as SWAP(q1, q2)
    #   ids < 0 are clipped to 0 and therefore act as the identity.
    #
    # Qubit k corresponds to bit k of the statevector index (bit 0 = LSB).

    @jax.jit
    def _apply_gate_fast_step(sv: "jnp.ndarray",
                              operation: "jnp.ndarray"):
        """
        Apply a single quantum gate to statevector *sv*.

        Parameters
        ----------
        sv : statevector of shape (2**n,); promoted to complex128
        operation : float64 array [g_id, q1, q2, param]
            g_id  - gate identifier (see the gate-id table above)
            q1    - target qubit (1-qubit gates) or control qubit (2-qubit)
            q2    - target qubit for 2-qubit gates; unused for 1-qubit
            param - rotation angle in radians; 0.0 for non-parametric gates

        Returns
        -------
        (new_sv, None) - compatible with jax.lax.scan

        Notes
        -----
        Gate ids and qubit indices are not validated here (the function is
        traced under jit); callers must supply in-range values.
        """
        # Keep every branch output (and the scan carry) in a single dtype.
        sv = jnp.asarray(sv, dtype=jnp.complex128)
        g_id = operation[0].astype(jnp.int32)
        q1 = operation[1].astype(jnp.int64)
        q2 = operation[2].astype(jnp.int64)
        param = operation[3]
        dim = sv.shape[0]
        idx_full = jnp.arange(dim, dtype=jnp.int64)

        # Scalars shared by the gate matrices / phase factors.
        inv2 = jnp.float64(1.0 / jnp.sqrt(2.0))
        half_p = param * jnp.float64(0.5)
        cos_p = jnp.cos(half_p).astype(jnp.complex128)
        sin_p = jnp.sin(half_p).astype(jnp.complex128)
        exp_pos = jnp.exp(1j * param).astype(jnp.complex128)
        exp_ph4 = jnp.exp(1j * jnp.pi / 4.0).astype(jnp.complex128)
        exp_mh4 = jnp.exp(-1j * jnp.pi / 4.0).astype(jnp.complex128)

        # 2x2 matrix for ids 0..12, indexed by position in this list.  For
        # two-qubit ids the clipped index selects an unused matrix.
        one_qubit_matrices = [
            lambda _: jnp.eye(2, dtype=jnp.complex128),                  # 0 I
            lambda _: jnp.array([[inv2, inv2],
                                 [inv2, -inv2]], dtype=jnp.complex128),  # 1 H
            lambda _: jnp.array([[0.0 + 0j, 1.0 + 0j],
                                 [1.0 + 0j, 0.0 + 0j]],
                                dtype=jnp.complex128),                   # 2 X
            lambda _: jnp.array([[0.0 + 0j, -1j],
                                 [1j, 0.0 + 0j]],
                                dtype=jnp.complex128),                   # 3 Y
            lambda _: jnp.array([[1.0 + 0j, 0.0 + 0j],
                                 [0.0 + 0j, -1.0 + 0j]],
                                dtype=jnp.complex128),                   # 4 Z
            lambda _: jnp.array([[1.0 + 0j, 0.0 + 0j],
                                 [0.0 + 0j, 1j]],
                                dtype=jnp.complex128),                   # 5 S
            lambda _: jnp.array([[1.0 + 0j, 0.0 + 0j],
                                 [0.0 + 0j, -1j]],
                                dtype=jnp.complex128),                   # 6 Sdg
            lambda _: jnp.array([[1.0 + 0j, 0.0 + 0j],
                                 [0.0 + 0j, exp_ph4]],
                                dtype=jnp.complex128),                   # 7 T
            lambda _: jnp.array([[1.0 + 0j, 0.0 + 0j],
                                 [0.0 + 0j, exp_mh4]],
                                dtype=jnp.complex128),                   # 8 Tdg
            lambda _: jnp.array([[cos_p, -1j * sin_p],
                                 [-1j * sin_p, cos_p]],
                                dtype=jnp.complex128),                   # 9 RX
            lambda _: jnp.array([[cos_p, -sin_p],
                                 [sin_p, cos_p]],
                                dtype=jnp.complex128),                   # 10 RY
            lambda _: jnp.array([[jnp.exp(-1j * half_p), 0.0 + 0j],
                                 [0.0 + 0j, jnp.exp(1j * half_p)]],
                                dtype=jnp.complex128),                   # 11 RZ
            lambda _: jnp.array([[1.0 + 0j, 0.0 + 0j],
                                 [0.0 + 0j, exp_pos]],
                                dtype=jnp.complex128),                   # 12 P
        ]
        g_1q = jax.lax.switch(jnp.clip(g_id, 0, 12), one_qubit_matrices, None)

        def do_1q(vec):
            # Pair each index with the index that differs only in bit q1.
            stride = jnp.int64(1) << q1
            mask_0 = (idx_full & stride) == 0
            amp_self = vec  # vec[arange(dim)] is vec itself
            amp_pair = vec[idx_full ^ stride]
            # Bit q1 == 0: row 0 of the gate, (a0, a1) = (self, pair).
            new_when_0 = g_1q[0, 0] * amp_self + g_1q[0, 1] * amp_pair
            # Bit q1 == 1: row 1 of the gate, (a0, a1) = (pair, self).
            new_when_1 = g_1q[1, 0] * amp_pair + g_1q[1, 1] * amp_self
            return jnp.where(mask_0, new_when_0, new_when_1)

        def do_2q(vec):
            ctrl_mask = jnp.int64(1) << q1
            trgt_mask = jnp.int64(1) << q2
            ctrl_set = (idx_full & ctrl_mask) != 0
            trgt_set = (idx_full & trgt_mask) != 0
            both_set = ctrl_set & trgt_set

            def apply_cx(v):
                # Flip the target bit wherever the control bit is 1.
                return jnp.where(ctrl_set, v[idx_full ^ trgt_mask], v)

            def apply_cz(v):
                return jnp.where(both_set, -v, v)

            def apply_cp(v):
                return jnp.where(both_set, exp_pos * v, v)

            def apply_swap(v):
                # Only basis states whose two bits differ are exchanged.
                differ = ctrl_set != trgt_set
                return jnp.where(differ, v[idx_full ^ ctrl_mask ^ trgt_mask], v)

            # One dispatch instead of evaluating every variant; unknown ids
            # above 12 fall through to SWAP.
            branch = jnp.where(
                g_id == 20, 0,
                jnp.where(g_id == 21, 1,
                          jnp.where(g_id == 22, 2, 3))).astype(jnp.int32)
            return jax.lax.switch(
                branch, [apply_cx, apply_cz, apply_cp, apply_swap], vec)

        new_sv = jax.lax.cond(
            g_id <= 12,
            lambda s: do_1q(s).astype(jnp.complex128),
            lambda s: do_2q(s).astype(jnp.complex128),
            sv,
        )
        return new_sv, None

    @jax.jit
    def _compile_and_run_circuit_jit(state_vector: "jnp.ndarray",
                                     compiled_ops: "jnp.ndarray") -> "jnp.ndarray":
        """
        Execute a pre-compiled gate sequence on *state_vector* via jax.lax.scan.

        Parameters
        ----------
        state_vector : array of shape (2**n,); real or complex input is
                       promoted to complex128
        compiled_ops : array of shape (n_gates, 4), promoted to float64;
                       each row = [g_id, q1, q2, param]

        Returns
        -------
        Final complex128 statevector after all gates (just the promoted
        input when n_gates == 0).
        """
        # lax.scan requires the carry type to be identical across steps and
        # the step always emits complex128, so promote first; a float64 or
        # complex64 statevector would otherwise fail with a carry mismatch.
        sv0 = jnp.asarray(state_vector, dtype=jnp.complex128)
        ops = jnp.asarray(compiled_ops, dtype=jnp.float64)
        final_sv, _ = jax.lax.scan(_apply_gate_fast_step, sv0, ops)
        return final_sv
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuantumTranspiler:
    """
    Gate-level transpiler that expands CCX (Toffoli) and SWAP into gates
    from the {H, T, Tdg, CX} set; every other gate is passed through as-is.

    Gates are plain tuples ``(gate_name, qubit, ...)``.
    """

    @staticmethod
    def decompose_toffoli(c1: int, c2: int, t: int) -> List[Tuple]:
        """
        Decompose CCX (Toffoli) with controls *c1*, *c2* and target *t*
        into 15 gates using the standard exact H/T/Tdg/CX construction.

        Gate count: 6 CX + 2 H + 7 T/Tdg = 15 total (T-count 7).
        """
        return [
            ('h', t),
            ('cx', c2, t), ('tdg', t),
            ('cx', c1, t), ('t', t),
            ('cx', c2, t), ('tdg', t),
            ('cx', c1, t),
            ('t', c2), ('t', t),
            # H on t and CX(c1, c2) act on disjoint qubits, so their
            # relative order does not matter.
            ('cx', c1, c2), ('h', t),
            ('t', c1), ('tdg', c2),
            ('cx', c1, c2),
        ]

    @staticmethod
    def decompose_swap(q1: int, q2: int) -> List[Tuple]:
        """Decompose SWAP(q1, q2) into 3 alternating CX gates."""
        return [('cx', q1, q2), ('cx', q2, q1), ('cx', q1, q2)]

    @classmethod
    def transpile(cls, circuit: List[Tuple]) -> List[Tuple]:
        """
        Expand CCX -> 15 gates and SWAP -> 3 CX.
        All other gates are passed through unchanged.

        Gate names are matched case-insensitively; passed-through tuples
        keep their original spelling.

        Parameters
        ----------
        circuit : list of tuples (gate_name, qubit, ...)

        Returns
        -------
        Expanded circuit as a new list of tuples.
        """
        out: List[Tuple] = []
        for cmd in circuit:
            name = cmd[0].lower()
            if name == 'ccx':
                out.extend(cls.decompose_toffoli(*cmd[1:4]))
            elif name == 'swap':
                out.extend(cls.decompose_swap(*cmd[1:3]))
            else:
                out.append(cmd)
        return out
|
|
|