# QureadAI — quread/engine.py
# Last change: "Update quread/engine.py" by hchevva (commit 7db25b7, verified).
from __future__ import annotations
import numpy as np
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple
from .gates import I, SINGLE_QUBIT_GATES, rx, ry, rz
from quread.gates import single_qubit_gate_matrix
# One entry of QuantumStateVector.history: a dict whose "type" key is
# "single", "cnot" or "measure", plus fields specific to that operation
# (e.g. "gate"/"target"/"theta", "control"/"target", or "result").
Op = Dict[str, Any]
def _normalize_state(state: np.ndarray) -> np.ndarray:
norm = np.linalg.norm(state)
if norm == 0:
raise ValueError("State norm is zero; cannot normalize.")
return state / norm
def _bit(value: int, bit_index_from_right: int) -> int:
# bit_index_from_right: 0 means least significant bit
return (value >> bit_index_from_right) & 1
def _flip_bit(value: int, bit_index_from_right: int) -> int:
return value ^ (1 << bit_index_from_right)
@dataclass
class QuantumStateVector:
    """State-vector simulator for a register of ``n_qubits`` qubits.

    Wire (qubit) index 0 maps to the most significant bit of a basis index,
    so the bitstrings produced by ``sample``, ``measure_collapse`` and
    ``ket_notation`` list qubit 0 first. Every gate and measurement applied
    through this class is appended to ``history``.
    """

    n_qubits: int
    # Complex amplitude vector of length 2**n_qubits; built in __post_init__.
    state: np.ndarray = field(init=False)
    # Ordered log of applied operations (dicts built by apply_single,
    # apply_cnot and measure_collapse).
    history: List[Op] = field(default_factory=list)

    # UI labels -> canonical names looked up in SINGLE_QUBIT_GATES.
    # (Un-annotated, so dataclass treats these as plain class constants.)
    _GATE_ALIASES = {
        "T†": "Tdg", "Tdg": "Tdg",
        "S†": "Sdg", "Sdg": "Sdg",
        "I†": "I", "Idg": "I",
        "√X": "SQRTX", "SX": "SQRTX",
        "√Z": "SQRTZ", "SZ": "SQRTZ",
    }
    # Prefixes of fixed-angle rotation labels such as "Rx(π/2)".
    _FIXED_ROTATION_PREFIXES = ("Rx(", "RX(", "Ry(", "RY(", "Rz(", "RZ(")

    def __post_init__(self) -> None:
        if self.n_qubits < 1:
            raise ValueError("n_qubits must be >= 1.")
        self.state = self._ground_state()

    def _ground_state(self) -> np.ndarray:
        """Return a fresh |0...0> amplitude vector sized for this register."""
        amplitudes = np.zeros(2 ** self.n_qubits, dtype=complex)
        amplitudes[0] = 1.0 + 0j
        return amplitudes

    def __eq__(self, other: object) -> bool:
        """Compare register size, amplitudes (exactly) and history.

        Defined explicitly because the dataclass-generated ``__eq__`` would
        compare the numpy ``state`` arrays with ``==`` inside a tuple, which
        raises ValueError for two distinct instances. Instances remain
        unhashable, as with the generated version.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self.n_qubits == other.n_qubits
            and np.array_equal(self.state, other.state)
            and self.history == other.history
        )

    def reset(self) -> None:
        """Return the register to |0...0> and clear the operation history."""
        self.state = self._ground_state()
        self.history.clear()

    # --------- Gate application (matrix-free, beginner friendly) ---------
    def apply_single(self, gate_name: str, target: int, theta: Optional[float] = None) -> None:
        """Apply a single-qubit gate to wire *target* and log it in ``history``.

        Args:
            gate_name: A key of ``SINGLE_QUBIT_GATES``; a UI alias ("T†",
                "S†", "I†", "√X", "√Z", "SX", "SZ", ...); a parametric
                rotation "RX"/"RY"/"RZ" (requires *theta*); or a fixed-angle
                label such as "Rx(π/2)" or "RZ(pi)".
            target: Wire index, ``0 <= target < n_qubits``.
            theta: Rotation angle in radians; used only for "RX"/"RY"/"RZ".

        Raises:
            ValueError: for an out-of-range target, an unknown gate, a
                missing theta, or a fixed angle other than π or π/2.
        """
        if not (0 <= target < self.n_qubits):
            raise ValueError("target out of range")

        def _parse_angle(label: str) -> float:
            # supports: "π", "pi", "π/2", "pi/2"
            s = label.strip().lower().replace(" ", "")
            if s in ("π", "pi"):
                return float(np.pi)
            if s in ("π/2", "pi/2"):
                return float(np.pi / 2)
            raise ValueError(f"Unsupported angle in {gate_name}. Use π or π/2.")

        g = gate_name.strip()
        g = self._GATE_ALIASES.get(g, g)

        rotation_builders = {"RX": rx, "RY": ry, "RZ": rz}
        if g[:3] in self._FIXED_ROTATION_PREFIXES and g.endswith(")"):
            axis = g[:2].upper()  # "Rx(" / "RX(" -> "RX"
            label = g[3:-1]
            gate = rotation_builders[axis](float(_parse_angle(label)))
            g = f"{axis}({label})"  # keep the readable label in history
        elif g in rotation_builders:
            # original parametric API
            if theta is None:
                raise ValueError(f"{g} requires theta")
            gate = rotation_builders[g](float(theta))
        elif g in SINGLE_QUBIT_GATES:
            gate = SINGLE_QUBIT_GATES[g]
        else:
            raise ValueError(f"Unknown gate: {gate_name}")

        # Update amplitude pairs that differ only in the target bit, instead of
        # building a full 2^n x 2^n matrix. Wire 0 is the most significant bit.
        bit_pos = self.n_qubits - 1 - target
        old = self.state
        new_state = old.copy()
        for basis in range(len(old)):
            if _bit(basis, bit_pos):
                continue  # already updated together with its bit-0 partner
            partner = _flip_bit(basis, bit_pos)
            a0, a1 = old[basis], old[partner]
            new_state[basis] = gate[0, 0] * a0 + gate[0, 1] * a1
            new_state[partner] = gate[1, 0] * a0 + gate[1, 1] * a1
        self.state = _normalize_state(new_state)

        op: Op = {"type": "single", "gate": g, "target": target}
        if theta is not None and g in rotation_builders:
            op["theta"] = float(theta)
        self.history.append(op)

    def apply_cnot(self, control: int, target: int) -> None:
        """Apply CNOT(control -> target) and log it in ``history``.

        Raises:
            ValueError: if control == target or either wire is out of range.
        """
        if control == target:
            raise ValueError("control and target must be different")
        if not (0 <= control < self.n_qubits) or not (0 <= target < self.n_qubits):
            raise ValueError("control/target out of range")
        c_bit = self.n_qubits - 1 - control
        t_bit = self.n_qubits - 1 - target
        new_state = self.state.copy()
        for basis in range(len(self.state)):
            if _bit(basis, c_bit) == 1:
                # The partner (target bit flipped) also has the control bit set,
                # so filling each side from the other swaps every pair once.
                new_state[basis] = self.state[_flip_bit(basis, t_bit)]
        self.state = _normalize_state(new_state)
        self.history.append({"type": "cnot", "control": control, "target": target})

    # --------- Measurement ---------
    def probabilities(self) -> np.ndarray:
        """Return |amplitude|^2 per basis state, rescaled to sum to 1 when non-zero."""
        probs = np.abs(self.state) ** 2
        total = float(np.sum(probs))
        if total == 0:
            return probs
        return probs / total

    def sample(
        self,
        shots: int = 1024,
        *,
        rng: Optional[np.random.Generator] = None,
    ) -> Dict[str, int]:
        """Draw *shots* measurement outcomes without collapsing the state.

        Args:
            shots: Number of samples to draw.
            rng: Optional generator for reproducible sampling; when omitted,
                NumPy's global ``np.random`` state is used.

        Returns:
            Bitstring -> count for every observed outcome, in ascending
            bitstring order (unobserved outcomes are omitted).
        """
        probs = self.probabilities()
        dim = len(probs)
        if rng is None:
            outcomes = np.random.choice(np.arange(dim), size=int(shots), p=probs)
        else:
            outcomes = rng.choice(dim, size=int(shots), p=probs)
        # Index order equals ascending order of the zero-padded bitstrings.
        tallies = np.bincount(outcomes, minlength=dim)
        return {
            format(idx, f"0{self.n_qubits}b"): int(count)
            for idx, count in enumerate(tallies)
            if count
        }

    def measure_collapse(self, *, rng: Optional[np.random.Generator] = None) -> str:
        """Measure all qubits, collapse the state to the result, and log it.

        Args:
            rng: Optional generator for reproducible sampling; when omitted,
                NumPy's global ``np.random`` state is used.

        Returns:
            The measured bitstring (qubit 0 first).
        """
        probs = self.probabilities()
        dim = len(probs)
        if rng is None:
            idx = int(np.random.choice(np.arange(dim), p=probs))
        else:
            idx = int(rng.choice(dim, p=probs))
        collapsed = np.zeros(dim, dtype=complex)
        collapsed[idx] = 1.0 + 0j
        self.state = collapsed
        bitstring = format(idx, f"0{self.n_qubits}b")
        self.history.append({"type": "measure", "result": bitstring})
        return bitstring

    # --------- Convenience helpers ---------
    def ket_notation(self, max_terms: int = 16, tol: float = 1e-9) -> str:
        """Render the state as ``(re+imj)|bits⟩`` terms joined by " + ".

        Only amplitudes with magnitude above *tol* are shown, largest first
        (ties keep basis order), truncated to *max_terms*. Returns "0" when no
        term survives.
        """
        terms = [
            (complex(amp), format(i, f"0{self.n_qubits}b"))
            for i, amp in enumerate(self.state)
            if abs(amp) > tol
        ]
        terms.sort(key=lambda term: abs(term[0]), reverse=True)
        terms = terms[:max_terms]
        if not terms:
            return "0"
        return " + ".join(f"({a.real:+.4f}{a.imag:+.4f}j)|{b}⟩" for a, b in terms)