from __future__ import annotations import numpy as np from dataclasses import dataclass, field from typing import List, Dict, Any, Optional, Tuple from .gates import I, SINGLE_QUBIT_GATES, rx, ry, rz from quread.gates import single_qubit_gate_matrix Op = Dict[str, Any] def _normalize_state(state: np.ndarray) -> np.ndarray: norm = np.linalg.norm(state) if norm == 0: raise ValueError("State norm is zero; cannot normalize.") return state / norm def _bit(value: int, bit_index_from_right: int) -> int: # bit_index_from_right: 0 means least significant bit return (value >> bit_index_from_right) & 1 def _flip_bit(value: int, bit_index_from_right: int) -> int: return value ^ (1 << bit_index_from_right) @dataclass class QuantumStateVector: n_qubits: int state: np.ndarray = field(init=False) history: List[Op] = field(default_factory=list) def __post_init__(self) -> None: if self.n_qubits < 1: raise ValueError("n_qubits must be >= 1.") dim = 2 ** self.n_qubits self.state = np.zeros(dim, dtype=complex) self.state[0] = 1.0 + 0j # |0...0> def reset(self) -> None: dim = 2 ** self.n_qubits self.state = np.zeros(dim, dtype=complex) self.state[0] = 1.0 + 0j self.history.clear() # --------- Gate application (matrix-free, beginner friendly) --------- def apply_single(self, gate_name: str, target: int, theta: Optional[float] = None) -> None: if not (0 <= target < self.n_qubits): raise ValueError("target out of range") def _parse_angle(label: str) -> float: # supports: "π", "pi", "π/2", "pi/2" s = label.strip().lower().replace(" ", "") if s in ("π", "pi"): return float(np.pi) if s in ("π/2", "pi/2"): return float(np.pi / 2) raise ValueError(f"Unsupported angle in {gate_name}. Use π or π/2.") g = gate_name.strip() # --- normalize common UI labels to internal canonical names --- # dagger variants if g in ("T†", "Tdg"): g = "Tdg" if g in ("S†", "Sdg"): g = "Sdg" # identity variants if g in ("I†", "Idg"): g = "I" # sqrt variants if g in ("√X", "SX"): g = "SQRTX" if g in ("√Z", "SZ"): g = "SQRTZ" # --- build gate matrix --- gate = None # fixed-angle rotation labels from UI if g.startswith(("Rx(", "RX(")) and g.endswith(")"): ang = _parse_angle(g[g.find("(")+1 : -1]) gate = rx(float(ang)) g = f"RX({g[g.find('(')+1:-1]})" # preserve readable name in history elif g.startswith(("Ry(", "RY(")) and g.endswith(")"): ang = _parse_angle(g[g.find("(")+1 : -1]) gate = ry(float(ang)) g = f"RY({g[g.find('(')+1:-1]})" elif g.startswith(("Rz(", "RZ(")) and g.endswith(")"): ang = _parse_angle(g[g.find("(")+1 : -1]) gate = rz(float(ang)) g = f"RZ({g[g.find('(')+1:-1]})" # original parametric API still supported elif g == "RX": if theta is None: raise ValueError("RX requires theta") gate = rx(float(theta)) elif g == "RY": if theta is None: raise ValueError("RY requires theta") gate = ry(float(theta)) elif g == "RZ": if theta is None: raise ValueError("RZ requires theta") gate = rz(float(theta)) # all normal single-qubit gates (incl. new ones) via map elif g in SINGLE_QUBIT_GATES: gate = SINGLE_QUBIT_GATES[g] else: raise ValueError(f"Unknown gate: {gate_name}") # Apply using pairwise amplitude updates (no full 2^n x 2^n matrix) msb_index = self.n_qubits - 1 - target # wire index -> bit position from right new_state = self.state.copy() dim = len(self.state) for basis in range(dim): if _bit(basis, msb_index) == 0: partner = _flip_bit(basis, msb_index) a0 = self.state[basis] a1 = self.state[partner] new_state[basis] = gate[0, 0] * a0 + gate[0, 1] * a1 new_state[partner] = gate[1, 0] * a0 + gate[1, 1] * a1 self.state = _normalize_state(new_state) op: Op = {"type": "single", "gate": g, "target": target} if theta is not None and g in ("RX", "RY", "RZ"): op["theta"] = float(theta) self.history.append(op) def apply_cnot(self, control: int, target: int) -> None: if control == target: raise ValueError("control and target must be different") if not (0 <= control < self.n_qubits) or not (0 <= target < self.n_qubits): raise ValueError("control/target out of range") c_bit = self.n_qubits - 1 - control t_bit = self.n_qubits - 1 - target new_state = self.state.copy() dim = len(self.state) visited = set() for basis in range(dim): if basis in visited: continue if _bit(basis, c_bit) == 1: flipped = _flip_bit(basis, t_bit) # swap amplitudes basis <-> flipped visited.add(basis); visited.add(flipped) new_state[basis], new_state[flipped] = self.state[flipped], self.state[basis] self.state = _normalize_state(new_state) self.history.append({"type": "cnot", "control": control, "target": target}) # --------- Measurement --------- def probabilities(self) -> np.ndarray: probs = np.abs(self.state) ** 2 total = float(np.sum(probs)) if total == 0: return probs return probs / total def sample(self, shots: int = 1024) -> Dict[str, int]: probs = self.probabilities() dim = len(probs) outcomes = np.random.choice(np.arange(dim), size=int(shots), p=probs) counts: Dict[str, int] = {} for idx in outcomes: b = format(int(idx), f"0{self.n_qubits}b") counts[b] = counts.get(b, 0) + 1 return dict(sorted(counts.items())) def measure_collapse(self) -> str: probs = self.probabilities() dim = len(probs) idx = int(np.random.choice(np.arange(dim), p=probs)) collapsed = np.zeros(dim, dtype=complex) collapsed[idx] = 1.0 + 0j self.state = collapsed bitstring = format(idx, f"0{self.n_qubits}b") self.history.append({"type": "measure", "result": bitstring}) return bitstring # --------- Convenience helpers --------- def ket_notation(self, max_terms: int = 16, tol: float = 1e-9) -> str: # human readable statevector terms = [] for i, amp in enumerate(self.state): if abs(amp) > tol: b = format(i, f"0{self.n_qubits}b") terms.append((amp, b)) # sort by magnitude desc terms.sort(key=lambda x: abs(x[0]), reverse=True) terms = terms[:max_terms] if not terms: return "0" parts = [] for amp, b in terms: a = complex(amp) parts.append(f"({a.real:+.4f}{a.imag:+.4f}j)|{b}⟩") return " + ".join(parts).lstrip("+").strip()