"""
Evaluate a trained BERT or DistilBERT model on the test split.
Uses Trainer.predict() for inference and sklearn for detailed metrics.
Usage:
python -m src.models.evaluate --model_dir checkpoints/distilbert/best
python -m src.models.evaluate --model_dir checkpoints/bert/best --split val
"""
import argparse
import json
import logging
from pathlib import Path

import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix, f1_score, matthews_corrcoef
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

from src.datasets.combined_pairs_dataset import (
    CombinedPairsDataset,
    CombinedPairsConfig,
    ID2LABEL,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)


def main() -> None:
    parser = argparse.ArgumentParser(description="Evaluate sentence-pair boundary classifier.")
    parser.add_argument("--model_dir", required=True, help="Path to saved model directory")
    parser.add_argument("--split", choices=["val", "test"], default="test")
    parser.add_argument("--data_root", default="data")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    model_dir = Path(args.model_dir)
    # ── load model + tokenizer ──────────────────────────────────────────
    log.info(f"Loading model from {model_dir}")
    model = AutoModelForSequenceClassification.from_pretrained(str(model_dir))
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir), use_fast=True)
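    # The checkpoint directory is expected to contain the fine-tuned weights together with
    # the saved config (including the id2label/label2id mapping for the three classes),
    # so from_pretrained restores a classification head of the right size.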
    # ── load data ───────────────────────────────────────────────────────
    cfg = CombinedPairsConfig(
        data_root=args.data_root,
        seed=args.seed,
        max_length=args.max_length,
    )
    builder = CombinedPairsDataset(cfg)
    dd = builder.build_hf_dataset_dict(tokenizer)
    ds = dd[args.split]
    log.info(f"Evaluating on {args.split} split ({len(ds):,} samples)")
    # ── predict via Trainer ─────────────────────────────────────────────
    eval_args = TrainingArguments(
        output_dir="/tmp/eval_output",
        per_device_eval_batch_size=args.batch_size,
        report_to="none",
        seed=args.seed,
    )
    trainer = Trainer(
        model=model,
        args=eval_args,
    )
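    # NOTE: no data collator is passed, so Trainer falls back to its default collator,
    # which assumes every example was padded to a fixed length during tokenization.
    # If the dataset builder pads dynamically instead, pass
    # data_collator=DataCollatorWithPadding(tokenizer) here.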
    predictions = trainer.predict(ds)
    preds = np.argmax(predictions.predictions, axis=-1)
    labels = predictions.label_ids
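    # predictions.predictions holds the raw logits with shape (num_examples, num_classes);
    # argmax over the last axis gives the predicted class ids, while predictions.label_ids
    # carries the gold labels from the dataset.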
    # ── report ──────────────────────────────────────────────────────────
    target_names = [ID2LABEL[i] for i in range(3)]
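    # Weighted F1 weights each class's F1 by its support, macro F1 averages the classes
    # equally, and MCC summarises the whole confusion matrix in a single score that is
    # robust to class imbalance.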
    weighted_f1 = f1_score(labels, preds, average="weighted")
    macro_f1 = f1_score(labels, preds, average="macro")
    mcc = matthews_corrcoef(labels, preds)

    print(f"\n{'='*60}")
    print(f" Evaluation on {args.split} split ({len(ds):,} samples)")
    print(f"{'='*60}\n")
    print(f" Weighted F1: {weighted_f1:.4f}")
    print(f" Macro F1: {macro_f1:.4f}")
    print(f" MCC: {mcc:.4f}\n")
    report = classification_report(
        labels, preds,
        target_names=target_names,
        digits=4,
    )
    print(report)

    print("Confusion matrix:")
    cm = confusion_matrix(labels, preds, labels=[0, 1, 2])
    header = " " * 16 + " ".join(f"{n:>14s}" for n in target_names)
    print(header)
    for i, row in enumerate(cm):
        row_str = " ".join(f"{v:>14,}" for v in row)
        print(f" {target_names[i]:>14s} {row_str}")
    # ── save metrics ────────────────────────────────────────────────────
    report_dict = classification_report(
        labels, preds,
        target_names=target_names,
        output_dict=True,
    )
    report_dict["weighted_f1"] = weighted_f1
    report_dict["macro_f1"] = macro_f1
    report_dict["mcc"] = mcc
    out_path = model_dir / f"{args.split}_metrics.json"
    with open(out_path, "w") as f:
        json.dump(report_dict, f, indent=2)
    print(f"\nMetrics saved to {out_path}")
if __name__ == "__main__":
    main()