| """ |
| FireEcho Quantum Gold - Measurement Operations |
| |
| Implements quantum measurement with wavefunction collapse, |
| probability computation, and statistical sampling. |
| |
| Measurement Theory: |
| - Measurement in computational basis collapses |ψ⟩ to basis state |i⟩ |
| with probability |⟨i|ψ⟩|² = |αᵢ|² |
| - Post-measurement state is renormalized |
| - Measurement is irreversible (wavefunction collapse) |
| """ |
|
|
| import torch |
| import math |
| from typing import Dict, List, Optional, Tuple, Union |
| from dataclasses import dataclass |
|
|
| from .simulator import StateVector |
|
|
|
|
def get_probabilities(state: StateVector) -> torch.Tensor:
    """
    Return the Born-rule probability of every computational basis state.

    Args:
        state: Quantum state vector whose ``amplitudes`` tensor is read.

    Returns:
        Real tensor ``p`` with ``p[i] = |⟨i|ψ⟩|²``, one entry per amplitude.
        No renormalisation is applied, so the entries sum to ‖ψ‖².
    """
    magnitudes = torch.abs(state.amplitudes)
    return torch.real(torch.pow(magnitudes, 2))
|
|
|
|
def measure(state: StateVector, qubit: int, collapse: bool = True) -> Tuple[int, StateVector]:
    """
    Measure a single qubit in the computational basis.

    Qubit ``qubit`` corresponds to bit ``qubit`` of the basis-state index
    (qubit 0 is the least significant bit).

    Args:
        state: Quantum state vector
        qubit: Qubit index to measure, in ``range(state.num_qubits)``
        collapse: Whether to collapse the state (default: True)

    Returns:
        Tuple of (measurement result 0 or 1, post-measurement state). When
        ``collapse`` is False the input ``state`` object itself is returned.
        Otherwise a copy is returned in which amplitudes inconsistent with the
        outcome are zeroed and the rest are renormalised.

    Raises:
        ValueError: If ``qubit`` is outside ``range(state.num_qubits)`` or the
            state has zero norm.

    Example:
        state = StateVector.uniform_superposition(2)  # |+⟩|+⟩
        result, new_state = measure(state, 0)
        # result is 0 or 1 with 50% probability each
        # new_state is collapsed to |0⟩|+⟩ or |1⟩|+⟩
    """
    num_qubits = state.num_qubits
    if not 0 <= qubit < num_qubits:
        raise ValueError(
            f"Qubit index {qubit} out of range for a {num_qubits}-qubit state"
        )

    probs = get_probabilities(state)
    indices = torch.arange(probs.numel(), device=probs.device)
    # True for basis states in which the measured qubit is |1⟩.
    bit_set = (indices & (1 << qubit)) != 0

    # Accumulate in float64 so float32 round-off does not skew the split.
    prob_1 = probs[bit_set].sum(dtype=torch.float64).item()
    prob_0 = probs[~bit_set].sum(dtype=torch.float64).item()
    total = prob_0 + prob_1
    if total <= 0.0:
        raise ValueError("Cannot measure a state with zero norm")

    # Scale the uniform draw by the actual total. An outcome whose probability
    # is exactly zero can then never be selected, even if the state is not
    # perfectly normalised.
    result = 0 if torch.rand(1).item() * total < prob_0 else 1

    if not collapse:
        return result, state

    new_state = state.copy()
    keep = bit_set if result == 1 else ~bit_set
    # Renormalise by the probability actually observed for the chosen outcome.
    norm = math.sqrt(prob_1 if result == 1 else prob_0)
    amps = new_state.amplitudes
    new_state.amplitudes = torch.where(keep, amps, torch.zeros_like(amps)) / norm

    return result, new_state
