# Copyright (c) 2026 Salvatore Pennacchio <jtatopenn@libero.it>
# Distributed under the Business Source License 1.1 (BSL 1.1)
# See LICENSE.md in the project root for full license terms.
import subprocess
import sys
import os
import time
import psutil
import platform
import warnings
from typing import Optional, List, Dict, Any
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')
try:
import cupy as cp
HAS_CUPY = True
except ImportError:
HAS_CUPY = False
try:
import jax
import jax.numpy as jnp
HAS_JAX = True
jax.config.update("jax_enable_x64", True)
except ImportError:
HAS_JAX = False
jnp = None
class QuantumHardwareRegistry:
    """
    One-shot probe of the host machine, taken when the object is constructed.

    Attributes
    ----------
    processor        : string returned by platform.processor()
    ram_total        : total system memory in GiB
    ram_avail        : available system memory in GiB
    has_cupy         : value of the module-level HAS_CUPY flag
    has_jax          : value of the module-level HAS_JAX flag
    has_gpu          : True when `nvidia-smi` ran successfully
    max_dense_qubits : qubit ceiling derived from ram_total
    """

    # Upper bound (seconds) on the nvidia-smi probe so that a wedged driver
    # cannot block construction indefinitely.
    _GPU_PROBE_TIMEOUT_S = 10

    def __init__(self):
        # Take a single memory snapshot so total and available are consistent.
        memory = psutil.virtual_memory()
        self.processor = platform.processor()
        self.ram_total = memory.total / (1024**3)
        self.ram_avail = memory.available / (1024**3)
        self.has_cupy = HAS_CUPY
        self.has_jax = HAS_JAX
        self.has_gpu = self._detect_gpu()
        self.max_dense_qubits = self._get_qubit_limit()

    def _detect_gpu(self) -> bool:
        """Return True if `nvidia-smi` can be executed and exits with status 0."""
        try:
            subprocess.check_output(
                ['nvidia-smi'],
                stderr=subprocess.DEVNULL,
                timeout=self._GPU_PROBE_TIMEOUT_S,
            )
        except (OSError, subprocess.SubprocessError):
            # OSError: binary missing or not executable.
            # SubprocessError: non-zero exit status or probe timed out.
            return False
        return True

    def _get_qubit_limit(self) -> int:
        """Map total RAM (GiB) to a qubit ceiling: >=50 -> 28, >=12 -> 24, else 20."""
        if self.ram_total >= 50:
            return 28
        if self.ram_total >= 12:
            return 24
        return 20

    def print_diagnostics(self):
        """Print a one-line summary of the qubit ceiling, JAX and GPU flags."""
        print(f"MAX_DENSE={self.max_dense_qubits}q | JAX={self.has_jax} | GPU={self.has_gpu}")
# Module-level registry instance; constructing it runs the host probe at
# import time.
HARDWARE_REGISTRY = QuantumHardwareRegistry()
# Global matplotlib theme: select the built-in dark style first, then
# override figure/axes/grid colours and the font on top of it.
plt.style.use('dark_background')
matplotlib.rcParams.update({
    'figure.facecolor': '#010409',
    'axes.facecolor': '#0d1117',
    'axes.edgecolor': '#21262d',
    'grid.color': '#21262d',
    'font.family': 'monospace',
    'font.size': 9,
})
# ─────────────────────────────────────────────────────────────────────────────
# Internal helpers
# ─────────────────────────────────────────────────────────────────────────────
def _fresh_rng() -> np.random.Generator:
    """
    Build a NumPy Generator seeded from fresh entropy.

    The 64-bit seed is eight bytes from os.urandom XOR-ed with the low
    64 bits of time.perf_counter_ns(), so consecutive calls get unrelated
    seeds.
    """
    mask64 = (1 << 64) - 1
    urandom_word = int.from_bytes(os.urandom(8), byteorder='big')
    tick = time.perf_counter_ns() & mask64
    return np.random.default_rng((urandom_word ^ tick) & mask64)
def _qubit_index_pairs(dim: int, q: int):
    """
    Split the basis-state indices 0 .. dim-1 into partner pairs for qubit q.

    Returns
    -------
    (idx_0, idx_1) : two np.intp arrays of equal length.
        idx_0 lists, in ascending order, every index whose bit q is clear;
        idx_1[i] == idx_0[i] | (1 << q) is the same index with bit q set.
        For a statevector of length dim = 2**n and q < n each array has
        dim // 2 entries.
    """
    bit = 1 << q
    basis = np.arange(dim, dtype=np.intp)
    bit_clear = np.bitwise_and(basis, bit) == 0
    idx_0 = basis[bit_clear]
    idx_1 = np.bitwise_or(idx_0, bit)
    return idx_0, idx_1
