import sys from pathlib import Path import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments BASE_DIR = Path(__file__).resolve().parent.parent if str(BASE_DIR) not in sys.path: sys.path.insert(0, str(BASE_DIR)) from config import ( DECISION_PHASE_DIFFICULTY_DATA_DIR, DECISION_PHASE_HEAD_CONFIG, DECISION_PHASE_TRAINING_WEIGHTS, FULL_INTENT_TAXONOMY_DATA_DIR, ) from training.common import ( build_label_weight_tensor, compute_classification_metrics, load_labeled_rows, load_labeled_rows_from_paths, prepare_dataset, write_json, ) class WeightedTrainer(Trainer): def __init__(self, *args, class_weights: torch.Tensor | None = None, **kwargs): super().__init__(*args, **kwargs) self.class_weights = class_weights def compute_loss(self, model, inputs, return_outputs=False, **kwargs): labels = inputs.pop("labels") outputs = model(**inputs) logits = outputs.get("logits") weight = self.class_weights.to(logits.device) if self.class_weights is not None else None loss_fct = torch.nn.CrossEntropyLoss(weight=weight) loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1)) return (loss, outputs) if return_outputs else loss train_rows = load_labeled_rows_from_paths( [ DECISION_PHASE_HEAD_CONFIG.split_paths["train"], FULL_INTENT_TAXONOMY_DATA_DIR / "train.jsonl", DECISION_PHASE_DIFFICULTY_DATA_DIR / "train.jsonl", ], DECISION_PHASE_HEAD_CONFIG.label_field, DECISION_PHASE_HEAD_CONFIG.label2id, ) val_rows = load_labeled_rows_from_paths( [ DECISION_PHASE_HEAD_CONFIG.split_paths["val"], FULL_INTENT_TAXONOMY_DATA_DIR / "val.jsonl", DECISION_PHASE_DIFFICULTY_DATA_DIR / "val.jsonl", ], DECISION_PHASE_HEAD_CONFIG.label_field, DECISION_PHASE_HEAD_CONFIG.label2id, ) test_rows = load_labeled_rows( DECISION_PHASE_HEAD_CONFIG.split_paths["test"], DECISION_PHASE_HEAD_CONFIG.label_field, DECISION_PHASE_HEAD_CONFIG.label2id, ) tokenizer = AutoTokenizer.from_pretrained(DECISION_PHASE_HEAD_CONFIG.model_name) train_dataset = prepare_dataset(train_rows, tokenizer, DECISION_PHASE_HEAD_CONFIG.max_length) val_dataset = prepare_dataset(val_rows, tokenizer, DECISION_PHASE_HEAD_CONFIG.max_length) test_dataset = prepare_dataset(test_rows, tokenizer, DECISION_PHASE_HEAD_CONFIG.max_length) class_weights = build_label_weight_tensor(DECISION_PHASE_HEAD_CONFIG.labels, DECISION_PHASE_TRAINING_WEIGHTS) model = AutoModelForSequenceClassification.from_pretrained( DECISION_PHASE_HEAD_CONFIG.model_name, num_labels=len(DECISION_PHASE_HEAD_CONFIG.labels), id2label=DECISION_PHASE_HEAD_CONFIG.id2label, label2id=DECISION_PHASE_HEAD_CONFIG.label2id, ) training_args = TrainingArguments( output_dir=str(DECISION_PHASE_HEAD_CONFIG.model_dir), eval_strategy="epoch", save_strategy="no", logging_strategy="epoch", num_train_epochs=4, per_device_train_batch_size=4, per_device_eval_batch_size=4, learning_rate=2e-5, weight_decay=0.01, report_to="none", ) trainer = WeightedTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=compute_classification_metrics, class_weights=class_weights, ) print( f"Loaded decision_phase splits: train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}" ) print(f"Decision phase weights: {[round(float(x), 3) for x in class_weights.tolist()]}") trainer.train() val_metrics = trainer.evaluate(eval_dataset=val_dataset, metric_key_prefix="val") test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test") print(val_metrics) print(test_metrics) DECISION_PHASE_HEAD_CONFIG.model_dir.mkdir(parents=True, exist_ok=True) model.save_pretrained(DECISION_PHASE_HEAD_CONFIG.model_dir) tokenizer.save_pretrained(DECISION_PHASE_HEAD_CONFIG.model_dir) write_json( DECISION_PHASE_HEAD_CONFIG.model_dir / "train_metrics.json", { "head": DECISION_PHASE_HEAD_CONFIG.slug, "train_count": len(train_rows), "val_count": len(val_rows), "test_count": len(test_rows), "val_metrics": val_metrics, "test_metrics": test_metrics, }, )