# FireEcho / quantum / measurement.py
# Joysulem's picture
# Upload 3258 files
# b5bff9c verified
"""
FireEcho Quantum Gold - Measurement Operations
Implements quantum measurement with wavefunction collapse,
probability computation, and statistical sampling.
Measurement Theory:
- Measurement in computational basis collapses |ψ⟩ to basis state |i⟩
with probability |⟨i|ψ⟩|² = |αᵢ|²
- Post-measurement state is renormalized
- Measurement is irreversible (wavefunction collapse)
"""
import torch
import math
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from .simulator import StateVector
def get_probabilities(state: StateVector) -> torch.Tensor:
    """
    Return the Born-rule probability of every computational basis state.

    Args:
        state: Quantum state vector; only its ``amplitudes`` tensor is read.

    Returns:
        Tensor shaped like ``state.amplitudes`` whose entry ``i`` is the
        squared magnitude of amplitude ``i``, i.e. |⟨i|ψ⟩|².
    """
    magnitudes = torch.abs(state.amplitudes)
    squared = torch.pow(magnitudes, 2)
    # The magnitude is taken before squaring; `.real` is kept as the final
    # step exactly as the computation has always been written.
    return squared.real
def measure(state: StateVector, qubit: int, collapse: bool = True) -> Tuple[int, StateVector]:
    """
    Measure a single qubit in the computational basis.

    Qubit ``q`` corresponds to bit ``q`` of the basis-state index (qubit 0 is
    the least significant bit).

    Args:
        state: Quantum state vector
        qubit: Qubit index to measure, ``0 <= qubit < state.num_qubits``
        collapse: Whether to collapse the state (default: True)

    Returns:
        Tuple of (measurement result 0 or 1, post-measurement state).
        With ``collapse=False`` the second element is the input ``state``
        object itself, unchanged.

    Raises:
        ValueError: If ``qubit`` is out of range, or if the state has zero
            norm so no outcome can be sampled.

    Example:
        state = StateVector.uniform_superposition(2)  # |+⟩|+⟩
        result, new_state = measure(state, 0)
        # result is 0 or 1 with 50% probability each
        # new_state is collapsed to |0⟩|+⟩ or |1⟩|+⟩
    """
    num_qubits = state.num_qubits
    # An out-of-range index would produce a mask that matches no basis state,
    # which silently reports outcome 0 every time; reject it instead.
    if not 0 <= qubit < num_qubits:
        raise ValueError(
            f"Qubit index {qubit} out of range for {num_qubits}-qubit state"
        )

    probs = get_probabilities(state)

    # Boolean mask over basis states: True where the measured qubit is |1⟩.
    basis_index = torch.arange(2 ** num_qubits, device=probs.device)
    qubit_is_one = (basis_index & (1 << qubit)) != 0

    # Both branch weights are summed directly (not as 1 - prob_0) so that a
    # branch with no amplitude has weight exactly 0 and can never be chosen.
    prob_0 = probs[~qubit_is_one].sum().item()
    prob_1 = probs[qubit_is_one].sum().item()
    total = prob_0 + prob_1
    if total <= 0.0:
        raise ValueError("Cannot measure a state with zero norm")

    # One uniform draw, scaled by the total weight so that a slightly
    # unnormalised state is still sampled with the correct ratio.
    result = 0 if torch.rand(1).item() * total < prob_0 else 1

    if not collapse:
        return result, state

    # Project onto the observed subspace and renormalise by its own weight.
    new_state = state.copy()
    keep = qubit_is_one if result == 1 else ~qubit_is_one
    amplitudes = new_state.amplitudes
    projected = torch.where(keep, amplitudes, torch.zeros_like(amplitudes))
    norm = math.sqrt(prob_1 if result == 1 else prob_0)
    new_state.amplitudes = projected / norm
    return result, new_state
