"""
Train H4 attention model for extractive question-answering.

Uses the autoresearch pattern: modify -> run (fixed TIME_BUDGET) -> measure -> keep/discard.
Metric: token-level F1 score on validation QA pairs (not bpb).

The model learns to generate answer text given [context | question |] as input.
This is extractive QA — the answer is a span from the context.

Architecture: same H4LanguageModel from Phase 5/6, trained on QA-formatted text.
"""

import os
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import re
from collections import Counter

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from h4_language_model import H4LanguageModel
from rag.prepare_qa import generate_sample_qa, prepare_training_data, format_qa_for_training, download_squad_dev
from rag.tokenizer import BPETokenizer

# ---------------------------------------------------------------------------
# Hyperparameters (AGENT MAY MODIFY THESE)
# ---------------------------------------------------------------------------

TIME_BUDGET = 600  # seconds of measured training time (10 minutes)

# Model architecture.
# NOTE: PRETRAINED_CKPT is currently NOT loaded (the BPE vocab size differs
# from the checkpoint's), so these values do not have to match it.
D_MODEL = 64
N_HEADS = 8
N_LAYERS = 2
D_VALUE = 16
D_FFN = 256
DROPOUT = 0.0
USE_BITLINEAR = True

# Pre-trained checkpoint path. Currently unused by main(); kept for reference.
PRETRAINED_CKPT = os.path.join(os.path.dirname(__file__), '..', '..', 'checkpoints', 'lm_pretrained.pt')

# Optimizer / LR schedule (linear warmup, then cosine decay to a floor).
LR = 3e-3
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 30
GRAD_CLIP = 1.0
LR_DECAY_STEPS = 5000  # optimizer step at which the cosine decay reaches its floor
LR_MIN_RATIO = 0.1     # LR floor, as a fraction of LR

# Training
BATCH_SIZE = 2  # sequences per optimizer step (via gradient accumulation)
MAX_SEQ_LEN = 512
EVAL_INTERVAL = 500

# Evaluation: generation budget = gold answer BPE length + extra, capped.
ANSWER_EXTRA_TOKENS = 10
ANSWER_MAX_TOKENS = 50


# ---------------------------------------------------------------------------
# QA-specific metrics
# ---------------------------------------------------------------------------

_PUNCT_RE = re.compile(r'[^\w\s]')


def normalize_answer(s: str) -> str:
    """Lower-case, drop punctuation, and collapse/trim whitespace.

    Whitespace is collapsed *after* punctuation removal so that e.g.
    "Paris ." and "Paris" normalize to the same string ("paris").
    """
    s = _PUNCT_RE.sub('', s.lower())
    return ' '.join(s.split())