|
|
|
|
def measure_all(state: StateVector, collapse: bool = True) -> Tuple[str, StateVector]:
    """
    Measure every qubit at once in the computational basis.

    One basis index is drawn with probability ``|⟨i|ψ⟩|²`` and rendered as a
    bitstring whose k-th character is the value of qubit k.

    Args:
        state: Quantum state vector
        collapse: If True, return a fresh basis state built from the outcome;
            if False, return the input ``state`` object unchanged.

    Returns:
        Tuple of (bitstring result, post-measurement state)

    Example:
        state = bell_state()  # (|00⟩ + |11⟩)/√2
        result, new_state = measure_all(state)
        # result is "00" or "11" with 50% probability each
    """
    outcome_index = torch.multinomial(get_probabilities(state), 1).item()

    # format() emits the most significant bit first; reversing puts qubit 0 first.
    label = format(outcome_index, f'0{state.num_qubits}b')[::-1]

    if collapse:
        # NOTE(review): assumes StateVector.from_label reads labels with qubit 0
        # as the first character, matching the reversal above — TODO confirm.
        target_device = str(state.amplitudes.device)
        return label, StateVector.from_label(label, device=target_device)

    return label, state
|
|
|
|
def sample(state: StateVector, shots: int = 1024,
           seed: Optional[int] = None) -> Dict[str, int]:
    """
    Sample measurement outcomes without collapsing the state.

    Args:
        state: Quantum state vector
        shots: Number of samples (0 returns an empty dict)
        seed: Random seed for reproducibility. It seeds a private
            ``torch.Generator``, so torch's global RNG state is left untouched.

    Returns:
        Dictionary of {bitstring: count}. Character k of each bitstring is
        the value of qubit k. Keys appear in first-sampled order.

    Example:
        state = ghz_state(3)  # (|000⟩ + |111⟩)/√2
        counts = sample(state, shots=1000)
        # {'000': ~500, '111': ~500}
    """
    if shots == 0:
        # torch.multinomial rejects a request for zero samples.
        return {}

    probs = get_probabilities(state)

    generator = None
    if seed is not None:
        # A dedicated generator keeps this draw reproducible without reseeding
        # the global RNG that measure()/measure_all() rely on.
        generator = torch.Generator(device=probs.device)
        generator.manual_seed(seed)

    indices = torch.multinomial(probs, shots, replacement=True, generator=generator)

    width = state.num_qubits
    counts: Dict[str, int] = {}
    for idx in indices.tolist():
        # Reverse so that qubit 0 (least significant bit) comes first.
        bitstring = format(idx, f'0{width}b')[::-1]
        counts[bitstring] = counts.get(bitstring, 0) + 1

    return counts
|
|
|
|
def expectation_value(state: StateVector, observable: str,
                      qubits: Optional[List[int]] = None) -> float:
    """
    Compute the expectation value of a Pauli-string observable.

    The Pauli string is applied directly to the amplitudes, because each Pauli
    maps a basis state to a single basis state times a phase. No
    2^n x 2^n matrix is built, and any amplitude dtype works. Qubits that
    are not listed in ``qubits`` are acted on by the identity.

    Args:
        state: Quantum state vector
        observable: Pauli string like "ZZI", "XXX", "ZIZ" (case-insensitive).
            Character k acts on qubit ``qubits[k]``.
        qubits: Distinct qubit indices the characters act on
            (default: 0, 1, ..., len(observable) - 1)

    Returns:
        Real part of the expectation value ⟨ψ|O|ψ⟩

    Raises:
        ValueError: If ``observable`` and ``qubits`` differ in length, or a
            qubit index is repeated or outside ``range(state.num_qubits)``.
        KeyError: If ``observable`` contains a character other than I/X/Y/Z.

    Example:
        state = bell_state()
        zz = expectation_value(state, "ZZ")               # +1
        xx = expectation_value(state, "XX")               # +1
        z1 = expectation_value(state, "Z", qubits=[1])    # 0
    """
    if qubits is None:
        qubits = list(range(len(observable)))

    if len(observable) != len(qubits):
        raise ValueError(
            f"Observable length {len(observable)} doesn't match qubits {len(qubits)}"
        )

    num_qubits = state.num_qubits
    if len(set(qubits)) != len(qubits):
        raise ValueError(f"Duplicate qubit indices in {list(qubits)}")
    for q in qubits:
        if not 0 <= q < num_qubits:
            raise ValueError(
                f"Qubit index {q} out of range for a {num_qubits}-qubit state"
            )

    # Per single-qubit Pauli: (flips the bit, contributes a (-1)^bit sign).
    #   X|b⟩ = |1-b⟩,   Z|b⟩ = (-1)^b |b⟩,   Y|b⟩ = i·(-1)^b |1-b⟩
    pauli_action = {
        'I': (False, False),
        'X': (True, False),
        'Y': (True, True),
        'Z': (False, True),
    }

    flip_mask = 0
    sign_qubits = []
    num_y = 0
    for pauli_char, q in zip(observable, qubits):
        op = pauli_char.upper()
        flips, adds_sign = pauli_action[op]  # KeyError for unknown characters
        if flips:
            flip_mask |= 1 << q
        if adds_sign:
            sign_qubits.append(q)
        if op == 'Y':
            num_y += 1

    amps = state.amplitudes
    index = torch.arange(amps.shape[0], device=amps.device)

    # Parity of the bits on sign-contributing qubits gives the (-1)^... factor.
    parity = torch.zeros_like(index)
    for q in sign_qubits:
        parity ^= (index >> q) & 1
    signs = (1 - 2 * parity).to(amps.dtype)

    # O|i⟩ = phase(i)·|i XOR flip_mask⟩, therefore
    # ⟨ψ|O|ψ⟩ = Σ_i conj(ψ[i XOR flip_mask]) · phase(i) · ψ[i],
    # where phase(i) = i^{num_y} · signs[i].
    global_phase = (1 + 0j, 1j, -1 + 0j, -1j)[num_y % 4]
    overlap = torch.sum(amps[index ^ flip_mask].conj() * signs * amps)

    return (global_phase * overlap).real.item()
