File size: 10,852 Bytes
662c9ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""
Q-GPT: Quantum-Enhanced GPT with Confidence Estimation
A quantum neural network head that estimates response confidence.

Author: squ11z1
"""

import torch
import torch.nn as nn
import numpy as np

try:
    import pennylane as qml
    PENNYLANE_AVAILABLE = True
except ImportError:
    PENNYLANE_AVAILABLE = False
    print("Warning: PennyLane not installed. Using classical fallback.")


class QuantumCircuit:
    """Variational quantum circuit for confidence estimation.

    Angle-encodes a classical feature vector into qubit rotations, applies
    ``n_layers`` of trainable rotations interleaved with CNOT entanglement,
    and measures one Pauli-Z expectation value per qubit. When PennyLane is
    not installed, a classical tanh transform with the same output shape is
    used instead.
    """

    def __init__(self, n_qubits: int = 4, n_layers: int = 3):
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        if PENNYLANE_AVAILABLE:
            self.dev = qml.device("default.qubit", wires=n_qubits)
            # The torch interface keeps the circuit differentiable so the
            # weights can be trained end-to-end with the classical layers.
            self.circuit = qml.QNode(self._quantum_circuit, self.dev, interface="torch")

    def _quantum_circuit(self, inputs, weights):
        """
        Variational quantum circuit.

        Args:
            inputs: Input features [n_qubits]
            weights: Trainable parameters [n_layers, n_qubits, 3]

        Returns:
            List of n_qubits Pauli-Z expectation values, each in [-1, 1].
        """
        # Encode classical data into quantum states (angle encoding).
        for i in range(self.n_qubits):
            qml.RY(inputs[i], wires=i)
            qml.RZ(inputs[i], wires=i)

        # Variational layers
        for layer in range(self.n_layers):
            # Trainable single-qubit rotation gates
            for i in range(self.n_qubits):
                qml.Rot(weights[layer, i, 0],
                        weights[layer, i, 1],
                        weights[layer, i, 2], wires=i)

            # Entanglement (CNOT ladder)
            for i in range(self.n_qubits - 1):
                qml.CNOT(wires=[i, i + 1])

            # Circular entanglement: close the ring when it is non-trivial.
            if self.n_qubits > 2:
                qml.CNOT(wires=[self.n_qubits - 1, 0])

        # Measure expectation values
        return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

    def forward(self, inputs, weights):
        """Execute the quantum circuit, or the classical fallback.

        Args:
            inputs: Input features [n_qubits].
            weights: Trainable parameters [n_layers, n_qubits, 3].

        Returns:
            Per-qubit outputs: a list or tensor of length n_qubits.
        """
        # Dispatch on the compiled QNode itself; it exists iff PennyLane
        # was available at construction time.
        circuit = getattr(self, "circuit", None)
        if circuit is not None:
            return circuit(inputs, weights)
        # Classical fallback: per-qubit tanh transform.
        # BUG FIX: the previous `inputs @ weights.mean(dim=(0, 2))` was a
        # dot product that collapsed the output to a scalar, breaking the
        # downstream layers that expect [n_qubits] values per sample.
        return torch.tanh(inputs * weights.mean(dim=(0, 2)))


class QuantumHead(nn.Module):
    """
    Quantum-enhanced confidence estimation head for GPT.

    Takes hidden states from the last layer and outputs:
    - confidence: Estimated confidence in the response [0, 1]
    - uncertainty: Quantum-derived uncertainty measure (non-negative)
    - should_refuse: Boolean mask, True where confidence falls below
      ``refuse_threshold``.
    """

    def __init__(
        self,
        hidden_size: int = 2880,  # GPT-OSS hidden size
        n_qubits: int = 4,
        n_layers: int = 3,
        intermediate_size: int = 64,
        refuse_threshold: float = 0.3,
    ):
        """
        Args:
            hidden_size: Width of the incoming hidden states.
            n_qubits: Number of qubits in the variational circuit.
            n_layers: Number of variational layers.
            intermediate_size: Width of the classical pre/post MLPs.
            refuse_threshold: Confidence below this value marks a sample as
                one the model should refuse (generalizes the previously
                hard-coded 0.3; default preserves old behavior).
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.refuse_threshold = refuse_threshold

        # Classical preprocessing: compress hidden states to one feature
        # per qubit.
        self.pre_quantum = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, n_qubits),
            nn.Tanh(),  # Normalize to [-1, 1] for quantum encoding
        )

        # Quantum circuit
        self.quantum = QuantumCircuit(n_qubits, n_layers)

        # Quantum weights (trainable); small init keeps the initial
        # rotations close to identity.
        self.quantum_weights = nn.Parameter(
            torch.randn(n_layers, n_qubits, 3) * 0.1
        )

        # Post-quantum processing
        self.post_quantum = nn.Sequential(
            nn.Linear(n_qubits, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, 2),  # [confidence, uncertainty]
        )

        # Output heads: sigmoid bounds confidence to (0, 1), softplus
        # keeps uncertainty non-negative.
        self.confidence_activation = nn.Sigmoid()
        self.uncertainty_activation = nn.Softplus()

    def forward(self, hidden_states: torch.Tensor) -> dict:
        """
        Compute confidence and uncertainty from hidden states.

        Args:
            hidden_states: Last layer hidden states [batch, seq_len, hidden_size]
                          or pooled representation [batch, hidden_size]

        Returns:
            dict with 'confidence', 'uncertainty' and 'should_refuse' tensors
        """
        # Pool if sequence dimension exists
        if hidden_states.dim() == 3:
            # Use last token representation
            hidden_states = hidden_states[:, -1, :]

        batch_size = hidden_states.size(0)

        # Preprocess
        quantum_input = self.pre_quantum(hidden_states)  # [batch, n_qubits]

        # Process through quantum circuit (per sample: the circuit is not
        # batched, so iterate over the batch dimension).
        quantum_outputs = []
        for i in range(batch_size):
            qout = self.quantum.forward(
                quantum_input[i],
                self.quantum_weights
            )
            # PennyLane may return a list of expectation values.
            if isinstance(qout, list):
                qout = torch.stack(qout)
            quantum_outputs.append(qout)

        quantum_output = torch.stack(quantum_outputs)  # [batch, n_qubits]

        # Post-process
        output = self.post_quantum(quantum_output)

        confidence = self.confidence_activation(output[:, 0])
        uncertainty = self.uncertainty_activation(output[:, 1])

        return {
            "confidence": confidence,
            "uncertainty": uncertainty,
            # Low confidence = should refuse
            "should_refuse": confidence < self.refuse_threshold,
        }

    def get_interpretable_confidence(self, confidence: torch.Tensor) -> str:
        """Convert confidence score to human-readable label."""
        # Scalar tensors are read directly; batched tensors are averaged.
        conf = confidence.item() if confidence.dim() == 0 else confidence.mean().item()

        if conf >= 0.9:
            return "very high"
        elif conf >= 0.7:
            return "high"
        elif conf >= 0.5:
            return "moderate"
        elif conf >= 0.3:
            return "low"
        else:
            return "very low (consider refusing)"


