""" Evaluate a trained BERT or DistilBERT model on the test split. Uses Trainer.predict() for inference and sklearn for detailed metrics. Usage: python -m src.models.evaluate --model_dir checkpoints/distilbert/best python -m src.models.evaluate --model_dir checkpoints/bert/best --split val """ import argparse import json import logging from pathlib import Path import numpy as np import torch from sklearn.metrics import classification_report, confusion_matrix, f1_score, matthews_corrcoef from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, ) from src.datasets.combined_pairs_dataset import ( CombinedPairsDataset, CombinedPairsConfig, ID2LABEL, ) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) def main() -> None: parser = argparse.ArgumentParser(description="Evaluate sentence-pair boundary classifier.") parser.add_argument("--model_dir", required=True, help="Path to saved model directory") parser.add_argument("--split", choices=["val", "test"], default="test") parser.add_argument("--data_root", default="data") parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--max_length", type=int, default=512) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() model_dir = Path(args.model_dir) # ── load model + tokenizer ────────────────────────────────────────── log.info(f"Loading model from {model_dir}") model = AutoModelForSequenceClassification.from_pretrained(str(model_dir)) tokenizer = AutoTokenizer.from_pretrained(str(model_dir), use_fast=True) # ── load data ─────────────────────────────────────────────────────── cfg = CombinedPairsConfig( data_root=args.data_root, seed=args.seed, max_length=args.max_length, ) builder = CombinedPairsDataset(cfg) dd = builder.build_hf_dataset_dict(tokenizer) ds = dd[args.split] log.info(f"Evaluating on {args.split} split ({len(ds):,} samples)") # ── predict via Trainer ───────────────────────────────────────────── eval_args = TrainingArguments( output_dir="/tmp/eval_output", per_device_eval_batch_size=args.batch_size, report_to="none", seed=args.seed, ) trainer = Trainer( model=model, args=eval_args, ) predictions = trainer.predict(ds) preds = np.argmax(predictions.predictions, axis=-1) labels = predictions.label_ids # ── report ────────────────────────────────────────────────────────── target_names = [ID2LABEL[i] for i in range(3)] weighted_f1 = f1_score(labels, preds, average="weighted") macro_f1 = f1_score(labels, preds, average="macro") mcc = matthews_corrcoef(labels, preds) print(f"\n{'='*60}") print(f" Evaluation on {args.split} split ({len(ds):,} samples)") print(f"{'='*60}\n") print(f" Weighted F1: {weighted_f1:.4f}") print(f" Macro F1: {macro_f1:.4f}") print(f" MCC: {mcc:.4f}\n") report = classification_report( labels, preds, target_names=target_names, digits=4, ) print(report) print("Confusion matrix:") cm = confusion_matrix(labels, preds, labels=[0, 1, 2]) header = " " + " ".join(f"{n:>14s}" for n in target_names) print(header) for i, row in enumerate(cm): row_str = " ".join(f"{v:>14,}" for v in row) print(f" {target_names[i]:>14s} {row_str}") # ── save metrics ──────────────────────────────────────────────────── report_dict = classification_report( labels, preds, target_names=target_names, output_dict=True, ) report_dict["weighted_f1"] = weighted_f1 report_dict["macro_f1"] = macro_f1 report_dict["mcc"] = mcc out_path = model_dir / f"{args.split}_metrics.json" with open(out_path, "w") as f: json.dump(report_dict, f, indent=2) print(f"\nMetrics saved to {out_path}") if __name__ == "__main__": main()