| """ | |
| QA Model Training Script | |
| ========================= | |
| Fine-tunes a QA model (roberta-base or squad2 warm-start) on the CUAD dataset. | |
| CUAD contexts are full contracts (avg 54K chars). | |
| We use sliding window tokenization to create 384-token windows | |
| with 128-token overlap β this is how SQuAD-style models handle long documents. | |
| Memory Safety: | |
| - Default batch_size=2 to avoid RAM exhaustion on laptops | |
| - Default max_train_samples=500 (use --max_train_samples=-1 for full dataset) | |
| - Aggressive garbage collection after data transformations | |
| - Tokenization uses small batch sizes to limit peak memory | |
| Usage: | |
| python -m src.train_qa # defaults (safe) | |
| python -m src.train_qa --epochs 3 --batch_size 4 | |
| python -m src.train_qa --base_model deepset/roberta-base-squad2 | |
| python -m src.train_qa --max_train_samples -1 # full dataset (needs 16GB+ RAM) | |
| """ | |
import argparse
import functools
import gc
import json
import logging
import os
import sys
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

def load_cuad_data(filepath: str) -> Dict:
    """Load a CUAD JSON file (SQuAD 2.0 format)."""
    with open(filepath, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data

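# For reference, the SQuAD 2.0 layout that cuad_to_squad_examples() below
# walks looks like this (key names taken from the accesses in that function;
# the values are placeholders):
#
# {"data": [
#     {"title": "...",
#      "paragraphs": [
#          {"context": "<full contract text>",
#           "qas": [
#               {"id": "...", "question": "...",
#                "answers": [{"text": "...", "answer_start": 123}],
#                "is_impossible": false}]}]}]}
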
def cuad_to_squad_examples(data: Dict) -> List[Dict]:
    """Convert CUAD data to a flat list of SQuAD-style examples.

    Each example has: id, question, context, answers, is_impossible.
    """
    examples = []
    for article in data["data"]:
        title = article.get("title", "")
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                example = {
                    "id": qa["id"],
                    "title": title,
                    "question": qa["question"],
                    "context": context,
                    "answers": {
                        "text": [a["text"] for a in qa.get("answers", [])],
                        "answer_start": [a["answer_start"] for a in qa.get("answers", [])],
                    },
                    "is_impossible": qa.get("is_impossible", False),
                }
                examples.append(example)
    return examples

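# A single flattened example then looks roughly like this (all values are
# hypothetical, shown only to illustrate the shape):
#
# {"id": "contract-x__Governing_Law",
#  "title": "contract-x",
#  "question": "Highlight the parts (if any) related to \"Governing Law\"...",
#  "context": "<full contract text>",
#  "answers": {"text": ["laws of the State of New York"], "answer_start": [10432]},
#  "is_impossible": False}
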
def prepare_train_features(examples, tokenizer, max_length=384, doc_stride=128):
    """Prepare training features with sliding-window tokenization.

    Handles long documents by creating overlapping windows and maps answer
    spans to the correct positions within each window.
    """
    pad_on_right = tokenizer.padding_side == "right"
    tokenized = tokenizer(
        examples["question"] if pad_on_right else examples["context"],
        examples["context"] if pad_on_right else examples["question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # One example can overflow into several windows; this maps each window
    # back to the example it came from.
    sample_mapping = tokenized.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized.pop("offset_mapping")
    tokenized["start_positions"] = []
    tokenized["end_positions"] = []
    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = tokenized.sequence_ids(i)
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        if examples["is_impossible"][sample_index] or len(answers["answer_start"]) == 0:
            # No answer: point both positions at the CLS token (SQuAD 2.0 convention).
            tokenized["start_positions"].append(cls_index)
            tokenized["end_positions"].append(cls_index)
        else:
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            # Find the first and last tokens that belong to the context.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1
            if not (
                offsets[token_start_index][0] <= start_char
                and offsets[token_end_index][1] >= end_char
            ):
                # The answer lies outside this window: label it unanswerable.
                tokenized["start_positions"].append(cls_index)
                tokenized["end_positions"].append(cls_index)
            else:
                # Walk inward to the tokens that tightly bracket the answer span.
                while (
                    token_start_index < len(offsets)
                    and offsets[token_start_index][0] <= start_char
                ):
                    token_start_index += 1
                tokenized["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized["end_positions"].append(token_end_index + 1)
    return tokenized

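# Rough window-count arithmetic (an estimate, not computed by the code):
# consecutive windows share doc_stride tokens of context, so each new window
# advances by roughly max_length - doc_stride = 256 tokens (slightly less once
# the question and special tokens are subtracted). A 54K-char contract
# tokenizes to ~13K tokens at ~4 chars/token, which yields on the order of
# 13000 / 256 ≈ 50 overlapping features per question.
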
def _tokenize_wrapper(examples, tokenizer, max_length, doc_stride):
    """Top-level wrapper so it can be pickled for multiprocessing on Windows."""
    return prepare_train_features(examples, tokenizer, max_length, doc_stride)

def train(
    train_path: str = "data/train.json",
    test_path: str = "data/test.json",
    base_model: str = "deepset/roberta-base-squad2",
    output_dir: str = "ckpt_obligation",
    epochs: int = 2,
    batch_size: int = 16,
    learning_rate: float = 3e-5,
    max_length: int = 384,
    doc_stride: int = 128,
    max_train_samples: Optional[int] = None,
    device: str = "auto",
):
    """Fine-tune a QA model on CUAD data.

    Args:
        train_path: Path to CUAD train.json
        test_path: Path to CUAD test.json (kept for CLI symmetry; not loaded here)
        base_model: HuggingFace model name to fine-tune
        output_dir: Directory to save the fine-tuned model
        epochs: Number of training epochs (2 is enough when warm-starting from squad2)
        batch_size: Training batch size (16 fits in RTX 4050 6GB VRAM with fp16)
        learning_rate: Learning rate
        max_length: Maximum sequence length for the tokenizer
        doc_stride: Sliding-window stride for long documents
        max_train_samples: Limit on training samples (default: auto-detect
            from available RAM; use -1 to force the full dataset)
        device: 'auto' (detect GPU), 'cuda', or 'cpu'
    """
    # ─── Imports (done here so the module can be imported without torch) ───
    try:
        import torch
        from datasets import Dataset
        from transformers import (
            AutoModelForQuestionAnswering,
            AutoTokenizer,
            TrainingArguments,
            Trainer,
            default_data_collator,
        )
    except ImportError as e:
        print(f"Missing dependency: {e}")
        print("Install with: pip install torch transformers datasets")
        sys.exit(1)

    # ─── Resolve device ───
    from all_model_code.model_1_code.utils import get_device, get_safe_train_samples

    device = get_device(device)
    logger.info(f"Training device: {device}")

    # ─── Handle max_train_samples ───
    # -1 means "force all regardless of RAM" (the user knows what they're doing).
    force_all = (max_train_samples is not None and max_train_samples < 0)
    if force_all:
        max_train_samples = None

    # ─── Load data ───
    logger.info(f"Loading training data from {train_path}")
    train_data = load_cuad_data(train_path)
    train_examples = cuad_to_squad_examples(train_data)
    logger.info(f"  → {len(train_examples)} training examples")

    # Free the raw JSON immediately; it's huge and no longer needed.
    del train_data
    gc.collect()

    # NOTE: test data is NOT loaded here to save memory.
    # Evaluation should be done separately via src.evaluate.

    # ─── Auto-detect a safe sample count if no explicit limit ───
    if max_train_samples is None and not force_all:
        max_train_samples = get_safe_train_samples(len(train_examples))

    # ─── Create HuggingFace Datasets ───
    # Convert to a column-oriented format for HF datasets.
    def examples_to_columns(examples):
        columns = {
            "id": [], "question": [], "context": [],
            "answers": [], "is_impossible": [],
        }
        for ex in examples:
            columns["id"].append(ex["id"])
            columns["question"].append(ex["question"])
            columns["context"].append(ex["context"])
            columns["answers"].append(ex["answers"])
            columns["is_impossible"].append(ex["is_impossible"])
        return columns

    train_dataset = Dataset.from_dict(examples_to_columns(train_examples))

    # Free the examples list; the dataset holds the data now.
    del train_examples
    gc.collect()

    if max_train_samples and max_train_samples < len(train_dataset):
        train_dataset = train_dataset.select(range(max_train_samples))
    logger.info(f"  Using {len(train_dataset)} training samples")

    # ─── Load tokenizer & model ───
    logger.info(f"Loading model: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForQuestionAnswering.from_pretrained(base_model)

    # ─── Tokenize (with disk cache) ───
    # CUAD contracts are huge (~54K chars each × 22K examples).
    # Tokenization takes ~50 min, so we cache to disk after the first run.
    # The cache is keyed by sample count, so changing --max_train_samples
    # auto-invalidates it.
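    # CAVEAT (derived from the cache key below): the key encodes only the
    # sample count. If you change --base_model, --max_length, or --doc_stride,
    # delete the cache directory (output_dir/tokenized_cache) by hand to force
    # re-tokenization; otherwise stale features are silently reused.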
| os.environ["TOKENIZERS_PARALLELISM"] = "true" # Rust-level threading | |
| num_samples = len(train_dataset) | |
| cache_dir = os.path.join(output_dir, "tokenized_cache") | |
| cache_path = os.path.join(cache_dir, f"tokenized_train_{num_samples}") | |
    if os.path.exists(cache_path):
        from datasets import load_from_disk

        logger.info(f"Loading cached tokenized data from {cache_path}")
        tokenized_train = load_from_disk(cache_path)
        logger.info(f"  → {len(tokenized_train)} cached features loaded instantly!")
    else:
        logger.info(
            f"Tokenizing {num_samples} training examples (sliding window); "
            "this only happens once per sample count..."
        )
        tokenized_train = train_dataset.map(
            functools.partial(
                _tokenize_wrapper,
                tokenizer=tokenizer,
                max_length=max_length,
                doc_stride=doc_stride,
            ),
            batched=True,
            batch_size=100,  # small batches to limit peak memory
            remove_columns=train_dataset.column_names,
            desc="Tokenizing",
        )
        # Save to disk so the next run skips this entirely.
        os.makedirs(cache_dir, exist_ok=True)
        tokenized_train.save_to_disk(cache_path)
        logger.info(f"  → Tokenized data cached to {cache_path}")

    # Free the un-tokenized dataset.
    del train_dataset
    gc.collect()
    logger.info(f"  → {len(tokenized_train)} tokenized features (from sliding windows)")

    # ─── Training arguments ───
    use_gpu = (device == "cuda")
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        logging_steps=50,
        save_strategy="epoch",
        save_total_limit=1,             # keep only 1 checkpoint to save disk space
        fp16=use_gpu,                   # FP16 on GPU for ~2x memory savings
        report_to="none",
        use_cpu=(not use_gpu),
        dataloader_num_workers=0,       # avoid multiprocessing memory overhead on Windows
        dataloader_pin_memory=use_gpu,  # pin_memory speeds up GPU transfer, wastes RAM on CPU
    )
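
    # Hedged note: on GPUs with less than 6GB of VRAM, one option (not enabled
    # here) is to halve batch_size and pass gradient_accumulation_steps=2 to
    # TrainingArguments, which keeps the same effective batch size at roughly
    # half the activation memory.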
    # ─── Trainer ───
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    # ─── Train ───
    logger.info("Starting training...")
    trainer.train()

    # ─── Save ───
    logger.info(f"Saving model to {output_dir}")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info("Training complete!")
    return output_dir

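# Programmatic usage sketch (module path taken from the Usage block above;
# argument values are just examples):
#
#   from src.train_qa import train
#   ckpt_dir = train(train_path="data/train.json", epochs=2, batch_size=8)
#   print(f"Model saved to {ckpt_dir}")
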
def main():
    parser = argparse.ArgumentParser(description="Fine-tune QA model on CUAD")
    parser.add_argument("--train_path", default="data/train.json")
    parser.add_argument("--test_path", default="data/test.json")
    parser.add_argument("--base_model", default="deepset/roberta-base-squad2")
    parser.add_argument("--output_dir", default="ckpt_obligation")
    parser.add_argument("--epochs", type=int, default=2,
                        help="Training epochs (2 is enough when warm-starting from squad2)")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size (16 fits RTX 4050 6GB VRAM with fp16)")
    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--max_length", type=int, default=384)
    parser.add_argument("--doc_stride", type=int, default=128)
    parser.add_argument("--max_train_samples", type=int, default=None,
                        help="Limit training samples. Default: auto-detect based on RAM. "
                             "Use -1 to force ALL.")
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"],
                        help="Device: 'auto' detects GPU, 'cuda' forces GPU, 'cpu' forces CPU")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    train(**vars(args))


if __name__ == "__main__":
    main()