|
|
""" |
|
|
Q-GPT Training Script |
|
|
Train the quantum head on GPT outputs. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from tqdm import tqdm |
|
|
import json |
|
|
import os |
|
|
|
|
|
from quantum_head import QuantumHead, load_qgpt |
|
|
|
|
|
|
|
|
class ConfidenceDataset(Dataset):
    """Dataset for training the quantum confidence head.

    Reads a JSONL file where each line is an object with keys:
        text (str):         the sample text (required)
        confidence (float): target confidence in [0, 1] (defaults to 0.5)
        is_correct (bool):  whether the sample is correct (defaults to True)
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        """
        Args:
            data_path: Path to the JSONL training file.
            tokenizer: HuggingFace-style tokenizer; called with
                truncation/padding kwargs and ``return_tensors="pt"``.
            max_length: Tokenized length; every sample is padded/truncated
                to exactly this many tokens.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        # Explicit utf-8 so loading does not depend on the platform default.
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # crashing in json.loads.
                    continue
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Fixed-size encoding so samples batch without a custom collate_fn.
        encoding = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )

        # squeeze() drops the leading batch dim added by return_tensors="pt".
        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "confidence_label": torch.tensor(item.get("confidence", 0.5)),
            "is_correct": torch.tensor(float(item.get("is_correct", True))),
        }
|
|
|
|
|
|
|
|
def train_quantum_head(
    model_name: str = "squ11z1/gpt-oss-9b-reasoning",
    train_data_path: str = None,
    output_dir: str = "./q_gpt_trained",
    epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    device: str = "cuda",
):
    """
    Train the quantum head on confidence estimation.

    The base language model is frozen and only used as a feature
    extractor; gradients flow exclusively through the quantum head.
    If no training data is available the head is saved untrained.

    Args:
        model_name: Base model name
        train_data_path: Path to training data (jsonl with text, confidence, is_correct),
            or None to skip training and just save the initialized head
        output_dir: Where to save trained weights
        epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate for quantum head
        device: Device to train the quantum head on

    Returns:
        The QuantumHead instance (also saved to <output_dir>/quantum_head.pt).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading model: {model_name}")

    # Base model is a frozen feature extractor: eval mode, no gradients.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    base_model.eval()
    for param in base_model.parameters():
        param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    hidden_size = base_model.config.hidden_size
    quantum_head = QuantumHead(hidden_size=hidden_size).to(device)

    # Only the head's parameters are optimized.
    optimizer = torch.optim.AdamW(quantum_head.parameters(), lr=learning_rate)

    # BCELoss expects probabilities in [0, 1] — assumes the head's
    # "confidence" output is already sigmoid-activated (TODO confirm).
    confidence_loss_fn = nn.BCELoss()
    correctness_loss_fn = nn.BCELoss()

    if train_data_path and os.path.exists(train_data_path):
        dataset = ConfidenceDataset(train_data_path, tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            quantum_head.train()
            total_loss = 0.0

            for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                confidence_labels = batch["confidence_label"].to(device)
                correctness_labels = batch["is_correct"].to(device)

                # Feature extraction only — no gradients through the base model.
                with torch.no_grad():
                    outputs = base_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=True
                    )
                    hidden_states = outputs.hidden_states[-1]

                # Cast to float32: the base model runs in bfloat16 while the
                # head was created in the default float32 dtype; feeding bf16
                # activations into float32 layers would raise a dtype error.
                qout = quantum_head(hidden_states.to(device).float())

                conf_loss = confidence_loss_fn(qout["confidence"], confidence_labels)

                # NOTE(review): correctness is also scored against the
                # "confidence" output; if QuantumHead exposes a separate
                # correctness output it should probably be used here — verify.
                correct_loss = correctness_loss_fn(qout["confidence"], correctness_labels)

                # Equal weighting of the two objectives.
                loss = 0.5 * conf_loss + 0.5 * correct_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # Guard against an empty dataset (e.g. an empty jsonl file),
            # which would otherwise raise ZeroDivisionError.
            if len(dataloader) > 0:
                avg_loss = total_loss / len(dataloader)
                print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")
    else:
        print("No training data provided. Saving untrained quantum head.")

    save_path = os.path.join(output_dir, "quantum_head.pt")
    torch.save(quantum_head.state_dict(), save_path)
    print(f"Saved quantum head to {save_path}")

    return quantum_head
|
|
|
|
|
|
|
|
def create_synthetic_training_data(
    model_name: str,
    output_path: str,
    num_samples: int = 1000,
):
    """Create synthetic training data from model predictions.

    Samples a prompt, generates a short continuation with the model, and
    attaches a heuristic confidence label: factual/computational prompts
    get high confidence, open-ended ones lower. Writes JSONL lines with
    keys ``text``, ``confidence``, ``is_correct``.

    Args:
        model_name: HuggingFace model id used for generation.
        output_path: Destination JSONL file (overwritten).
        num_samples: Number of samples to generate.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import random

    print("Creating synthetic training data...")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Mix of factual/computational prompts and open-ended ones so both
    # confidence regimes are represented.
    prompts = [
        "What is 2 + 2?",
        "Explain quantum mechanics.",
        "Who was the first president of USA?",
        "Solve: x^2 - 4 = 0",
        "What is the capital of France?",
        "Explain machine learning.",
        "What is consciousness?",
        "Calculate 15% of 200.",
    ]

    data = []

    for _ in tqdm(range(num_samples)):
        prompt = random.choice(prompts)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
            )

        text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Heuristic labels: factual/computational prompts are assumed to
        # elicit higher-confidence answers than open-ended ones.
        is_factual = any(kw in prompt.lower() for kw in ["what is", "who", "calculate", "solve"])
        confidence = random.uniform(0.7, 0.95) if is_factual else random.uniform(0.4, 0.7)

        data.append({
            "text": text,
            "confidence": confidence,
            # By construction this is always True for factual prompts.
            "is_correct": confidence > 0.5,
        })

    # Explicit encoding so output does not depend on the platform default.
    with open(output_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')

    print(f"Created {len(data)} samples at {output_path}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: either bootstrap synthetic data (--create-data)
    # or train the quantum head on an existing JSONL file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="squ11z1/gpt-oss-9b-reasoning")
    parser.add_argument("--data", default=None)
    parser.add_argument("--output", default="./q_gpt_trained")
    parser.add_argument("--epochs", type=int, default=3)
    # Expose the remaining train_quantum_head / create_synthetic_training_data
    # knobs; defaults match the function defaults, so behavior is unchanged
    # when the flags are omitted.
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--num-samples", type=int, default=1000)
    parser.add_argument("--create-data", action="store_true")

    args = parser.parse_args()

    if args.create_data:
        create_synthetic_training_data(
            args.model,
            args.data or "train_data.jsonl",
            num_samples=args.num_samples,
        )
    else:
        train_quantum_head(
            model_name=args.model,
            train_data_path=args.data,
            output_dir=args.output,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
        )
|
|
|