"""
FireEcho Quantum Gold - Gate Kernels

SM120-optimized quantum gate implementations using Triton.
All gates use Thread Block Clusters (num_ctas=2) for cooperative execution.

Gate Mathematics:
- Single-qubit gates operate on pairs of amplitudes differing in the target bit
- Two-qubit gates operate on quadruples of amplitudes
- All gates are unitary: U†U = I, preserving normalization
"""
|
|
| import torch |
| import triton |
| import triton.language as tl |
| import numpy as np |
| from typing import Optional |
| import math |
|
|
| |
| |
| |
|
|
# Reference matrices for the named gates, keyed by gate name and stored as
# complex64 tensors. Entries are listed as plain nested lists and converted
# in one pass; insertion order is X, Y, Z, H, S, T, CNOT, CZ, SWAP.
GATE_MATRICES = {
    gate_name: torch.tensor(rows, dtype=torch.complex64)
    for gate_name, rows in (
        # Pauli gates
        ('X', [[0, 1], [1, 0]]),
        ('Y', [[0, -1j], [1j, 0]]),
        ('Z', [[1, 0], [0, -1]]),
        # Hadamard (unnormalized here; scaled by 1/sqrt(2) below)
        ('H', [[1, 1], [1, -1]]),
        # Phase gates: S = diag(1, i), T = diag(1, e^(i*pi/4))
        ('S', [[1, 0], [0, 1j]]),
        ('T', [[1, 0], [0, math.e**(1j * math.pi / 4)]]),
        # Two-qubit gates
        ('CNOT', [
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
            [0, 0, 1, 0],
        ]),
        ('CZ', [
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, -1],
        ]),
        ('SWAP', [
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ]),
    )
}
# Normalize the Hadamard entry after construction (tensor / Python float).
GATE_MATRICES['H'] = GATE_MATRICES['H'] / math.sqrt(2)
|
|
|
|
| |
| |
| |
|
|
@triton.jit
def _hadamard_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    In-place Hadamard on one qubit: H = [[1, 1], [1, -1]] / sqrt(2).

    Every lane handles one amplitude pair (lo, hi) whose indices differ only
    in the target bit and writes (lo + hi)/sqrt(2) and (lo - hi)/sqrt(2) back.
    Real and imaginary parts live in separate buffers.
    """
    # One lane per pair; a state of 2**num_qubits amplitudes has half as many pairs.
    lane = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    active = lane < ((1 << num_qubits) // 2)

    # Map the pair number to the index with the target bit clear (lo);
    # adding the bit's weight gives the partner with the bit set (hi).
    step = 1 << target_qubit
    lo = (lane // step) * (2 * step) + (lane % step)
    hi = lo + step

    lo_re = tl.load(state_real_ptr + lo, mask=active, other=0.0)
    lo_im = tl.load(state_imag_ptr + lo, mask=active, other=0.0)
    hi_re = tl.load(state_real_ptr + hi, mask=active, other=0.0)
    hi_im = tl.load(state_imag_ptr + hi, mask=active, other=0.0)

    scale = 0.7071067811865476  # 1 / sqrt(2)

    tl.store(state_real_ptr + lo, (lo_re + hi_re) * scale, mask=active)
    tl.store(state_imag_ptr + lo, (lo_im + hi_im) * scale, mask=active)
    tl.store(state_real_ptr + hi, (lo_re - hi_re) * scale, mask=active)
    tl.store(state_imag_ptr + hi, (lo_im - hi_im) * scale, mask=active)
|
|
|
|
@triton.jit
def _pauli_x_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    In-place Pauli-X (NOT) on one qubit: X = [[0, 1], [1, 0]].

    Each lane exchanges the two amplitudes of one pair whose indices differ
    only in the target bit.
    """
    lane = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    active = lane < ((1 << num_qubits) // 2)

    # lo has the target bit clear, hi has it set.
    step = 1 << target_qubit
    lo = (lane // step) * (2 * step) + (lane % step)
    hi = lo + step

    lo_re = tl.load(state_real_ptr + lo, mask=active, other=0.0)
    lo_im = tl.load(state_imag_ptr + lo, mask=active, other=0.0)
    hi_re = tl.load(state_real_ptr + hi, mask=active, other=0.0)
    hi_im = tl.load(state_imag_ptr + hi, mask=active, other=0.0)

    # Write each amplitude into the other's slot.
    tl.store(state_real_ptr + lo, hi_re, mask=active)
    tl.store(state_imag_ptr + lo, hi_im, mask=active)
    tl.store(state_real_ptr + hi, lo_re, mask=active)
    tl.store(state_imag_ptr + hi, lo_im, mask=active)
|
|
|
|
@triton.jit
def _pauli_y_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    In-place Pauli-Y on one qubit: Y = [[0, -i], [i, 0]].

    For a pair (a0, a1) differing in the target bit the result is
    (-i * a1, i * a0), expanded below into real/imaginary components.
    """
    lane = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    active = lane < ((1 << num_qubits) // 2)

    # lo has the target bit clear, hi has it set.
    step = 1 << target_qubit
    lo = (lane // step) * (2 * step) + (lane % step)
    hi = lo + step

    lo_re = tl.load(state_real_ptr + lo, mask=active, other=0.0)
    lo_im = tl.load(state_imag_ptr + lo, mask=active, other=0.0)
    hi_re = tl.load(state_real_ptr + hi, mask=active, other=0.0)
    hi_im = tl.load(state_imag_ptr + hi, mask=active, other=0.0)

    # -i * (hi_re + i*hi_im) = hi_im - i*hi_re
    tl.store(state_real_ptr + lo, hi_im, mask=active)
    tl.store(state_imag_ptr + lo, -hi_re, mask=active)
    #  i * (lo_re + i*lo_im) = -lo_im + i*lo_re
    tl.store(state_real_ptr + hi, -lo_im, mask=active)
    tl.store(state_imag_ptr + hi, lo_re, mask=active)
|
|
|
|
@triton.jit
def _pauli_z_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    In-place Pauli-Z on one qubit: Z = [[1, 0], [0, -1]].

    Only the amplitude with the target bit set changes (it is negated), so
    the amplitude with the bit clear is never loaded or stored.
    """
    lane = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    active = lane < ((1 << num_qubits) // 2)

    # Index of the pair member whose target bit is 1.
    step = 1 << target_qubit
    hi = (lane // step) * (2 * step) + (lane % step) + step

    hi_re = tl.load(state_real_ptr + hi, mask=active, other=0.0)
    hi_im = tl.load(state_imag_ptr + hi, mask=active, other=0.0)

    tl.store(state_real_ptr + hi, -hi_re, mask=active)
    tl.store(state_imag_ptr + hi, -hi_im, mask=active)
|
|
|
|
@triton.jit
def _rotation_z_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    cos_half,
    sin_half,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotation around Z axis: Rz(θ) = [[e^(-iθ/2), 0], [0, e^(iθ/2)]]

    cos_half / sin_half are cos(θ/2) and sin(θ/2). They are ordinary runtime
    scalars rather than tl.constexpr: a constexpr value becomes part of the
    compiled kernel, so every distinct angle would trigger a fresh JIT
    compilation. Only BLOCK_SIZE needs to be a compile-time constant.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit

    # One lane per amplitude pair (indices differing only in the target bit).
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)

    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride

    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)

    # a0 *= (cos - i*sin);  a1 *= (cos + i*sin)
    new_a0_real = a0_real * cos_half + a0_imag * sin_half
    new_a0_imag = a0_imag * cos_half - a0_real * sin_half
    new_a1_real = a1_real * cos_half - a1_imag * sin_half
    new_a1_imag = a1_imag * cos_half + a1_real * sin_half

    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
|
|
|
|
@triton.jit
def _rotation_x_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    cos_half,
    sin_half,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotation around X axis: Rx(θ) = [[cos(θ/2), -i*sin(θ/2)], [-i*sin(θ/2), cos(θ/2)]]

    cos_half / sin_half are cos(θ/2) and sin(θ/2). They are ordinary runtime
    scalars rather than tl.constexpr: a constexpr value becomes part of the
    compiled kernel, so every distinct angle would trigger a fresh JIT
    compilation. Only BLOCK_SIZE needs to be a compile-time constant.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit

    # One lane per amplitude pair (indices differing only in the target bit).
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)

    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride

    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)

    # new_a0 = cos*a0 - i*sin*a1
    # new_a1 = -i*sin*a0 + cos*a1
    # using -i*(x + i*y) = y - i*x
    new_a0_real = cos_half * a0_real + sin_half * a1_imag
    new_a0_imag = cos_half * a0_imag - sin_half * a1_real
    new_a1_real = sin_half * a0_imag + cos_half * a1_real
    new_a1_imag = -sin_half * a0_real + cos_half * a1_imag

    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
