"""
Evaluation script.
Runs all evaluation metrics on the test set.
Run: python scripts/evaluate.py --config configs/training_config.yaml --split test
"""

import click
import yaml
import json
import torch
from loguru import logger
from rich.console import Console
from rich.table import Table

from src.model.base_model import load_model_and_tokenizer
from src.model.generation_utils import batch_generate
from src.evaluation.gleu_scorer import GLEUScorer
from src.evaluation.errant_evaluator import ERRANTEvaluator
from src.evaluation.style_metrics import StyleEvaluator
from src.style.fingerprinter import StyleFingerprinter
from src.vocabulary.awl_loader import AWLLoader

console = Console()


@click.command()
@click.option("--config", default="configs/training_config.yaml")
@click.option("--split", default="test")
@click.option("--max-samples", default=100, help="Max samples to evaluate")
def evaluate(config: str, split: str, max_samples: int):
    """Run evaluation on the specified data split."""
    with open(config) as f:
        cfg = yaml.safe_load(f)

    model_cfg = cfg.get("model", {})
    gen_cfg = cfg.get("generation", {})

    checkpoint = "checkpoints/best_model"
    try:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    except Exception:
        model, tokenizer, _ = load_model_and_tokenizer(model_cfg.get("key", "flan-t5-large"), quantize=False, use_lora=False)
    model.eval()

    data_path = cfg.get("data", {}).get(f"{split}_path", f"data/processed/{split}.jsonl")
    sources, references = [], []
    with open(data_path) as f:
        for i, line in enumerate(f):
            if i >= max_samples:
                break
            obj = json.loads(line.strip())
            sources.append(obj["input"])
            references.append(obj["target"])

    prefix = "Correct the following text for grammar, spelling, and clarity. Text to correct: "
    with torch.no_grad():  # inference only; no gradients are needed
        predictions = batch_generate(model, tokenizer, [prefix + s for s in sources], gen_cfg)

    gleu_scorer = GLEUScorer()
    gleu = gleu_scorer.compute_gleu(predictions, references)
    bert_p, bert_r, bert_f1 = gleu_scorer.compute_bert_score(predictions, references)

    errant_scores = ERRANTEvaluator().evaluate(sources, predictions, references)

    fp = StyleFingerprinter(spacy_model="en_core_web_sm")
    style_scores = StyleEvaluator(fp, AWLLoader()).evaluate_batch(sources, predictions, references)

    table = Table(title=f"Evaluation ({split}, {len(sources)} samples)")
    table.add_column("Metric", style="cyan")
    table.add_column("Score", style="green")
    table.add_row("GLEU", f"{gleu:.2f}")
    table.add_row("BERTScore F1", f"{bert_f1:.4f}")
    table.add_row("ERRANT F0.5", f"{errant_scores['f0.5']:.4f}")
    table.add_row("Style Similarity", f"{style_scores['style_similarity_mean']:.4f}")
    table.add_row("AWL Coverage", f"{style_scores['awl_coverage_mean']:.4f}")
    console.print(table)

    results = {"gleu": gleu, "bert_f1": bert_f1, "errant": errant_scores, "style": style_scores}
    with open(f"logs/eval_results_{split}.json", "w") as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    evaluate()