def compute_f1(prediction: str, gold: str) -> float:
    """Token-level F1 between prediction and gold answer.

    Returns 1.0 when both normalize to no tokens, 0.0 when exactly one does.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    if not gold_tokens:
        return 1.0 if not pred_tokens else 0.0
    if not pred_tokens:
        return 0.0
    # Multiset intersection: each shared token counts min(pred, gold) times.
    common = Counter(pred_tokens) & Counter(gold_tokens)
    n_common = sum(common.values())
    if n_common == 0:
        return 0.0
    precision = n_common / len(pred_tokens)
    recall = n_common / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def compute_exact_match(prediction: str, gold: str) -> float:
    """Exact match (1.0 / 0.0) after normalization."""
    return 1.0 if normalize_answer(prediction) == normalize_answer(gold) else 0.0


# ---------------------------------------------------------------------------
# Data preparation
# ---------------------------------------------------------------------------

def _load_qa_pairs():
    """Return SQuAD dev QA pairs when more than 100 are available, else sample QA pairs."""
    squad_qa = download_squad_dev()
    if len(squad_qa) > 100:
        print(f"Using SQuAD 2.0: {len(squad_qa)} QA pairs")
        return squad_qa
    print("SQuAD not available, using sample QA pairs")
    return generate_sample_qa()


def _encode_sequences(tokenizer, all_qa):
    """BPE-encode each QA pair as [context SEP question SEP answer].

    Returns a list of (full_ids, input_len, qa) for sequences with more than
    5 and at most MAX_SEQ_LEN tokens; input_len is the length of the
    context+question prefix. With BPE, seq_len=512 covers ~2500 chars.
    """
    seqs = []
    for qa in all_qa:
        input_ids, answer_ids = tokenizer.encode_qa(qa['context'], qa['question'], qa['answer'])
        full_ids = input_ids + answer_ids
        if 5 < len(full_ids) <= MAX_SEQ_LEN:
            seqs.append((full_ids, len(input_ids), qa))
    return seqs


def _split_train_val(all_seqs):
    """Deterministically split encoded sequences into train tensors and val pairs.

    Validation gets between 20 and 30 sequences (about 5% of the data), but
    at least one sequence is always left for training.

    Returns:
        train_seqs: shuffled list of 1-D LongTensors (full token sequences).
        val_pairs: list of (prompt_ids, gold_answer_text, qa_dict).
    """
    rng = np.random.RandomState(42)
    indices = list(range(len(all_seqs)))
    rng.shuffle(indices)
    n_val = min(30, max(20, int(len(all_seqs) * 0.05)))
    # Never let validation consume the whole dataset (training would have nothing).
    n_val = min(n_val, max(len(all_seqs) - 1, 0))
    val_indices = set(indices[:n_val])

    train_seqs = []
    val_pairs = []
    for i, (full_ids, input_len, qa) in enumerate(all_seqs):
        if i in val_indices:
            val_pairs.append((full_ids[:input_len], qa['answer'], qa))
        else:
            train_seqs.append(torch.tensor(full_ids, dtype=torch.long))
    rng.shuffle(train_seqs)
    return train_seqs, val_pairs


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

def _answer_budget(tokenizer, gold_answer):
    """Max new tokens to generate: gold answer's BPE length + slack, capped."""
    return min(len(tokenizer.encode(gold_answer)) + ANSWER_EXTRA_TOKENS, ANSWER_MAX_TOKENS)


def _evaluate(model, tokenizer, val_pairs, *, temperature, top_k, verbose=False):
    """Generate an answer for every validation pair and score it.

    The caller is responsible for switching the model to eval/train mode.
    Returns (mean_f1, mean_em); both are 0.0 when val_pairs is empty.
    """
    f1_scores = []
    em_scores = []
    with torch.no_grad():
        for inp_ids, gold_answer, qa in val_pairs:
            inp_tensor = torch.tensor([inp_ids], dtype=torch.long)
            # NOTE(review): prompt + budget may exceed MAX_SEQ_LEN; presumably
            # model.generate handles that — confirm in H4LanguageModel.
            generated = model.generate(
                inp_tensor,
                max_new_tokens=_answer_budget(tokenizer, gold_answer),
                temperature=temperature,
                top_k_sample=top_k,
            )
            # generate returns prompt + continuation; keep only the continuation.
            gen_ids = generated[0, len(inp_ids):]
            pred_answer = tokenizer.decode(gen_ids.tolist())
            f1 = compute_f1(pred_answer, gold_answer)
            em = compute_exact_match(pred_answer, gold_answer)
            f1_scores.append(f1)
            em_scores.append(em)
            if verbose:
                print(f" Q: {qa['question']}")
                print(f" Gold: {gold_answer}")
                print(f" Pred: {pred_answer[:80]}")
                print(f" F1={f1:.3f} EM={em:.1f}")
                print()
    if not f1_scores:
        return 0.0, 0.0
    return sum(f1_scores) / len(f1_scores), sum(em_scores) / len(em_scores)


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

def _lr_multiplier(step):
    """LambdaLR factor: linear warmup to 1.0, then cosine decay to LR_MIN_RATIO."""
    if step < WARMUP_STEPS:
        return step / max(WARMUP_STEPS, 1)
    progress = (step - WARMUP_STEPS) / max(1, LR_DECAY_STEPS - WARMUP_STEPS)
    cosine = 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))
    return LR_MIN_RATIO + (1 - LR_MIN_RATIO) * cosine


