from __future__ import annotations
import argparse
import logging
from pathlib import Path
import evaluate
import numpy as np
import torch
from datasets import load_from_disk
from transformers import AutoModelForSeq2SeqLM
from .common import (
ARTIFACT_DIR,
DEFAULT_SUMMARY_COLUMN,
DEFAULT_TEXT_COLUMN,
ensure_project_dirs,
load_tokenizer,
maybe_limit_split,
resolve_model_reference,
validate_model_dir,
)
LOGGER = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Compare two models side-by-side on a test set."
)
parser.add_argument("--model-a", required=True, help="Path to Model A checkpoint.")
parser.add_argument("--model-b", required=True, help="Path to Model B checkpoint.")
parser.add_argument(
"--dataset-dir", required=True, help="Prepared dataset directory."
)
parser.add_argument("--split", default="test")
parser.add_argument("--max-samples", type=int, default=20)
parser.add_argument("--text-column", default=DEFAULT_TEXT_COLUMN)
parser.add_argument("--summary-column", default=DEFAULT_SUMMARY_COLUMN)
parser.add_argument(
"--output-file", default=str(ARTIFACT_DIR / "comparison.html")
)
return parser.parse_args()
@torch.inference_mode()
def generate_summaries(
model_path: str, dataset, text_col: str, device: torch.device
) -> list[str]:
ref = resolve_model_reference(model_path)
validate_model_dir(ref)
LOGGER.info(f"Loading {ref}...")
tokenizer = load_tokenizer(ref)
model = AutoModelForSeq2SeqLM.from_pretrained(ref).to(device)
model.eval()
predictions = []
for item in dataset:
text = item[text_col]
inputs = tokenizer(
text, return_tensors="pt", truncation=True, max_length=512
).to(device)
out = model.generate(**inputs, max_length=128, num_beams=4)
pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()
predictions.append(pred)
del model
torch.cuda.empty_cache()
return predictions
def score_predictions(predictions: list[str], references: list[str]) -> dict:
rouge = evaluate.load("rouge")
r_res = rouge.compute(
predictions=predictions, references=references, use_stemmer=True
)
from bert_score import score as bert_score_fn
safe_preds = [p if p.strip() else "..." for p in predictions]
safe_refs = [r if r.strip() else "..." for r in references]
LOGGER.info("Computing BERTScore...")
_, _, f1 = bert_score_fn(safe_preds, safe_refs, lang="en", verbose=False)
return {
"rouge1": r_res["rouge1"],
"rouge2": r_res["rouge2"],
"rougeL": r_res["rougeL"],
"bertscore": float(f1.mean().item()),
}
def generate_html(
model_a_name: str,
model_b_name: str,
scores_a: dict,
scores_b: dict,
dataset,
preds_a: list[str],
preds_b: list[str],
text_col: str,
sum_col: str,
) -> str:
html = f"""
Model Comparison
Model Comparison
Aggregate Scores
| Metric |
Model A: {model_a_name} |
Model B: {model_b_name} |
"""
for k in ["rouge1", "rouge2", "rougeL", "bertscore"]:
va = scores_a[k]
vb = scores_b[k]
ca = "better" if va >= vb else ""
cb = "better" if vb > va else ""
html += f"""
| {k.upper()} |
{va:.4f} |
{vb:.4f} |
"""
html += """
Side-by-Side Predictions
| Source |
Reference |
Model A |
Model B |
"""
for i, item in enumerate(dataset):
html += f"""
| {item[text_col]} |
{item[sum_col]} |
{preds_a[i]} |
{preds_b[i]} |
"""
html += """
"""
return html
def main() -> None:
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
args = parse_args()
ensure_project_dirs()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOGGER.info(f"Loading dataset {args.dataset_dir} (split: {args.split})...")
dataset = load_from_disk(args.dataset_dir)[args.split]
dataset = maybe_limit_split(dataset, args.max_samples)
refs = [item[args.summary_column] for item in dataset]
LOGGER.info("--- Processing Model A ---")
preds_a = generate_summaries(args.model_a, dataset, args.text_column, device)
scores_a = score_predictions(preds_a, refs)
LOGGER.info("--- Processing Model B ---")
preds_b = generate_summaries(args.model_b, dataset, args.text_column, device)
scores_b = score_predictions(preds_b, refs)
name_a = Path(args.model_a).name
name_b = Path(args.model_b).name
LOGGER.info("Generating HTML report...")
html = generate_html(
name_a,
name_b,
scores_a,
scores_b,
dataset,
preds_a,
preds_b,
args.text_column,
args.summary_column,
)
out_file = Path(args.output_file)
out_file.parent.mkdir(parents=True, exist_ok=True)
out_file.write_text(html, encoding="utf-8")
LOGGER.info(f"Comparison report written to {out_file.absolute()}")
if __name__ == "__main__":
main()