# ─────────────────────────────────────────────────────────────────────────────
# NoiseModel
# ─────────────────────────────────────────────────────────────────────────────
class NoiseModel:
    """
    Stochastic single-qubit noise channels applied directly to a statevector.

    One call produces one noisy realisation. For every target qubit a fresh
    uniform random number is drawn for each amplitude pair
    (index with bit q clear, same index with bit q set); the operator selected
    by that draw is applied to the pair, and the final vector is renormalised.

    Supported models
    ----------------
    'ideal'             identity - no modification
    'depolarizing'      {√(1-p)·I, √(p/3)·X, √(p/3)·Y, √(p/3)·Z}
    'bitflip'           {√(1-p)·I, √p·X}
    'phaseflip'         {√(1-p)·I, √p·Z}
    'amplitude_damping' {K0=diag(1, √(1-γ)), K1=[[0, √γ], [0, 0]]}
    'combined'          depolarizing(p/2) followed by amplitude_damping(p/3)
    """

    MODELS = ['ideal', 'depolarizing', 'bitflip', 'phaseflip',
              'amplitude_damping', 'combined']

    @staticmethod
    def _draw_uniform(rng, key, size, count=1):
        """
        Draw *count* arrays of *size* uniform samples in [0, 1).

        With key=None the NumPy generator *rng* is used. Otherwise *key* is a
        JAX PRNG key: it is split into count + 1 keys, the first is carried
        forward and the rest feed jax.random.uniform.

        Returns (list_of_sample_arrays, carried_key).
        """
        if key is None:
            return [rng.random(size) for _ in range(count)], None
        key, *subkeys = jax.random.split(key, count + 1)
        samples = [
            jax.random.uniform(sk, shape=(size,), minval=0.0, maxval=1.0)
            for sk in subkeys
        ]
        return samples, key

    @staticmethod
    def _store(sv_out, idx, values, is_jax):
        """Write *values* at *idx*: functional update for JAX, in place for NumPy."""
        if is_jax:
            return sv_out.at[idx].set(values)
        sv_out[idx] = values
        return sv_out

    @staticmethod
    def _pauli_error(xp, v0, v1, fire, ch):
        """
        Apply X, Y or Z (equiprobable) to the pairs selected by *fire*.

        *ch* is an independent uniform [0, 1) sample per pair, so the
        selection thresholds are 1/3 and 2/3. Comparing it against p/3 and
        2p/3 would make X and Y fire with probability p**2/3 instead of p/3.

        Returns (new_v0, new_v1).
        """
        third = 1.0 / 3.0
        x_gate = fire & (ch < third)
        y_gate = fire & (ch >= third) & (ch < 2.0 * third)
        z_gate = fire & (ch >= 2.0 * third)
        # X swaps the pair, Y swaps with factors (-i, +i), Z negates v1.
        new_v0 = xp.where(x_gate, v1,
                          xp.where(y_gate, -1j * v1, v0))
        new_v1 = xp.where(x_gate, v0,
                          xp.where(y_gate, 1j * v0,
                                   xp.where(z_gate, -v1, v1)))
        return new_v0, new_v1

    @staticmethod
    def _amplitude_damp(xp, v0, v1, decay, gamma):
        """
        Per-pair amplitude-damping update with rate *gamma*.

        decay path   : v0 += √γ · v1, v1 = 0
        no-decay path: v0 unchanged,  v1 *= √(1-γ)

        NOTE(review): the decay branch adds √γ·v1 onto v0 (K1 alone would
        replace v0) and fires with probability γ regardless of |v1|**2.
        Kept as found - confirm this is the intended unravelling.

        Returns (new_v0, new_v1).
        """
        sq_gamma = xp.sqrt(gamma)
        sq_1m_gamma = xp.sqrt(1.0 - gamma)
        new_v0 = xp.where(decay, v0 + v1 * sq_gamma, v0)
        new_v1 = xp.where(decay, 0.0 + 0j, v1 * sq_1m_gamma)
        return new_v0, new_v1

    @staticmethod
    def apply_to_sv(
        sv: np.ndarray,
        n: int,
        model: str,
        p: float,
        rng: Optional[np.random.Generator] = None,
        qubits: Optional[List[int]] = None,
        jax_key: Optional[Any] = None,
    ) -> np.ndarray:
        """
        Apply a stochastic noise channel to statevector *sv*.

        The caller's array is never modified: the NumPy path works on a copy,
        the JAX path uses functional `.at[...].set(...)` updates.

        Parameters
        ----------
        sv      : complex statevector of length 2**n
        n       : number of qubits
        model   : one of NoiseModel.MODELS
        p       : error probability (or damping rate γ for amplitude_damping)
        rng     : optional NumPy Generator; a fresh one is created if None
        qubits  : subset of qubits to apply the channel to; defaults to all
        jax_key : optional JAX PRNG key; a fresh one is created if None and
                  sv is a JAX array

        Returns
        -------
        Normalised statevector (same array type as the input). For
        model == 'ideal' or p <= 0 the input object itself is returned
        unchanged. A model name outside the handled set applies no channel
        and returns the renormalised vector.
        """
        if model == 'ideal' or p <= 0.0:
            return sv
        is_jax = HAS_JAX and isinstance(sv, jnp.ndarray)
        dim = len(sv)

        # -- backend and RNG initialisation --------------------------------
        if is_jax:
            xp = jnp
            if jax_key is None:
                jax_seed = int.from_bytes(os.urandom(4), byteorder='big')
                jax_seed ^= time.perf_counter_ns() & 0xFFFF_FFFF
                key = jax.random.PRNGKey(jax_seed)
            else:
                key = jax_key
            sv_out = sv  # JAX arrays are immutable; updates are functional
        else:
            xp = np
            key = None  # tells _draw_uniform to use the NumPy generator
            if rng is None:
                rng = _fresh_rng()
            sv_out = sv.copy()  # never mutate the caller's array

        target_qubits = qubits if qubits is not None else list(range(n))
        draw = NoiseModel._draw_uniform
        store = NoiseModel._store

        for q in target_qubits:
            idx_0, idx_1 = _qubit_index_pairs(dim, q)
            half = len(idx_0)  # == dim // 2

            # One draw per qubit for every model (unused by 'combined'),
            # kept so the random stream consumed per qubit stays the same.
            (r,), key = draw(rng, key, half)

            if model == 'depolarizing':
                # Fire with probability p, then pick X / Y / Z uniformly.
                (ch,), key = draw(rng, key, half)
                new_v0, new_v1 = NoiseModel._pauli_error(
                    xp, sv_out[idx_0], sv_out[idx_1], r < p, ch)
                sv_out = store(sv_out, idx_0, new_v0, is_jax)
                sv_out = store(sv_out, idx_1, new_v1, is_jax)

            elif model == 'bitflip':
                # X applied with probability p: swap the pair.
                fire = r < p
                v0, v1 = sv_out[idx_0], sv_out[idx_1]
                sv_out = store(sv_out, idx_0, xp.where(fire, v1, v0), is_jax)
                sv_out = store(sv_out, idx_1, xp.where(fire, v0, v1), is_jax)

            elif model == 'phaseflip':
                # Z applied with probability p: only the bit-set amplitude
                # changes sign.
                fire = r < p
                v1 = sv_out[idx_1]
                sv_out = store(sv_out, idx_1, xp.where(fire, -v1, v1), is_jax)

            elif model == 'amplitude_damping':
                gamma = float(np.clip(p, 0.0, 1.0))
                new_v0, new_v1 = NoiseModel._amplitude_damp(
                    xp, sv_out[idx_0], sv_out[idx_1], r < gamma, gamma)
                sv_out = store(sv_out, idx_0, new_v0, is_jax)
                sv_out = store(sv_out, idx_1, new_v1, is_jax)

            elif model == 'combined':
                # depolarizing(p/2) followed by amplitude_damping(p/3) on the
                # same qubit. The damping rate is clipped like the standalone
                # model so that sqrt(1 - rate) stays real for large p.
                p_dep = p * 0.5
                p_damp = float(np.clip(p / 3.0, 0.0, 1.0))

                (r_dep, ch), key = draw(rng, key, half, count=2)
                new_v0, new_v1 = NoiseModel._pauli_error(
                    xp, sv_out[idx_0], sv_out[idx_1], r_dep < p_dep, ch)
                sv_out = store(sv_out, idx_0, new_v0, is_jax)
                sv_out = store(sv_out, idx_1, new_v1, is_jax)

                (r_damp,), key = draw(rng, key, half)
                new_v0, new_v1 = NoiseModel._amplitude_damp(
                    xp, sv_out[idx_0], sv_out[idx_1], r_damp < p_damp, p_damp)
                sv_out = store(sv_out, idx_0, new_v0, is_jax)
                sv_out = store(sv_out, idx_1, new_v1, is_jax)

        # -- normalise (epsilon guards against division by a zero norm) ----
        norm = xp.linalg.norm(sv_out)
        return sv_out / (norm + 1e-15)

    @staticmethod
    def kraus_description(model: str) -> Dict:
        """
        Return a small descriptive dict for *model*.

        Keys: 'kraus' (operator count), 'formula' and 'physical' (display
        strings). Unknown model names fall back to the 'ideal' entry.
        """
        desc = {
            'ideal': {
                'kraus': 1,
                'formula': 'K₀ = I',
                'physical': 'No noise',
            },
            'depolarizing': {
                'kraus': 4,
                'formula': 'K₀=√(1-p)I K₁=√(p/3)X K₂=√(p/3)Y K₃=√(p/3)Z',
                'physical': 'Isotropic Pauli error — equiprobable X, Y, Z',
            },
            'bitflip': {
                'kraus': 2,
                'formula': 'K₀=√(1-p)I K₁=√p·X',
                'physical': 'Bit flip σ_x with probability p',
            },
            'phaseflip': {
                'kraus': 2,
                'formula': 'K₀=√(1-p)I K₁=√p·Z',
                'physical': 'Pure dephasing σ_z with probability p',
            },
            'amplitude_damping': {
                'kraus': 2,
                'formula': 'K₀=diag(1,√(1-γ)) K₁=[[0,√γ],[0,0]]',
                'physical': 'T₁ energy relaxation |1⟩→|0⟩ with rate γ',
            },
            'combined': {
                'kraus': 6,
                'formula': 'Depolarizing(p/2) ∘ AmplitudeDamping(p/3)',
                'physical': 'Worst-case NISQ: dephasing + relaxation',
            },
        }
        return desc.get(model, desc['ideal'])
|