def main():
    """Train on QA-formatted sequences for TIME_BUDGET seconds, then report F1/EM."""
    t_start = time.time()
    torch.manual_seed(42)
    np.random.seed(42)

    # Load QA data — use SQuAD if available, fall back to sample
    all_qa = _load_qa_pairs()

    # Build BPE tokenizer from (the first 2000 pairs of) the data
    tokenizer = BPETokenizer(max_vocab=4096)
    all_texts = [qa['context'] + ' ' + qa['question'] + ' ' + qa['answer'] for qa in all_qa[:2000]]
    tokenizer.build_vocab(all_texts)
    vocab_size = tokenizer.vocab_size

    all_seqs = _encode_sequences(tokenizer, all_qa)
    print(f"BPE sequences that fit seq_len={MAX_SEQ_LEN}: {len(all_seqs)} "
          f"(from {len(all_qa)} total QA pairs)")

    train_seqs, val_pairs = _split_train_val(all_seqs)
    if not train_seqs:
        raise RuntimeError(
            f"No training sequences: none of {len(all_qa)} QA pairs encoded to "
            f"a usable sequence of at most {MAX_SEQ_LEN} tokens")
    avg_len = sum(len(s) for s in train_seqs) / len(train_seqs)
    print(f"Train: {len(train_seqs)} seqs (avg {avg_len:.0f} BPE tokens), Val: {len(val_pairs)} pairs")

    # Create model
    model = H4LanguageModel(
        vocab_size=vocab_size,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_value=D_VALUE,
        d_ffn=D_FFN,
        max_seq_len=MAX_SEQ_LEN,
        dropout=DROPOUT,
        use_bitlinear=USE_BITLINEAR,
    )
    params = model.count_params()
    print(f"Model: {params['trainable']:,} trainable params")
    # Note: pre-trained checkpoint skipped because vocab size changed with BPE.
    # The model trains from scratch but BPE makes this much more efficient.

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.95))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_multiplier)

    # Sequences have different lengths, so instead of padding a batch we
    # accumulate gradients over BATCH_SIZE single-sequence forward passes.
    accum_steps = max(1, BATCH_SIZE)

    # Training loop
    step = 0
    total_training_time = 0.0
    best_f1 = 0.0
    model.train()
    print(f"\nTraining for {TIME_BUDGET}s, metric=F1")
    print(f"{'step':>6} {'loss':>8} {'val_f1':>8} {'val_em':>8} {'lr':>10}")
    print("-" * 48)

    while True:
        t0 = time.time()

        optimizer.zero_grad()
        step_loss = 0.0
        for micro in range(accum_steps):
            # Walk the (pre-shuffled) training set in order, wrapping around.
            seq = train_seqs[(step * accum_steps + micro) % len(train_seqs)]
            # Next-token prediction; sequences were filtered to <= MAX_SEQ_LEN
            # tokens, so x/y (length T-1) always fit the model's context.
            x = seq[:-1].unsqueeze(0)  # (1, T-1)
            y = seq[1:].unsqueeze(0)   # (1, T-1)
            logits = model(x, use_tree=False)
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
            (loss / accum_steps).backward()
            step_loss += loss.item() / accum_steps

        if GRAD_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        scheduler.step()

        dt = time.time() - t0
        # The first few steps are excluded from the budget (presumably to
        # skip one-time startup overhead).
        if step > 2:
            total_training_time += dt

        # Evaluate (not counted toward the training time budget)
        if step % EVAL_INTERVAL == 0:
            model.eval()
            avg_f1, avg_em = _evaluate(model, tokenizer, val_pairs, temperature=0.5, top_k=10)
            best_f1 = max(best_f1, avg_f1)
            current_lr = scheduler.get_last_lr()[0]
            print(f"{step:6d} {step_loss:8.4f} {avg_f1:8.3f} {avg_em:8.3f} {current_lr:10.6f}")
            model.train()

        step += 1
        if step > 2 and total_training_time >= TIME_BUDGET:
            break

    # Final evaluation with sample outputs
    model.eval()
    print("\n" + "=" * 60)
    print("SAMPLE QA RESULTS:")
    avg_f1, avg_em = _evaluate(model, tokenizer, val_pairs, temperature=0.3, top_k=5, verbose=True)
    # The final evaluation is an evaluation too; keep best_f1 >= val_f1.
    best_f1 = max(best_f1, avg_f1)
    print("=" * 60)

    print("\n---")
    print(f"val_f1: {avg_f1:.4f}")
    print(f"val_em: {avg_em:.4f}")
    print(f"best_f1: {best_f1:.4f}")
    print(f"training_seconds: {total_training_time:.1f}")
    print(f"total_seconds: {time.time() - t_start:.1f}")
    print(f"num_steps: {step}")
    print(f"num_params: {params['trainable']}")
    print(f"ternary: {'yes' if USE_BITLINEAR else 'no'}")


if __name__ == '__main__':
    main()