"""
Evaluation script.
Runs all evaluation metrics on the test set.
Run: python scripts/evaluate.py --config configs/training_config.yaml --split test
"""
import json
import os

import click
import yaml
from rich.console import Console
from rich.table import Table
from src.model.base_model import load_model_and_tokenizer
from src.model.generation_utils import batch_generate
from src.evaluation.gleu_scorer import GLEUScorer
from src.evaluation.errant_evaluator import ERRANTEvaluator
from src.evaluation.style_metrics import StyleEvaluator
from src.style.fingerprinter import StyleFingerprinter
from src.vocabulary.awl_loader import AWLLoader

console = Console()


@click.command()
@click.option("--config", default="configs/training_config.yaml", help="Path to the training config YAML")
@click.option("--split", default="test", help="Data split to evaluate (e.g. test, val)")
@click.option("--max-samples", default=100, help="Max samples to evaluate")
def evaluate(config: str, split: str, max_samples: int):
    """Run evaluation on the specified data split."""
    with open(config) as f:
        cfg = yaml.safe_load(f)
    model_cfg = cfg.get("model", {})
    gen_cfg = cfg.get("generation", {})

    checkpoint = "checkpoints/best_model"
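    # Prefer the fine-tuned checkpoint; fall back to the base model if loading fails.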
    try:
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    except Exception:
        model, tokenizer, _ = load_model_and_tokenizer(
            model_cfg.get("key", "flan-t5-large"), quantize=False, use_lora=False
        )
    model.eval()
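
    # Load up to max_samples pairs from the JSONL split. Each line is expected
    # to hold an "input"/"target" pair, e.g.:
    #   {"input": "She go to school yesterday.", "target": "She went to school yesterday."}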
    data_path = cfg.get("data", {}).get(f"{split}_path", f"data/processed/{split}.jsonl")
    sources, references = [], []
    with open(data_path) as f:
        for i, line in enumerate(f):
            if i >= max_samples:
                break
            obj = json.loads(line.strip())
            sources.append(obj["input"])
            references.append(obj["target"])
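
    # Instruction prefix for generation (presumably the same prompt format
    # the model saw during fine-tuning).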
    prefix = "Correct the following text for grammar, spelling, and clarity. Text to correct: "
    predictions = batch_generate(model, tokenizer, [prefix + s for s in sources], gen_cfg)
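
    # Score predictions: GLEU and BERTScore for fluency/semantic similarity,
    # ERRANT for edit-level F0.5, and style/vocabulary metrics for voice.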
    gleu_scorer = GLEUScorer()
    gleu = gleu_scorer.compute_gleu(predictions, references)
    _, _, bert_f1 = gleu_scorer.compute_bert_score(predictions, references)
    errant_scores = ERRANTEvaluator().evaluate(sources, predictions, references)
    fp = StyleFingerprinter(spacy_model="en_core_web_sm")
    style_scores = StyleEvaluator(fp, AWLLoader()).evaluate_batch(sources, predictions, references)

    table = Table(title=f"Evaluation ({split}, {len(sources)} samples)")
    table.add_column("Metric", style="cyan")
    table.add_column("Score", style="green")
    table.add_row("GLEU", f"{gleu:.2f}")
    table.add_row("BERTScore F1", f"{bert_f1:.4f}")
    table.add_row("ERRANT F0.5", f"{errant_scores['f0.5']:.4f}")
    table.add_row("Style Similarity", f"{style_scores['style_similarity_mean']:.4f}")
    table.add_row("AWL Coverage", f"{style_scores['awl_coverage_mean']:.4f}")
    console.print(table)
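
    # Persist the full result dict so runs can be compared later.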
    results = {"gleu": gleu, "bert_f1": bert_f1, "errant": errant_scores, "style": style_scores}
    os.makedirs("logs", exist_ok=True)
    with open(f"logs/eval_results_{split}.json", "w") as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    evaluate()