| """ |
| Evaluate a trained BERT or DistilBERT model on the test split. |
| |
| Uses Trainer.predict() for inference and sklearn for detailed metrics. |
| |
| Usage: |
| python -m src.models.evaluate --model_dir checkpoints/distilbert/best |
| python -m src.models.evaluate --model_dir checkpoints/bert/best --split val |
| """ |
|
|
import argparse
import json
import logging
from pathlib import Path

import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, f1_score, matthews_corrcoef
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

from src.datasets.combined_pairs_dataset import (
    CombinedPairsDataset,
    CombinedPairsConfig,
    ID2LABEL,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)


def main() -> None:
    parser = argparse.ArgumentParser(description="Evaluate sentence-pair boundary classifier.")
    parser.add_argument("--model_dir", required=True, help="Path to saved model directory")
    parser.add_argument("--split", choices=["val", "test"], default="test")
    parser.add_argument("--data_root", default="data")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    model_dir = Path(args.model_dir)

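    # Load the fine-tuned classifier and its tokenizer from the checkpoint directory.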
| log.info(f"Loading model from {model_dir}") |
| model = AutoModelForSequenceClassification.from_pretrained(str(model_dir)) |
| tokenizer = AutoTokenizer.from_pretrained(str(model_dir), use_fast=True) |
|
|
| |
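    # Rebuild the tokenized dataset; seed and max_length should match the training run
    # so the split boundaries line up.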
    cfg = CombinedPairsConfig(
        data_root=args.data_root,
        seed=args.seed,
        max_length=args.max_length,
    )
    builder = CombinedPairsDataset(cfg)
    dd = builder.build_hf_dataset_dict(tokenizer)
    ds = dd[args.split]

    log.info(f"Evaluating on {args.split} split ({len(ds):,} samples)")

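    # Minimal TrainingArguments: Trainer is used here only for batched inference.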
    eval_args = TrainingArguments(
        output_dir="/tmp/eval_output",
        per_device_eval_batch_size=args.batch_size,
        report_to="none",
        seed=args.seed,
    )

    trainer = Trainer(
        model=model,
        args=eval_args,
        tokenizer=tokenizer,  # lets Trainer pad variable-length batches via DataCollatorWithPadding
    )

    predictions = trainer.predict(ds)
    preds = np.argmax(predictions.predictions, axis=-1)
    labels = predictions.label_ids

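    # Map label ids to human-readable names (in id order) for the reports below.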
    target_names = [ID2LABEL[i] for i in range(3)]

    weighted_f1 = f1_score(labels, preds, average="weighted")
    macro_f1 = f1_score(labels, preds, average="macro")
    mcc = matthews_corrcoef(labels, preds)

    print(f"\n{'='*60}")
    print(f" Evaluation on {args.split} split ({len(ds):,} samples)")
    print(f"{'='*60}\n")

    print(f" Weighted F1: {weighted_f1:.4f}")
    print(f" Macro F1:    {macro_f1:.4f}")
    print(f" MCC:         {mcc:.4f}\n")

    report = classification_report(
        labels, preds,
        target_names=target_names,
        digits=4,
    )
    print(report)

    # Confusion matrix: rows are true labels, columns are predictions.
    print("Confusion matrix:")
    cm = confusion_matrix(labels, preds, labels=[0, 1, 2])
    header = " " * 16 + " ".join(f"{n:>14s}" for n in target_names)  # pad past the row-label column
    print(header)
    for i, row in enumerate(cm):
        row_str = " ".join(f"{v:>14,}" for v in row)
        print(f" {target_names[i]:>14s} {row_str}")

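    # Persist the full per-class report plus the headline metrics next to the checkpoint.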
    report_dict = classification_report(
        labels, preds,
        target_names=target_names,
        output_dict=True,
    )
    report_dict["weighted_f1"] = float(weighted_f1)
    report_dict["macro_f1"] = float(macro_f1)
    report_dict["mcc"] = float(mcc)

    out_path = model_dir / f"{args.split}_metrics.json"
    with open(out_path, "w") as f:
        # default=float converts any remaining numpy scalars (e.g. int64 support counts).
        json.dump(report_dict, f, indent=2, default=float)
    print(f"\nMetrics saved to {out_path}")


| if __name__ == "__main__": |
| main() |
|
|