Upload qads/quantum/kernels.py
Browse files- qads/quantum/kernels.py +92 -0
qads/quantum/kernels.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Quantum Kernel Attention for similarity measurement."""
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
import pennylane as qml
|
| 7 |
+
from pennylane import numpy as pnp
|
| 8 |
+
HAS_PENNYLANE = True
|
| 9 |
+
except ImportError:
|
| 10 |
+
HAS_PENNYLANE = False
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class QuantumKernelAttention:
|
| 14 |
+
"""
|
| 15 |
+
Quantum kernel-based attention mechanism.
|
| 16 |
+
Computes K(x, x') = |<φ(x)|φ(x')>|² using quantum circuits.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, config: Any):
|
| 20 |
+
self.config = config
|
| 21 |
+
self.n_qubits = config.n_qubits
|
| 22 |
+
self.n_layers = config.n_layers
|
| 23 |
+
self.device = None
|
| 24 |
+
|
| 25 |
+
if HAS_PENNYLANE:
|
| 26 |
+
self.device = qml.device("default.qubit", wires=self.n_qubits, shots=1000)
|
| 27 |
+
self.params = pnp.random.uniform(0, 2*np.pi, (self.n_layers, self.n_qubits, 3))
|
| 28 |
+
|
| 29 |
+
def compute(self, query: np.ndarray, keys: np.ndarray) -> np.ndarray:
|
| 30 |
+
"""Compute quantum kernel attention scores."""
|
| 31 |
+
if HAS_PENNYLANE and self.device is not None:
|
| 32 |
+
try:
|
| 33 |
+
return self._quantum_kernel(query, keys)
|
| 34 |
+
except Exception:
|
| 35 |
+
pass
|
| 36 |
+
return self._classical_attention(query, keys)
|
| 37 |
+
|
| 38 |
+
def _quantum_kernel(self, query: np.ndarray, keys: np.ndarray) -> np.ndarray:
|
| 39 |
+
"""Compute quantum kernel K(x, x') = |<φ(x)|φ(x')>|²."""
|
| 40 |
+
@qml.qnode(self.device)
|
| 41 |
+
def kernel_circuit(x1, x2, params):
|
| 42 |
+
# Encode first state
|
| 43 |
+
for i in range(min(self.n_qubits, len(x1))):
|
| 44 |
+
qml.RY(np.arcsin(np.clip(x1[i], -0.999, 0.999)), wires=i)
|
| 45 |
+
qml.RZ(x1[i] * np.pi, wires=i)
|
| 46 |
+
|
| 47 |
+
# Variational layers
|
| 48 |
+
for layer in range(self.n_layers):
|
| 49 |
+
for i in range(self.n_qubits):
|
| 50 |
+
qml.RX(params[layer, i, 0], wires=i)
|
| 51 |
+
qml.RY(params[layer, i, 1], wires=i)
|
| 52 |
+
qml.RZ(params[layer, i, 2], wires=i)
|
| 53 |
+
for i in range(self.n_qubits - 1):
|
| 54 |
+
qml.CNOT(wires=[i, i+1])
|
| 55 |
+
|
| 56 |
+
# Inverse encoding of second state (swap test)
|
| 57 |
+
for i in range(min(self.n_qubits, len(x2))):
|
| 58 |
+
qml.RZ(-x2[i] * np.pi, wires=i)
|
| 59 |
+
qml.RY(-np.arcsin(np.clip(x2[i], -0.999, 0.999)), wires=i)
|
| 60 |
+
|
| 61 |
+
# SWAP test measurement
|
| 62 |
+
return qml.expval(qml.PauliZ(0))
|
| 63 |
+
|
| 64 |
+
scores = []
|
| 65 |
+
q_padded = np.zeros(self.n_qubits)
|
| 66 |
+
q_padded[:min(len(query), self.n_qubits)] = query[:self.n_qubits]
|
| 67 |
+
q_padded = q_padded / (np.max(np.abs(q_padded)) + 1e-10) if np.max(np.abs(q_padded)) > 0 else q_padded
|
| 68 |
+
|
| 69 |
+
for key in keys:
|
| 70 |
+
k_padded = np.zeros(self.n_qubits)
|
| 71 |
+
k_padded[:min(len(key), self.n_qubits)] = key[:self.n_qubits]
|
| 72 |
+
k_padded = k_padded / (np.max(np.abs(k_padded)) + 1e-10) if np.max(np.abs(k_padded)) > 0 else k_padded
|
| 73 |
+
|
| 74 |
+
overlap = kernel_circuit(q_padded, k_padded, self.params)
|
| 75 |
+
# Probability is (1 + overlap)/2
|
| 76 |
+
scores.append((1.0 + float(overlap)) / 2.0)
|
| 77 |
+
|
| 78 |
+
scores = np.array(scores)
|
| 79 |
+
# Normalize
|
| 80 |
+
if scores.sum() > 0:
|
| 81 |
+
scores = scores / scores.sum()
|
| 82 |
+
return scores
|
| 83 |
+
|
| 84 |
+
def _classical_attention(self, query: np.ndarray, keys: np.ndarray) -> np.ndarray:
|
| 85 |
+
"""Standard scaled dot-product attention fallback."""
|
| 86 |
+
q = query / (np.linalg.norm(query) + 1e-10)
|
| 87 |
+
k = keys / (np.linalg.norm(keys, axis=1, keepdims=True) + 1e-10)
|
| 88 |
+
scores = np.dot(k, q)
|
| 89 |
+
scores = scores / np.sqrt(len(query))
|
| 90 |
+
exp_scores = np.exp(scores - np.max(scores))
|
| 91 |
+
attention = exp_scores / (exp_scores.sum() + 1e-10)
|
| 92 |
+
return attention
|