| """ |
| Train BERT or DistilBERT or DeBERTa on combined sentence-pair boundary classification. |
| |
| Uses HuggingFace Trainer and TrainingArguments natively. Class imbalance |
| is handled via a WeightedTrainer subclass that overrides compute_loss. |
| |
| Usage: |
| python -m src.models.train --model distilbert --out checkpoints/distilbert |
| python -m src.models.train --model bert --out checkpoints/bert --epochs 5 --lr 3e-5 |
| """ |
|
|
import argparse
import json
import logging
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import wandb
from dotenv import load_dotenv
from sklearn.metrics import f1_score, matthews_corrcoef

# Authenticate with W&B up front (WB_TOKEN is read from .env) so the
# Trainer's wandb integration reports under the right account.
load_dotenv()
wandb.login(key=os.getenv("WB_TOKEN"))

from transformers import EarlyStoppingCallback, Trainer

from src.datasets.combined_pairs_dataset import (
    CombinedPairsDataset,
    CombinedPairsConfig,
    NUM_LABELS,
    ID2LABEL,
    LABEL2ID,
)
from src.models.bert import load_bert, load_bert_tokenizer
from src.models.deberta import load_deberta, load_deberta_tokenizer
from src.models.distilbert import load_distilbert, load_distilbert_tokenizer
from src.schemas.training_args import BertTrainingArgs, DebertaTrainingArgs, DistilBertTrainingArgs

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)

MODEL_REGISTRY = {
    "bert": (load_bert, load_bert_tokenizer, BertTrainingArgs),
    "distilbert": (load_distilbert, load_distilbert_tokenizer, DistilBertTrainingArgs),
    "deberta": (load_deberta, load_deberta_tokenizer, DebertaTrainingArgs),
}
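# Each entry maps a CLI model name to (model loader, tokenizer loader,
# training-args dataclass). Example: MODEL_REGISTRY["distilbert"] unpacks to
# (load_distilbert, load_distilbert_tokenizer, DistilBertTrainingArgs), so
# supporting a new backbone only needs a loader pair, an args dataclass, and
# one registry entry.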


class WeightedTrainer(Trainer):
    """Trainer with weighted cross-entropy loss for class imbalance."""

    def __init__(self, class_weights: torch.Tensor | None = None, **kwargs):
        super().__init__(**kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Pop labels so the model's forward pass doesn't compute its own
        # (unweighted) loss; the weighted loss is computed below instead.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        if self.class_weights is not None:
            weight = self.class_weights.to(logits.device)
        else:
            weight = None

        loss = nn.functional.cross_entropy(logits, labels, weight=weight)
        return (loss, outputs) if return_outputs else loss
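

# Illustrative only: the weights passed to WeightedTrainer come from
# CombinedPairsDataset.compute_class_weights. A common recipe (not
# necessarily the one it implements) is inverse frequency, normalized so the
# weights average to 1:
#
#     counts = torch.bincount(label_tensor, minlength=NUM_LABELS).float()
#     weights = counts.sum() / (NUM_LABELS * counts)
#
# Passed as `weight=` to cross_entropy, such a tensor scales each class's
# contribution to the loss, boosting the gradient signal for rare classes.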


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    macro_f1 = f1_score(labels, preds, average="macro")
    weighted_f1 = f1_score(labels, preds, average="weighted")
    mcc = matthews_corrcoef(labels, preds)
    # labels=[0, 1, 2] pins the column order so the per-class keys below stay
    # aligned with the label ids.
    per_class = f1_score(labels, preds, average=None, labels=[0, 1, 2])
    return {
        "macro_f1": macro_f1,
        "weighted_f1": weighted_f1,
        "mcc": mcc,
        "f1_same_para": per_class[0],
        "f1_new_para": per_class[1],
        "f1_newline": per_class[2],
    }
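

# For reference: `eval_pred` unpacks to (logits, labels) with shapes
# (N, NUM_LABELS) and (N,) over the N evaluation pairs, so the argmax above
# yields one predicted class id per sentence pair.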


def main() -> None:
    parser = argparse.ArgumentParser(description="Train sentence-pair boundary classifier.")
    parser.add_argument("--model", choices=["bert", "distilbert", "deberta"], default="distilbert")
    parser.add_argument("--out", help="Output directory (overrides dataclass default)")
    parser.add_argument("--data_root", default="data")
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--weight_decay", type=float)
    parser.add_argument("--warmup_ratio", type=float)
    parser.add_argument("--max_length", type=int)
    parser.add_argument("--gutenberg_cap", type=int)
    parser.add_argument("--seed", type=int)
    # default=None (not False) so that omitting --bf16 falls through to the
    # dataclass default instead of always forcing bf16 off.
    parser.add_argument("--bf16", action="store_true", default=None)
    parser.add_argument("--patience", type=int)
    args = parser.parse_args()

    model_loader, tokenizer_loader, args_cls = MODEL_REGISTRY[args.model]

    # CLI flags override the dataclass defaults only when explicitly provided.
    override = {}
    for field in ("output_dir", "epochs", "batch_size", "lr", "weight_decay",
                  "warmup_ratio", "max_length", "gutenberg_cap", "seed", "bf16", "patience"):
        cli_key = "out" if field == "output_dir" else field
        val = getattr(args, cli_key, None)
        if val is not None:
            override[field] = val
    train_args = args_cls(**override)
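    # Example: `--model bert --lr 3e-5` builds BertTrainingArgs(lr=3e-5) with
    # every other field at its dataclass default; flags left unset stay None
    # and never enter `override`.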

    out_dir = Path(train_args.output_dir)

    # W&B reads these env vars when the Trainer starts its run; WANDB_NAME
    # (not WANDB_RUN_NAME) is the variable wandb recognizes for the run name.
    os.environ["WANDB_PROJECT"] = "bottlecap"
    os.environ["WANDB_NAME"] = args.model

    model = model_loader()
    tokenizer = tokenizer_loader()

    log.info(f"Model: {args.model} ({sum(p.numel() for p in model.parameters()):,} params)")

    cfg = CombinedPairsConfig(
        data_root=args.data_root,
        gutenberg_train_cap=train_args.gutenberg_cap,
        seed=train_args.seed,
        max_length=train_args.max_length,
    )
    builder = CombinedPairsDataset(cfg)

    log.info("Building splits and tokenizing ...")
    raw_splits = builder.build_splits()
    class_weights = builder.compute_class_weights(raw_splits["train"])
    log.info(f"Class weights: {class_weights.tolist()}")

    dd = builder.build_hf_dataset_dict(tokenizer, raw_splits=raw_splits)
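    # `dd` is a tokenized dataset dict keyed by split ("train" and "val" are
    # used below), built from the same raw_splits the weights came from.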

    training_args = train_args.to_training_arguments()

    callbacks = []
    if train_args.patience > 0:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=train_args.patience))
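    # Note: EarlyStoppingCallback relies on load_best_model_at_end=True and a
    # metric_for_best_model in the underlying TrainingArguments; those are
    # assumed to be set inside to_training_arguments().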

    trainer = WeightedTrainer(
        class_weights=class_weights,
        model=model,
        args=training_args,
        train_dataset=dd["train"],
        eval_dataset=dd["val"],
        compute_metrics=compute_metrics,
        callbacks=callbacks,
    )

    log.info("Starting training ...")
    trainer.train()

    best_dir = out_dir / "best"
    best_dir.mkdir(parents=True, exist_ok=True)

    trainer.save_model(str(best_dir))
    tokenizer.save_pretrained(str(best_dir))
    torch.save(class_weights, best_dir / "class_weights.pt")

    train_config = {
        "model_type": args.model,
        "pretrained": model.config._name_or_path,
        "epochs": train_args.epochs,
        "batch_size": train_args.batch_size,
        "lr": train_args.lr,
        "max_length": train_args.max_length,
        "class_weights": class_weights.tolist(),
        "num_labels": NUM_LABELS,
        "id2label": ID2LABEL,
        "label2id": LABEL2ID,
    }
    with open(best_dir / "train_config.json", "w") as f:
        json.dump(train_config, f, indent=2)
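    # Reload sketch for the inference side (assuming the layout saved above):
    #     from transformers import AutoModelForSequenceClassification, AutoTokenizer
    #     tok = AutoTokenizer.from_pretrained(best_dir)
    #     mdl = AutoModelForSequenceClassification.from_pretrained(best_dir)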

    log.info(f"Best model saved to {best_dir}")

    # Final validation pass; if load_best_model_at_end is enabled, this
    # reflects the best checkpoint rather than the last one.
    metrics = trainer.evaluate()
    log.info(f"Val metrics: {metrics}")

    with open(out_dir / "val_metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    log.info("Done.")


if __name__ == "__main__":
    main()