class QGPT(nn.Module):
    """
    Q-GPT: GPT with Quantum Confidence Head

    Wraps any HuggingFace GPT model and adds quantum confidence estimation
    on top of its last-layer hidden states.
    """

    def __init__(self, base_model, quantum_head: QuantumHead = None):
        super().__init__()
        self.base_model = base_model

        # Infer the hidden width from the wrapped model's config; fall
        # back to the GPT-OSS default when neither field is present.
        config = base_model.config
        if hasattr(config, 'hidden_size'):
            width = config.hidden_size
        elif hasattr(config, 'd_model'):
            width = config.d_model
        else:
            width = 2880  # GPT-OSS default

        self.quantum_head = quantum_head or QuantumHead(hidden_size=width)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Forward pass with confidence estimation attached to the outputs."""
        # Run the wrapped model, requesting hidden states for the head.
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        # Score the last layer's hidden states with the quantum head and
        # attach the results to the model output object.
        scores = self.quantum_head(outputs.hidden_states[-1])
        outputs.confidence = scores["confidence"]
        outputs.uncertainty = scores["uncertainty"]
        outputs.should_refuse = scores["should_refuse"]

        return outputs

    def generate_with_confidence(
        self,
        input_ids,
        attention_mask=None,
        max_new_tokens=256,
        **kwargs
    ):
        """Generate text and return the sequences plus confidence scores."""
        outputs = self.base_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            output_hidden_states=True,
            return_dict_in_generate=True,
            **kwargs
        )

        # Prefer the hidden states emitted during generation: last step,
        # last layer. Otherwise re-run a forward pass on the generated
        # sequence to obtain them.
        # NOTE(review): assumes HF generate returns per-step tuples of
        # per-layer hidden states — confirm against the transformers
        # version in use.
        if hasattr(outputs, 'hidden_states') and outputs.hidden_states:
            last_hidden = outputs.hidden_states[-1][-1]
        else:
            with torch.no_grad():
                recomputed = self.base_model(
                    outputs.sequences,
                    output_hidden_states=True
                )
                last_hidden = recomputed.hidden_states[-1]

        scores = self.quantum_head(last_hidden)

        return {
            "sequences": outputs.sequences,
            "confidence": scores["confidence"],
            "uncertainty": scores["uncertainty"],
            "should_refuse": scores["should_refuse"],
            "confidence_label": self.quantum_head.get_interpretable_confidence(
                scores["confidence"]
            ),
        }


def load_qgpt(
    model_name: str = "squ11z1/gpt-oss-9b-reasoning",
    quantum_head_path: str = None,
    device: str = "auto",
    torch_dtype = None,
    **kwargs
):
    """
    Load Q-GPT model with quantum head.

    Args:
        model_name: HuggingFace model name or path
        quantum_head_path: Path to trained quantum head weights
        device: Device map passed to ``from_pretrained``
        torch_dtype: Model dtype (defaults to ``torch.bfloat16``)

    Returns:
        QGPT model and tokenizer
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype = torch.bfloat16 if torch_dtype is None else torch_dtype

    # Load the base model and its tokenizer.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        **kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        **kwargs
    )

    # Wrap the base model with the quantum confidence head.
    model = QGPT(base_model)

    if quantum_head_path:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(quantum_head_path, map_location="cpu")
        model.quantum_head.load_state_dict(state_dict)
        print(f"Loaded quantum head from {quantum_head_path}")

    return model, tokenizer


if __name__ == "__main__":
    # Smoke-test the quantum head on random pooled activations.
    print("Testing QuantumHead...")

    quantum_head = QuantumHead(hidden_size=2880)
    dummy = torch.randn(2, 2880)  # batch of two pooled hidden states

    result = quantum_head(dummy)
    print(f"Confidence: {result['confidence']}")
    print(f"Uncertainty: {result['uncertainty']}")
    print(f"Should refuse: {result['should_refuse']}")

    print("\n✓ QuantumHead test passed!")