|
|
|
|
def partial_trace(state: StateVector, keep_qubits: List[int]) -> torch.Tensor:
    """
    Compute the reduced density matrix by tracing out all other qubits.

    Implements ρ_A[a, b] = Σ_e ψ[a, e] · conj(ψ[b, e]), where e ranges over
    the traced-out qubits. The computation is vectorised as M·M^H after
    reshaping the amplitudes into a (kept, traced) matrix M.

    Args:
        state: Quantum state vector
        keep_qubits: Distinct qubit indices to keep

    Returns:
        Reduced density matrix of shape (2^k, 2^k), where k = len(keep_qubits).
        Bit m of a row/column index is the value of qubit ``keep_qubits[m]``.
        The dtype is the amplitude dtype promoted to at least complex64.

    Raises:
        ValueError: If ``keep_qubits`` repeats an index or contains one
            outside ``range(state.num_qubits)``.

    Example:
        state = bell_state()  # Entangled 2-qubit state
        rho = partial_trace(state, [0])  # Single qubit density matrix
        # rho should be maximally mixed: [[0.5, 0], [0, 0.5]]
    """
    num_qubits = state.num_qubits
    keep = list(keep_qubits)
    if len(set(keep)) != len(keep):
        raise ValueError(f"Duplicate qubit indices in {keep}")
    for q in keep:
        if not 0 <= q < num_qubits:
            raise ValueError(
                f"Qubit index {q} out of range for a {num_qubits}-qubit state"
            )

    keep_set = set(keep)
    trace_qubits = [q for q in range(num_qubits) if q not in keep_set]

    out_dtype = torch.promote_types(state.amplitudes.dtype, torch.complex64)
    # A row-major reshape puts the most significant index bit on axis 0, so
    # qubit q lives on axis (num_qubits - 1 - q).
    psi = state.amplitudes.to(out_dtype).reshape([2] * num_qubits)

    # Order the kept axes so that keep[0] becomes the least significant bit of
    # the reduced index. The traced axes follow.
    axes = [num_qubits - 1 - q for q in reversed(keep)]
    axes += [num_qubits - 1 - q for q in trace_qubits]
    psi_matrix = psi.permute(axes).reshape(2 ** len(keep), -1)

    # ρ_A = M · M^H  (not M^H-conjugated the other way, which would give ρ*).
    return psi_matrix @ psi_matrix.conj().transpose(0, 1)