def measure_all(state: StateVector, collapse: bool = True) -> Tuple[str, StateVector]:
    """
    Measure every qubit at once in the computational basis.

    Args:
        state: Quantum state vector
        collapse: Whether to return a collapsed state (default: True)

    Returns:
        Tuple of (bitstring result, post-measurement state). Character ``k``
        of the bitstring is bit ``k`` of the sampled basis index. With
        ``collapse=False`` the input ``state`` object is returned as is.

    Example:
        state = bell_state()  # (|00⟩ + |11⟩)/√2
        result, new_state = measure_all(state)
        # result is "00" or "11" with 50% probability each
    """
    distribution = get_probabilities(state)

    # Draw one basis index according to the measurement probabilities.
    outcome = torch.multinomial(distribution, 1).item()

    # Zero-padded binary digits, reversed so the lowest bit comes first.
    width = state.num_qubits
    bitstring = f"{outcome:0{width}b}"[::-1]

    if collapse:
        # NOTE(review): assumes StateVector.from_label reads the label in the
        # same character order as produced here — confirm in the simulator.
        device_name = str(state.amplitudes.device)
        return bitstring, StateVector.from_label(bitstring, device=device_name)
    return bitstring, state
def sample(state: StateVector, shots: int = 1024,
           seed: Optional[int] = None) -> Dict[str, int]:
    """
    Draw repeated measurement outcomes from a state; the state is not modified.

    Args:
        state: Quantum state vector
        shots: Number of samples to draw (with replacement)
        seed: If given, passed to ``torch.manual_seed`` before sampling.
            Note that this reseeds torch's global random generator.

    Returns:
        Dictionary of {bitstring: count}. Character ``k`` of each bitstring
        is bit ``k`` of the sampled basis index.

    Example:
        state = ghz_state(3)  # (|000⟩ + |111⟩)/√2
        counts = sample(state, shots=1000)
        # {'000': ~500, '111': ~500}
    """
    if seed is not None:
        torch.manual_seed(seed)

    distribution = get_probabilities(state)
    draws = torch.multinomial(distribution, shots, replacement=True)

    width = state.num_qubits
    tally: Dict[str, int] = {}
    for outcome in draws.tolist():
        # Zero-padded binary digits, reversed so the lowest bit comes first.
        label = f"{outcome:0{width}b}"[::-1]
        if label in tally:
            tally[label] += 1
        else:
            tally[label] = 1
    return tally
def expectation_value(state: StateVector, observable: str,
                      qubits: Optional[List[int]] = None) -> float:
    """
    Compute the expectation value of a Pauli-string observable.

    Character ``observable[i]`` acts on qubit ``qubits[i]``; every qubit not
    listed is acted on by the identity. Qubit ``q`` is bit ``q`` of the
    basis-state index (qubit 0 is the least significant bit), so qubit 0 is
    the right-most factor of the Kronecker product.

    Args:
        state: Quantum state vector
        observable: Pauli string like "ZZI", "XXX", "ZIZ" (case-insensitive)
        qubits: Qubit indices the characters act on
            (default: 0, 1, ..., len(observable) - 1)

    Returns:
        Real part of the expectation value ⟨ψ|O|ψ⟩

    Raises:
        ValueError: If ``observable`` and ``qubits`` differ in length, if a
            qubit index is repeated, or if an index is outside
            ``[0, state.num_qubits)``.
        KeyError: If ``observable`` contains a character other than
            I, X, Y or Z.

    Note:
        The full 2**n x 2**n operator is materialised, so memory grows as
        4**n in the number of qubits.

    Example:
        state = bell_state()
        zz = expectation_value(state, "ZZ")  # Should be +1
        xx = expectation_value(state, "XX")  # Should be +1
    """
    num_qubits = state.num_qubits
    if qubits is None:
        qubits = list(range(len(observable)))
    if len(observable) != len(qubits):
        raise ValueError(
            f"Observable length {len(observable)} doesn't match qubits {len(qubits)}"
        )
    if len(set(qubits)) != len(qubits):
        raise ValueError(f"Duplicate qubit indices in {list(qubits)}")
    for q in qubits:
        if not 0 <= q < num_qubits:
            raise ValueError(
                f"Qubit index {q} out of range for {num_qubits}-qubit state"
            )

    amplitudes = state.amplitudes
    device = amplitudes.device
    # Match the state's precision; real-valued amplitudes are promoted so
    # they can be multiplied with the complex Pauli Y.
    if amplitudes.is_complex():
        dtype = amplitudes.dtype
    else:
        dtype = torch.complex64
        amplitudes = amplitudes.to(dtype)

    pauli_map = {
        'I': torch.tensor([[1, 0], [0, 1]], dtype=dtype, device=device),
        'X': torch.tensor([[0, 1], [1, 0]], dtype=dtype, device=device),
        'Y': torch.tensor([[0, -1j], [1j, 0]], dtype=dtype, device=device),
        'Z': torch.tensor([[1, 0], [0, -1]], dtype=dtype, device=device),
    }

    # One single-qubit operator per register qubit, identity unless the
    # observable names that qubit.
    factors = [pauli_map['I']] * num_qubits
    for pauli_char, q in zip(observable, qubits):
        factors[q] = pauli_map[pauli_char.upper()]

    # Each later qubit is placed to the left, making it a more significant
    # bit of the operator's index than the qubits before it.
    full_obs = torch.ones((1, 1), dtype=dtype, device=device)
    for factor in factors:
        full_obs = torch.kron(factor, full_obs)

    # ⟨ψ|O|ψ⟩
    o_psi = torch.mv(full_obs, amplitudes)
    expectation = torch.sum(amplitudes.conj() * o_psi)
    return expectation.real.item()
