File size: 20,213 Bytes
4eff328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
import numpy as np
from typing import List, Tuple, Optional
from .registry import HAS_JAX
from .gates import GATES, PARAMETRIC_GATES, GATE_IDS
from .compiler import QuantumTranspiler

if HAS_JAX:
    import jax
    import jax.numpy as jnp
    jax.config.update("jax_enable_x64", True)
    from .compiler import _compile_and_run_circuit_jit


# ─────────────────────────────────────────────────────────────────────────────
# Internal helpers 
# ─────────────────────────────────────────────────────────────────────────────
#  Copyright (c) 2026 Salvatore Pennacchio <jtatopenn@libero.it>
#  Distributed under the Business Source License 1.1 (BSL 1.1)
#  See LICENSE.md in the project root for full license terms.

def _qubit_stride_pairs(n: int, qubit: int) -> int:
    """
    Return the statevector index stride of *qubit* in an *n*-qubit register.

    The simulator uses MSB-first ordering, so qubit 0 is the *most*
    significant bit of a basis-state index:

        physical_bit_position = n - 1 - qubit
        stride                = 1 << physical_bit_position

    Only the stride is returned (a single int, not a tuple).

    Parameters
    ----------
    n     : total number of qubits in the register
    qubit : logical qubit index, 0 <= qubit < n

    Raises
    ------
    ValueError : if *qubit* lies outside [0, n)
    """
    if not 0 <= qubit < n:
        # A negative index would otherwise produce a stride >= 2**n, which
        # addresses nothing inside the statevector, without any error.
        raise ValueError(f"Qubit index {qubit} out of range [0, {n})")
    return 1 << (n - 1 - qubit)


def _cx_numpy(sv: np.ndarray, n: int, ctrl: int, tgt: int) -> np.ndarray:
    """
    Apply a CNOT to a NumPy statevector and return the result as a new array.

    Basis indices are MSB-first, so qubit q owns the index bit of weight
    1 << (n - 1 - q).  Every amplitude whose control bit is set is exchanged
    with the amplitude that differs from it only in the target bit.  The
    input array is left untouched; all work is vectorised index arithmetic.
    """
    size = len(sv)
    ctrl_bit = 1 << (n - 1 - ctrl)
    tgt_bit = 1 << (n - 1 - tgt)

    basis = np.arange(size, dtype=np.intp)
    ctrl_set = (basis & ctrl_bit) != 0
    tgt_clear = (basis & tgt_bit) == 0

    # One representative per swapped pair: control = 1, target = 0 ...
    lower = basis[ctrl_set & tgt_clear]
    # ... and its partner with the target bit switched on.
    upper = lower | tgt_bit

    swapped = sv.copy()
    swapped[lower] = sv[upper]
    swapped[upper] = sv[lower]
    return swapped


def _cz_numpy(sv: np.ndarray, n: int, ctrl: int, tgt: int) -> np.ndarray:
    """
    Apply a CZ to a NumPy statevector and return the result as a new array.

    With MSB-first indexing qubit q owns the index bit 1 << (n - 1 - q).
    Amplitudes whose control and target bits are both set change sign; the
    input array itself is not modified.
    """
    size = len(sv)
    both_bits = (1 << (n - 1 - ctrl)) | (1 << (n - 1 - tgt))

    basis = np.arange(size, dtype=np.intp)
    # True exactly where the control bit and the target bit are both 1.
    flip = (basis & both_bits) == both_bits

    flipped = sv.copy()
    flipped[flip] *= -1
    return flipped


# ─────────────────────────────────────────────────────────────────────────────
# DenseSVSimulator
# ─────────────────────────────────────────────────────────────────────────────