|
|
|
|
def entropy(state: StateVector, base: int = 2) -> float:
    """
    Return the von Neumann entropy of a pure state, which is always zero.

    A state vector describes a pure state, so its entropy vanishes no matter
    which logarithm base is chosen. For the entanglement between subsystems,
    use ``entanglement_entropy`` instead.

    Args:
        state: Quantum state vector (not inspected)
        base: Logarithm base (not used; the result is 0 in every base)

    Returns:
        0.0
    """
    # A pure state's density matrix has a single eigenvalue equal to 1, and
    # -1·log(1) = 0, so there is nothing to compute.
    pure_state_entropy = 0.0
    return pure_state_entropy
|
|
|
|
def entanglement_entropy(state: StateVector, partition: List[int],
                         base: float = 2) -> float:
    """
    Compute the entanglement entropy of subsystem A for a bipartition.

    S(A) = -Tr(ρ_A log ρ_A), where ρ_A is the reduced density matrix of A.
    This is evaluated from the eigenvalues λ of ρ_A as -Σ λ log λ.

    Args:
        state: Quantum state vector
        partition: List of qubits in subsystem A
        base: Logarithm base (must be positive and different from 1)

    Returns:
        Entanglement entropy, clamped to be non-negative. The clamp means an
        unentangled partition yields exactly 0.0 rather than -0.0 or a tiny
        negative round-off value.

    Raises:
        ValueError: If ``base`` is not positive or equals 1.

    Example:
        state = bell_state()
        S = entanglement_entropy(state, [0])  # Should be 1.0 (maximally entangled)
    """
    if base <= 0 or base == 1:
        raise ValueError(f"Logarithm base must be positive and != 1, got {base}")

    rho = partial_trace(state, partition)

    # ρ_A is Hermitian, so eigvalsh applies and returns real eigenvalues.
    eigenvalues = torch.linalg.eigvalsh(rho)

    # Drop zero and round-off-negative eigenvalues, using the convention 0·log 0 = 0.
    eigenvalues = eigenvalues[eigenvalues > 1e-10]

    if base == 2:
        value = -torch.sum(eigenvalues * torch.log2(eigenvalues))
    elif base == math.e:
        value = -torch.sum(eigenvalues * torch.log(eigenvalues))
    else:
        value = -torch.sum(eigenvalues * torch.log(eigenvalues)) / math.log(base)

    # max(0.0, x) returns +0.0 for x == -0.0 and also absorbs tiny negatives
    # caused by eigenvalues rounding slightly above 1.
    return max(0.0, float(value.item()))
|
|
|
|
@dataclass
class MeasurementResult:
    """Counts of measurement outcomes collected over a number of shots."""

    # Bitstring -> number of times that outcome was observed.
    outcomes: Dict[str, int]
    # Number of shots; used as the denominator for probabilities.
    total_shots: int
    # Optional reference to the state the outcomes were drawn from.
    state_before: Optional[StateVector] = None

    @property
    def most_likely(self) -> str:
        """Return the bitstring with the highest count (the earliest key wins ties)."""
        best_bitstring, _ = max(self.outcomes.items(), key=lambda item: item[1])
        return best_bitstring

    @property
    def probabilities(self) -> Dict[str, float]:
        """Return each outcome's observed frequency, count / total_shots."""
        shots = self.total_shots
        return {bits: n / shots for bits, n in self.outcomes.items()}

    def __repr__(self):
        return f"MeasurementResult(shots={self.total_shots}, outcomes={len(self.outcomes)})"

    def __str__(self):
        # Stable sort by descending count. Equal counts keep their insertion order.
        ranked = sorted(self.outcomes.items(), key=lambda item: item[1], reverse=True)
        out = [f"Measurement Results ({self.total_shots} shots):"]
        for bitstring, count in ranked[:10]:
            prob = count / self.total_shots
            bar = "█" * int(prob * 40)
            out.append(f" |{bitstring}⟩: {count:5d} ({prob:.3f}) {bar}")
        hidden = len(ranked) - 10
        if hidden > 0:
            out.append(f" ... ({hidden} more outcomes)")
        return "\n".join(out)
|
|