""" QA Model Training Script — HACKATHON SPEED EDITION (FIXED) =========================================================== Target: ~15–25 min on RTX 4050 6GB VRAM, NO crashes. Fixes vs original train_qa_fast.py: [BUG FIX] doc_stride=192 with max_length=256 was BROKEN. Stride must be significantly less than max_length. Rule of thumb: stride = max_length // 4 (64 here). Original "192" left only 64-token overlap → answers at window edges were silently mapped to CLS (treated as unanswerable). [BUG FIX] batch_size=32 would OOM on 6GB VRAM with distilroberta. Safe value is 16 with gradient_accumulation_steps=2, giving effective batch size 32 with no OOM risk. [BUG FIX] CUAD contracts are very long — 3000 examples × ~120 windows = ~360k tokenized features. Way too slow for a hackathon. Fix: cap source examples at 500 AND cap tokenized features at 12,000. Gives ~375 steps → ~15-25 min on RTX 4050. [SAFETY] Added PYTORCH_CUDA_ALLOC_CONF fragmentation guard. [SAFETY] fp16 + gradient checkpointing combo — halves VRAM footprint. [SPEED] max_train_samples=500, MAX_FEATURES=12000 — done in <30 min. [SPEED] tinyroberta-squad2 is kept — it's 4x smaller than roberta-base. [QUALITY] doc_stride fixed, so answer extraction is now correct. Expected wall-clock (RTX 4050 6GB, fp16): 12000 features (~500 samples) → ~15–25 min ← default 20000 features (~800 samples) → ~35–45 min 30000 features (~1200 samples) → ~55–70 min (use --max_features 30000) Usage: python train_qa_fast.py # safe defaults (~20 min) python train_qa_fast.py --max_features 20000 # better model (~40 min) python train_qa_fast.py --max_features 30000 # best quality (~65 min) python train_qa_fast.py --device cpu # CPU fallback (slow) """ import argparse import gc import json import logging import os import platform import sys from typing import Dict, List logger = logging.getLogger(__name__) # ────────────────────────────────────────────────────────────────────────────── # VRAM Safety — set BEFORE torch is imported anywhere # ────────────────────────────────────────────────────────────────────────────── # Prevents CUDA OOM from memory fragmentation (critical on 6GB cards) os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") # Parallel tokenization (safe on Linux/Mac, disabled on Windows below) os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") # set per-platform later # ────────────────────────────────────────────────────────────────────────────── # Data loading # ────────────────────────────────────────────────────────────────────────────── def load_cuad_data(filepath: str) -> Dict: """Load a CUAD JSON file (SQuAD 2.0 format).""" with open(filepath, "r", encoding="utf-8") as f: return json.load(f) def cuad_to_squad_examples(data: Dict) -> List[Dict]: """Convert CUAD data to flat SQuAD-style examples.""" examples = [] for article in data["data"]: title = article.get("title", "") for paragraph in article["paragraphs"]: context = paragraph["context"] for qa in paragraph["qas"]: examples.append({ "id": qa["id"], "title": title, "question": qa["question"], "context": context, "answers": { "text": [a["text"] for a in qa.get("answers", [])], "answer_start": [a["answer_start"] for a in qa.get("answers", [])], }, "is_impossible": qa.get("is_impossible", False), }) return examples # ────────────────────────────────────────────────────────────────────────────── # Tokenization # ────────────────────────────────────────────────────────────────────────────── def prepare_train_features(examples, tokenizer, max_length=256, doc_stride=64): """ Sliding-window tokenization for extractive QA. IMPORTANT — doc_stride constraint: doc_stride MUST be substantially less than max_length. The overlap region (max_length - doc_stride) is what lets the model "see" answers that fall near a window boundary. Safe range: doc_stride = max_length // 4 to max_length // 3 (64–85 tokens of overlap for max_length=256 is fine) The original script used stride=192 on max_length=256 — only 64-token overlap — which caused most boundary answers to silently fall back to the CLS token (marked unanswerable). This function uses the value passed in from train(); the default here is 64. """ pad_on_right = tokenizer.padding_side == "right" tokenized = tokenizer( examples["question"] if pad_on_right else examples["context"], examples["context"] if pad_on_right else examples["question"], truncation="only_second" if pad_on_right else "only_first", max_length=max_length, stride=doc_stride, return_overflowing_tokens=True, return_offsets_mapping=True, padding="max_length", ) sample_mapping = tokenized.pop("overflow_to_sample_mapping") offset_mapping = tokenized.pop("offset_mapping") tokenized["start_positions"] = [] tokenized["end_positions"] = [] for i, offsets in enumerate(offset_mapping): input_ids = tokenized["input_ids"][i] cls_index = input_ids.index(tokenizer.cls_token_id) sequence_ids = tokenized.sequence_ids(i) sample_index = sample_mapping[i] answers = examples["answers"][sample_index] # Unanswerable → point to CLS if examples["is_impossible"][sample_index] or len(answers["answer_start"]) == 0: tokenized["start_positions"].append(cls_index) tokenized["end_positions"].append(cls_index) continue start_char = answers["answer_start"][0] end_char = start_char + len(answers["text"][0]) # Find the context token span in this window context_seq_id = 1 if pad_on_right else 0 token_start_index = 0 while sequence_ids[token_start_index] != context_seq_id: token_start_index += 1 token_end_index = len(input_ids) - 1 while sequence_ids[token_end_index] != context_seq_id: token_end_index -= 1 # Answer not fully within this window → CLS fallback if not ( offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char ): tokenized["start_positions"].append(cls_index) tokenized["end_positions"].append(cls_index) continue # Walk forward to answer start token while ( token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char ): token_start_index += 1 tokenized["start_positions"].append(token_start_index - 1) # Walk backward to answer end token while offsets[token_end_index][1] >= end_char: token_end_index -= 1 tokenized["end_positions"].append(token_end_index + 1) return tokenized # ────────────────────────────────────────────────────────────────────────────── # Training # ────────────────────────────────────────────────────────────────────────────── def train( train_path: str = "data/train.json", test_path: str = "data/test.json", base_model: str = "deepset/tinyroberta-squad2", output_dir: str = "ckpt_obligation_fast", epochs: int = 1, # FIXED: 16 + grad_accum=2 gives effective batch 32, zero OOM risk on 6GB batch_size: int = 16, learning_rate: float = 3e-5, max_length: int = 256, # FIXED: was 192 (broken). 64 = max_length//4, gives 192-token overlap. doc_stride: int = 64, # FIXED: was 3000 → ~360k features (hours). 500 examples is plenty to start. max_train_samples: int = 500, # FIXED: hard cap on tokenized features — the real time killer on CUAD. # 12000 features @ effective batch 32 = ~375 steps = ~15-25 min on RTX 4050. max_features: int = 12_000, device: str = "auto", ): try: import torch from datasets import Dataset from transformers import ( AutoModelForQuestionAnswering, AutoTokenizer, TrainingArguments, Trainer, default_data_collator, ) except ImportError as e: print(f"Missing dependency: {e}") print("Install with: pip install torch transformers datasets") sys.exit(1) # ── Validate stride ────────────────────────────────────────────────── if doc_stride >= max_length: raise ValueError( f"doc_stride ({doc_stride}) must be less than max_length ({max_length}). " f"Recommended: doc_stride = max_length // 4 = {max_length // 4}" ) if doc_stride > max_length // 2: logger.warning( f"doc_stride={doc_stride} is more than half of max_length={max_length}. " f"This gives only {max_length - doc_stride} tokens of overlap — answers " f"near window edges may be missed. Consider doc_stride <= {max_length // 4}." ) # ── Device resolution ──────────────────────────────────────────────── if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" use_gpu = (device == "cuda") if use_gpu: vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 logger.info(f"GPU: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)") if vram_gb < 5.5: logger.warning( f"Only {vram_gb:.1f}GB VRAM detected. " "Reducing batch_size to 8 and enabling gradient checkpointing." ) batch_size = 8 else: logger.warning("No GPU detected — training on CPU will be very slow.") logger.info(f"Training device: {device}") # ── Sample limit ───────────────────────────────────────────────────── force_all = (max_train_samples is not None and max_train_samples < 0) if force_all: max_train_samples = None logger.warning("Using FULL dataset — this will take many hours!") # ── Load data ──────────────────────────────────────────────────────── logger.info(f"Loading training data from {train_path}") train_data = load_cuad_data(train_path) train_examples = cuad_to_squad_examples(train_data) logger.info(f" → {len(train_examples)} total examples in dataset") del train_data gc.collect() # ── Build HF Dataset ───────────────────────────────────────────────── def to_columns(examples): cols = {"id": [], "question": [], "context": [], "answers": [], "is_impossible": []} for ex in examples: for k in cols: cols[k].append(ex[k]) return cols train_dataset = Dataset.from_dict(to_columns(train_examples)) del train_examples gc.collect() # Apply sample cap with shuffle so we get diverse contract types if max_train_samples and max_train_samples < len(train_dataset): train_dataset = train_dataset.shuffle(seed=42).select(range(max_train_samples)) logger.info(f" Using {len(train_dataset)} training samples (shuffled)") else: logger.info(f" Using all {len(train_dataset)} training samples") # ── Load tokenizer & model ─────────────────────────────────────────── logger.info(f"Loading model: {base_model}") tokenizer = AutoTokenizer.from_pretrained(base_model) model = AutoModelForQuestionAnswering.from_pretrained(base_model) # Gradient checkpointing: trades ~20% speed for ~40% VRAM reduction. # Essential for staying safe on 6GB cards. if use_gpu: model.gradient_checkpointing_enable() logger.info(" Gradient checkpointing enabled (saves ~40% VRAM)") # ── Tokenize with disk cache ───────────────────────────────────────── is_windows = platform.system() == "Windows" num_samples = len(train_dataset) os.environ["TOKENIZERS_PARALLELISM"] = "false" if is_windows else "true" cache_dir = os.path.join(output_dir, "tokenized_cache") safe_model = base_model.replace("/", "_") cache_path = os.path.join( cache_dir, f"{safe_model}_L{max_length}_S{doc_stride}_N{num_samples}" ) if os.path.exists(cache_path): from datasets import load_from_disk logger.info(f"Loading cached tokenized data from {cache_path}") tokenized_train = load_from_disk(cache_path) logger.info(f" → {len(tokenized_train)} features loaded from cache") else: logger.info( f"Tokenizing {num_samples} examples " f"(max_length={max_length}, stride={doc_stride}, " f"overlap={max_length - doc_stride} tokens)…" ) num_proc = 0 if is_windows else min(4, os.cpu_count() or 1) tokenized_train = train_dataset.map( lambda ex: prepare_train_features(ex, tokenizer, max_length, doc_stride), batched=True, batch_size=64, remove_columns=train_dataset.column_names, num_proc=num_proc, desc="Tokenizing", ) os.makedirs(cache_dir, exist_ok=True) tokenized_train.save_to_disk(cache_path) logger.info(f" → Tokenized data cached to {cache_path}") del train_dataset gc.collect() if use_gpu: torch.cuda.empty_cache() logger.info(f" → {len(tokenized_train)} tokenized features before cap") # ── Cap tokenized features ─────────────────────────────────────────── # CUAD contracts are very long — each example fans out to ~120 windows. # Even 500 examples → ~60k features without this cap. # 12k features @ effective batch 32 = ~375 steps = ~15-25 min on RTX 4050. if max_features > 0 and len(tokenized_train) > max_features: tokenized_train = tokenized_train.shuffle(seed=42).select(range(max_features)) logger.info(f" → Capped to {max_features} features for hackathon speed") else: logger.info(f" → Using all {len(tokenized_train)} features") # ── Optimizer ──────────────────────────────────────────────────────── try: optim = ( "adamw_torch_fused" if (use_gpu and torch.__version__ >= "2.0") else "adamw_torch" ) except Exception: optim = "adamw_torch" # ── Training arguments ─────────────────────────────────────────────── # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps # = 16 * 2 = 32 (same as original intent, but safe on 6GB) training_args = TrainingArguments( output_dir=output_dir, overwrite_output_dir=True, num_train_epochs=epochs, per_device_train_batch_size=batch_size, # 16 — safe on 6GB gradient_accumulation_steps=2, # effective batch = 32 learning_rate=learning_rate, weight_decay=0.01, warmup_ratio=0.06, logging_steps=25, save_strategy="epoch", save_total_limit=1, fp16=use_gpu, # half precision on GPU fp16_full_eval=False, report_to="none", use_cpu=(not use_gpu), group_by_length=True, # less padding waste dataloader_num_workers=0 if is_windows else 2, dataloader_pin_memory=use_gpu, optim=optim, prediction_loss_only=True, disable_tqdm=False, # Prevents OOM from large gradient buffers max_grad_norm=1.0, ) # ── Trainer ────────────────────────────────────────────────────────── trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_train, tokenizer=tokenizer, data_collator=default_data_collator, ) # ── Pre-flight summary ─────────────────────────────────────────────── effective_batch = batch_size * 2 # * grad_accum_steps steps_per_epoch = len(tokenized_train) // effective_batch logger.info("=" * 60) logger.info("Starting training") logger.info(f" Model : {base_model}") logger.info(f" Features : {len(tokenized_train)}") logger.info(f" Epochs : {epochs}") logger.info(f" Batch (device) : {batch_size}") logger.info(f" Grad accum steps : 2 → effective batch = {effective_batch}") logger.info(f" Steps/epoch : ~{steps_per_epoch}") logger.info(f" max_length : {max_length}") logger.info(f" doc_stride : {doc_stride} (overlap = {max_length - doc_stride} tokens)") logger.info(f" Device : {device}") logger.info(f" Optimizer : {optim}") logger.info("=" * 60) trainer.train() # ── Save ───────────────────────────────────────────────────────────── logger.info(f"Saving model to {output_dir}") trainer.save_model(output_dir) tokenizer.save_pretrained(output_dir) logger.info("✓ Training complete!") return output_dir # ────────────────────────────────────────────────────────────────────────────── # CLI # ────────────────────────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser( description="Fine-tune QA model on CUAD — hackathon speed edition (fixed)" ) parser.add_argument("--train_path", default="data/train.json") parser.add_argument("--test_path", default="data/test.json") parser.add_argument( "--base_model", default="deepset/tinyroberta-squad2", help="HuggingFace model. tinyroberta-squad2 is fast and supports unanswerable questions." ) parser.add_argument("--output_dir", default="ckpt_obligation_fast") parser.add_argument("--epochs", type=int, default=1) parser.add_argument( "--batch_size", type=int, default=16, help="Per-device batch size. 16 is safe on 6GB VRAM. Gradient accum=2 gives effective batch 32." ) parser.add_argument("--learning_rate", type=float, default=3e-5) parser.add_argument("--max_length", type=int, default=256) parser.add_argument( "--doc_stride", type=int, default=64, help=( "Stride between sliding windows. MUST be < max_length. " "Overlap = max_length - doc_stride. " "Recommended range: max_length//4 to max_length//3. " "Default 64 gives 192-token overlap on max_length=256." ) ) parser.add_argument( "--max_train_samples", type=int, default=500, help=( "Source QA examples to load before tokenization. " "500 is the default — combined with --max_features this keeps " "training well under 1 hour. Use -1 for full dataset (hours)." ) ) parser.add_argument( "--max_features", type=int, default=12_000, help=( "Hard cap on tokenized features (windows). " "This is the real time control — CUAD examples fan out to ~120 windows each. " "12000 ≈ 15-25min | 20000 ≈ 35-45min | 30000 ≈ 55-70min on RTX 4050. " "Use 0 to disable the cap." ) ) parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"]) args = parser.parse_args() logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) train(**vars(args)) if __name__ == "__main__": main()