squ11z1 committed on
Commit
662c9ff
·
verified ·
1 Parent(s): 219ff75

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +109 -0
  2. __init__.py +18 -0
  3. quantum_head.py +335 -0
  4. train.py +246 -0
README.md ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Q-GPT: Quantum-Enhanced GPT
2
+
3
+ A quantum neural network head that adds confidence estimation to GPT models.
4
+
5
+ ## Features
6
+
7
+ - 🔮 **Variational Quantum Circuit** - Uses PennyLane for true quantum computing simulation
8
+ - 📊 **Confidence Estimation** - Estimates how confident the model is in its response
9
+ - 🚫 **Refusal Detection** - Identifies when the model should refuse to answer
10
+ - ⚡ **Classical Fallback** - Works without PennyLane using classical approximation
11
+
12
+ ## Installation
13
+
14
+ ```bash
15
+ pip install pennylane torch transformers
16
+ ```
17
+
18
+ ## Usage
19
+
20
+ ### Quick Start
21
+
22
+ ```python
23
+ from quantum_head import load_qgpt
24
+
25
+ # Load Q-GPT
26
+ model, tokenizer = load_qgpt(
27
+ "squ11z1/gpt-oss-9b-reasoning",
28
+ torch_dtype="auto",
29
+ device="auto",
30
+ )
31
+
32
+ # Generate with confidence
33
+ inputs = tokenizer("What is 2 + 2?", return_tensors="pt").to(model.device)
34
+ outputs = model.generate_with_confidence(inputs.input_ids, max_new_tokens=50)
35
+
36
+ print(f"Response: {tokenizer.decode(outputs['sequences'][0])}")
37
+ print(f"Confidence: {outputs['confidence_label']}") # e.g., "high"
38
+ print(f"Should refuse: {outputs['should_refuse']}")
39
+ ```
40
+
41
+ ### Just the Quantum Head
42
+
43
+ ```python
44
+ from quantum_head import QuantumHead
45
+ import torch
46
+
47
+ # Create quantum head
48
+ head = QuantumHead(hidden_size=2880) # Match your model's hidden size
49
+
50
+ # Forward pass with hidden states
51
+ hidden_states = torch.randn(1, 2880) # From your model
52
+ output = head(hidden_states)
53
+
54
+ print(f"Confidence: {output['confidence'].item():.2f}")
55
+ print(f"Uncertainty: {output['uncertainty'].item():.2f}")
56
+ ```
57
+
58
+ ### Training
59
+
60
+ ```bash
61
+ # Create synthetic training data
62
+ python train.py --model squ11z1/gpt-oss-9b-reasoning --create-data --data train.jsonl
63
+
64
+ # Train quantum head
65
+ python train.py --model squ11z1/gpt-oss-9b-reasoning --data train.jsonl --epochs 3
66
+ ```
67
+
68
+ ## Architecture
69
+
70
+ ```
71
+ Hidden States → [Classical Compression] → [Quantum Circuit] → [Post-Processing] → Confidence
72
+ ↓ ↓ ↓ ↓
73
+ [B, H] [B, n_qubits] [B, n_qubits] [B, 2]
74
+
75
+ confidence + uncertainty
76
+ ```
77
+
78
+ ### Quantum Circuit
79
+
80
+ ```
81
+ |0⟩ ─ RY(x₀) ─ RZ(x₀) ─ Rot(θ) ─ ●─────── Rot(θ) ─ ... ─ ⟨Z⟩
82
+
83
+ |0⟩ ─ RY(x₁) ─ RZ(x₁) ─ Rot(θ) ─ ⊕ ─ ●─── Rot(θ) ─ ... ─ ⟨Z⟩
84
+
85
+ |0⟩ ─ RY(x₂) ─ RZ(x₂) ─ Rot(θ) ───── ⊕ ─ ●─ Rot(θ) ─ ... ─ ⟨Z⟩
86
+
87
+ |0⟩ ─ RY(x₃) ─ RZ(x₃) ─ Rot(θ) ───────── ⊕ ─ Rot(θ) ─ ... ─ ⟨Z⟩
88
+ ```
89
+
90
+ ## Files
91
+
92
+ - `quantum_head.py` - Main implementation (QuantumHead, QGPT, load_qgpt)
93
+ - `train.py` - Training script for quantum head
94
+ - `quantum_head.pt` - Pre-trained weights (after training)
95
+
96
+ ## Citation
97
+
98
+ ```bibtex
99
+ @misc{qgpt2026,
100
+ title={Q-GPT: Quantum-Enhanced Confidence Estimation for Language Models},
101
+ author={squ11z1},
102
+ year={2026},
103
+ url={https://huggingface.co/squ11z1/Q-GPT}
104
+ }
105
+ ```
106
+
107
+ ## License
108
+
109
+ Apache 2.0
__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Q-GPT: Quantum-Enhanced GPT with Confidence Estimation"""
2
+
3
+ from .quantum_head import (
4
+ QuantumHead,
5
+ QuantumCircuit,
6
+ QGPT,
7
+ load_qgpt,
8
+ )
9
+
10
+ __version__ = "1.0.0"
11
+ __author__ = "squ11z1"
12
+
13
+ __all__ = [
14
+ "QuantumHead",
15
+ "QuantumCircuit",
16
+ "QGPT",
17
+ "load_qgpt",
18
+ ]
quantum_head.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Q-GPT: Quantum-Enhanced GPT with Confidence Estimation
3
+ A quantum neural network head that estimates response confidence.
4
+
5
+ Author: squ11z1
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+
12
+ try:
13
+ import pennylane as qml
14
+ PENNYLANE_AVAILABLE = True
15
+ except ImportError:
16
+ PENNYLANE_AVAILABLE = False
17
+ print("Warning: PennyLane not installed. Using classical fallback.")
18
+
19
+
20
class QuantumCircuit:
    """Variational Quantum Circuit for confidence estimation.

    Encodes a classical feature vector into single-qubit rotations, applies
    ``n_layers`` of trainable ``Rot`` gates with CNOT entanglement, and
    measures a Pauli-Z expectation value per qubit. When PennyLane is not
    installed a classical approximation with the same output shape is used.
    """

    def __init__(self, n_qubits: int = 4, n_layers: int = 3):
        """
        Args:
            n_qubits: Number of qubits (= input feature / output dimension).
            n_layers: Number of variational rotation+entanglement layers.
        """
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        if PENNYLANE_AVAILABLE:
            # Statevector simulator; torch interface keeps autograd through
            # the circuit parameters.
            self.dev = qml.device("default.qubit", wires=n_qubits)
            self.circuit = qml.QNode(self._quantum_circuit, self.dev, interface="torch")

    def _quantum_circuit(self, inputs, weights):
        """
        Variational quantum circuit.

        Args:
            inputs: Input features [n_qubits], expected in [-1, 1]
            weights: Trainable parameters [n_layers, n_qubits, 3]

        Returns:
            List of n_qubits Pauli-Z expectation values, each in [-1, 1].
        """
        # Encode classical data into quantum states (angle encoding)
        for i in range(self.n_qubits):
            qml.RY(inputs[i], wires=i)
            qml.RZ(inputs[i], wires=i)

        # Variational layers
        for layer in range(self.n_layers):
            # Rotation gates
            for i in range(self.n_qubits):
                qml.Rot(weights[layer, i, 0],
                        weights[layer, i, 1],
                        weights[layer, i, 2], wires=i)

            # Entanglement (CNOT ladder)
            for i in range(self.n_qubits - 1):
                qml.CNOT(wires=[i, i + 1])

            # Circular entanglement closes the ladder into a ring
            if self.n_qubits > 2:
                qml.CNOT(wires=[self.n_qubits - 1, 0])

        # Measure expectation values
        return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

    def forward(self, inputs, weights):
        """Execute the quantum circuit (or the classical fallback).

        Both paths yield one value per qubit so downstream layers see a
        consistent [n_qubits] shape.
        """
        if PENNYLANE_AVAILABLE:
            return self.circuit(inputs, weights)
        else:
            # Classical fallback: per-qubit tanh transformation.
            # BUGFIX: the previous `inputs @ weights.mean(dim=(0, 2))`
            # collapsed two [n_qubits] vectors into a scalar, while the
            # quantum path returns n_qubits expectation values; the head's
            # post-processing layer expects [n_qubits]. An elementwise
            # product keeps the fallback shape-compatible.
            return torch.tanh(inputs * weights.mean(dim=(0, 2)))
70
+
71
+
72
class QuantumHead(nn.Module):
    """Quantum-enhanced confidence estimation head for GPT.

    Maps last-layer hidden states to:
      - confidence: estimated confidence in the response, in [0, 1]
      - uncertainty: non-negative quantum-derived uncertainty measure
    """

    def __init__(
        self,
        hidden_size: int = 2880,  # GPT-OSS hidden size
        n_qubits: int = 4,
        n_layers: int = 3,
        intermediate_size: int = 64,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        # Classical compression: squeeze the hidden state down to one
        # feature per qubit; the final Tanh keeps values in [-1, 1],
        # suitable for angle encoding.
        self.pre_quantum = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, n_qubits),
            nn.Tanh(),
        )

        # Variational circuit and its trainable rotation angles.
        self.quantum = QuantumCircuit(n_qubits, n_layers)
        self.quantum_weights = nn.Parameter(
            torch.randn(n_layers, n_qubits, 3) * 0.1
        )

        # Maps per-qubit measurements to the two output scalars.
        self.post_quantum = nn.Sequential(
            nn.Linear(n_qubits, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, 2),  # [confidence, uncertainty]
        )

        # Output activations: sigmoid bounds confidence to [0, 1];
        # softplus keeps uncertainty non-negative.
        self.confidence_activation = nn.Sigmoid()
        self.uncertainty_activation = nn.Softplus()

    def forward(self, hidden_states: torch.Tensor) -> dict:
        """Compute confidence and uncertainty from hidden states.

        Args:
            hidden_states: Last layer hidden states [batch, seq_len, hidden_size]
                or pooled representation [batch, hidden_size]

        Returns:
            dict with 'confidence', 'uncertainty' and 'should_refuse' tensors
        """
        # Pool a full sequence down to its last-token representation.
        if hidden_states.dim() == 3:
            hidden_states = hidden_states[:, -1, :]

        features = self.pre_quantum(hidden_states)  # [batch, n_qubits]

        # The circuit processes one sample at a time.
        per_sample = []
        for sample in features:
            measured = self.quantum.forward(sample, self.quantum_weights)
            if isinstance(measured, list):
                measured = torch.stack(measured)
            per_sample.append(measured)
        quantum_output = torch.stack(per_sample)  # [batch, n_qubits]

        logits = self.post_quantum(quantum_output)
        confidence = self.confidence_activation(logits[:, 0])
        uncertainty = self.uncertainty_activation(logits[:, 1])

        return {
            "confidence": confidence,
            "uncertainty": uncertainty,
            # Low confidence is treated as a refusal signal.
            "should_refuse": confidence < 0.3,
        }

    def get_interpretable_confidence(self, confidence: torch.Tensor) -> str:
        """Convert confidence score to human-readable label."""
        value = confidence.item() if confidence.dim() == 0 else confidence.mean().item()

        # Descending thresholds; the first one that matches wins.
        bands = (
            (0.9, "very high"),
            (0.7, "high"),
            (0.5, "moderate"),
            (0.3, "low"),
        )
        for threshold, label in bands:
            if value >= threshold:
                return label
        return "very low (consider refusing)"
182
+
183
+
184
class QGPT(nn.Module):
    """
    Q-GPT: GPT with Quantum Confidence Head

    Wraps any HuggingFace GPT model and adds quantum confidence estimation.
    """

    def __init__(self, base_model, quantum_head: QuantumHead = None):
        """
        Args:
            base_model: HuggingFace causal LM exposing ``.config``,
                ``.forward(..., output_hidden_states=True)`` and ``.generate``.
            quantum_head: Optional pre-built head; when omitted one is built
                from the base model's hidden size.
        """
        super().__init__()
        self.base_model = base_model

        # Get hidden size from model config (different architectures name
        # it differently; fall back to the GPT-OSS value).
        if hasattr(base_model.config, 'hidden_size'):
            hidden_size = base_model.config.hidden_size
        elif hasattr(base_model.config, 'd_model'):
            hidden_size = base_model.config.d_model
        else:
            hidden_size = 2880  # GPT-OSS default

        self.quantum_head = quantum_head or QuantumHead(hidden_size=hidden_size)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Forward pass with confidence estimation.

        Runs the base model with hidden-state output enabled, then attaches
        ``confidence``, ``uncertainty`` and ``should_refuse`` to the returned
        output object.
        """
        # Get base model outputs with hidden states
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        # Get last layer hidden states
        hidden_states = outputs.hidden_states[-1]

        # Compute quantum confidence (head pools to the last token internally)
        confidence_output = self.quantum_head(hidden_states)

        # Add to outputs as dynamic attributes.
        # NOTE(review): these are set on the HF ModelOutput after the fact,
        # so they presumably will not appear in outputs.keys()/to_tuple() —
        # access them as attributes only.
        outputs.confidence = confidence_output["confidence"]
        outputs.uncertainty = confidence_output["uncertainty"]
        outputs.should_refuse = confidence_output["should_refuse"]

        return outputs

    def generate_with_confidence(
        self,
        input_ids,
        attention_mask=None,
        max_new_tokens=256,
        **kwargs
    ):
        """Generate text and return confidence score.

        Returns:
            dict with ``sequences``, ``confidence``, ``uncertainty``,
            ``should_refuse`` and a human-readable ``confidence_label``.
        """
        # Generate (return_dict_in_generate is required so hidden states
        # are attached to the result)
        outputs = self.base_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            output_hidden_states=True,
            return_dict_in_generate=True,
            **kwargs
        )

        # Get hidden states from last generation step.
        # NOTE(review): assumes generate returns per-step tuples of per-layer
        # states, i.e. hidden_states[-1][-1] == last step, last layer —
        # confirm against the installed transformers version.
        if hasattr(outputs, 'hidden_states') and outputs.hidden_states:
            last_hidden = outputs.hidden_states[-1][-1]  # Last layer, last step
        else:
            # Fallback: run forward pass on generated sequence
            with torch.no_grad():
                model_outputs = self.base_model(
                    outputs.sequences,
                    output_hidden_states=True
                )
                last_hidden = model_outputs.hidden_states[-1]

        # Compute confidence
        confidence_output = self.quantum_head(last_hidden)

        return {
            "sequences": outputs.sequences,
            "confidence": confidence_output["confidence"],
            "uncertainty": confidence_output["uncertainty"],
            "should_refuse": confidence_output["should_refuse"],
            "confidence_label": self.quantum_head.get_interpretable_confidence(
                confidence_output["confidence"]
            ),
        }
