"""
Q-GPT: Quantum-Enhanced GPT with Confidence Estimation
A quantum neural network head that estimates response confidence.
Author: squ11z1
"""
import torch
import torch.nn as nn
import numpy as np
try:
import pennylane as qml
PENNYLANE_AVAILABLE = True
except ImportError:
PENNYLANE_AVAILABLE = False
print("Warning: PennyLane not installed. Using classical fallback.")
class QuantumCircuit:
"""Variational Quantum Circuit for confidence estimation."""
def __init__(self, n_qubits: int = 4, n_layers: int = 3):
self.n_qubits = n_qubits
self.n_layers = n_layers
if PENNYLANE_AVAILABLE:
self.dev = qml.device("default.qubit", wires=n_qubits)
self.circuit = qml.QNode(self._quantum_circuit, self.dev, interface="torch")
def _quantum_circuit(self, inputs, weights):
"""
Variational quantum circuit.
Args:
inputs: Input features [n_qubits]
weights: Trainable parameters [n_layers, n_qubits, 3]
"""
# Encode classical data into quantum states
for i in range(self.n_qubits):
qml.RY(inputs[i], wires=i)
qml.RZ(inputs[i], wires=i)
# Variational layers
for layer in range(self.n_layers):
# Rotation gates
for i in range(self.n_qubits):
qml.Rot(weights[layer, i, 0],
weights[layer, i, 1],
weights[layer, i, 2], wires=i)
# Entanglement (CNOT ladder)
for i in range(self.n_qubits - 1):
qml.CNOT(wires=[i, i + 1])
# Circular entanglement
if self.n_qubits > 2:
qml.CNOT(wires=[self.n_qubits - 1, 0])
# Measure expectation values
return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]
def forward(self, inputs, weights):
"""Execute quantum circuit."""
if PENNYLANE_AVAILABLE:
return self.circuit(inputs, weights)
else:
            # Classical fallback: per-wire tanh squashing, matching the
            # quantum path's output shape [n_qubits] and range [-1, 1]
            # (a matmul here would collapse the output to a scalar)
            return torch.tanh(inputs * weights.mean(dim=(0, 2)))
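
# Minimal standalone sketch of the circuit above: random angles in, four
# Pauli-Z expectations out. Shapes follow the constructor defaults; this
# helper is illustrative only and is not called anywhere in the module.
def _demo_quantum_circuit():
    circuit = QuantumCircuit(n_qubits=4, n_layers=3)
    inputs = torch.rand(4) * 2 - 1          # encoded features in [-1, 1]
    weights = torch.randn(3, 4, 3) * 0.1    # [n_layers, n_qubits, 3]
    out = circuit.forward(inputs, weights)  # 4 expectation values in [-1, 1]
    return torch.stack(out) if isinstance(out, (list, tuple)) else out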
class QuantumHead(nn.Module):
"""
Quantum-enhanced confidence estimation head for GPT.
Takes hidden states from the last layer and outputs:
- confidence: Estimated confidence in the response [0, 1]
- uncertainty: Quantum-derived uncertainty measure
"""
def __init__(
self,
hidden_size: int = 2880, # GPT-OSS hidden size
n_qubits: int = 4,
n_layers: int = 3,
intermediate_size: int = 64,
):
super().__init__()
self.hidden_size = hidden_size
self.n_qubits = n_qubits
self.n_layers = n_layers
# Classical preprocessing: compress hidden states
self.pre_quantum = nn.Sequential(
nn.Linear(hidden_size, intermediate_size),
nn.LayerNorm(intermediate_size),
nn.GELU(),
nn.Linear(intermediate_size, n_qubits),
nn.Tanh(), # Normalize to [-1, 1] for quantum encoding
)
# Quantum circuit
self.quantum = QuantumCircuit(n_qubits, n_layers)
# Quantum weights (trainable)
self.quantum_weights = nn.Parameter(
torch.randn(n_layers, n_qubits, 3) * 0.1
)
# Post-quantum processing
self.post_quantum = nn.Sequential(
nn.Linear(n_qubits, intermediate_size),
nn.GELU(),
nn.Linear(intermediate_size, 2), # [confidence, uncertainty]
)
# Output heads
self.confidence_activation = nn.Sigmoid()
self.uncertainty_activation = nn.Softplus()
def forward(self, hidden_states: torch.Tensor) -> dict:
"""
Compute confidence and uncertainty from hidden states.
Args:
hidden_states: Last layer hidden states [batch, seq_len, hidden_size]
or pooled representation [batch, hidden_size]
Returns:
dict with 'confidence' and 'uncertainty' tensors
"""
# Pool if sequence dimension exists
if hidden_states.dim() == 3:
# Use last token representation
hidden_states = hidden_states[:, -1, :]
batch_size = hidden_states.size(0)
# Preprocess
quantum_input = self.pre_quantum(hidden_states) # [batch, n_qubits]
# Process through quantum circuit (per sample)
quantum_outputs = []
for i in range(batch_size):
qout = self.quantum.forward(
quantum_input[i],
self.quantum_weights
)
            if isinstance(qout, (list, tuple)):  # newer PennyLane returns a tuple
                qout = torch.stack(qout)
quantum_outputs.append(qout)
        # Stack to [batch, n_qubits]; the cast guards against float64
        # results from the default.qubit simulator
        quantum_output = torch.stack(quantum_outputs).to(quantum_input.dtype)
# Post-process
output = self.post_quantum(quantum_output)
confidence = self.confidence_activation(output[:, 0])
uncertainty = self.uncertainty_activation(output[:, 1])
return {
"confidence": confidence,
"uncertainty": uncertainty,
"should_refuse": confidence < 0.3, # Low confidence = should refuse
}
def get_interpretable_confidence(self, confidence: torch.Tensor) -> str:
"""Convert confidence score to human-readable label."""
conf = confidence.item() if confidence.dim() == 0 else confidence.mean().item()
if conf >= 0.9:
return "very high"
elif conf >= 0.7:
return "high"
elif conf >= 0.5:
return "moderate"
elif conf >= 0.3:
return "low"
else:
return "very low (consider refusing)"
class QGPT(nn.Module):
"""
Q-GPT: GPT with Quantum Confidence Head
Wraps any HuggingFace GPT model and adds quantum confidence estimation.
"""
def __init__(self, base_model, quantum_head: QuantumHead = None):
super().__init__()
self.base_model = base_model
# Get hidden size from model config
if hasattr(base_model.config, 'hidden_size'):
hidden_size = base_model.config.hidden_size
elif hasattr(base_model.config, 'd_model'):
hidden_size = base_model.config.d_model
else:
hidden_size = 2880 # GPT-OSS default
self.quantum_head = quantum_head or QuantumHead(hidden_size=hidden_size)
def forward(self, input_ids, attention_mask=None, **kwargs):
"""Forward pass with confidence estimation."""
# Get base model outputs with hidden states
outputs = self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs
)
# Get last layer hidden states
hidden_states = outputs.hidden_states[-1]
# Compute quantum confidence
confidence_output = self.quantum_head(hidden_states)
# Add to outputs
outputs.confidence = confidence_output["confidence"]
outputs.uncertainty = confidence_output["uncertainty"]
outputs.should_refuse = confidence_output["should_refuse"]
return outputs
def generate_with_confidence(
self,
input_ids,
attention_mask=None,
max_new_tokens=256,
**kwargs
):
"""Generate text and return confidence score."""
# Generate
outputs = self.base_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs
)
# Get hidden states from last generation step
if hasattr(outputs, 'hidden_states') and outputs.hidden_states:
            last_hidden = outputs.hidden_states[-1][-1]  # last step, then last layer
else:
# Fallback: run forward pass on generated sequence
with torch.no_grad():
model_outputs = self.base_model(
outputs.sequences,
output_hidden_states=True
)
last_hidden = model_outputs.hidden_states[-1]
# Compute confidence
confidence_output = self.quantum_head(last_hidden)
return {
"sequences": outputs.sequences,
"confidence": confidence_output["confidence"],
"uncertainty": confidence_output["uncertainty"],
"should_refuse": confidence_output["should_refuse"],
"confidence_label": self.quantum_head.get_interpretable_confidence(
confidence_output["confidence"]
),
}
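
# Usage sketch for the wrapper class. "gpt2" is used purely as a small
# stand-in checkpoint so the shapes line up; it is not the model this head
# was trained against. Not called at import time.
def _example_forward_with_confidence():
    from transformers import AutoModelForCausalLM, AutoTokenizer
    base = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = QGPT(base)  # QuantumHead is sized from base.config.hidden_size
    batch = tokenizer("What is the capital of France?", return_tensors="pt")
    outputs = model(**batch)
    print(outputs.confidence, outputs.uncertainty, outputs.should_refuse)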
def load_qgpt(
model_name: str = "squ11z1/gpt-oss-9b-reasoning",
quantum_head_path: str = None,
device: str = "auto",
torch_dtype = None,
**kwargs
):
"""
Load Q-GPT model with quantum head.
Args:
model_name: HuggingFace model name or path
quantum_head_path: Path to trained quantum head weights
device: Device to load model on
torch_dtype: Model dtype (e.g., torch.bfloat16)
Returns:
QGPT model and tokenizer
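    Example (untested sketch; assumes the checkpoint above is reachable):
        >>> model, tokenizer = load_qgpt()
        >>> batch = tokenizer("Hello", return_tensors="pt")
        >>> result = model.generate_with_confidence(**batch, max_new_tokens=32)
        >>> print(result["confidence_label"])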
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
if torch_dtype is None:
torch_dtype = torch.bfloat16
# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
device_map=device,
trust_remote_code=True,
**kwargs
)
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
**kwargs
)
# Create Q-GPT
model = QGPT(base_model)
# Load quantum head weights if provided
if quantum_head_path:
state_dict = torch.load(quantum_head_path, map_location="cpu")
model.quantum_head.load_state_dict(state_dict)
print(f"Loaded quantum head from {quantum_head_path}")
return model, tokenizer
if __name__ == "__main__":
# Quick test
print("Testing QuantumHead...")
head = QuantumHead(hidden_size=2880)
dummy_input = torch.randn(2, 2880) # Batch of 2
output = head(dummy_input)
print(f"Confidence: {output['confidence']}")
print(f"Uncertainty: {output['uncertainty']}")
print(f"Should refuse: {output['should_refuse']}")
print("\n✓ QuantumHead test passed!")