Premchan369 commited on
Commit
fd385e8
·
verified ·
1 Parent(s): 9d5a297

Upload qads/quantum/kernels.py

Browse files
Files changed (1) hide show
  1. qads/quantum/kernels.py +92 -0
qads/quantum/kernels.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Quantum Kernel Attention for similarity measurement."""
2
+ import numpy as np
3
+ from typing import Any
4
+
5
+ try:
6
+ import pennylane as qml
7
+ from pennylane import numpy as pnp
8
+ HAS_PENNYLANE = True
9
+ except ImportError:
10
+ HAS_PENNYLANE = False
11
+
12
+
13
+ class QuantumKernelAttention:
14
+ """
15
+ Quantum kernel-based attention mechanism.
16
+ Computes K(x, x') = |<φ(x)|φ(x')>|² using quantum circuits.
17
+ """
18
+
19
+ def __init__(self, config: Any):
20
+ self.config = config
21
+ self.n_qubits = config.n_qubits
22
+ self.n_layers = config.n_layers
23
+ self.device = None
24
+
25
+ if HAS_PENNYLANE:
26
+ self.device = qml.device("default.qubit", wires=self.n_qubits, shots=1000)
27
+ self.params = pnp.random.uniform(0, 2*np.pi, (self.n_layers, self.n_qubits, 3))
28
+
29
+ def compute(self, query: np.ndarray, keys: np.ndarray) -> np.ndarray:
30
+ """Compute quantum kernel attention scores."""
31
+ if HAS_PENNYLANE and self.device is not None:
32
+ try:
33
+ return self._quantum_kernel(query, keys)
34
+ except Exception:
35
+ pass
36
+ return self._classical_attention(query, keys)
37
+
38
+ def _quantum_kernel(self, query: np.ndarray, keys: np.ndarray) -> np.ndarray:
39
+ """Compute quantum kernel K(x, x') = |<φ(x)|φ(x')>|²."""
40
+ @qml.qnode(self.device)
41
+ def kernel_circuit(x1, x2, params):
42
+ # Encode first state
43
+ for i in range(min(self.n_qubits, len(x1))):
44
+ qml.RY(np.arcsin(np.clip(x1[i], -0.999, 0.999)), wires=i)
45
+ qml.RZ(x1[i] * np.pi, wires=i)
46
+
47
+ # Variational layers
48
+ for layer in range(self.n_layers):
49
+ for i in range(self.n_qubits):
50
+ qml.RX(params[layer, i, 0], wires=i)
51
+ qml.RY(params[layer, i, 1], wires=i)
52
+ qml.RZ(params[layer, i, 2], wires=i)
53
+ for i in range(self.n_qubits - 1):
54
+ qml.CNOT(wires=[i, i+1])
55
+
56
+ # Inverse encoding of second state (swap test)
57
+ for i in range(min(self.n_qubits, len(x2))):
58
+ qml.RZ(-x2[i] * np.pi, wires=i)
59
+ qml.RY(-np.arcsin(np.clip(x2[i], -0.999, 0.999)), wires=i)
60
+
61
+ # SWAP test measurement
62
+ return qml.expval(qml.PauliZ(0))
63
+
64
+ scores = []
65
+ q_padded = np.zeros(self.n_qubits)
66
+ q_padded[:min(len(query), self.n_qubits)] = query[:self.n_qubits]
67
+ q_padded = q_padded / (np.max(np.abs(q_padded)) + 1e-10) if np.max(np.abs(q_padded)) > 0 else q_padded
68
+
69
+ for key in keys:
70
+ k_padded = np.zeros(self.n_qubits)
71
+ k_padded[:min(len(key), self.n_qubits)] = key[:self.n_qubits]
72
+ k_padded = k_padded / (np.max(np.abs(k_padded)) + 1e-10) if np.max(np.abs(k_padded)) > 0 else k_padded
73
+
74
+ overlap = kernel_circuit(q_padded, k_padded, self.params)
75
+ # Probability is (1 + overlap)/2
76
+ scores.append((1.0 + float(overlap)) / 2.0)
77
+
78
+ scores = np.array(scores)
79
+ # Normalize
80
+ if scores.sum() > 0:
81
+ scores = scores / scores.sum()
82
+ return scores
83
+
84
+ def _classical_attention(self, query: np.ndarray, keys: np.ndarray) -> np.ndarray:
85
+ """Standard scaled dot-product attention fallback."""
86
+ q = query / (np.linalg.norm(query) + 1e-10)
87
+ k = keys / (np.linalg.norm(keys, axis=1, keepdims=True) + 1e-10)
88
+ scores = np.dot(k, q)
89
+ scores = scores / np.sqrt(len(query))
90
+ exp_scores = np.exp(scores - np.max(scores))
91
+ attention = exp_scores / (exp_scores.sum() + 1e-10)
92
+ return attention