"""
QA Model Training Script
=========================
Fine-tunes a QA model (roberta-base or squad2 warm-start) on the CUAD dataset.
CUAD contexts are full contracts (avg 54K chars).
We use sliding window tokenization to create 384-token windows
with 128-token overlap; this is how SQuAD-style models handle long documents.
Memory Safety:
- Training samples are auto-limited based on available RAM unless
  --max_train_samples is set explicitly (use --max_train_samples -1 for the full dataset)
- Default batch_size=16 fits a 6GB GPU with fp16; lower it for CPU-only laptop runs
- Aggressive garbage collection after data transformations
- Tokenization uses small batch sizes to limit peak memory
Usage:
python -m src.train_qa # defaults (safe)
python -m src.train_qa --epochs 3 --batch_size 4
python -m src.train_qa --base_model deepset/roberta-base-squad2
python -m src.train_qa --max_train_samples -1 # full dataset (needs 16GB+ RAM)
"""
import argparse
import functools
import gc
import json
import logging
import os
import sys
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
def load_cuad_data(filepath: str) -> Dict:
"""Load a CUAD JSON file (SQuAD 2.0 format)."""
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
return data
def cuad_to_squad_examples(data: Dict) -> List[Dict]:
"""Convert CUAD data to a flat list of SQuAD-style examples.
Each example has: id, question, context, answers, is_impossible.
"""
examples = []
for article in data["data"]:
title = article.get("title", "")
for paragraph in article["paragraphs"]:
context = paragraph["context"]
for qa in paragraph["qas"]:
example = {
"id": qa["id"],
"title": title,
"question": qa["question"],
"context": context,
"answers": {
"text": [a["text"] for a in qa.get("answers", [])],
"answer_start": [a["answer_start"] for a in qa.get("answers", [])],
},
"is_impossible": qa.get("is_impossible", False),
}
examples.append(example)
return examples
def prepare_train_features(examples, tokenizer, max_length=384, doc_stride=128):
"""Prepare training features with sliding window tokenization.
Handles long documents by creating overlapping windows.
Maps answer spans to the correct window positions.
"""
pad_on_right = tokenizer.padding_side == "right"
tokenized = tokenizer(
examples["question"] if pad_on_right else examples["context"],
examples["context"] if pad_on_right else examples["question"],
truncation="only_second" if pad_on_right else "only_first",
max_length=max_length,
stride=doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
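    # With return_overflowing_tokens=True each long example can yield several
    # windows ("features"). overflow_to_sample_mapping maps each window back to
    # its source example, and offset_mapping gives (char_start, char_end) for
    # each token so character-level answers can be mapped to token positions.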
sample_mapping = tokenized.pop("overflow_to_sample_mapping")
offset_mapping = tokenized.pop("offset_mapping")
tokenized["start_positions"] = []
tokenized["end_positions"] = []
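    # Label every window: map the character-level answer span onto token
    # positions inside that window, or point both positions at the CLS token
    # when the window contains no answer (SQuAD 2.0 "no answer" convention).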
for i, offsets in enumerate(offset_mapping):
input_ids = tokenized["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
sequence_ids = tokenized.sequence_ids(i)
sample_index = sample_mapping[i]
answers = examples["answers"][sample_index]
if examples["is_impossible"][sample_index] or len(answers["answer_start"]) == 0:
tokenized["start_positions"].append(cls_index)
tokenized["end_positions"].append(cls_index)
else:
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])
token_start_index = 0
while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
token_start_index += 1
token_end_index = len(input_ids) - 1
while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
token_end_index -= 1
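            # If the answer is not fully contained in this window, label the
            # window as "no answer" (CLS index).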
if not (
offsets[token_start_index][0] <= start_char
and offsets[token_end_index][1] >= end_char
):
tokenized["start_positions"].append(cls_index)
tokenized["end_positions"].append(cls_index)
else:
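                # Otherwise walk token_start_index / token_end_index inward to
                # the first and last tokens of the answer within this window.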
while (
token_start_index < len(offsets)
and offsets[token_start_index][0] <= start_char
):
token_start_index += 1
tokenized["start_positions"].append(token_start_index - 1)
while offsets[token_end_index][1] >= end_char:
token_end_index -= 1
tokenized["end_positions"].append(token_end_index + 1)
return tokenized
def _tokenize_wrapper(examples, tokenizer, max_length, doc_stride):
"""Top-level wrapper so it can be pickled for multiprocessing on Windows."""
return prepare_train_features(examples, tokenizer, max_length, doc_stride)
def train(
train_path: str = "data/train.json",
test_path: str = "data/test.json",
base_model: str = "deepset/roberta-base-squad2",
output_dir: str = "ckpt_obligation",
epochs: int = 2,
batch_size: int = 16,
learning_rate: float = 3e-5,
max_length: int = 384,
doc_stride: int = 128,
    max_train_samples: Optional[int] = None,
device: str = "auto",
):
"""Fine-tune a QA model on CUAD data.
Args:
train_path: Path to CUAD train.json
test_path: Path to CUAD test.json
base_model: HuggingFace model name to fine-tune
output_dir: Directory to save the fine-tuned model
epochs: Number of training epochs (2 is enough when warm-starting from squad2)
batch_size: Training batch size (16 fits in RTX 4050 6GB VRAM with fp16)
learning_rate: Learning rate
max_length: Maximum sequence length for tokenizer
doc_stride: Sliding window stride for long documents
        max_train_samples: Limit training samples (default: auto-detected from available RAM; use -1 for all)
device: 'auto' (detect GPU), 'cuda', or 'cpu'
"""
# ─── Imports (done here to allow module to be imported without torch) ─
try:
import torch
from datasets import Dataset
from transformers import (
AutoModelForQuestionAnswering,
AutoTokenizer,
TrainingArguments,
Trainer,
default_data_collator,
)
except ImportError as e:
print(f"Missing dependency: {e}")
print("Install with: pip install torch transformers datasets")
sys.exit(1)
# ─── Resolve device ──────────────────────────────────────────────────
from all_model_code.model_1_code.utils import get_device, get_safe_train_samples
device = get_device(device)
logger.info(f"Training device: {device}")
# ─── Handle max_train_samples ─────────────────────────────────────────
# -1 means "force all regardless of RAM" (user knows what they're doing)
force_all = (max_train_samples is not None and max_train_samples < 0)
if force_all:
max_train_samples = None
# ─── Load Data ───────────────────────────────────────────────────────
logger.info(f"Loading training data from {train_path}")
train_data = load_cuad_data(train_path)
train_examples = cuad_to_squad_examples(train_data)
    logger.info(f" → {len(train_examples)} training examples")
    # Free the raw JSON immediately; it's huge and no longer needed
del train_data
gc.collect()
# NOTE: test data is NOT loaded here to save memory.
# Evaluation should be done separately via src.evaluate.
# ─── Auto-detect safe sample count if no explicit limit ──────────────
if max_train_samples is None and not force_all:
max_train_samples = get_safe_train_samples(len(train_examples))
# ─── Create HuggingFace Datasets ─────────────────────────────────────
# Convert to column-oriented format for HF datasets
def examples_to_columns(examples):
columns = {
"id": [], "question": [], "context": [],
"answers": [], "is_impossible": [],
}
for ex in examples:
columns["id"].append(ex["id"])
columns["question"].append(ex["question"])
columns["context"].append(ex["context"])
columns["answers"].append(ex["answers"])
columns["is_impossible"].append(ex["is_impossible"])
return columns
train_dataset = Dataset.from_dict(examples_to_columns(train_examples))
    # Free the examples list; the dataset holds the data now
del train_examples
gc.collect()
if max_train_samples and max_train_samples < len(train_dataset):
train_dataset = train_dataset.select(range(max_train_samples))
logger.info(f" Using {len(train_dataset)} training samples")
# ─── Load Tokenizer & Model ──────────────────────────────────────────
logger.info(f"Loading model: {base_model}")
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForQuestionAnswering.from_pretrained(base_model)
# ─── Tokenize (with disk cache) ────────────────────────────────────
    # CUAD contracts are huge (~54K chars each × 22K examples).
# Tokenization takes ~50 min, so we cache to disk after the first run.
# Cache is keyed by sample count so changing --max_train_samples auto-invalidates.
os.environ["TOKENIZERS_PARALLELISM"] = "true" # Rust-level threading
num_samples = len(train_dataset)
cache_dir = os.path.join(output_dir, "tokenized_cache")
cache_path = os.path.join(cache_dir, f"tokenized_train_{num_samples}")
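    # e.g. "ckpt_obligation/tokenized_cache/tokenized_train_500" when training
    # on 500 samples with the default output_dir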
if os.path.exists(cache_path):
from datasets import load_from_disk
logger.info(f"Loading cached tokenized data from {cache_path}")
tokenized_train = load_from_disk(cache_path)
        logger.info(f" → {len(tokenized_train)} cached features loaded instantly!")
else:
        logger.info(f"Tokenizing {num_samples} training examples (sliding window); this only happens once per sample count...")
        tokenized_train = train_dataset.map(
            functools.partial(
                _tokenize_wrapper,
                tokenizer=tokenizer,
                max_length=max_length,
                doc_stride=doc_stride,
            ),
            batched=True,
            batch_size=100,  # small map batches to limit peak memory
            remove_columns=train_dataset.column_names,
            desc="Tokenizing",
        )
# Save to disk so next run skips this entirely
os.makedirs(cache_dir, exist_ok=True)
tokenized_train.save_to_disk(cache_path)
        logger.info(f" → Tokenized data cached to {cache_path}")
# Free the un-tokenized dataset
del train_dataset
gc.collect()
    logger.info(f" → {len(tokenized_train)} tokenized features (from sliding windows)")
# ─── Training Arguments ──────────────────────────────────────────────
use_gpu = (device == "cuda")
training_args = TrainingArguments(
output_dir=output_dir,
overwrite_output_dir=True,
num_train_epochs=epochs,
per_device_train_batch_size=batch_size,
learning_rate=learning_rate,
weight_decay=0.01,
warmup_ratio=0.1,
logging_steps=50,
save_strategy="epoch",
save_total_limit=1, # keep only 1 checkpoint to save disk/memory
fp16=use_gpu, # FP16 on GPU for ~2x memory savings
report_to="none",
use_cpu=(not use_gpu),
dataloader_num_workers=0, # avoid multiprocessing memory overhead on Windows
dataloader_pin_memory=use_gpu, # pin_memory speeds up GPU, wastes RAM on CPU
)
# ─── Trainer ─────────────────────────────────────────────────────────
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_train,
tokenizer=tokenizer,
data_collator=default_data_collator,
)
# ─── Train ───────────────────────────────────────────────────────────
logger.info("Starting training...")
trainer.train()
# ─── Save ────────────────────────────────────────────────────────────
logger.info(f"Saving model to {output_dir}")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
logger.info("Training complete!")
return output_dir
def main():
parser = argparse.ArgumentParser(description="Fine-tune QA model on CUAD")
parser.add_argument("--train_path", default="data/train.json")
parser.add_argument("--test_path", default="data/test.json")
parser.add_argument("--base_model", default="deepset/roberta-base-squad2")
parser.add_argument("--output_dir", default="ckpt_obligation")
parser.add_argument("--epochs", type=int, default=2,
help="Training epochs (2 is enough when warm-starting from squad2)")
parser.add_argument("--batch_size", type=int, default=16,
help="Batch size (16 fits RTX 4050 6GB VRAM with fp16)")
parser.add_argument("--learning_rate", type=float, default=3e-5)
parser.add_argument("--max_length", type=int, default=384)
parser.add_argument("--doc_stride", type=int, default=128)
parser.add_argument("--max_train_samples", type=int, default=None,
help="Limit training samples. Default: auto-detect based on RAM. Use -1 to force ALL.")
parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"],
help="Device: 'auto' detects GPU, 'cuda' forces GPU, 'cpu' forces CPU")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
train(**vars(args))
if __name__ == "__main__":
main()