+
271
+
272
def load_qgpt(
    model_name: str = "squ11z1/gpt-oss-9b-reasoning",
    quantum_head_path: str = None,
    device: str = "auto",
    torch_dtype = None,
    **kwargs
):
    """
    Load Q-GPT model with quantum head.

    Args:
        model_name: HuggingFace model name or path
        quantum_head_path: Path to trained quantum head weights
        device: Device map passed to `from_pretrained` (e.g. "auto")
        torch_dtype: Model dtype (e.g., torch.bfloat16); defaults to bfloat16
        **kwargs: Forwarded to both the model and tokenizer loaders

    Returns:
        QGPT model and tokenizer
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if torch_dtype is None:
        torch_dtype = torch.bfloat16

    # Load base model
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        device_map=device,
        trust_remote_code=True,
        **kwargs
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        **kwargs
    )

    # Create Q-GPT wrapper around the frozen base model
    model = QGPT(base_model)

    # Load quantum head weights if provided.
    # SECURITY: weights_only=True restricts torch.load to tensor data and
    # refuses arbitrary pickled objects, preventing code execution from an
    # untrusted checkpoint file.
    if quantum_head_path:
        state_dict = torch.load(quantum_head_path, map_location="cpu", weights_only=True)
        model.quantum_head.load_state_dict(state_dict)
        print(f"Loaded quantum head from {quantum_head_path}")

    return model, tokenizer
321
+
322
+
323
+ if __name__ == "__main__":
324
+ # Quick test
325
+ print("Testing QuantumHead...")
326
+
327
+ head = QuantumHead(hidden_size=2880)
328
+ dummy_input = torch.randn(2, 2880) # Batch of 2
329
+
330
+ output = head(dummy_input)
331
+ print(f"Confidence: {output['confidence']}")
332
+ print(f"Uncertainty: {output['uncertainty']}")
333
+ print(f"Should refuse: {output['should_refuse']}")
334
+
335
+ print("\n✓ QuantumHead test passed!")
train.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Q-GPT Training Script
3
+ Train the quantum head on GPT outputs.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from tqdm import tqdm
10
+ import json
11
+ import os
12
+
13
+ from quantum_head import QuantumHead, load_qgpt
14
+
15
+
16
class ConfidenceDataset(Dataset):
    """Dataset for training the quantum confidence head.

    Reads a JSONL file where each record provides:
      - "text": training text (required)
      - "confidence": target confidence in [0, 1] (optional, default 0.5)
      - "is_correct": whether the response was correct (optional, default True)
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        """
        Args:
            data_path: Path to the JSONL training file.
            tokenizer: HuggingFace-style tokenizer (callable returning tensors).
            max_length: Token length every sample is truncated/padded to.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load one JSON record per line; skip blank lines so a stray empty
        # line in the file doesn't crash json.loads.
        with open(data_path, 'r') as f:
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Tokenize to a fixed length so DataLoader can stack samples.
        encoding = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )

        return {
            # squeeze(0) removes only the batch dim added by
            # return_tensors="pt"; a bare squeeze() could also collapse a
            # legitimate size-1 axis.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "confidence_label": torch.tensor(item.get("confidence", 0.5)),
            "is_correct": torch.tensor(float(item.get("is_correct", True))),
        }
51
+
52
+
53
def train_quantum_head(
    model_name: str = "squ11z1/gpt-oss-9b-reasoning",
    train_data_path: str = None,
    output_dir: str = "./q_gpt_trained",
    epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    device: str = "cuda",
):
    """
    Train the quantum head on confidence estimation.

    The base model is frozen and used only as a feature extractor; gradients
    flow through the quantum head alone. When no training data is given the
    untrained head is still saved so downstream loading works.

    Args:
        model_name: Base model name
        train_data_path: Path to training data (jsonl with text, confidence, is_correct)
        output_dir: Where to save trained weights
        epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate for quantum head
        device: Device to train the quantum head on

    Returns:
        The (possibly untrained) QuantumHead instance.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading model: {model_name}")

    # Load base model (frozen feature extractor, bfloat16 to save memory)
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    base_model.eval()
    for param in base_model.parameters():
        param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Padding is required by the dataset's fixed-length tokenization.
        tokenizer.pad_token = tokenizer.eos_token

    # Create quantum head (parameters default to float32)
    hidden_size = base_model.config.hidden_size
    quantum_head = QuantumHead(hidden_size=hidden_size).to(device)

    # Optimizer over the quantum head only — the base model is frozen.
    optimizer = torch.optim.AdamW(quantum_head.parameters(), lr=learning_rate)

    # Loss functions: both targets live in [0, 1], matching the sigmoid output.
    confidence_loss_fn = nn.BCELoss()
    correctness_loss_fn = nn.BCELoss()

    # Training loop
    if train_data_path and os.path.exists(train_data_path):
        dataset = ConfidenceDataset(train_data_path, tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            quantum_head.train()
            total_loss = 0

            for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                confidence_labels = batch["confidence_label"].to(device)
                correctness_labels = batch["is_correct"].to(device)

                # Get hidden states from base model (no gradients needed)
                with torch.no_grad():
                    outputs = base_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=True
                    )
                    hidden_states = outputs.hidden_states[-1]

                # Forward through quantum head.
                # BUGFIX: the base model runs in bfloat16 but the quantum
                # head's Linear layers are float32; feeding bf16 activations
                # into fp32 weights raises a dtype-mismatch error in
                # PyTorch, so cast the features to float32 first.
                qout = quantum_head(hidden_states.float().to(device))

                # Regress toward the annotated confidence target
                conf_loss = confidence_loss_fn(qout["confidence"], confidence_labels)

                # High confidence should correlate with correctness
                correct_loss = correctness_loss_fn(qout["confidence"], correctness_labels)

                loss = 0.5 * conf_loss + 0.5 * correct_loss

                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")
    else:
        print("No training data provided. Saving untrained quantum head.")

    # Save only the head's weights — the base model is unchanged.
    save_path = os.path.join(output_dir, "quantum_head.pt")
    torch.save(quantum_head.state_dict(), save_path)
    print(f"Saved quantum head to {save_path}")

    return quantum_head