|
|
|
|
@triton.jit
def _rotation_y_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    cos_half,
    sin_half,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotation around Y axis: Ry(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]

    cos_half / sin_half are cos(θ/2) and sin(θ/2). They are ordinary runtime
    scalars rather than tl.constexpr: a constexpr value becomes part of the
    compiled kernel, so every distinct angle would trigger a fresh JIT
    compilation. Only BLOCK_SIZE needs to be a compile-time constant.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit

    # One lane per amplitude pair (indices differing only in the target bit).
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)

    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride

    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)

    # The matrix is real, so real and imaginary parts transform independently.
    new_a0_real = cos_half * a0_real - sin_half * a1_real
    new_a0_imag = cos_half * a0_imag - sin_half * a1_imag
    new_a1_real = sin_half * a0_real + cos_half * a1_real
    new_a1_imag = sin_half * a0_imag + cos_half * a1_imag

    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
|
|
|
|
@triton.jit
def _phase_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    cos_phi,
    sin_phi,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Phase gate: P(φ) = [[1, 0], [0, e^(iφ)]]

    cos_phi / sin_phi are cos(φ) and sin(φ). They are ordinary runtime
    scalars rather than tl.constexpr: a constexpr value becomes part of the
    compiled kernel, so every distinct angle would trigger a fresh JIT
    compilation. Only BLOCK_SIZE needs to be a compile-time constant.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit

    # One lane per amplitude pair (indices differing only in the target bit).
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)

    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride

    # Only the amplitude with the target bit set is touched.
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)

    # a1 *= (cos φ + i*sin φ)
    new_a1_real = a1_real * cos_phi - a1_imag * sin_phi
    new_a1_imag = a1_real * sin_phi + a1_imag * cos_phi

    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
