"""Fine-tune an IAB topic classification head with class-weighted cross-entropy.

Loads the configured train/val/test splits, trains a sequence-classification
model with balanced class weights, evaluates on val and test, and persists the
model, tokenizer, and a metrics summary to ``IAB_HEAD_CONFIG.model_dir``.
"""

import sys
from pathlib import Path

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Make the project root importable when this file is executed as a script.
BASE_DIR = Path(__file__).resolve().parent.parent
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))

from config import IAB_HEAD_CONFIG
from training.common import (
    build_balanced_class_weights,
    compute_classification_metrics,
    load_labeled_rows,
    prepare_dataset,
    write_json,
)


class WeightedTrainer(Trainer):
    """Trainer that applies per-class weights to the cross-entropy loss.

    With an imbalanced label distribution, unweighted cross-entropy lets
    frequent classes dominate; per-class weights counteract that.
    """

    def __init__(self, *args, class_weights: torch.Tensor | None = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored as-is; moved to the logits' device lazily at loss time.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # ``**kwargs`` absorbs extra arguments newer transformers versions
        # pass to compute_loss (e.g. ``num_items_in_batch``).
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        weight = (
            self.class_weights.to(logits.device)
            if self.class_weights is not None
            else None
        )
        loss_fct = torch.nn.CrossEntropyLoss(weight=weight)
        # Use the logits' own class dimension rather than
        # ``model.config.num_labels``: when the Trainer wraps the model
        # (e.g. DataParallel) the wrapper has no ``.config`` attribute.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


def _load_splits() -> tuple:
    """Return (train, val, test) labeled rows from the configured split paths."""
    return tuple(
        load_labeled_rows(
            IAB_HEAD_CONFIG.split_paths[split],
            IAB_HEAD_CONFIG.label_field,
            IAB_HEAD_CONFIG.label2id,
        )
        for split in ("train", "val", "test")
    )


def main() -> None:
    """Train, evaluate, and persist the IAB classification head."""
    train_rows, val_rows, test_rows = _load_splits()

    tokenizer = AutoTokenizer.from_pretrained(IAB_HEAD_CONFIG.model_name)
    train_dataset = prepare_dataset(train_rows, tokenizer, IAB_HEAD_CONFIG.max_length)
    val_dataset = prepare_dataset(val_rows, tokenizer, IAB_HEAD_CONFIG.max_length)
    test_dataset = prepare_dataset(test_rows, tokenizer, IAB_HEAD_CONFIG.max_length)

    # Weights are derived from the training split only, never from val/test.
    class_weights = build_balanced_class_weights(
        train_rows, len(IAB_HEAD_CONFIG.labels)
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        IAB_HEAD_CONFIG.model_name,
        num_labels=len(IAB_HEAD_CONFIG.labels),
        id2label=IAB_HEAD_CONFIG.id2label,
        label2id=IAB_HEAD_CONFIG.label2id,
    )

    training_args = TrainingArguments(
        output_dir=str(IAB_HEAD_CONFIG.model_dir),
        eval_strategy="epoch",
        save_strategy="no",  # final weights are saved explicitly below
        logging_strategy="epoch",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        report_to="none",
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_classification_metrics,
        class_weights=class_weights,
    )

    print(f"Loaded IAB splits: train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
    print(
        "IAB class weights summary:",
        {
            "min": round(float(class_weights.min().item()), 4),
            "max": round(float(class_weights.max().item()), 4),
            "mean": round(float(class_weights.mean().item()), 4),
        },
    )

    trainer.train()

    val_metrics = trainer.evaluate(eval_dataset=val_dataset, metric_key_prefix="val")
    test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
    print(val_metrics)
    print(test_metrics)

    # Persist the model, tokenizer, and a metrics summary next to the weights.
    IAB_HEAD_CONFIG.model_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(IAB_HEAD_CONFIG.model_dir)
    tokenizer.save_pretrained(IAB_HEAD_CONFIG.model_dir)
    write_json(
        IAB_HEAD_CONFIG.model_dir / "train_metrics.json",
        {
            "head": IAB_HEAD_CONFIG.slug,
            "train_count": len(train_rows),
            "val_count": len(val_rows),
            "test_count": len(test_rows),
            "label_count": len(IAB_HEAD_CONFIG.labels),
            "val_metrics": val_metrics,
            "test_metrics": test_metrics,
        },
    )


if __name__ == "__main__":
    main()