def partial_trace(state: StateVector, keep_qubits: List[int]) -> torch.Tensor:
    """
    Compute the reduced density matrix of a subset of qubits.

    The result is ρ_A[a, a'] = Σ_b ψ[a, b] · conj(ψ[a', b]), i.e. the partial
    trace of |ψ⟩⟨ψ| over every qubit not in ``keep_qubits``. Qubit ``q`` is
    bit ``q`` of the basis-state index, and ``keep_qubits[k]`` becomes bit
    ``k`` of the reduced matrix's row/column index.

    Args:
        state: Quantum state vector
        keep_qubits: List of distinct qubit indices to keep

    Returns:
        Reduced density matrix of shape (2**k, 2**k) for the k kept qubits,
        in the complex dtype of the amplitudes (complex64 if they are real).

    Raises:
        ValueError: If ``keep_qubits`` contains a repeated index or an index
            outside ``[0, state.num_qubits)``.

    Example:
        state = bell_state()  # Entangled 2-qubit state
        rho = partial_trace(state, [0])  # Single qubit density matrix
        # rho should be maximally mixed: [[0.5, 0], [0, 0.5]]
    """
    num_qubits = state.num_qubits
    keep = list(keep_qubits)
    if len(set(keep)) != len(keep):
        raise ValueError(f"Duplicate qubit indices in {keep}")
    for q in keep:
        if not 0 <= q < num_qubits:
            raise ValueError(
                f"Qubit index {q} out of range for {num_qubits}-qubit state"
            )
    kept = set(keep)
    trace_qubits = [q for q in range(num_qubits) if q not in kept]

    amplitudes = state.amplitudes
    if not amplitudes.is_complex():
        amplitudes = amplitudes.to(torch.complex64)
    device = amplitudes.device

    basis_index = torch.arange(2 ** num_qubits, device=device)

    def _pack_bits(qubit_list: List[int]) -> torch.Tensor:
        # Gather the listed qubits' bits of every basis index into a compact
        # integer: qubit_list[k] becomes bit k.
        packed = torch.zeros_like(basis_index)
        for position, q in enumerate(qubit_list):
            packed |= ((basis_index >> q) & 1) << position
        return packed

    rows = _pack_bits(keep)          # index within the kept subsystem
    cols = _pack_bits(trace_qubits)  # index within the traced-out subsystem

    # Arrange ψ as a (kept, traced) matrix; every basis state maps to a
    # unique (row, col) cell because the two qubit sets partition the register.
    psi_matrix = torch.zeros(
        (2 ** len(keep), 2 ** len(trace_qubits)),
        dtype=amplitudes.dtype, device=device,
    )
    psi_matrix[rows, cols] = amplitudes

    # ρ_A = M M†: the ket index carries ψ and the bra index carries conj(ψ).
    return psi_matrix @ psi_matrix.conj().transpose(0, 1)