159
+
160
+
161
def create_synthetic_training_data(
    model_name: str,
    output_path: str,
    num_samples: int = 1000,
):
    """Create synthetic training data from model predictions."""
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import random

    print("Creating synthetic training data...")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Prompt pool mixing factual/computational and open-ended questions.
    prompts = [
        "What is 2 + 2?",
        "Explain quantum mechanics.",
        "Who was the first president of USA?",
        "Solve: x^2 - 4 = 0",
        "What is the capital of France?",
        "Explain machine learning.",
        "What is consciousness?",
        "Calculate 15% of 200.",
    ]

    factual_markers = ("what is", "who", "calculate", "solve")
    records = []

    for _ in tqdm(range(num_samples)):
        prompt = random.choice(prompts)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            generated = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
            )

        text = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Heuristic confidence target: factual/computational prompts are
        # assumed easier and get higher targets than open-ended ones.
        lowered = prompt.lower()
        if any(marker in lowered for marker in factual_markers):
            confidence = random.uniform(0.7, 0.95)
        else:
            confidence = random.uniform(0.4, 0.7)

        records.append({
            "text": text,
            "confidence": confidence,
            "is_correct": confidence > 0.5,
        })

    # Persist as one JSON record per line (JSONL).
    with open(output_path, 'w') as f:
        for record in records:
            f.write(json.dumps(record) + '\n')

    print(f"Created {len(records)} samples at {output_path}")
224
+
225
+
226
+ if __name__ == "__main__":
227
+ import argparse
228
+
229
+ parser = argparse.ArgumentParser()
230
+ parser.add_argument("--model", default="squ11z1/gpt-oss-9b-reasoning")
231
+ parser.add_argument("--data", default=None)
232
+ parser.add_argument("--output", default="./q_gpt_trained")
233
+ parser.add_argument("--epochs", type=int, default=3)
234
+ parser.add_argument("--create-data", action="store_true")
235
+
236
+ args = parser.parse_args()
237
+
238
+ if args.create_data:
239
+ create_synthetic_training_data(args.model, args.data or "train_data.jsonl")
240
+ else:
241
+ train_quantum_head(
242
+ model_name=args.model,
243
+ train_data_path=args.data,
244
+ output_dir=args.output,
245
+ epochs=args.epochs,
246
+ )