# bc-test / src/models/train.py
# (HF Hub page residue, commented out so the module parses)
# lamossta's picture — training and models — commit 945de56
"""
Train BERT or DistilBERT or DeBERTa on combined sentence-pair boundary classification.
Uses HuggingFace Trainer and TrainingArguments natively. Class imbalance
is handled via a WeightedTrainer subclass that overrides compute_loss.
Usage:
python -m src.models.train --model distilbert --out checkpoints/distilbert
python -m src.models.train --model bert --out checkpoints/bert --epochs 5 --lr 3e-5
"""
import argparse
import json
import logging
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import wandb
from dotenv import load_dotenv
from sklearn.metrics import f1_score, matthews_corrcoef
load_dotenv()
wandb.login(key=os.getenv("WB_TOKEN"))
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
EarlyStoppingCallback,
Trainer,
TrainingArguments,
)
from src.datasets.combined_pairs_dataset import (
CombinedPairsDataset,
CombinedPairsConfig,
NUM_LABELS,
ID2LABEL,
LABEL2ID,
)
from src.models.bert import load_bert, load_bert_tokenizer
from src.models.deberta import load_deberta, load_deberta_tokenizer
from src.models.distilbert import load_distilbert, load_distilbert_tokenizer
from src.schemas.training_args import BertTrainingArgs, DebertaTrainingArgs, DistilBertTrainingArgs
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
MODEL_REGISTRY = {
"bert": (load_bert, load_bert_tokenizer, BertTrainingArgs),
"distilbert": (load_distilbert, load_distilbert_tokenizer, DistilBertTrainingArgs),
"deberta": (load_deberta, load_deberta_tokenizer, DebertaTrainingArgs),
}
# ─────────────────────────────────────────────────────────────────────────────
# Weighted Trainer
# ─────────────────────────────────────────────────────────────────────────────
class WeightedTrainer(Trainer):
"""Trainer with weighted cross-entropy loss for class imbalance."""
def __init__(self, class_weights: torch.Tensor | None = None, **kwargs):
super().__init__(**kwargs)
self.class_weights = class_weights
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
if self.class_weights is not None:
weight = self.class_weights.to(logits.device)
else:
weight = None
loss = nn.functional.cross_entropy(logits, labels, weight=weight)
return (loss, outputs) if return_outputs else loss
# ─────────────────────────────────────────────────────────────────────────────
# Metrics
# ─────────────────────────────────────────────────────────────────────────────
def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
macro_f1 = f1_score(labels, preds, average="macro")
weighted_f1 = f1_score(labels, preds, average="weighted")
mcc = matthews_corrcoef(labels, preds)
per_class = f1_score(labels, preds, average=None, labels=[0, 1, 2])
return {
"macro_f1": macro_f1,
"weighted_f1": weighted_f1,
"mcc": mcc,
"f1_same_para": per_class[0],
"f1_new_para": per_class[1],
"f1_newline": per_class[2],
}
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(description="Train sentence-pair boundary classifier.")
parser.add_argument("--model", choices=["bert", "distilbert", "deberta"], default="distilbert")
parser.add_argument("--out", help="Output directory (overrides dataclass default)")
parser.add_argument("--data_root", default="data")
parser.add_argument("--epochs", type=int)
parser.add_argument("--batch_size", type=int)
parser.add_argument("--lr", type=float)
parser.add_argument("--weight_decay", type=float)
parser.add_argument("--warmup_ratio", type=float)
parser.add_argument("--max_length", type=int)
parser.add_argument("--gutenberg_cap", type=int)
parser.add_argument("--seed", type=int)
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--patience", type=int)
args = parser.parse_args()
# ── build training args from dataclass + CLI overrides ──────────────
model_loader, tokenizer_loader, args_cls = MODEL_REGISTRY[args.model]
override = {}
for field in ("output_dir", "epochs", "batch_size", "lr", "weight_decay",
"warmup_ratio", "max_length", "gutenberg_cap", "seed", "bf16", "patience"):
cli_key = "out" if field == "output_dir" else field
val = getattr(args, cli_key, None)
if val is not None:
override[field] = val
train_args = args_cls(**override)
out_dir = Path(train_args.output_dir)
# ── wandb ──────────────────────────────────────────────────────────
os.environ["WANDB_PROJECT"] = "bottlecap"
os.environ["WANDB_RUN_NAME"] = args.model
# ── model + tokenizer ───────────────────────────────────────────────
model = model_loader()
tokenizer = tokenizer_loader()
log.info(f"Model: {args.model} ({sum(p.numel() for p in model.parameters()):,} params)")
# ── data ────────────────────────────────────────────────────────────
cfg = CombinedPairsConfig(
data_root=args.data_root,
gutenberg_train_cap=train_args.gutenberg_cap,
seed=train_args.seed,
max_length=train_args.max_length,
)
builder = CombinedPairsDataset(cfg)
log.info("Building splits and tokenizing ...")
raw_splits = builder.build_splits()
class_weights = builder.compute_class_weights(raw_splits["train"])
log.info(f"Class weights: {class_weights.tolist()}")
dd = builder.build_hf_dataset_dict(tokenizer, raw_splits=raw_splits)
# ── training arguments ──────────────────────────────────────────────
training_args = train_args.to_training_arguments()
callbacks = []
if train_args.patience > 0:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=train_args.patience))
# ── trainer ─────────────────────────────────────────────────────────
trainer = WeightedTrainer(
class_weights=class_weights,
model=model,
args=training_args,
train_dataset=dd["train"],
eval_dataset=dd["val"],
compute_metrics=compute_metrics,
callbacks=callbacks,
)
log.info("Starting training ...")
trainer.train()
# ── save best ───────────────────────────────────────────────────────
best_dir = out_dir / "best"
best_dir.mkdir(parents=True, exist_ok=True)
trainer.save_model(str(best_dir))
tokenizer.save_pretrained(str(best_dir))
torch.save(class_weights, best_dir / "class_weights.pt")
train_config = {
"model_type": args.model,
"pretrained": model.config._name_or_path,
"epochs": train_args.epochs,
"batch_size": train_args.batch_size,
"lr": train_args.lr,
"max_length": train_args.max_length,
"class_weights": class_weights.tolist(),
"num_labels": NUM_LABELS,
"id2label": ID2LABEL,
"label2id": LABEL2ID,
}
with open(best_dir / "train_config.json", "w") as f:
json.dump(train_config, f, indent=2)
log.info(f"Best model saved to {best_dir}")
# ── final eval ──────────────────────────────────────────────────────
metrics = trainer.evaluate()
log.info(f"Val metrics: {metrics}")
with open(out_dir / "val_metrics.json", "w") as f:
json.dump(metrics, f, indent=2)
log.info("Done.")
if __name__ == "__main__":
main()