# FireEcho / quantum / gates.py
# Uploaded by Joysulem — commit "Upload 3258 files" (b5bff9c, verified)
"""
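FireEcho Quantum Gold - Gate Kernels
SM120-optimized quantum gate implementations using Triton.
All kernel launches request Thread Block Clusters (num_ctas=2) for cooperative execution.
Gate Mathematics:
- Qubit k corresponds to bit k of the amplitude index (qubit 0 is the least significant bit)
- Single-qubit gates operate on pairs of amplitudes whose indices differ only in the target bit
- Two-qubit gates act on quadruples of amplitudes (the four basis states of the two qubits involved)
- All gates are unitary (U†U = I), so they preserve the state's normalization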
"""
import torch
import triton
import triton.language as tl
import numpy as np
from typing import Optional
import math
# =============================================================================
# Gate Matrices (for reference and validation)
# =============================================================================
def _complex_matrix(rows):
    """Convert a nested list of matrix rows into a complex64 tensor."""
    return torch.tensor(rows, dtype=torch.complex64)


# Reference unitaries keyed by gate name (for documentation and validation).
# Two-qubit matrices use the basis order |00>, |01>, |10>, |11>, where the
# left label is the first qubit (the control for CNOT/CZ); e.g. CNOT
# exchanges the |10> and |11> rows.
GATE_MATRICES = {
    # Pauli gates
    'X': _complex_matrix([[0, 1], [1, 0]]),
    'Y': _complex_matrix([[0, -1j], [1j, 0]]),
    'Z': _complex_matrix([[1, 0], [0, -1]]),
    # Hadamard, normalized by 1/sqrt(2)
    'H': _complex_matrix([[1, 1], [1, -1]]) / math.sqrt(2),
    # Phase gates: S applies phase i, T applies phase e^(i*pi/4) to |1>
    'S': _complex_matrix([[1, 0], [0, 1j]]),
    'T': _complex_matrix([[1, 0], [0, math.e**(1j * math.pi / 4)]]),
    # Two-qubit gates
    'CNOT': _complex_matrix([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 0],
    ]),
    'CZ': _complex_matrix([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, -1],
    ]),
    'SWAP': _complex_matrix([
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]),
}
# =============================================================================
# Triton Kernels for Single-Qubit Gates
# =============================================================================
@triton.jit
def _hadamard_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Hadamard gate: H|0⟩ = (|0⟩+|1⟩)/√2, H|1⟩ = (|0⟩-|1⟩)/√2
    Matrix: [[1,1],[1,-1]]/√2

    Updates the real and imaginary planes in place. Program `pid` covers pair
    indices [pid*BLOCK_SIZE, (pid+1)*BLOCK_SIZE); pair k maps to exactly one
    (i0, i1) index pair, so no two elements of the launch touch the same address.
    The caller must guarantee 0 <= target_qubit < num_qubits, otherwise i1
    falls outside the 2**num_qubits buffer.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit
    # Each element of the block handles one pair of amplitudes
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)
    # Compute indices i0 and i1 that differ only in target_qubit bit:
    # insert a 0 at bit position target_qubit of pair_idx (i0), then set it (i1).
    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride
    # Load amplitudes (all loads happen before the in-place stores below)
    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # H gate: 1/√2 * [[1, 1], [1, -1]]
    inv_sqrt2 = 0.7071067811865476
    new_a0_real = (a0_real + a1_real) * inv_sqrt2
    new_a0_imag = (a0_imag + a1_imag) * inv_sqrt2
    new_a1_real = (a0_real - a1_real) * inv_sqrt2
    new_a1_imag = (a0_imag - a1_imag) * inv_sqrt2
    # Store results
    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
@triton.jit
def _pauli_x_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Pauli-X (NOT) gate: X|0⟩ = |1⟩, X|1⟩ = |0⟩
    Matrix: [[0,1],[1,0]]

    In place: each element of the launch exchanges the two amplitudes of one
    (i0, i1) pair that differs only in the target bit. Requires
    0 <= target_qubit < num_qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)
    # i0 has the target bit cleared, i1 = i0 with the target bit set
    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride
    # Load
    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # Swap amplitudes
    tl.store(state_real_ptr + i0, a1_real, mask=mask)
    tl.store(state_imag_ptr + i0, a1_imag, mask=mask)
    tl.store(state_real_ptr + i1, a0_real, mask=mask)
    tl.store(state_imag_ptr + i1, a0_imag, mask=mask)
@triton.jit
def _pauli_y_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Pauli-Y gate: Y|0⟩ = i|1⟩, Y|1⟩ = -i|0⟩
    Matrix: [[0,-i],[i,0]]

    In place, one (i0, i1) amplitude pair per element of the launch.
    Requires 0 <= target_qubit < num_qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)
    # i0 has the target bit cleared, i1 = i0 with the target bit set
    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride
    # Load
    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # Y: [[0, -i], [i, 0]]
    # new_a0 = -i * a1 = (a1_imag, -a1_real)
    # new_a1 = i * a0 = (-a0_imag, a0_real)
    new_a0_real = a1_imag
    new_a0_imag = -a1_real
    new_a1_real = -a0_imag
    new_a1_imag = a0_real
    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
@triton.jit
def _pauli_z_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Pauli-Z gate: Z|0⟩ = |0⟩, Z|1⟩ = -|1⟩
    Matrix: [[1,0],[0,-1]]

    In place. Only the amplitudes whose target bit is 1 (index i1) are read
    and negated; the |0⟩ halves are left untouched. Requires
    0 <= target_qubit < num_qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)
    # i0 is only needed to derive i1 (the index with the target bit set)
    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride
    # Only negate |1⟩ amplitude
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    tl.store(state_real_ptr + i1, -a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, -a1_imag, mask=mask)
@triton.jit
def _rotation_z_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    cos_half,
    sin_half,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotation around Z axis: Rz(θ) = [[e^(-iθ/2), 0], [0, e^(iθ/2)]]

    cos_half / sin_half are cos(θ/2) / sin(θ/2). They are ordinary runtime
    scalars, not tl.constexpr: constexpr values are part of Triton's
    compilation key, so a constexpr angle would JIT-compile (and cache) a new
    kernel for every distinct θ.

    In place, one (i0, i1) amplitude pair per element of the launch.
    Requires 0 <= target_qubit < num_qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)
    # i0 has the target bit cleared, i1 = i0 with the target bit set
    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride
    # Load amplitudes
    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # e^(-iθ/2) = cos(θ/2) - i*sin(θ/2)
    # e^(iθ/2) = cos(θ/2) + i*sin(θ/2)
    new_a0_real = a0_real * cos_half + a0_imag * sin_half
    new_a0_imag = a0_imag * cos_half - a0_real * sin_half
    new_a1_real = a1_real * cos_half - a1_imag * sin_half
    new_a1_imag = a1_imag * cos_half + a1_real * sin_half
    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
@triton.jit
def _rotation_x_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    cos_half,
    sin_half,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotation around X axis: Rx(θ) = [[cos(θ/2), -i*sin(θ/2)], [-i*sin(θ/2), cos(θ/2)]]

    cos_half / sin_half are cos(θ/2) / sin(θ/2). They are ordinary runtime
    scalars, not tl.constexpr: constexpr values are part of Triton's
    compilation key, so a constexpr angle would JIT-compile (and cache) a new
    kernel for every distinct θ.

    In place, one (i0, i1) amplitude pair per element of the launch.
    Requires 0 <= target_qubit < num_qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)
    # i0 has the target bit cleared, i1 = i0 with the target bit set
    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride
    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # Rx = [[cos, -i*sin], [-i*sin, cos]]
    # new_a0 = cos*a0 - i*sin*a1 = cos*a0 + sin*a1_imag - i*sin*a1_real
    # new_a1 = -i*sin*a0 + cos*a1
    new_a0_real = cos_half * a0_real + sin_half * a1_imag
    new_a0_imag = cos_half * a0_imag - sin_half * a1_real
    new_a1_real = sin_half * a0_imag + cos_half * a1_real
    new_a1_imag = -sin_half * a0_real + cos_half * a1_imag
    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
@triton.jit
def _rotation_y_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    cos_half,
    sin_half,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotation around Y axis: Ry(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]

    cos_half / sin_half are cos(θ/2) / sin(θ/2). They are ordinary runtime
    scalars, not tl.constexpr: constexpr values are part of Triton's
    compilation key, so a constexpr angle would JIT-compile (and cache) a new
    kernel for every distinct θ.

    In place, one (i0, i1) amplitude pair per element of the launch.
    Requires 0 <= target_qubit < num_qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)
    # i0 has the target bit cleared, i1 = i0 with the target bit set
    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride
    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # Ry = [[cos, -sin], [sin, cos]] is real, so real and imaginary parts
    # transform independently.
    new_a0_real = cos_half * a0_real - sin_half * a1_real
    new_a0_imag = cos_half * a0_imag - sin_half * a1_imag
    new_a1_real = sin_half * a0_real + cos_half * a1_real
    new_a1_imag = sin_half * a0_imag + cos_half * a1_imag
    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
@triton.jit
def _phase_kernel(
    state_real_ptr,
    state_imag_ptr,
    target_qubit,
    num_qubits,
    cos_phi,
    sin_phi,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Phase gate: P(φ) = [[1, 0], [0, e^(iφ)]]

    cos_phi / sin_phi are cos(φ) / sin(φ). They are ordinary runtime scalars,
    not tl.constexpr: constexpr values are part of Triton's compilation key,
    so a constexpr phase would JIT-compile (and cache) a new kernel for every
    distinct φ.

    In place. Only amplitudes whose target bit is 1 (index i1) are read and
    rotated. Requires 0 <= target_qubit < num_qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    stride = 1 << target_qubit
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < (state_size // 2)
    # i0 is only needed to derive i1 (the index with the target bit set)
    i0 = (pair_idx // stride) * (2 * stride) + (pair_idx % stride)
    i1 = i0 + stride
    # Only modify |1⟩ amplitude
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # e^(iφ) = cos(φ) + i*sin(φ)
    new_a1_real = a1_real * cos_phi - a1_imag * sin_phi
    new_a1_imag = a1_real * sin_phi + a1_imag * cos_phi
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
# =============================================================================
# Triton Kernels for Two-Qubit Gates
# =============================================================================
@triton.jit
def _cnot_kernel(
    state_real_ptr,
    state_imag_ptr,
    control_qubit,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    CNOT (Controlled-NOT) gate.
    Flips target qubit when control qubit is |1⟩.
    Truth table (written |control target⟩):
    |00⟩ → |00⟩
    |01⟩ → |01⟩
    |10⟩ → |11⟩
    |11⟩ → |10⟩

    In place. Element k of the launch owns the amplitude pair (i0, i1) that
    differs only in the target bit, and swaps the two amplitudes when the
    pair's control bit is set. Assumes control_qubit != target_qubit: if they
    were equal, i0's control bit would always be 0 and nothing would change.
    Requires both indices in [0, num_qubits).
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    num_pairs = state_size // 2
    # Each element of the block handles one candidate swap pair
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < num_pairs
    control_mask = 1 << control_qubit
    target_mask = 1 << target_qubit
    # Generate indices for pairs where target bit differs
    # i0 has target bit = 0, i1 has target bit = 1
    i0 = (pair_idx // target_mask) * (2 * target_mask) + (pair_idx % target_mask)
    i1 = i0 + target_mask
    # Only swap when control bit is 1 in both (it's the same for the pair since only target differs)
    control_is_1 = (i0 & control_mask) != 0
    # Load both amplitudes
    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # Swap if control is 1 (otherwise the original values are written back)
    new_a0_real = tl.where(control_is_1, a1_real, a0_real)
    new_a0_imag = tl.where(control_is_1, a1_imag, a0_imag)
    new_a1_real = tl.where(control_is_1, a0_real, a1_real)
    new_a1_imag = tl.where(control_is_1, a0_imag, a1_imag)
    # Store results
    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
@triton.jit
def _cz_kernel(
    state_real_ptr,
    state_imag_ptr,
    control_qubit,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Controlled-Z gate.
    Applies phase flip (-1) when both control and target are |1⟩.

    In place. Unlike the pair-based kernels, each element of the launch
    handles a single amplitude index, so the grid must cover all
    2**num_qubits indices. The gate is symmetric in its two qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < state_size
    control_mask = 1 << control_qubit
    target_mask = 1 << target_qubit
    # Negate amplitude where both control and target bits are 1
    is_11 = ((idx & control_mask) != 0) & ((idx & target_mask) != 0)
    val_real = tl.load(state_real_ptr + idx, mask=mask, other=0.0)
    val_imag = tl.load(state_imag_ptr + idx, mask=mask, other=0.0)
    new_real = tl.where(is_11, -val_real, val_real)
    new_imag = tl.where(is_11, -val_imag, val_imag)
    tl.store(state_real_ptr + idx, new_real, mask=mask)
    tl.store(state_imag_ptr + idx, new_imag, mask=mask)
@triton.jit
def _swap_kernel(
    state_real_ptr,
    state_imag_ptr,
    qubit_a,
    qubit_b,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    SWAP gate - exchanges two qubits.
    Swaps amplitudes where bits a and b differ.
    |01⟩ ↔ |10⟩ pattern

    In place, with one launch element per amplitude index (the grid must
    cover all 2**num_qubits indices). Only the "owner" index of each swapped
    pair (bit a = 1, bit b = 0) does any work: it reads both amplitudes and
    writes both back exchanged. Every address is therefore read and written by
    exactly one element, so the result does not depend on the order in which
    programs run. (Letting both partners each load the other and store to
    themselves races when the partners fall in different programs.)
    If qubit_a == qubit_b no index is an owner and the state is unchanged.
    Requires both indices in [0, num_qubits).
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < state_size
    mask_a = 1 << qubit_a
    mask_b = 1 << qubit_b
    # Check bits at positions a and b
    bit_a = (idx >> qubit_a) & 1
    bit_b = (idx >> qubit_b) & 1
    # Owner of the pair: bit_a=1, bit_b=0. Its partner (bits flipped) is the
    # unique index with bit_a=0, bit_b=1, which never acts as an owner.
    owner = in_bounds & (bit_a == 1) & (bit_b == 0)
    partner = idx ^ mask_a ^ mask_b  # Flip both bits
    # Read both amplitudes of the pair before writing either
    val_real = tl.load(state_real_ptr + idx, mask=owner, other=0.0)
    val_imag = tl.load(state_imag_ptr + idx, mask=owner, other=0.0)
    partner_real = tl.load(state_real_ptr + partner, mask=owner, other=0.0)
    partner_imag = tl.load(state_imag_ptr + partner, mask=owner, other=0.0)
    # Write them back exchanged
    tl.store(state_real_ptr + idx, partner_real, mask=owner)
    tl.store(state_imag_ptr + idx, partner_imag, mask=owner)
    tl.store(state_real_ptr + partner, val_real, mask=owner)
    tl.store(state_imag_ptr + partner, val_imag, mask=owner)
# =============================================================================
# Triton Kernels for Three-Qubit Gates
# =============================================================================
@triton.jit
def _ccx_kernel(
    state_real_ptr,
    state_imag_ptr,
    control1_qubit,
    control2_qubit,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Toffoli (CCX) gate - Controlled-Controlled-NOT.
    Flips target qubit only when BOTH control qubits are |1⟩.
    Truth table (for c1, c2, t):
    |000⟩ → |000⟩
    |001⟩ → |001⟩
    |010⟩ → |010⟩
    |011⟩ → |011⟩
    |100⟩ → |100⟩
    |101⟩ → |101⟩
    |110⟩ → |111⟩ (flip when both controls are 1)
    |111⟩ → |110⟩

    In place, pair-based like _cnot_kernel: each launch element owns one
    (i0, i1) pair differing only in the target bit and swaps it when both
    control bits of i0 are set. Assumes the controls differ from the target.
    Requires all indices in [0, num_qubits).
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    num_pairs = state_size // 2
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pair_idx < num_pairs
    control1_mask = 1 << control1_qubit
    control2_mask = 1 << control2_qubit
    target_mask = 1 << target_qubit
    # Generate indices for pairs where target bit differs
    i0 = (pair_idx // target_mask) * (2 * target_mask) + (pair_idx % target_mask)
    i1 = i0 + target_mask
    # Only swap when BOTH control bits are 1 (identical for i0 and i1)
    control1_is_1 = (i0 & control1_mask) != 0
    control2_is_1 = (i0 & control2_mask) != 0
    both_controls_1 = control1_is_1 & control2_is_1
    # Load amplitudes
    a0_real = tl.load(state_real_ptr + i0, mask=mask, other=0.0)
    a0_imag = tl.load(state_imag_ptr + i0, mask=mask, other=0.0)
    a1_real = tl.load(state_real_ptr + i1, mask=mask, other=0.0)
    a1_imag = tl.load(state_imag_ptr + i1, mask=mask, other=0.0)
    # Swap if both controls are 1
    new_a0_real = tl.where(both_controls_1, a1_real, a0_real)
    new_a0_imag = tl.where(both_controls_1, a1_imag, a0_imag)
    new_a1_real = tl.where(both_controls_1, a0_real, a1_real)
    new_a1_imag = tl.where(both_controls_1, a0_imag, a1_imag)
    tl.store(state_real_ptr + i0, new_a0_real, mask=mask)
    tl.store(state_imag_ptr + i0, new_a0_imag, mask=mask)
    tl.store(state_real_ptr + i1, new_a1_real, mask=mask)
    tl.store(state_imag_ptr + i1, new_a1_imag, mask=mask)
@triton.jit
def _cswap_kernel(
    state_real_ptr,
    state_imag_ptr,
    control_qubit,
    target1_qubit,
    target2_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fredkin (CSWAP) gate - Controlled-SWAP.
    Swaps target1 and target2 only when control qubit is |1⟩.
    When control=1: swaps states where t1 and t2 bits differ
    |101⟩ ↔ |110⟩ pattern (when control=1)

    In place, with one launch element per amplitude index (the grid must
    cover all 2**num_qubits indices). Only the "owner" index of each swapped
    pair (control = 1, t1 = 1, t2 = 0) does any work: it reads both
    amplitudes and writes both back exchanged. Every address is therefore
    read and written by exactly one element, so the result does not depend on
    program scheduling. (Letting both partners each load the other and store
    to themselves races when the partners fall in different programs.)
    Requires all indices in [0, num_qubits).
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < state_size
    control_mask = 1 << control_qubit
    t1_mask = 1 << target1_qubit
    t2_mask = 1 << target2_qubit
    # Check bits
    control_is_1 = (idx & control_mask) != 0
    t1_bit = (idx >> target1_qubit) & 1
    t2_bit = (idx >> target2_qubit) & 1
    # Owner of the pair: control=1, t1=1, t2=0. Its partner (both target bits
    # flipped) has t1=0, t2=1 and so is never an owner itself.
    owner = in_bounds & control_is_1 & (t1_bit == 1) & (t2_bit == 0)
    partner = idx ^ t1_mask ^ t2_mask  # Flip both target bits
    # Read both amplitudes of the pair before writing either
    val_real = tl.load(state_real_ptr + idx, mask=owner, other=0.0)
    val_imag = tl.load(state_imag_ptr + idx, mask=owner, other=0.0)
    partner_real = tl.load(state_real_ptr + partner, mask=owner, other=0.0)
    partner_imag = tl.load(state_imag_ptr + partner, mask=owner, other=0.0)
    # Write them back exchanged
    tl.store(state_real_ptr + idx, partner_real, mask=owner)
    tl.store(state_imag_ptr + idx, partner_imag, mask=owner)
    tl.store(state_real_ptr + partner, val_real, mask=owner)
    tl.store(state_imag_ptr + partner, val_imag, mask=owner)
@triton.jit
def _ccz_kernel(
    state_real_ptr,
    state_imag_ptr,
    control1_qubit,
    control2_qubit,
    target_qubit,
    num_qubits,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Controlled-Controlled-Z (CCZ) gate.
    Applies phase flip (-1) when all three qubits are |1⟩.

    In place, one launch element per amplitude index (the grid must cover all
    2**num_qubits indices). The gate is symmetric in its three qubits.
    """
    pid = tl.program_id(0)
    state_size = 1 << num_qubits
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < state_size
    c1_mask = 1 << control1_qubit
    c2_mask = 1 << control2_qubit
    t_mask = 1 << target_qubit
    # Check if all three bits are 1
    is_111 = ((idx & c1_mask) != 0) & ((idx & c2_mask) != 0) & ((idx & t_mask) != 0)
    val_real = tl.load(state_real_ptr + idx, mask=mask, other=0.0)
    val_imag = tl.load(state_imag_ptr + idx, mask=mask, other=0.0)
    new_real = tl.where(is_111, -val_real, val_real)
    new_imag = tl.where(is_111, -val_imag, val_imag)
    tl.store(state_real_ptr + idx, new_real, mask=mask)
    tl.store(state_imag_ptr + idx, new_imag, mask=mask)
# =============================================================================
# Python Gate Functions
# =============================================================================
def _get_state_arrays(state: torch.Tensor):
    """Return contiguous float32 real and imaginary planes of a complex state.

    complex64 input is split as-is; complex128 input is downcast to float32,
    so a double-precision state loses precision once results are copied back.

    Raises:
        ValueError: If ``state`` is not complex64 or complex128.
    """
    dtype = state.dtype
    if dtype not in (torch.complex64, torch.complex128):
        raise ValueError(f"Expected complex tensor, got {state.dtype}")
    real_part, imag_part = state.real, state.imag
    if dtype == torch.complex128:
        real_part, imag_part = real_part.float(), imag_part.float()
    return real_part.contiguous(), imag_part.contiguous()
def _combine_state_arrays(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
    """Build a complex tensor from matching real and imaginary planes.

    Thin wrapper around ``torch.complex``; float32 planes yield complex64.
    """
    combined = torch.complex(real=real, imag=imag)
    return combined
def _get_num_qubits(state: torch.Tensor) -> int:
    """Return ``n`` for a state tensor holding exactly ``2**n`` elements.

    Raises:
        ValueError: If the element count is not a positive power of two.
    """
    size = state.numel()
    # Integer power-of-two test: a positive power of two has a single set bit.
    # Unlike an ``assert`` this survives ``python -O``, and unlike
    # ``math.log2`` it handles size 0 with a clear message.
    if size <= 0 or size & (size - 1):
        raise ValueError(f"State size {size} is not a power of 2")
    return size.bit_length() - 1
def hadamard(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply a Hadamard gate to qubit ``target``, updating ``state`` in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        target: Qubit index, ``0 <= target < num_qubits``.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If ``target`` is not a valid qubit index for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel pairs index i0 with i0 + 2**target; an out-of-range target
    # would read and write outside the state buffer.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per amplitude pair: 2**num_qubits / 2 pairs.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _hadamard_kernel[(grid,)](
        real, imag, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def pauli_x(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply a Pauli-X (NOT) gate to qubit ``target``, updating ``state`` in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        target: Qubit index, ``0 <= target < num_qubits``.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If ``target`` is not a valid qubit index for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel pairs index i0 with i0 + 2**target; an out-of-range target
    # would read and write outside the state buffer.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per amplitude pair: 2**num_qubits / 2 pairs.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _pauli_x_kernel[(grid,)](
        real, imag, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def pauli_y(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply a Pauli-Y gate to qubit ``target``, updating ``state`` in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        target: Qubit index, ``0 <= target < num_qubits``.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If ``target`` is not a valid qubit index for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel pairs index i0 with i0 + 2**target; an out-of-range target
    # would read and write outside the state buffer.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per amplitude pair: 2**num_qubits / 2 pairs.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _pauli_y_kernel[(grid,)](
        real, imag, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def pauli_z(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply a Pauli-Z gate to qubit ``target``, updating ``state`` in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        target: Qubit index, ``0 <= target < num_qubits``.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If ``target`` is not a valid qubit index for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel writes index i0 + 2**target; an out-of-range target would
    # write outside the state buffer.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per amplitude pair: 2**num_qubits / 2 pairs.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _pauli_z_kernel[(grid,)](
        real, imag, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def rotation_x(state: torch.Tensor, target: int, theta: float) -> torch.Tensor:
    """Apply Rx(theta) = [[cos(t/2), -i sin(t/2)], [-i sin(t/2), cos(t/2)]] in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        target: Qubit index, ``0 <= target < num_qubits``.
        theta: Rotation angle in radians.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If ``target`` is not a valid qubit index for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel pairs index i0 with i0 + 2**target; an out-of-range target
    # would read and write outside the state buffer.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)
    cos_half = math.cos(theta / 2)
    sin_half = math.sin(theta / 2)
    BLOCK_SIZE = 256
    # One kernel element per amplitude pair: 2**num_qubits / 2 pairs.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _rotation_x_kernel[(grid,)](
        real, imag, target, num_qubits,
        cos_half, sin_half,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def rotation_y(state: torch.Tensor, target: int, theta: float) -> torch.Tensor:
    """Apply Ry(theta) = [[cos(t/2), -sin(t/2)], [sin(t/2), cos(t/2)]] in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        target: Qubit index, ``0 <= target < num_qubits``.
        theta: Rotation angle in radians.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If ``target`` is not a valid qubit index for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel pairs index i0 with i0 + 2**target; an out-of-range target
    # would read and write outside the state buffer.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)
    cos_half = math.cos(theta / 2)
    sin_half = math.sin(theta / 2)
    BLOCK_SIZE = 256
    # One kernel element per amplitude pair: 2**num_qubits / 2 pairs.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _rotation_y_kernel[(grid,)](
        real, imag, target, num_qubits,
        cos_half, sin_half,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def rotation_z(state: torch.Tensor, target: int, theta: float) -> torch.Tensor:
    """Apply Rz(theta) = diag(e^(-i t/2), e^(i t/2)) to ``target`` in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        target: Qubit index, ``0 <= target < num_qubits``.
        theta: Rotation angle in radians.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If ``target`` is not a valid qubit index for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel pairs index i0 with i0 + 2**target; an out-of-range target
    # would read and write outside the state buffer.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)
    cos_half = math.cos(theta / 2)
    sin_half = math.sin(theta / 2)
    BLOCK_SIZE = 256
    # One kernel element per amplitude pair: 2**num_qubits / 2 pairs.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _rotation_z_kernel[(grid,)](
        real, imag, target, num_qubits,
        cos_half, sin_half,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def phase_gate(state: torch.Tensor, target: int, phi: float) -> torch.Tensor:
    """Apply phase gate P(phi) = [[1, 0], [0, e^(i*phi)]] to ``target`` in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        target: Qubit index, ``0 <= target < num_qubits``.
        phi: Phase angle in radians applied to the |1> component.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If ``target`` is not a valid qubit index for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel writes index i0 + 2**target; an out-of-range target would
    # write outside the state buffer.
    if not 0 <= target < num_qubits:
        raise ValueError(
            f"target qubit {target} out of range for a {num_qubits}-qubit state"
        )
    real, imag = _get_state_arrays(state)
    cos_phi = math.cos(phi)
    sin_phi = math.sin(phi)
    BLOCK_SIZE = 256
    # One kernel element per amplitude pair: 2**num_qubits / 2 pairs.
    grid = ((2 ** num_qubits // 2) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _phase_kernel[(grid,)](
        real, imag, target, num_qubits,
        cos_phi, sin_phi,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def t_gate(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply the T gate to ``target``: a phase gate with phi = pi/4 (eighth turn)."""
    eighth_turn = math.pi / 4
    return phase_gate(state, target, phi=eighth_turn)
def s_gate(state: torch.Tensor, target: int) -> torch.Tensor:
    """Apply the S gate to ``target``: a phase gate with phi = pi/2 (quarter turn)."""
    quarter_turn = math.pi / 2
    return phase_gate(state, target, phi=quarter_turn)
def cnot(state: torch.Tensor, control: int, target: int) -> torch.Tensor:
    """Apply a CNOT (controlled-NOT) gate, updating ``state`` in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        control: Control qubit index, in ``[0, num_qubits)``.
        target: Target qubit index, in ``[0, num_qubits)``.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If a qubit index is out of range for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # An out-of-range target makes the kernel address i0 + 2**target outside
    # the buffer; an out-of-range control would be silently ignored.
    for role, qubit in (("control", control), ("target", target)):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{role} qubit {qubit} out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per target-bit amplitude pair.
    num_pairs = 2 ** num_qubits // 2
    grid = (num_pairs + BLOCK_SIZE - 1) // BLOCK_SIZE
    _cnot_kernel[(grid,)](
        real, imag, control, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def cz(state: torch.Tensor, control: int, target: int) -> torch.Tensor:
    """Apply a controlled-Z gate, updating ``state`` in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        control: Control qubit index, in ``[0, num_qubits)``.
        target: Target qubit index, in ``[0, num_qubits)``.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If a qubit index is out of range for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel would silently ignore a too-large index and perform an
    # ill-defined negative shift for a negative one; reject both up front.
    for role, qubit in (("control", control), ("target", target)):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{role} qubit {qubit} out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per amplitude (not per pair).
    grid = ((2 ** num_qubits) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _cz_kernel[(grid,)](
        real, imag, control, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
def swap(state: torch.Tensor, qubit_a: int, qubit_b: int) -> torch.Tensor:
    """Apply a SWAP gate exchanging ``qubit_a`` and ``qubit_b``, in place.

    Args:
        state: Complex state vector with ``2**num_qubits`` amplitudes.
        qubit_a: First qubit index, in ``[0, num_qubits)``.
        qubit_b: Second qubit index, in ``[0, num_qubits)``.

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If a qubit index is out of range for ``state``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel's partner index flips bits a and b; an out-of-range qubit
    # would make it point outside the state buffer.
    for role, qubit in (("qubit_a", qubit_a), ("qubit_b", qubit_b)):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{role} {qubit} out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per amplitude (not per pair).
    grid = ((2 ** num_qubits) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _swap_kernel[(grid,)](
        real, imag, qubit_a, qubit_b, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
# =============================================================================
# Three-Qubit Gate Functions
# =============================================================================
def ccx(state: torch.Tensor, control1: int, control2: int, target: int) -> torch.Tensor:
    """
    Apply Toffoli (CCX) gate - Controlled-Controlled-NOT, in place.
    Flips target qubit only when both control qubits are |1⟩.
    This is a universal gate for classical (reversible) computation.

    Args:
        state: Quantum state vector (complex, ``2**num_qubits`` amplitudes)
        control1: First control qubit index
        control2: Second control qubit index
        target: Target qubit index

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # An out-of-range target makes the kernel address i0 + 2**target outside
    # the buffer; out-of-range controls would be silently ignored.
    for role, qubit in (
        ("control1", control1), ("control2", control2), ("target", target)
    ):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{role} qubit {qubit} out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per target-bit amplitude pair.
    num_pairs = 2 ** num_qubits // 2
    grid = (num_pairs + BLOCK_SIZE - 1) // BLOCK_SIZE
    _ccx_kernel[(grid,)](
        real, imag, control1, control2, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
# Alias for Toffoli gate
toffoli = ccx
def cswap(state: torch.Tensor, control: int, target1: int, target2: int) -> torch.Tensor:
    """
    Apply Fredkin (CSWAP) gate - Controlled-SWAP, in place.
    Swaps target1 and target2 qubits only when control qubit is |1⟩.
    This is a universal gate for reversible computation.

    Args:
        state: Quantum state vector (complex, ``2**num_qubits`` amplitudes)
        control: Control qubit index
        target1: First target qubit index
        target2: Second target qubit index

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel's partner index flips both target bits; an out-of-range
    # target would make it point outside the state buffer.
    for role, qubit in (
        ("control", control), ("target1", target1), ("target2", target2)
    ):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{role} qubit {qubit} out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per amplitude (not per pair).
    grid = ((2 ** num_qubits) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _cswap_kernel[(grid,)](
        real, imag, control, target1, target2, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
# Alias for Fredkin gate
fredkin = cswap
def ccz(state: torch.Tensor, control1: int, control2: int, target: int) -> torch.Tensor:
    """
    Apply Controlled-Controlled-Z (CCZ) gate, in place.
    Applies phase flip (-1) when all three qubits are |1⟩.
    CCZ equals CCX conjugated by Hadamard gates on the target (H·CCX·H).

    Args:
        state: Quantum state vector (complex, ``2**num_qubits`` amplitudes)
        control1: First control qubit index
        control2: Second control qubit index
        target: Target qubit index

    Returns:
        The same ``state`` tensor, with its contents overwritten.

    Raises:
        ValueError: If a qubit index is outside ``[0, num_qubits)``
            (the state helpers also reject non-complex or wrongly sized states).
    """
    num_qubits = _get_num_qubits(state)
    # The kernel would silently ignore a too-large index and perform an
    # ill-defined negative shift for a negative one; reject both up front.
    for role, qubit in (
        ("control1", control1), ("control2", control2), ("target", target)
    ):
        if not 0 <= qubit < num_qubits:
            raise ValueError(
                f"{role} qubit {qubit} out of range for a {num_qubits}-qubit state"
            )
    real, imag = _get_state_arrays(state)
    BLOCK_SIZE = 256
    # One kernel element per amplitude (not per pair).
    grid = ((2 ** num_qubits) + BLOCK_SIZE - 1) // BLOCK_SIZE
    _ccz_kernel[(grid,)](
        real, imag, control1, control2, target, num_qubits,
        BLOCK_SIZE=BLOCK_SIZE,
        num_ctas=2,
    )
    # The kernel updated separate real/imag buffers; write them back.
    state.real.copy_(real)
    state.imag.copy_(imag)
    return state
# Register the three-qubit reference matrices. Each is the 8x8 identity with
# either two rows exchanged (CCX, CSWAP) or one sign flipped (CCZ). Basis
# states are labelled |c1 c2 t> (CCX/CCZ) or |c t1 t2> (CSWAP), with the
# leftmost label as the most significant bit of the row index.
GATE_MATRICES.update({
    # Toffoli: exchange |110> and |111>
    'CCX': torch.eye(8, dtype=torch.complex64)[[0, 1, 2, 3, 4, 5, 7, 6]],
    # Fredkin: exchange |101> and |110>
    'CSWAP': torch.eye(8, dtype=torch.complex64)[[0, 1, 2, 3, 4, 6, 5, 7]],
    # CCZ: negate |111>
    'CCZ': torch.diag(
        torch.tensor([1, 1, 1, 1, 1, 1, 1, -1], dtype=torch.complex64)
    ),
})