""" Evaluation script. Runs all evaluation metrics on the test set. Run: python scripts/evaluate.py --config configs/training_config.yaml --split test """ import click import yaml import json import torch from loguru import logger from rich.console import Console from rich.table import Table from src.model.base_model import load_model_and_tokenizer from src.model.generation_utils import batch_generate from src.evaluation.gleu_scorer import GLEUScorer from src.evaluation.errant_evaluator import ERRANTEvaluator from src.evaluation.style_metrics import StyleEvaluator from src.style.fingerprinter import StyleFingerprinter from src.vocabulary.awl_loader import AWLLoader console = Console() @click.command() @click.option("--config", default="configs/training_config.yaml") @click.option("--split", default="test") @click.option("--max-samples", default=100, help="Max samples to evaluate") def evaluate(config: str, split: str, max_samples: int): """Run evaluation on the specified data split.""" with open(config) as f: cfg = yaml.safe_load(f) model_cfg = cfg.get("model", {}) gen_cfg = cfg.get("generation", {}) checkpoint = "checkpoints/best_model" try: from transformers import AutoTokenizer, AutoModelForSeq2SeqLM model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint) tokenizer = AutoTokenizer.from_pretrained(checkpoint) except Exception: model, tokenizer, _ = load_model_and_tokenizer(model_cfg.get("key", "flan-t5-large"), quantize=False, use_lora=False) model.eval() data_path = cfg.get("data", {}).get(f"{split}_path", f"data/processed/{split}.jsonl") sources, references = [], [] with open(data_path) as f: for i, line in enumerate(f): if i >= max_samples: break obj = json.loads(line.strip()) sources.append(obj["input"]) references.append(obj["target"]) prefix = "Correct the following text for grammar, spelling, and clarity. Text to correct: " predictions = batch_generate(model, tokenizer, [prefix + s for s in sources], gen_cfg) gleu_scorer = GLEUScorer() gleu = gleu_scorer.compute_gleu(predictions, references) bert_p, bert_r, bert_f1 = gleu_scorer.compute_bert_score(predictions, references) errant_scores = ERRANTEvaluator().evaluate(sources, predictions, references) fp = StyleFingerprinter(spacy_model="en_core_web_sm") style_scores = StyleEvaluator(fp, AWLLoader()).evaluate_batch(sources, predictions, references) table = Table(title=f"Evaluation ({split}, {len(sources)} samples)") table.add_column("Metric", style="cyan") table.add_column("Score", style="green") table.add_row("GLEU", f"{gleu:.2f}") table.add_row("BERTScore F1", f"{bert_f1:.4f}") table.add_row("ERRANT F0.5", f"{errant_scores['f0.5']:.4f}") table.add_row("Style Similarity", f"{style_scores['style_similarity_mean']:.4f}") table.add_row("AWL Coverage", f"{style_scores['awl_coverage_mean']:.4f}") console.print(table) results = {"gleu": gleu, "bert_f1": bert_f1, "errant": errant_scores, "style": style_scores} with open(f"logs/eval_results_{split}.json", "w") as f: json.dump(results, f, indent=2) if __name__ == "__main__": evaluate()