# Tatopenn's picture
# Upload 20 files
# 4eff328 verified
import numpy as np
from typing import List, Tuple, Optional
from .registry import HAS_JAX
from .gates import GATES, PARAMETRIC_GATES, GATE_IDS
from .compiler import QuantumTranspiler
if HAS_JAX:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from .compiler import _compile_and_run_circuit_jit
# ─────────────────────────────────────────────────────────────────────────────
# Internal helpers
# ─────────────────────────────────────────────────────────────────────────────
# Copyright (c) 2026 Salvatore Pennacchio <jtatopenn@libero.it>
# Distributed under the Business Source License 1.1 (BSL 1.1)
# See LICENSE.md in the project root for full license terms.
def _qubit_stride_pairs(n: int, qubit: int):
"""
Return (stride, outer_step, inner_step) for the MSB-first statevector
convention used throughout this simulator.
In MSB-first ordering qubit 0 is the *most* significant bit, so:
physical_bit_position = n - 1 - qubit
stride = 1 << physical_bit_position
"""
phys = n - 1 - qubit
stride = 1 << phys
return stride
def _cx_numpy(sv: np.ndarray, n: int, ctrl: int, tgt: int) -> np.ndarray:
"""
Vectorised CX on a NumPy statevector.
No Python loops β€” uses strided index arithmetic.
"""
dim = len(sv)
c_stride = 1 << (n - 1 - ctrl)
t_stride = 1 << (n - 1 - tgt)
all_i = np.arange(dim, dtype=np.intp)
# Select indices where ctrl bit == 1 and tgt bit == 0
mask = ((all_i & c_stride) != 0) & ((all_i & t_stride) == 0)
idx_0 = all_i[mask]
idx_1 = idx_0 | t_stride
sv = sv.copy()
sv[idx_0], sv[idx_1] = sv[idx_1].copy(), sv[idx_0].copy()
return sv
def _cz_numpy(sv: np.ndarray, n: int, ctrl: int, tgt: int) -> np.ndarray:
"""Vectorised CZ on a NumPy statevector."""
dim = len(sv)
c_stride = 1 << (n - 1 - ctrl)
t_stride = 1 << (n - 1 - tgt)
all_i = np.arange(dim, dtype=np.intp)
mask = ((all_i & c_stride) != 0) & ((all_i & t_stride) != 0)
sv = sv.copy()
sv[mask] *= -1
return sv
# ─────────────────────────────────────────────────────────────────────────────
# DenseSVSimulator
# ─────────────────────────────────────────────────────────────────────────────
class DenseSVSimulator:
    """
    Dense statevector quantum circuit simulator.

    Qubit ordering: MSB-first (qubit 0 is the most significant bit of the
    basis-state index). Reshaping the statevector to ``[2] * n`` therefore
    puts qubit ``k`` on axis ``k``.

    Backends: NumPy (CPU) when JAX is unavailable, otherwise JAX (XLA JIT).

    Parameters
    ----------
    n_qubits    : number of qubits, in [1, 34]
    use_gpu     : reserved for future CuPy/JAX GPU dispatch (currently unused)
    use_float32 : use complex64 instead of complex128
    """

    # Gate-name groups understood by the JIT op encoder (_encode_ops).
    _PARAM_1Q = ('rx', 'ry', 'rz', 'p', 'u1', 'phase')  # (name, qubit, angle)
    _PARAM_2Q = ('cp', 'crz', 'cphase')                 # (name, ctrl, tgt, angle)
    _FIXED_2Q = ('cx', 'cz', 'swap', 'cy')              # (name, ctrl, tgt)

    def __init__(self, n_qubits: int,
                 use_gpu: bool = False,
                 use_float32: bool = False):
        if n_qubits < 1 or n_qubits > 34:
            raise ValueError(f"n_qubits must be in [1, 34], got {n_qubits}")
        self.n = n_qubits
        self.dim = 1 << n_qubits  # 2 ** n_qubits
        self.use_float32 = use_float32
        self.dtype = np.complex64 if use_float32 else np.complex128
        self.xp = jnp if HAS_JAX else np
        self._reset_sv()

    # ── state initialisation ─────────────────────────────────────────
    def _reset_sv(self):
        """Allocate |0...0> on the active backend."""
        if HAS_JAX:
            self.sv = jnp.zeros(self.dim, dtype=self.dtype).at[0].set(1.0)
        else:
            self.sv = np.zeros(self.dim, dtype=self.dtype)
            self.sv[0] = 1.0

    def set_initial_state(self, state: Optional[np.ndarray] = None):
        """
        Reset the simulator.

        Parameters
        ----------
        state : optional complex array of shape (2**n,).
                If None, resets to |0...0>.
                The array is normalised automatically.

        Raises
        ------
        ValueError
            If *state* has the wrong shape or (near-)zero norm.
        """
        if state is None:
            self._reset_sv()
            return
        state = np.asarray(state, dtype=self.dtype)
        if state.shape != (self.dim,):
            # Use size/shape rather than len(): len() fails on 0-d input and
            # is misleading for multi-dimensional input.
            raise ValueError(
                f"State vector length {state.size} (shape {state.shape}) "
                f"!= 2**{self.n} = {self.dim}")
        norm = np.linalg.norm(state)
        if norm < 1e-12:
            raise ValueError("Cannot set a zero-norm state vector")
        state = state / norm  # fresh array, caller's input is not aliased
        self.sv = jnp.array(state) if HAS_JAX else state

    def set_state(self, state: np.ndarray):
        """Alias of set_initial_state (used by the VQE engine)."""
        self.set_initial_state(state)

    # ── normalisation ─────────────────────────────────────────────────
    def normalize(self):
        """Rescale the statevector to unit norm (no-op if the norm is ~0)."""
        norm = float(self.xp.linalg.norm(self.sv))
        if norm > 1e-12:
            if HAS_JAX:
                self.sv = self.sv / norm  # JAX arrays are immutable
            else:
                self.sv /= norm

    # ── 1-qubit gate ──────────────────────────────────────────────────
    def apply_gate_1q(self, gate: np.ndarray, qubit: int):
        """
        Apply a 2x2 unitary to *qubit* via tensor contraction.

        Uses reshape + moveaxis + dot; fully vectorised and compatible
        with both NumPy and JAX.

        Raises
        ------
        ValueError
            If *qubit* is out of range.
        """
        if not 0 <= qubit < self.n:
            raise ValueError(f"Qubit index {qubit} out of range [0, {self.n})")
        gate = self.xp.array(gate, dtype=self.dtype)
        sv_nd = self.sv.reshape([2] * self.n)
        sv_moved = self.xp.moveaxis(sv_nd, qubit, -1)  # qubit axis -> last
        # (dim/2, 2) @ (2, 2).T -> (dim/2, 2): out[i, a] = sum_b sv[i, b] * U[a, b]
        result = self.xp.dot(sv_moved.reshape(self.dim >> 1, 2), gate.T)
        self.sv = self.xp.moveaxis(
            result.reshape([2] * self.n), -1, qubit).ravel()

    # ── 2-qubit gate ──────────────────────────────────────────────────
    def apply_gate_2q(self, gate: np.ndarray, q1: int, q2: int):
        """
        Apply a 4x4 unitary to qubits (q1, q2) via tensor contraction.

        The gate's row/column index is ``2 * bit(q1) + bit(q2)``, i.e. *q1*
        is the more significant qubit of the 4x4 matrix.

        Raises
        ------
        ValueError
            If the qubits coincide or are out of range.
        """
        if q1 == q2:
            raise ValueError("Control and target qubits must differ")
        if not (0 <= q1 < self.n and 0 <= q2 < self.n):
            raise ValueError(f"Qubit indices ({q1},{q2}) out of range [0, {self.n})")
        gate = self.xp.array(gate, dtype=self.dtype)
        sv_nd = self.sv.reshape([2] * self.n)
        sv_moved = self.xp.moveaxis(sv_nd, (q1, q2), (-2, -1))
        result = self.xp.dot(sv_moved.reshape(self.dim >> 2, 4),
                             gate.reshape(4, 4).T)
        self.sv = self.xp.moveaxis(
            result.reshape([2] * self.n), (-2, -1), (q1, q2)).ravel()

    # ── specialised 2-qubit gates ─────────────────────────────────────
    def apply_cx(self, ctrl: int, tgt: int):
        """
        CX (CNOT) gate.

        JAX path: matrix contraction via apply_gate_2q.
        NumPy path: vectorised index swap (no Python loops).
        """
        if ctrl == tgt:
            raise ValueError("Control and target qubits must differ")
        if not (0 <= ctrl < self.n and 0 <= tgt < self.n):
            raise ValueError(f"Qubit indices ({ctrl},{tgt}) out of range [0, {self.n})")
        if HAS_JAX:
            cx_mat = jnp.array([
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1],
                [0, 0, 1, 0],
            ], dtype=self.dtype)
            self.apply_gate_2q(cx_mat, ctrl, tgt)
        else:
            # _cx_numpy already returns a copy; no need to copy beforehand.
            self.sv = _cx_numpy(np.asarray(self.sv), self.n, ctrl, tgt)

    def apply_cz(self, ctrl: int, tgt: int):
        """
        CZ gate.

        JAX path: matrix contraction via apply_gate_2q.
        NumPy path: vectorised sign flip (no Python loops).
        """
        if ctrl == tgt:
            raise ValueError("Control and target qubits must differ")
        if not (0 <= ctrl < self.n and 0 <= tgt < self.n):
            raise ValueError(f"Qubit indices ({ctrl},{tgt}) out of range [0, {self.n})")
        if HAS_JAX:
            cz_mat = jnp.array([
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, -1],
            ], dtype=self.dtype)
            self.apply_gate_2q(cz_mat, ctrl, tgt)
        else:
            self.sv = _cz_numpy(np.asarray(self.sv), self.n, ctrl, tgt)

    # ── parametric rotations ──────────────────────────────────────────
    def apply_rx(self, qubit: int, theta: float):
        """Apply RX(theta) = [[c, -i s], [-i s, c]] with c, s = cos/sin(theta/2)."""
        cos, sin = self.xp.cos(theta / 2), self.xp.sin(theta / 2)
        mat = self.xp.array([[cos, -1j * sin], [-1j * sin, cos]], dtype=self.dtype)
        self.apply_gate_1q(mat, qubit)

    def apply_ry(self, qubit: int, theta: float):
        """Apply RY(theta) = [[c, -s], [s, c]] with c, s = cos/sin(theta/2)."""
        cos, sin = self.xp.cos(theta / 2), self.xp.sin(theta / 2)
        mat = self.xp.array([[cos, -sin], [sin, cos]], dtype=self.dtype)
        self.apply_gate_1q(mat, qubit)

    def apply_rz(self, qubit: int, theta: float):
        """Apply RZ(theta) = diag(exp(-i theta/2), exp(+i theta/2))."""
        exp_neg = self.xp.exp(-1j * theta / 2)
        exp_pos = self.xp.exp(1j * theta / 2)
        mat = self.xp.array([[exp_neg, 0.0], [0.0, exp_pos]], dtype=self.dtype)
        self.apply_gate_1q(mat, qubit)

    # ── measurement ───────────────────────────────────────────────────
    def measure(self, qubit_idx: int) -> int:
        """
        Projective measurement on *qubit_idx*.

        Samples 0 or 1 from the marginal distribution (via the global
        ``np.random`` state), zeroes the amplitudes of the other outcome and
        renormalises the statevector.

        The statevector is viewed as ``(high bits, measured bit, low bits)``
        with ``stride = 1 << (n - 1 - qubit_idx)`` (MSB-first). The same
        view is used on both backends so NumPy and JAX measure the same
        qubit.

        Raises
        ------
        ValueError
            If *qubit_idx* is out of range.
        RuntimeError
            If the statevector has (near-)zero norm.
        """
        if not 0 <= qubit_idx < self.n:
            raise ValueError(
                f"Qubit {qubit_idx} out of range [0, {self.n})")
        stride = 1 << (self.n - 1 - qubit_idx)

        # ── marginal probabilities ──────────────────────────────────
        probs = self.xp.abs(self.sv.reshape(-1, 2, stride)) ** 2
        prob_0 = float(probs[:, 0, :].sum())
        prob_1 = float(probs[:, 1, :].sum())
        total = prob_0 + prob_1
        if total < 1e-12:
            raise RuntimeError("Statevector norm is zero β€” cannot measure")
        prob_0 /= total
        prob_1 /= total
        result = int(np.random.choice([0, 1], p=[prob_0, prob_1]))

        # ── collapse: zero the amplitudes of the opposite outcome ───
        zero_slot = 1 - result
        if HAS_JAX:
            self.sv = (self.sv.reshape(-1, 2, stride)
                       .at[:, zero_slot, :].set(0.0)
                       .ravel())
        else:
            sv_res = self.sv.reshape(-1, 2, stride)
            sv_res[:, zero_slot, :] = 0.0
            self.sv = sv_res.ravel()
        self.normalize()
        return result

    # ── circuit execution ─────────────────────────────────────────────
    def run_circuit(self, circuit: List[Tuple], transpile: bool = True):
        """
        Execute *circuit* gate by gate on the active backend.

        Commands are tuples ``(name, *args)``:
          - fixed gates in GATES: ``(name, q)`` or ``(name, q1, q2)``;
          - PARAMETRIC_GATES: ``(name, q, theta)``, ``(name, q1, q2, theta)``
            or ``(name, q, a, b, c)``.

        NOTE: names in neither table, and parametric commands with another
        argument count, are skipped silently.
        """
        target = QuantumTranspiler.transpile(circuit) if transpile else circuit
        for cmd in target:
            name = cmd[0].lower()
            args = cmd[1:]
            if name in GATES:
                mat = self.xp.array(GATES[name], dtype=self.dtype)
                if mat.shape == (2, 2):
                    self.apply_gate_1q(mat, int(args[0]))
                else:
                    self.apply_gate_2q(mat, int(args[0]), int(args[1]))
            elif name in PARAMETRIC_GATES:
                if len(args) == 2:
                    mat = self.xp.array(PARAMETRIC_GATES[name](args[1]), dtype=self.dtype)
                    self.apply_gate_1q(mat, int(args[0]))
                elif len(args) == 3:
                    mat = self.xp.array(PARAMETRIC_GATES[name](args[2]), dtype=self.dtype)
                    self.apply_gate_2q(mat, int(args[0]), int(args[1]))
                elif len(args) == 4:
                    mat = self.xp.array(
                        PARAMETRIC_GATES[name](args[1], args[2], args[3]),
                        dtype=self.dtype)
                    self.apply_gate_1q(mat, int(args[0]))

    @classmethod
    def _encode_ops(cls, target: List, resolve_params: bool = True):
        """
        Encode transpiled commands as JIT-kernel rows ``[gate_id, q1, q2, p]``.

        Commands whose name is not in GATE_IDS are skipped.

        Parameters
        ----------
        target         : transpiled command list.
        resolve_params : if True, the angle of parametric gates is read
                         from the command; if False it is left as 0.0 (a
                         placeholder to be overwritten by the caller), so
                         symbolic/placeholder angles are never converted.

        Returns
        -------
        (rows, param_rows) : list of 4-float rows, and the row indices of
                             parametric gates in circuit order.
        """
        rows = []
        param_rows = []
        for cmd in target:
            name = str(cmd[0]).lower()
            if name not in GATE_IDS:
                continue
            g_id = float(GATE_IDS[name])
            args = cmd[1:]
            if name in cls._PARAM_1Q:
                # (name, qubit, angle)
                angle = float(args[1]) if resolve_params and len(args) > 1 else 0.0
                param_rows.append(len(rows))
                rows.append([g_id, float(args[0]), 0.0, angle])
            elif name in cls._PARAM_2Q:
                # (name, ctrl, tgt, angle)
                tgt = float(args[1]) if len(args) > 1 else 0.0
                angle = float(args[2]) if resolve_params and len(args) > 2 else 0.0
                param_rows.append(len(rows))
                rows.append([g_id, float(args[0]), tgt, angle])
            elif name in cls._FIXED_2Q:
                # (name, ctrl, tgt)
                tgt = float(args[1]) if len(args) > 1 else 0.0
                rows.append([g_id, float(args[0]), tgt, 0.0])
            else:
                # 1-qubit non-parametric: (name, qubit)
                rows.append([g_id, float(args[0]) if args else 0.0, 0.0, 0.0])
        return rows, param_rows

    def _run_transpiled_jit(self, target: List):
        """Run an already-transpiled command list through the JIT kernel (JAX only)."""
        rows, _ = self._encode_ops(target, resolve_params=True)
        if rows:
            ops_jnp = jnp.array(rows, dtype=jnp.float64)
            self.sv = _compile_and_run_circuit_jit(self.sv, ops_jnp)

    def run_circuit_jit_beast_mode(self, circuit: List):
        """
        Transpile *circuit* and execute it in one JIT-compiled kernel call.

        Falls back to run_circuit when JAX is unavailable. Gates not listed
        in GATE_IDS are skipped.
        """
        if not HAS_JAX:
            return self.run_circuit(circuit)
        self._run_transpiled_jit(QuantumTranspiler.transpile(circuit))

    def run_circuit_with_chunking(self, circuit: List, chunk_size: int = 500):
        """
        Execute a circuit in chunks of at most *chunk_size* commands.

        The circuit is transpiled exactly once. Each chunk is a separate
        _compile_and_run_circuit_jit call, which bounds the ops-array size
        of each call. Only the final chunk may be shorter. Gates skipped by
        the encoder can also shrink a chunk's array, so chunk sizes are not
        strictly uniform.

        Raises
        ------
        ValueError
            If *chunk_size* < 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        target = QuantumTranspiler.transpile(circuit)
        if not HAS_JAX:
            self.run_circuit(target, transpile=False)
            return
        for start in range(0, len(target), chunk_size):
            self._run_transpiled_jit(target[start:start + chunk_size])

    def run_parametric_batch_jit(self,
                                 base_circuit: List,
                                 parameter_batch: np.ndarray) -> "jnp.ndarray":
        """
        Simulate *base_circuit* once per row of *parameter_batch* (vmapped).

        Each run starts from |0...0> (complex128), independent of self.sv.
        The k-th parametric gate of the transpiled circuit (rx/ry/rz/p/u1/
        phase/cp/crz/cphase, in circuit order) takes angle
        ``parameter_batch[b, k]``. Angles written in base_circuit are
        ignored.

        Parameters
        ----------
        base_circuit    : command list; transpiled before encoding.
        parameter_batch : array of shape (batch, n_params) with
                          n_params >= number of parametric gates; extra
                          columns are ignored.

        Returns
        -------
        jnp.ndarray of final statevectors, one per batch row.

        Raises
        ------
        RuntimeError
            If JAX is unavailable.
        ValueError
            If parameter_batch is not 2-D or has too few columns.
        """
        if not HAS_JAX:
            raise RuntimeError("run_parametric_batch_jit requires JAX")
        target = QuantumTranspiler.transpile(base_circuit)
        rows, param_rows = self._encode_ops(target, resolve_params=False)
        # reshape keeps a (0, 4) shape for circuits with no encodable gates.
        template = jnp.array(rows, dtype=jnp.float64).reshape(-1, 4)
        n_params = len(param_rows)

        params = jnp.asarray(parameter_batch, dtype=jnp.float64)
        if params.ndim != 2:
            raise ValueError(
                f"parameter_batch must be 2-D (batch, n_params), got shape {params.shape}")
        if params.shape[1] < n_params:
            # Without this check JAX would silently clamp the out-of-range
            # index and reuse the last angle for the remaining gates.
            raise ValueError(
                f"parameter_batch provides {params.shape[1]} parameters per row "
                f"but the circuit has {n_params} parametric gates")

        param_idx = jnp.asarray(param_rows, dtype=jnp.int32)
        init_sv = jnp.zeros(self.dim, dtype=jnp.complex128).at[0].set(1.0)

        def simulate_single_instance(single_params: "jnp.ndarray") -> "jnp.ndarray":
            """Run one parameter vector through the circuit."""
            # Scatter the k-th angle into the param slot of the k-th parametric row.
            patched_ops = template.at[param_idx, 3].set(single_params[:n_params])
            return _compile_and_run_circuit_jit(init_sv, patched_ops)

        return jax.jit(jax.vmap(simulate_single_instance, in_axes=(0,)))(params)

    # ── observables ───────────────────────────────────────────────────
    def get_probabilities(self) -> np.ndarray:
        """Return the measurement probability distribution as a NumPy float64 array."""
        probs = np.array(self.xp.abs(self.sv) ** 2, dtype=np.float64)
        # guard against floating-point leakage outside [0, 1]
        probs = np.clip(probs, 0.0, 1.0)
        total = probs.sum()
        if total > 1e-12:
            probs /= total
        return probs

    def get_statevector(self) -> np.ndarray:
        """Return the current statevector as a NumPy complex array (a copy)."""
        return np.array(self.sv, dtype=self.dtype)

    def memory_mb(self) -> float:
        """Statevector memory footprint in megabytes (10**6 bytes)."""
        # complex64 -> 8 bytes, complex128 -> 16 bytes
        return self.dim * np.dtype(self.dtype).itemsize / 1_000_000