""" Train BERT or DistilBERT or DeBERTa on combined sentence-pair boundary classification. Uses HuggingFace Trainer and TrainingArguments natively. Class imbalance is handled via a WeightedTrainer subclass that overrides compute_loss. Usage: python -m src.models.train --model distilbert --out checkpoints/distilbert python -m src.models.train --model bert --out checkpoints/bert --epochs 5 --lr 3e-5 """ import argparse import json import logging import os from pathlib import Path import numpy as np import torch import torch.nn as nn import wandb from dotenv import load_dotenv from sklearn.metrics import f1_score, matthews_corrcoef load_dotenv() wandb.login(key=os.getenv("WB_TOKEN")) from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, EarlyStoppingCallback, Trainer, TrainingArguments, ) from src.datasets.combined_pairs_dataset import ( CombinedPairsDataset, CombinedPairsConfig, NUM_LABELS, ID2LABEL, LABEL2ID, ) from src.models.bert import load_bert, load_bert_tokenizer from src.models.deberta import load_deberta, load_deberta_tokenizer from src.models.distilbert import load_distilbert, load_distilbert_tokenizer from src.schemas.training_args import BertTrainingArgs, DebertaTrainingArgs, DistilBertTrainingArgs logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) MODEL_REGISTRY = { "bert": (load_bert, load_bert_tokenizer, BertTrainingArgs), "distilbert": (load_distilbert, load_distilbert_tokenizer, DistilBertTrainingArgs), "deberta": (load_deberta, load_deberta_tokenizer, DebertaTrainingArgs), } # ───────────────────────────────────────────────────────────────────────────── # Weighted Trainer # ───────────────────────────────────────────────────────────────────────────── class WeightedTrainer(Trainer): """Trainer with weighted cross-entropy loss for class imbalance.""" def __init__(self, class_weights: torch.Tensor | None = None, **kwargs): super().__init__(**kwargs) self.class_weights = class_weights def compute_loss(self, model, inputs, return_outputs=False, **kwargs): labels = inputs.pop("labels") outputs = model(**inputs) logits = outputs.logits if self.class_weights is not None: weight = self.class_weights.to(logits.device) else: weight = None loss = nn.functional.cross_entropy(logits, labels, weight=weight) return (loss, outputs) if return_outputs else loss # ───────────────────────────────────────────────────────────────────────────── # Metrics # ───────────────────────────────────────────────────────────────────────────── def compute_metrics(eval_pred): logits, labels = eval_pred preds = np.argmax(logits, axis=-1) macro_f1 = f1_score(labels, preds, average="macro") weighted_f1 = f1_score(labels, preds, average="weighted") mcc = matthews_corrcoef(labels, preds) per_class = f1_score(labels, preds, average=None, labels=[0, 1, 2]) return { "macro_f1": macro_f1, "weighted_f1": weighted_f1, "mcc": mcc, "f1_same_para": per_class[0], "f1_new_para": per_class[1], "f1_newline": per_class[2], } # ───────────────────────────────────────────────────────────────────────────── # Main # ───────────────────────────────────────────────────────────────────────────── def main() -> None: parser = argparse.ArgumentParser(description="Train sentence-pair boundary classifier.") parser.add_argument("--model", choices=["bert", "distilbert", "deberta"], default="distilbert") parser.add_argument("--out", help="Output directory (overrides dataclass default)") parser.add_argument("--data_root", default="data") parser.add_argument("--epochs", type=int) parser.add_argument("--batch_size", type=int) parser.add_argument("--lr", type=float) parser.add_argument("--weight_decay", type=float) parser.add_argument("--warmup_ratio", type=float) parser.add_argument("--max_length", type=int) parser.add_argument("--gutenberg_cap", type=int) parser.add_argument("--seed", type=int) parser.add_argument("--bf16", action="store_true") parser.add_argument("--patience", type=int) args = parser.parse_args() # ── build training args from dataclass + CLI overrides ────────────── model_loader, tokenizer_loader, args_cls = MODEL_REGISTRY[args.model] override = {} for field in ("output_dir", "epochs", "batch_size", "lr", "weight_decay", "warmup_ratio", "max_length", "gutenberg_cap", "seed", "bf16", "patience"): cli_key = "out" if field == "output_dir" else field val = getattr(args, cli_key, None) if val is not None: override[field] = val train_args = args_cls(**override) out_dir = Path(train_args.output_dir) # ── wandb ────────────────────────────────────────────────────────── os.environ["WANDB_PROJECT"] = "bottlecap" os.environ["WANDB_RUN_NAME"] = args.model # ── model + tokenizer ─────────────────────────────────────────────── model = model_loader() tokenizer = tokenizer_loader() log.info(f"Model: {args.model} ({sum(p.numel() for p in model.parameters()):,} params)") # ── data ──────────────────────────────────────────────────────────── cfg = CombinedPairsConfig( data_root=args.data_root, gutenberg_train_cap=train_args.gutenberg_cap, seed=train_args.seed, max_length=train_args.max_length, ) builder = CombinedPairsDataset(cfg) log.info("Building splits and tokenizing ...") raw_splits = builder.build_splits() class_weights = builder.compute_class_weights(raw_splits["train"]) log.info(f"Class weights: {class_weights.tolist()}") dd = builder.build_hf_dataset_dict(tokenizer, raw_splits=raw_splits) # ── training arguments ────────────────────────────────────────────── training_args = train_args.to_training_arguments() callbacks = [] if train_args.patience > 0: callbacks.append(EarlyStoppingCallback(early_stopping_patience=train_args.patience)) # ── trainer ───────────────────────────────────────────────────────── trainer = WeightedTrainer( class_weights=class_weights, model=model, args=training_args, train_dataset=dd["train"], eval_dataset=dd["val"], compute_metrics=compute_metrics, callbacks=callbacks, ) log.info("Starting training ...") trainer.train() # ── save best ─────────────────────────────────────────────────────── best_dir = out_dir / "best" best_dir.mkdir(parents=True, exist_ok=True) trainer.save_model(str(best_dir)) tokenizer.save_pretrained(str(best_dir)) torch.save(class_weights, best_dir / "class_weights.pt") train_config = { "model_type": args.model, "pretrained": model.config._name_or_path, "epochs": train_args.epochs, "batch_size": train_args.batch_size, "lr": train_args.lr, "max_length": train_args.max_length, "class_weights": class_weights.tolist(), "num_labels": NUM_LABELS, "id2label": ID2LABEL, "label2id": LABEL2ID, } with open(best_dir / "train_config.json", "w") as f: json.dump(train_config, f, indent=2) log.info(f"Best model saved to {best_dir}") # ── final eval ────────────────────────────────────────────────────── metrics = trainer.evaluate() log.info(f"Val metrics: {metrics}") with open(out_dir / "val_metrics.json", "w") as f: json.dump(metrics, f, indent=2) log.info("Done.") if __name__ == "__main__": main()