| """ |
| Evaluation script. |
| Runs all evaluation metrics on the test set. |
| Run: python scripts/evaluate.py --config configs/training_config.yaml --split test |
| """ |

import json
from pathlib import Path

import click
import torch
import yaml
from loguru import logger
from rich.console import Console
from rich.table import Table

from src.model.base_model import load_model_and_tokenizer
from src.model.generation_utils import batch_generate
from src.evaluation.gleu_scorer import GLEUScorer
from src.evaluation.errant_evaluator import ERRANTEvaluator
from src.evaluation.style_metrics import StyleEvaluator
from src.style.fingerprinter import StyleFingerprinter
from src.vocabulary.awl_loader import AWLLoader

console = Console()


@click.command()
@click.option("--config", default="configs/training_config.yaml", help="Path to the YAML config.")
@click.option("--split", default="test", help="Data split to evaluate.")
@click.option("--max-samples", default=100, help="Max samples to evaluate")
def evaluate(config: str, split: str, max_samples: int):
    """Run evaluation on the specified data split."""
    with open(config) as f:
        cfg = yaml.safe_load(f)

    model_cfg = cfg.get("model", {})
    gen_cfg = cfg.get("generation", {})
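
    # Illustrative shape of the config this script reads; the keys under
    # "generation" are assumptions, passed straight through to batch_generate:
    #   model:      {key: flan-t5-large}
    #   generation: {e.g. max_new_tokens, num_beams}
    #   data:       {test_path: data/processed/test.jsonl}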

    # Prefer the fine-tuned checkpoint; fall back to the base model if it is missing.
    checkpoint = "checkpoints/best_model"
    try:
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        logger.info(f"Loaded fine-tuned checkpoint from {checkpoint}")
    except Exception:
        logger.warning(f"Could not load {checkpoint}; falling back to the base model")
        model, tokenizer, _ = load_model_and_tokenizer(
            model_cfg.get("key", "flan-t5-large"), quantize=False, use_lora=False
        )
    model.eval()
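
    # Device placement is an assumption here: if batch_generate already moves
    # the model itself, this line is redundant but harmless.
    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))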

    # Load up to max_samples source/reference pairs from the JSONL split file.
    data_path = cfg.get("data", {}).get(f"{split}_path", f"data/processed/{split}.jsonl")
    sources, references = [], []
    with open(data_path) as f:
        for i, line in enumerate(f):
            if i >= max_samples:
                break
            obj = json.loads(line)
            sources.append(obj["input"])
            references.append(obj["target"])
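
    # Defensive guard: fail fast on an empty split rather than letting the
    # metrics below run on zero samples.
    if not sources:
        raise click.ClickException(f"No samples found in {data_path}")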

    # Build prompts with the task instruction prefix before generating.
    prefix = "Correct the following text for grammar, spelling, and clarity. Text to correct: "
    predictions = batch_generate(model, tokenizer, [prefix + s for s in sources], gen_cfg)
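
    # Sanity check: one prediction per source (assumes batch_generate preserves order and length).
    assert len(predictions) == len(sources), "prediction count does not match source count"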

    # Fluency metrics: corpus GLEU and BERTScore against the references.
    gleu_scorer = GLEUScorer()
    gleu = gleu_scorer.compute_gleu(predictions, references)
    bert_p, bert_r, bert_f1 = gleu_scorer.compute_bert_score(predictions, references)

    # Edit-based metric: ERRANT F0.5 over source/hypothesis/reference edits.
    errant_scores = ERRANTEvaluator().evaluate(sources, predictions, references)

    # Style metrics: fingerprint similarity and Academic Word List coverage.
    fp = StyleFingerprinter(spacy_model="en_core_web_sm")
    style_scores = StyleEvaluator(fp, AWLLoader()).evaluate_batch(sources, predictions, references)

    table = Table(title=f"Evaluation ({split}, {len(sources)} samples)")
    table.add_column("Metric", style="cyan")
    table.add_column("Score", style="green")
    table.add_row("GLEU", f"{gleu:.2f}")
    table.add_row("BERTScore F1", f"{bert_f1:.4f}")
    table.add_row("ERRANT F0.5", f"{errant_scores['f0.5']:.4f}")
    table.add_row("Style Similarity", f"{style_scores['style_similarity_mean']:.4f}")
    table.add_row("AWL Coverage", f"{style_scores['awl_coverage_mean']:.4f}")
    console.print(table)

    # Persist the full results; create logs/ if it does not exist yet.
    results = {"gleu": gleu, "bert_f1": bert_f1, "errant": errant_scores, "style": style_scores}
    Path("logs").mkdir(parents=True, exist_ok=True)
    with open(f"logs/eval_results_{split}.json", "w") as f:
        json.dump(results, f, indent=2)
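
    # Echo the output path for convenience.
    console.print(f"Results written to logs/eval_results_{split}.json")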


if __name__ == "__main__":
    evaluate()