def entropy(state: StateVector, base: int = 2) -> float:
    """
    Return the von Neumann entropy of a pure state, which is always zero.

    Neither argument is inspected: the function returns the constant 0.0.
    Use entanglement_entropy to quantify entanglement across a bipartition.

    Args:
        state: Quantum state vector (unused)
        base: Logarithm base (unused; default: 2 for bits)

    Returns:
        0.0
    """
    pure_state_entropy = 0.0
    return pure_state_entropy
def entanglement_entropy(state: StateVector, partition: List[int],
                         base: int = 2) -> float:
    """
    Compute the entanglement entropy of a bipartition.

    S(A) = -Tr(ρ_A log ρ_A), where ρ_A is the reduced density matrix of the
    qubits in ``partition``, evaluated from the eigenvalues of ρ_A.

    Args:
        state: Quantum state vector
        partition: List of qubits in subsystem A
        base: Logarithm base (2 gives bits, math.e gives nats)

    Returns:
        Entanglement entropy, never negative: rounding noise below zero
        (including -0.0) is reported as 0.0.

    Raises:
        ValueError: If ``base`` is not a positive number different from 1.

    Example:
        state = bell_state()
        S = entanglement_entropy(state, [0])  # Should be 1.0 (maximally entangled)
    """
    if base <= 0 or base == 1:
        raise ValueError(
            f"Logarithm base must be positive and different from 1, got {base}"
        )

    rho = partial_trace(state, partition)

    # ρ_A is Hermitian, so the Hermitian eigenvalue solver applies.
    eigenvalues = torch.linalg.eigvalsh(rho).real

    # Drop (near-)zero eigenvalues: they contribute nothing in the limit and
    # would otherwise produce log(0).
    eigenvalues = eigenvalues[eigenvalues > 1e-10]

    # Σ λ log(λ); base 2 uses log2 directly, any other base is converted from
    # the natural logarithm.
    if base == 2:
        weighted_log = torch.sum(eigenvalues * torch.log2(eigenvalues))
    else:
        weighted_log = torch.sum(eigenvalues * torch.log(eigenvalues)) / math.log(base)

    value = -weighted_log.item()
    # Entropy cannot be negative; clamp rounding noise and -0.0 to exactly 0.0
    # (a NaN fails the comparison and is passed through unchanged).
    return 0.0 if value <= 0.0 else value
@dataclass
class MeasurementResult:
    """Container for the outcome counts of a repeated measurement."""
    # Mapping from measured bitstring to the number of times it was observed.
    outcomes: Dict[str, int]
    # Denominator used when converting counts to frequencies.
    total_shots: int
    # Optional reference to the state that was measured.
    state_before: Optional[StateVector] = None

    @property
    def most_likely(self) -> str:
        """Return the bitstring with the highest count (first one on ties)."""
        best_label, _ = max(self.outcomes.items(), key=lambda pair: pair[1])
        return best_label

    @property
    def probabilities(self) -> Dict[str, float]:
        """Return each outcome's observed frequency, count / total_shots."""
        shots = self.total_shots
        fractions = {}
        for label, count in self.outcomes.items():
            fractions[label] = count / shots
        return fractions

    def __repr__(self):
        distinct = len(self.outcomes)
        return f"MeasurementResult(shots={self.total_shots}, outcomes={distinct})"

    def __str__(self):
        # Highest counts first; the sort is stable, so ties keep dict order.
        ranked = sorted(self.outcomes.items(), key=lambda pair: pair[1], reverse=True)
        report = [f"Measurement Results ({self.total_shots} shots):"]
        # Only the ten most frequent outcomes get a histogram row.
        for bitstring, count in ranked[:10]:
            prob = count / self.total_shots
            bar = "█" * int(prob * 40)  # bar length scales to 40 characters
            report.append(f" |{bitstring}⟩: {count:5d} ({prob:.3f}) {bar}")
        hidden = len(ranked) - 10
        if hidden > 0:
            report.append(f" ... ({hidden} more outcomes)")
        return "\n".join(report)