|
|
|
|
| |
| |
| |
|
|
@triton.jit
def _cnot_kernel(
    state_real_ptr,
    state_imag_ptr,
    control_qubit,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    In-place CNOT: flips the target qubit where the control qubit is |1⟩.

    Truth table (control, target):
    |00⟩ → |00⟩
    |01⟩ → |01⟩
    |10⟩ → |11⟩
    |11⟩ → |10⟩

    Each lane owns one pair of indices differing only in the target bit and
    exchanges the pair's amplitudes when the control bit of the index is set.
    """
    lane = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    active = lane < ((1 << num_qubits) // 2)

    # lo has the target bit clear, hi has it set.
    step = 1 << target_qubit
    lo = (lane // step) * (2 * step) + (lane % step)
    hi = lo + step

    # The control bit is tested on lo.
    ctrl_set = (lo & (1 << control_qubit)) != 0

    lo_re = tl.load(state_real_ptr + lo, mask=active, other=0.0)
    lo_im = tl.load(state_imag_ptr + lo, mask=active, other=0.0)
    hi_re = tl.load(state_real_ptr + hi, mask=active, other=0.0)
    hi_im = tl.load(state_imag_ptr + hi, mask=active, other=0.0)

    # Exchange the pair where the control is set; otherwise write values back unchanged.
    tl.store(state_real_ptr + lo, tl.where(ctrl_set, hi_re, lo_re), mask=active)
    tl.store(state_imag_ptr + lo, tl.where(ctrl_set, hi_im, lo_im), mask=active)
    tl.store(state_real_ptr + hi, tl.where(ctrl_set, lo_re, hi_re), mask=active)
    tl.store(state_imag_ptr + hi, tl.where(ctrl_set, lo_im, hi_im), mask=active)
|
|
|
|
@triton.jit
def _cz_kernel(
    state_real_ptr,
    state_imag_ptr,
    control_qubit,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    In-place Controlled-Z.

    Negates every amplitude whose index has both the control and the target
    bit set. One lane per amplitude; each lane touches only its own element.
    """
    pos = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = pos < (1 << num_qubits)

    # Both bits are set exactly when masking with their union returns the union.
    both_bits = (1 << control_qubit) | (1 << target_qubit)
    flip = (pos & both_bits) == both_bits

    amp_re = tl.load(state_real_ptr + pos, mask=in_range, other=0.0)
    amp_im = tl.load(state_imag_ptr + pos, mask=in_range, other=0.0)

    tl.store(state_real_ptr + pos, tl.where(flip, -amp_re, amp_re), mask=in_range)
    tl.store(state_imag_ptr + pos, tl.where(flip, -amp_im, amp_im), mask=in_range)
|
|
|
|
@triton.jit
def _swap_kernel(
    state_real_ptr,
    state_imag_ptr,
    qubit_a,
    qubit_b,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    SWAP gate - exchanges two qubits.

    Swaps amplitudes where bits a and b differ.
    |01⟩ ↔ |10⟩ pattern

    The update is in place, so each swapped pair must be handled by exactly
    one lane. Previously every lane loaded its partner's element and stored
    only its own; a lane in another program could overwrite that partner
    element first, and the swap was then lost. Now the lane whose index has
    (bit a, bit b) == (1, 0) owns the pair: it loads both elements and stores
    both, and all other lanes are fully masked off.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits

    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < state_size

    bit_a = (idx >> qubit_a) & 1
    bit_b = (idx >> qubit_b) & 1

    # Exactly one member of each {idx, partner} pair satisfies this predicate.
    # If qubit_a == qubit_b it is never true and the kernel is a no-op.
    owns_pair = in_range & (bit_a == 1) & (bit_b == 0)

    # Flipping both bits turns ...1...0... into ...0...1...
    partner = idx ^ (1 << qubit_a) ^ (1 << qubit_b)

    val_real = tl.load(state_real_ptr + idx, mask=owns_pair, other=0.0)
    val_imag = tl.load(state_imag_ptr + idx, mask=owns_pair, other=0.0)
    partner_real = tl.load(state_real_ptr + partner, mask=owns_pair, other=0.0)
    partner_imag = tl.load(state_imag_ptr + partner, mask=owns_pair, other=0.0)

    tl.store(state_real_ptr + idx, partner_real, mask=owns_pair)
    tl.store(state_imag_ptr + idx, partner_imag, mask=owns_pair)
    tl.store(state_real_ptr + partner, val_real, mask=owns_pair)
    tl.store(state_imag_ptr + partner, val_imag, mask=owns_pair)
|
|
|
|
| |
| |
| |
|
|
@triton.jit
def _ccx_kernel(
    state_real_ptr,
    state_imag_ptr,
    control1_qubit,
    control2_qubit,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    In-place Toffoli (CCX): flips the target qubit only where BOTH control
    qubits are |1⟩.

    Truth table (for c1, c2, t):
    |000⟩ → |000⟩
    |001⟩ → |001⟩
    |010⟩ → |010⟩
    |011⟩ → |011⟩
    |100⟩ → |100⟩
    |101⟩ → |101⟩
    |110⟩ → |111⟩ (flip when both controls are 1)
    |111⟩ → |110⟩

    Each lane owns one pair of indices differing only in the target bit.
    """
    lane = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    active = lane < ((1 << num_qubits) // 2)

    # lo has the target bit clear, hi has it set.
    step = 1 << target_qubit
    lo = (lane // step) * (2 * step) + (lane % step)
    hi = lo + step

    # Both control bits are set exactly when masking with their union returns the union.
    ctrl_bits = (1 << control1_qubit) | (1 << control2_qubit)
    fire = (lo & ctrl_bits) == ctrl_bits

    lo_re = tl.load(state_real_ptr + lo, mask=active, other=0.0)
    lo_im = tl.load(state_imag_ptr + lo, mask=active, other=0.0)
    hi_re = tl.load(state_real_ptr + hi, mask=active, other=0.0)
    hi_im = tl.load(state_imag_ptr + hi, mask=active, other=0.0)

    # Exchange the pair where both controls are set; otherwise write back unchanged.
    tl.store(state_real_ptr + lo, tl.where(fire, hi_re, lo_re), mask=active)
    tl.store(state_imag_ptr + lo, tl.where(fire, hi_im, lo_im), mask=active)
    tl.store(state_real_ptr + hi, tl.where(fire, lo_re, hi_re), mask=active)
    tl.store(state_imag_ptr + hi, tl.where(fire, lo_im, hi_im), mask=active)
|
|
|
|
@triton.jit
def _cswap_kernel(
    state_real_ptr,
    state_imag_ptr,
    control_qubit,
    target1_qubit,
    target2_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fredkin (CSWAP) gate - Controlled-SWAP.
    Swaps target1 and target2 only when control qubit is |1⟩.

    When control=1: swaps states where t1 and t2 bits differ
    |101⟩ ↔ |110⟩ pattern (when control=1)

    The update is in place, so each swapped pair must be handled by exactly
    one lane. Previously every lane loaded its partner's element and stored
    only its own; a lane in another program could overwrite that partner
    element first, and the swap was then lost. Now the lane whose index has
    control = 1, t1 = 1, t2 = 0 owns the pair: it loads both elements and
    stores both, and all other lanes are fully masked off.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits

    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < state_size

    control_is_1 = (idx & (1 << control_qubit)) != 0
    t1_bit = (idx >> target1_qubit) & 1
    t2_bit = (idx >> target2_qubit) & 1

    # Exactly one member of each swapped pair satisfies this predicate.
    # If target1_qubit == target2_qubit it is never true (no-op).
    owns_pair = in_range & control_is_1 & (t1_bit == 1) & (t2_bit == 0)

    # Flipping both target bits yields the other member of the pair.
    partner = idx ^ (1 << target1_qubit) ^ (1 << target2_qubit)

    val_real = tl.load(state_real_ptr + idx, mask=owns_pair, other=0.0)
    val_imag = tl.load(state_imag_ptr + idx, mask=owns_pair, other=0.0)
    partner_real = tl.load(state_real_ptr + partner, mask=owns_pair, other=0.0)
    partner_imag = tl.load(state_imag_ptr + partner, mask=owns_pair, other=0.0)

    tl.store(state_real_ptr + idx, partner_real, mask=owns_pair)
    tl.store(state_imag_ptr + idx, partner_imag, mask=owns_pair)
    tl.store(state_real_ptr + partner, val_real, mask=owns_pair)
    tl.store(state_imag_ptr + partner, val_imag, mask=owns_pair)
|
|
|
|
@triton.jit
def _ccz_kernel(
    state_real_ptr,
    state_imag_ptr,
    control1_qubit,
    control2_qubit,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    In-place Controlled-Controlled-Z (CCZ).

    Negates every amplitude whose index has all three qubit bits set.
    One lane per amplitude; each lane touches only its own element.
    """
    pos = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = pos < (1 << num_qubits)

    # All three bits are set exactly when masking with their union returns the union.
    all_bits = (1 << control1_qubit) | (1 << control2_qubit) | (1 << target_qubit)
    flip = (pos & all_bits) == all_bits

    amp_re = tl.load(state_real_ptr + pos, mask=in_range, other=0.0)
    amp_im = tl.load(state_imag_ptr + pos, mask=in_range, other=0.0)

    tl.store(state_real_ptr + pos, tl.where(flip, -amp_re, amp_re), mask=in_range)
    tl.store(state_imag_ptr + pos, tl.where(flip, -amp_im, amp_im), mask=in_range)
|
|
|
|
| |
| |
| |
|
|
| def _get_state_arrays(state: torch.Tensor): |
| """Split complex state into real/imag float tensors.""" |
| if state.dtype == torch.complex64: |
| return state.real.contiguous(), state.imag.contiguous() |
| elif state.dtype == torch.complex128: |
| return state.real.float().contiguous(), state.imag.float().contiguous() |
| else: |
| raise ValueError(f"Expected complex tensor, got {state.dtype}") |
|
|
|
|
| def _combine_state_arrays(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor: |
| """Combine real/imag arrays back to complex.""" |
| return torch.complex(real, imag) |
|
|
|
|
| def _get_num_qubits(state: torch.Tensor) -> int: |
| """Get number of qubits from state vector size.""" |
| size = state.numel() |
| num_qubits = int(math.log2(size)) |
| assert 2 ** num_qubits == size, f"State size {size} is not a power of 2" |
| return num_qubits |
|
|
|
|
def hadamard(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply Hadamard gate to target qubit.

    The state is split into real/imag work buffers, transformed by
    ``_hadamard_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        target: Index of the qubit to act on.

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If ``target`` is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses element i0 + 2**target; an out-of-range target
    # would read and write past the end of the buffers.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} is out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _hadamard_kernel[(grid,)](
        real, imag, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def pauli_x(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply Pauli-X (NOT) gate to target qubit.

    The state is split into real/imag work buffers, transformed by
    ``_pauli_x_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        target: Index of the qubit to act on.

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If ``target`` is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses element i0 + 2**target; an out-of-range target
    # would read and write past the end of the buffers.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} is out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _pauli_x_kernel[(grid,)](
        real, imag, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def pauli_y(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply Pauli-Y gate to target qubit.

    The state is split into real/imag work buffers, transformed by
    ``_pauli_y_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        target: Index of the qubit to act on.

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If ``target`` is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses element i0 + 2**target; an out-of-range target
    # would read and write past the end of the buffers.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} is out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _pauli_y_kernel[(grid,)](
        real, imag, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def pauli_z(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply Pauli-Z gate to target qubit.

    The state is split into real/imag work buffers, transformed by
    ``_pauli_z_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        target: Index of the qubit to act on.

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If ``target`` is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses element i0 + 2**target; an out-of-range target
    # would read and write past the end of the buffers.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} is out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _pauli_z_kernel[(grid,)](
        real, imag, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def rotation_x(state: torch.Tensor, target: int, theta: float) -> torch.Tensor:
    """Apply Rx(theta) rotation gate.

    The state is split into real/imag work buffers, transformed by
    ``_rotation_x_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        target: Index of the qubit to act on.
        theta: Rotation angle; the kernel receives cos(theta/2) and sin(theta/2).

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If ``target`` is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses element i0 + 2**target; an out-of-range target
    # would read and write past the end of the buffers.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} is out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)

    cos_half = math.cos(theta / 2)
    sin_half = math.sin(theta / 2)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _rotation_x_kernel[(grid,)](
        real, imag, target, num_qubits,
        cos_half, sin_half,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def rotation_y(state: torch.Tensor, target: int, theta: float) -> torch.Tensor:
    """Apply Ry(theta) rotation gate.

    The state is split into real/imag work buffers, transformed by
    ``_rotation_y_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        target: Index of the qubit to act on.
        theta: Rotation angle; the kernel receives cos(theta/2) and sin(theta/2).

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If ``target`` is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses element i0 + 2**target; an out-of-range target
    # would read and write past the end of the buffers.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} is out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)

    cos_half = math.cos(theta / 2)
    sin_half = math.sin(theta / 2)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _rotation_y_kernel[(grid,)](
        real, imag, target, num_qubits,
        cos_half, sin_half,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def rotation_z(state: torch.Tensor, target: int, theta: float) -> torch.Tensor:
    """Apply Rz(theta) rotation gate.

    The state is split into real/imag work buffers, transformed by
    ``_rotation_z_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        target: Index of the qubit to act on.
        theta: Rotation angle; the kernel receives cos(theta/2) and sin(theta/2).

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If ``target`` is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses element i0 + 2**target; an out-of-range target
    # would read and write past the end of the buffers.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} is out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)

    cos_half = math.cos(theta / 2)
    sin_half = math.sin(theta / 2)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _rotation_z_kernel[(grid,)](
        real, imag, target, num_qubits,
        cos_half, sin_half,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def phase_gate(state: torch.Tensor, target: int, phi: float) -> torch.Tensor:
    """Apply phase gate P(phi) = [[1, 0], [0, e^(i*phi)]].

    The state is split into real/imag work buffers, transformed by
    ``_phase_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        target: Index of the qubit to act on.
        phi: Phase angle; the kernel receives cos(phi) and sin(phi).

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If ``target`` is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses element i0 + 2**target; an out-of-range target
    # would read and write past the end of the buffers.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} is out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)

    cos_phi = math.cos(phi)
    sin_phi = math.sin(phi)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _phase_kernel[(grid,)](
        real, imag, target, num_qubits,
        cos_phi, sin_phi,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def t_gate(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply the T gate: the phase gate with phi = π/4."""
    t_phase = math.pi / 4
    return phase_gate(state, target, t_phase)
|
|
|
|
def s_gate(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply the S gate: the phase gate with phi = π/2."""
    s_phase = math.pi / 2
    return phase_gate(state, target, s_phase)
|
|
|
|
def cnot(state: torch.Tensor, control: int, target: int) -> torch.Tensor:
    """Apply CNOT (controlled-NOT) gate.

    The state is split into real/imag work buffers, transformed by
    ``_cnot_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        control: Index of the control qubit.
        target: Index of the qubit that is flipped when the control is set.

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)`` or
            ``control == target``.
    """
    num_qubits = _get_num_qubits(state)
    # Out-of-range indices would make the kernel address elements past the
    # end of the buffers, so reject them before launching.
    for label, qubit in (("control", control), ("target", target)):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{label} qubit {qubit} is out of range for a {num_qubits}-qubit state"
            )
    if control == target:
        # The kernel tests the control bit on the index whose target bit is
        # clear, so with identical qubits it would silently do nothing.
        raise ValueError("control and target must be different qubits")
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    num_pairs = 2 ** num_qubits // 2
    grid = (num_pairs + BLOCK_SIZE - 1) // BLOCK_SIZE

    _cnot_kernel[(grid,)](
        real, imag, control, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def cz(state: torch.Tensor, control: int, target: int) -> torch.Tensor:
    """Apply Controlled-Z gate.

    The state is split into real/imag work buffers, transformed by
    ``_cz_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        control: Index of the control qubit.
        target: Index of the target qubit.

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # Shifts by a negative or oversized qubit index are meaningless in the
    # kernel, so reject them before launching.
    for label, qubit in (("control", control), ("target", target)):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{label} qubit {qubit} is out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude.
    grid = ((2 ** num_qubits) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _cz_kernel[(grid,)](
        real, imag, control, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
def swap(state: torch.Tensor, qubit_a: int, qubit_b: int) -> torch.Tensor:
    """Apply SWAP gate.

    The state is split into real/imag work buffers, transformed by
    ``_swap_kernel`` and copied back into ``state``.

    Args:
        state: Complex state vector with a power-of-two number of elements.
        qubit_a: Index of the first qubit.
        qubit_b: Index of the second qubit.

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses idx ^ 2**a ^ 2**b; an out-of-range index would
    # point past the end of the buffers.
    for label, qubit in (("qubit_a", qubit_a), ("qubit_b", qubit_b)):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{label} {qubit} is out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude.
    grid = ((2 ** num_qubits) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _swap_kernel[(grid,)](
        real, imag, qubit_a, qubit_b, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
| |
| |
| |
|
|
def ccx(state: torch.Tensor, control1: int, control2: int, target: int) -> torch.Tensor:
    """
    Apply Toffoli (CCX) gate - Controlled-Controlled-NOT.

    Flips target qubit only when both control qubits are |1⟩.
    This is a universal gate for classical computation.

    Args:
        state: Quantum state vector
        control1: First control qubit index
        control2: Second control qubit index
        target: Target qubit index

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)`` or the
            target coincides with one of the controls.
    """
    num_qubits = _get_num_qubits(state)
    # Out-of-range indices would make the kernel address elements past the
    # end of the buffers, so reject them before launching.
    for label, qubit in (
        ("control1", control1),
        ("control2", control2),
        ("target", target),
    ):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{label} qubit {qubit} is out of range for a {num_qubits}-qubit state"
            )
    if target in (control1, control2):
        # The kernel tests control bits on the index whose target bit is
        # clear, so a control equal to the target would never fire.
        raise ValueError("target must differ from both control qubits")
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude pair.
    num_pairs = 2 ** num_qubits // 2
    grid = (num_pairs + BLOCK_SIZE - 1) // BLOCK_SIZE

    _ccx_kernel[(grid,)](
        real, imag, control1, control2, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state


# Alias: the CCX gate is also known as the Toffoli gate.
toffoli = ccx
|
|
|
|
def cswap(state: torch.Tensor, control: int, target1: int, target2: int) -> torch.Tensor:
    """
    Apply Fredkin (CSWAP) gate - Controlled-SWAP.

    Swaps target1 and target2 qubits only when control qubit is |1⟩.
    This is a universal gate for reversible computation.

    Args:
        state: Quantum state vector
        control: Control qubit index
        target1: First target qubit index
        target2: Second target qubit index

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)`` or the
            control coincides with one of the targets.
    """
    num_qubits = _get_num_qubits(state)
    # The kernel addresses idx ^ 2**t1 ^ 2**t2; an out-of-range index would
    # point past the end of the buffers.
    for label, qubit in (
        ("control", control),
        ("target1", target1),
        ("target2", target2),
    ):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{label} qubit {qubit} is out of range for a {num_qubits}-qubit state"
            )
    if control in (target1, target2):
        # A control that is itself swapped does not describe a Fredkin gate.
        raise ValueError("control must differ from both target qubits")
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude.
    grid = ((2 ** num_qubits) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _cswap_kernel[(grid,)](
        real, imag, control, target1, target2, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state


# Alias: the CSWAP gate is also known as the Fredkin gate.
fredkin = cswap
|
|
|
|
def ccz(state: torch.Tensor, control1: int, control2: int, target: int) -> torch.Tensor:
    """
    Apply Controlled-Controlled-Z (CCZ) gate.

    Applies phase flip (-1) when all three qubits are |1⟩.
    CCZ is equivalent to CCX with Hadamard gates on target.

    Args:
        state: Quantum state vector
        control1: First control qubit index
        control2: Second control qubit index
        target: Target qubit index

    Returns:
        The same ``state`` tensor, updated in place.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)``.
    """
    num_qubits = _get_num_qubits(state)
    # Shifts by a negative or oversized qubit index are meaningless in the
    # kernel, so reject them before launching.
    for label, qubit in (
        ("control1", control1),
        ("control2", control2),
        ("target", target),
    ):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{label} qubit {qubit} is out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)

    BLOCK_SIZE = 256
    # One kernel lane per amplitude.
    grid = ((2 ** num_qubits) + BLOCK_SIZE - 1) // BLOCK_SIZE

    _ccz_kernel[(grid,)](
        real, imag, control1, control2, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )

    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
|
|
|
|
| |
# Three-qubit reference matrices (8x8, complex64), built from the identity.

# CCX (Toffoli): identity with rows 6 and 7 exchanged.
GATE_MATRICES['CCX'] = torch.eye(8, dtype=torch.complex64)[[0, 1, 2, 3, 4, 5, 7, 6]]

# CSWAP (Fredkin): identity with rows 5 and 6 exchanged.
GATE_MATRICES['CSWAP'] = torch.eye(8, dtype=torch.complex64)[[0, 1, 2, 3, 4, 6, 5, 7]]

# CCZ: identity except for -1 in the last diagonal entry.
GATE_MATRICES['CCZ'] = torch.diag(
    torch.tensor([1, 1, 1, 1, 1, 1, 1, -1], dtype=torch.complex64)
)
|
|