class DenseSVSimulator:
    """
    Dense statevector quantum circuit simulator.

    Qubit ordering: MSB-first (qubit 0 is the most significant bit).
    Backends: NumPy (CPU), JAX XLA JIT (CPU/GPU/TPU).

    Parameters
    ----------
    n_qubits   : number of qubits
    use_gpu    : reserved for future CuPy/JAX GPU dispatch
    use_float32: use complex64 instead of complex128
    """

    def __init__(self, n_qubits: int,
                 use_gpu:     bool = False,
                 use_float32: bool = False):
        """
        Build a simulator for *n_qubits* qubits, initialised to |0...0>.

        *use_gpu* is accepted but not read by this constructor.  The array
        backend is JAX whenever ``HAS_JAX`` is true and NumPy otherwise.

        Raises
        ------
        ValueError : if *n_qubits* is outside [1, 34]
        """
        if n_qubits > 34 or n_qubits < 1:
            raise ValueError(f"n_qubits must be in [1, 34], got {n_qubits}")

        # Precision and array backend.
        self.use_float32 = use_float32
        self.dtype = np.complex128 if not use_float32 else np.complex64
        self.xp = np if not HAS_JAX else jnp

        # Register geometry: the statevector holds 2 ** n_qubits amplitudes.
        self.n = n_qubits
        self.dim = 1 << n_qubits

        self._reset_sv()

    # ── state initialisation ─────────────────────────────────────────

    def _reset_sv(self):
        """Point ``self.sv`` at a freshly allocated |0...0> state."""
        if not HAS_JAX:
            ground = np.zeros(self.dim, dtype=self.dtype)
            ground[0] = 1.0
            self.sv = ground
            return
        # JAX arrays are immutable, so amplitude 0 is set functionally.
        blank = jnp.zeros(self.dim, dtype=self.dtype)
        self.sv = blank.at[0].set(1.0)

    def set_initial_state(self, state: Optional[np.ndarray] = None):
        """
        Reset the simulator.

        Parameters
        ----------
        state : optional complex array of length 2**n.
                If None, resets to |0...0⟩.
                The array is normalised automatically.
        """
        if state is None:
            self._reset_sv()
            return
        state = np.asarray(state, dtype=self.dtype)
        if state.shape != (self.dim,):
            raise ValueError(
                f"State vector length {len(state)} != 2**{self.n} = {self.dim}")
        norm = np.linalg.norm(state)
        if norm < 1e-12:
            raise ValueError("Cannot set a zero-norm state vector")
        state = state / norm
        if HAS_JAX:
            self.sv = jnp.array(state)
        else:
            self.sv = state.copy()

    # Alias used by the VQE engine
    def set_state(self, state: np.ndarray):
        self.set_initial_state(state)

    # ── normalisation ─────────────────────────────────────────────────

    def normalize(self):
        norm = float(self.xp.linalg.norm(self.sv))
        if norm > 1e-12:
            if HAS_JAX:
                self.sv = self.sv / norm
            else:
                self.sv /= norm

    # ── 1-qubit gate ──────────────────────────────────────────────────

    def apply_gate_1q(self, gate: np.ndarray, qubit: int):
        """
        Apply a 2Γ—2 unitary to *qubit* via tensor contraction.

        Uses reshape + moveaxis + matmul β€” fully vectorised,
        no Python loops, compatible with both NumPy and JAX.
        """
        if not 0 <= qubit < self.n:
            raise ValueError(f"Qubit index {qubit} out of range [0, {self.n})")
        gate = self.xp.array(gate, dtype=self.dtype)
        sv_nd      = self.sv.reshape([2] * self.n)
        sv_moved   = self.xp.moveaxis(sv_nd, qubit, -1)          # qubit axis β†’ last
        flat_shape = (self.dim >> 1, 2)
        # matmul: (dim/2, 2) @ (2, 2).T  β†’ (dim/2, 2)
        result     = self.xp.dot(sv_moved.reshape(flat_shape),
                                 gate.T)
        self.sv    = self.xp.moveaxis(
            result.reshape([2] * self.n), -1, qubit).ravel()

    # ── 2-qubit gate ──────────────────────────────────────────────────

    def apply_gate_2q(self, gate: np.ndarray, q1: int, q2: int):
        """
        Apply a 4Γ—4 unitary to qubits (q1, q2) via tensor contraction.
        """
        if q1 == q2:
            raise ValueError("Control and target qubits must differ")
        if not (0 <= q1 < self.n and 0 <= q2 < self.n):
            raise ValueError(f"Qubit indices ({q1},{q2}) out of range [0, {self.n})")
        gate = self.xp.array(gate, dtype=self.dtype)
        sv_nd      = self.sv.reshape([2] * self.n)
        sv_moved   = self.xp.moveaxis(sv_nd, (q1, q2), (-2, -1))
        flat_shape = (self.dim >> 2, 4)
        result     = self.xp.dot(sv_moved.reshape(flat_shape),
                                 gate.reshape(4, 4).T)
        self.sv    = self.xp.moveaxis(
            result.reshape([2] * self.n), (-2, -1), (q1, q2)).ravel()

    # ── specialised 2-qubit gates ─────────────────────────────────────

    def apply_cx(self, ctrl: int, tgt: int):
        """
        CX (CNOT) gate.

        JAX path: matrix contraction via apply_gate_2q.
        NumPy path: fully vectorised index swap β€” no Python loops.
        """
        if ctrl == tgt:
            raise ValueError("Control and target qubits must differ")
        if not (0 <= ctrl < self.n and 0 <= tgt < self.n):
            raise ValueError(f"Qubit indices ({ctrl},{tgt}) out of range [0, {self.n})")
        if HAS_JAX:
            cx_mat = jnp.array([
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1],
                [0, 0, 1, 0],
            ], dtype=self.dtype)
            self.apply_gate_2q(cx_mat, ctrl, tgt)
        else:
            self.sv = _cx_numpy(np.array(self.sv), self.n, ctrl, tgt)

    def apply_cz(self, ctrl: int, tgt: int):
        """
        CZ gate.

        JAX path: matrix contraction via apply_gate_2q.
        NumPy path: fully vectorised sign flip β€” no Python loops.
        """
        if ctrl == tgt:
            raise ValueError("Control and target qubits must differ")
        if not (0 <= ctrl < self.n and 0 <= tgt < self.n):
            raise ValueError(f"Qubit indices ({ctrl},{tgt}) out of range [0, {self.n})")
        if HAS_JAX:
            cz_mat = jnp.array([
                [1, 0, 0,  0],
                [0, 1, 0,  0],
                [0, 0, 1,  0],
                [0, 0, 0, -1],
            ], dtype=self.dtype)
            self.apply_gate_2q(cz_mat, ctrl, tgt)
        else:
            self.sv = _cz_numpy(np.array(self.sv), self.n, ctrl, tgt)

    def apply_rx(self, qubit: int, theta: float):
        """Apply a parameterized RX gate using the active backend (NumPy/JAX)."""
        cos, sin = self.xp.cos(theta / 2), self.xp.sin(theta / 2)
        mat = self.xp.array([[cos, -1j * sin], [-1j * sin, cos]], dtype=self.dtype)
        self.apply_gate_1q(mat, qubit)

    def apply_ry(self, qubit: int, theta: float):
        """Apply a parameterized RY gate using the active backend (NumPy/JAX)."""
        cos, sin = self.xp.cos(theta / 2), self.xp.sin(theta / 2)
        mat = self.xp.array([[cos, -sin], [sin, cos]], dtype=self.dtype)
        self.apply_gate_1q(mat, qubit)

    def apply_rz(self, qubit: int, theta: float):
        """Apply a parameterized RZ gate using the active backend (NumPy/JAX)."""
        exp_neg = self.xp.exp(-1j * theta / 2)
        exp_pos = self.xp.exp(1j * theta / 2)
        mat = self.xp.array([[exp_neg, 0.0], [0.0, exp_pos]], dtype=self.dtype)
        self.apply_gate_1q(mat, qubit)

    # ── measurement ───────────────────────────────────────────────────

    def measure(self, qubit_idx: int) -> int:
        """
        Projective measurement on *qubit_idx*.

        Returns 0 or 1 and collapses the statevector.
        Uses MSB-first physical bit index: phys = n - 1 - qubit_idx.

        BUG FIX (original): the original NumPy collapse wrote
            sv_reshaped[:, 1 if result == 0 else 0, :] = 0.0
        which zeroed the *wrong* basis state (0 when result=1, 1 when result=0)
        and never normalised the JAX path.
        """
        if not 0 <= qubit_idx < self.n:
            raise ValueError(
                f"Qubit {qubit_idx} out of range [0, {self.n})")

        phys   = self.n - 1 - qubit_idx
        stride = 1 << phys

        # ── compute marginal probabilities ──────────────────────────
        if HAS_JAX:
            probs   = jnp.abs(self.sv) ** 2
            sv_nd   = probs.reshape([2] * self.n)
            mv      = jnp.moveaxis(sv_nd, phys, 0)
            prob_0  = float(jnp.sum(mv[0]))
            prob_1  = float(jnp.sum(mv[1]))
        else:
            sv_res  = self.sv.reshape(-1, 2, stride)
            prob_0  = float(np.sum(np.abs(sv_res[:, 0, :]) ** 2))
            prob_1  = float(np.sum(np.abs(sv_res[:, 1, :]) ** 2))

        total = prob_0 + prob_1
        if total < 1e-12:
            raise RuntimeError("Statevector norm is zero β€” cannot measure")
        prob_0 /= total
        prob_1 /= total

        result = int(np.random.choice([0, 1], p=[prob_0, prob_1]))

        # ── collapse ────────────────────────────────────────────────
        # Zero out the amplitudes corresponding to the *opposite* outcome.
        zero_slot = 1 - result     # if result=0, zero slot 1; if result=1, zero slot 0

        if HAS_JAX:
            sv_nd  = self.sv.reshape([2] * self.n)
            mv     = jnp.moveaxis(sv_nd, phys, 0)
            mv     = mv.at[zero_slot].set(0.0 + 0j)
            self.sv = jnp.moveaxis(mv, 0, phys).ravel()
        else:
            sv_res = self.sv.reshape(-1, 2, stride)
            sv_res[:, zero_slot, :] = 0.0
            self.sv = sv_res.ravel()

        self.normalize()
        return result

    # ── circuit execution ─────────────────────────────────────────────

    def run_circuit(self, circuit: List[Tuple], transpile: bool = True):
        target = QuantumTranspiler.transpile(circuit) if transpile else circuit
        for cmd in target:
            name = cmd[0].lower()
            args = cmd[1:]

            if name in GATES:
                mat = self.xp.array(GATES[name], dtype=self.dtype)
                if mat.shape == (2, 2):
                    self.apply_gate_1q(mat, int(args[0]))
                else:
                    self.apply_gate_2q(mat, int(args[0]), int(args[1]))

            elif name in PARAMETRIC_GATES:
                if len(args) == 2:
                    mat = self.xp.array(PARAMETRIC_GATES[name](args[1]), dtype=self.dtype)
                    self.apply_gate_1q(mat, int(args[0]))
                elif len(args) == 3:
                    mat = self.xp.array(PARAMETRIC_GATES[name](args[2]), dtype=self.dtype)
                    self.apply_gate_2q(mat, int(args[0]), int(args[1]))
                elif len(args) == 4:
                    mat = self.xp.array(
                        PARAMETRIC_GATES[name](args[1], args[2], args[3]),
                        dtype=self.dtype)
                    self.apply_gate_1q(mat, int(args[0]))



    def run_circuit_jit_beast_mode(self, circuit: List):
       
        if not HAS_JAX:
            return self.run_circuit(circuit)

        target       = QuantumTranspiler.transpile(circuit)
        compiled_ops = []

        for cmd in target:
            name = cmd[0].lower() if isinstance(cmd[0], str) else str(cmd[0]).lower()
            if name not in GATE_IDS:
                continue

            g_id = float(GATE_IDS[name])
            args = cmd[1:]

            # ── gate argument parsing ──────────────────────────────
            # 1-qubit parametric: (name, qubit, param)
            if name in ('rx', 'ry', 'rz', 'p', 'u1', 'phase'):
                q1 = float(args[0])
                p  = float(args[1]) if len(args) > 1 else 0.0
                compiled_ops.append([g_id, q1, 0.0, p])

            # 2-qubit parametric: (name, ctrl, tgt, param)
            elif name in ('cp', 'crz', 'cphase'):
                ctrl = float(args[0])
                tgt  = float(args[1]) if len(args) > 1 else 0.0
                p    = float(args[2]) if len(args) > 2 else 0.0
                compiled_ops.append([g_id, ctrl, tgt, p])

            # 2-qubit non-parametric: (name, ctrl, tgt)
            elif name in ('cx', 'cz', 'swap', 'cy'):
                ctrl = float(args[0])
                tgt  = float(args[1]) if len(args) > 1 else 0.0
                compiled_ops.append([g_id, ctrl, tgt, 0.0])

            # 1-qubit non-parametric: (name, qubit)
            else:
                q1 = float(args[0]) if args else 0.0
                compiled_ops.append([g_id, q1, 0.0, 0.0])

        if compiled_ops:
            ops_jnp = jnp.array(compiled_ops, dtype=jnp.float64)
            self.sv  = _compile_and_run_circuit_jit(self.sv, ops_jnp)

    def run_circuit_with_chunking(self, circuit: List, chunk_size: int = 500):
        """
        Execute a circuit in chunks to avoid JIT recompilation on
        large variable-length circuits.

        Each chunk is a separate _compile_and_run_circuit_jit call
        with a fixed-size ops array, allowing XLA to cache each size.
        """
        target = QuantumTranspiler.transpile(circuit)
        for i in range(0, len(target), chunk_size):
            self.run_circuit_jit_beast_mode(target[i: i + chunk_size])

    def run_parametric_batch_jit(self,
                                  base_circuit:    List,
                                  parameter_batch: np.ndarray) -> "jnp.ndarray":
        """
        Simulate one circuit for a whole batch of parameter vectors.

        The circuit is transpiled and encoded once into a template of float64
        rows ``[gate_id, q1, q2, param]``; parametric gates carry -1.0 as a
        placeholder.  For every row of *parameter_batch* the placeholders are
        filled, in circuit order, with consecutive entries of that row, and
        the patched ops are run from |0...0> by
        ``_compile_and_run_circuit_jit``.  The per-row simulation is
        vectorised with ``jax.vmap`` and jitted.

        Parameters
        ----------
        base_circuit    : list of gate tuples; parameter values written in it
                          are ignored (only the gate layout is used)
        parameter_batch : array of shape (batch, n_params) with n_params at
                          least the number of parametric gates

        Returns
        -------
        One result of ``_compile_and_run_circuit_jit`` per batch row, stacked
        along axis 0.

        Raises
        ------
        RuntimeError : if JAX is not available
        ValueError   : if *parameter_batch* is not 2-D or has fewer columns
                       than the circuit has parametric gates
        """
        if not HAS_JAX:
            raise RuntimeError("run_parametric_batch_jit requires JAX")

        target       = QuantumTranspiler.transpile(base_circuit)
        compiled_ops = []
        n_slots      = 0          # number of parametric gates in the template

        for cmd in target:
            name = cmd[0].lower() if isinstance(cmd[0], str) else str(cmd[0]).lower()
            if name not in GATE_IDS:
                continue
            g_id = float(GATE_IDS[name])
            args = cmd[1:]
            if name in ('rx', 'ry', 'rz', 'p', 'u1', 'phase'):
                compiled_ops.append([g_id, float(args[0]), 0.0, -1.0])   # -1.0 = param slot
                n_slots += 1
            elif name in ('cp', 'crz', 'cphase'):
                compiled_ops.append([g_id, float(args[0]), float(args[1]) if len(args) > 1 else 0.0, -1.0])
                n_slots += 1
            elif name in ('cx', 'cz', 'swap', 'cy'):
                compiled_ops.append([g_id, float(args[0]), float(args[1]) if len(args) > 1 else 0.0, 0.0])
            else:
                compiled_ops.append([g_id, float(args[0]) if args else 0.0, 0.0, 0.0])

        batch = jnp.asarray(parameter_batch, dtype=jnp.float64)
        if batch.ndim != 2:
            raise ValueError(
                f"parameter_batch must be 2-D (batch, n_params), got shape {batch.shape}")
        if batch.shape[1] < n_slots:
            # JAX clamps out-of-range indices instead of raising, so a short
            # row would silently reuse its last entry for the missing gates.
            raise ValueError(
                f"parameter_batch has {batch.shape[1]} parameters per row but "
                f"the circuit has {n_slots} parametric gates")

        if not compiled_ops:
            # No recognised gate: every instance stays in |0...0>.
            return jnp.zeros((batch.shape[0], self.dim),
                             dtype=jnp.complex128).at[:, 0].set(1.0)

        template    = jnp.array(compiled_ops, dtype=jnp.float64)
        # NOTE(review): always complex128, independent of use_float32 --
        # presumably what the JIT kernel expects; confirm.
        init_sv     = jnp.zeros(self.dim, dtype=jnp.complex128).at[0].set(1.0)

        def simulate_single_instance(single_params: "jnp.ndarray") -> "jnp.ndarray":
            """Run one parameter vector through the circuit."""

            def patch_and_apply(carry: "jnp.ndarray",
                                op:    "jnp.ndarray"):
                """
                carry: jnp.int32 scalar -- index of the next unused parameter.
                op:    [g_id, q1, q2, p_sentinel]
                """
                idx       = carry
                is_param  = op[3] == -1.0
                final_p   = jnp.where(is_param, single_params[idx], op[3])
                next_idx  = jnp.where(is_param, idx + jnp.int32(1), idx)
                patched   = jnp.array([op[0], op[1], op[2], final_p],
                                       dtype=jnp.float64)
                return next_idx, patched

            _, patched_ops = jax.lax.scan(
                patch_and_apply,
                jnp.int32(0),       # scan carry must be a scalar, not a tuple
                template,
            )
            return _compile_and_run_circuit_jit(init_sv, patched_ops)

        return jax.jit(jax.vmap(simulate_single_instance, in_axes=(0,)))(batch)

    # ── observables ───────────────────────────────────────────────────

    def get_probabilities(self) -> np.ndarray:
        """Return measurement probability distribution as a NumPy float64 array."""
        probs = np.array(self.xp.abs(self.sv) ** 2, dtype=np.float64)
        # guard against floating-point leakage outside [0, 1]
        probs = np.clip(probs, 0.0, 1.0)
        total = probs.sum()
        if total > 1e-12:
            probs /= total
        return probs

    def get_statevector(self) -> np.ndarray:
        """Return the current statevector as a NumPy complex array."""
        return np.array(self.sv, dtype=self.dtype)

    def memory_mb(self) -> float:
        """Statevector memory footprint in megabytes."""
        bytes_per_element = 8 if self.use_float32 else 16   # complex64=8, complex128=16
        return self.dim * bytes_per_element / 1_000_000