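"""Evaluate a seq2seq summarization model on a JSONL test set.

Generates summaries with a Hugging Face checkpoint and reports ROUGE-L,
BERTScore F1, an optional QAFactEval factuality score, the rate at which
summary lengths land within a tolerance band around the character target,
and the average generation latency, then writes the aggregates to a CSV report.

Example invocation (script and data paths are illustrative):

    python evaluate.py --test-path data/test.jsonl --output-csv metrics_report.csv
"""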
from __future__ import annotations

import argparse
import csv
import time
from pathlib import Path

try:
    from bert_score import score as bertscore
    from rouge_score import rouge_scorer
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except Exception as exc:  # pragma: no cover
    raise SystemExit(
        "Evaluation requires bert-score, rouge-score, torch and transformers. Install dependencies first."
    ) from exc

from data_utils import load_jsonl


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate summarization models")
    parser.add_argument("--test-path", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--target-length", type=int, default=120)
    parser.add_argument("--tolerance", type=float, default=0.2)
    parser.add_argument("--output-csv", default="metrics_report.csv")
    parser.add_argument("--qafacteval-model-folder", default=None)
    return parser.parse_args()


def length_hit(text: str, target_length: int, tolerance: float) -> bool:
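    """Return True when the character length of text falls within target_length * (1 ± tolerance)."""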
    low = int(target_length * (1 - tolerance))
    high = int(target_length * (1 + tolerance))
    return low <= len(text) <= high


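class CharTokenizer:
    """Character-level tokenizer for rouge-score.

    rouge-score's default tokenizer keeps only [a-z0-9] tokens, which strips
    Chinese characters and drives ROUGE-L to zero on raw Chinese strings.
    Splitting into individual characters is a minimal workaround; a word
    segmenter could be substituted if word-level ROUGE is preferred.
    """

    def tokenize(self, text: str) -> list[str]:
        # Drop whitespace and treat every remaining character as a token.
        return [ch for ch in text if not ch.isspace()]

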
def try_qafacteval(model_folder: str | None, sources, preds):
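    """Return QAFactEval LERC-QUIP scores per prediction, or None placeholders
    when no model folder is given or the optional dependency is missing."""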
    if not model_folder:
        return [None] * len(preds)
    try:
        from qafacteval import QAFactEval
    except Exception:
        return [None] * len(preds)
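    # Paths follow the model folder layout of the QAFactEval release
    # (quip-512-mocha, generation/, answering/, lerc/).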
    metric = QAFactEval(
        lerc_quip_path=f"{model_folder}/quip-512-mocha",
        generation_model_path=f"{model_folder}/generation/model.tar.gz",
        answering_model_dir=f"{model_folder}/answering",
        lerc_model_path=f"{model_folder}/lerc/model.tar.gz",
        lerc_pretrained_model_path=f"{model_folder}/lerc/pretraining.tar.gz",
        cuda_device=0 if torch.cuda.is_available() else -1,
        use_lerc_quip=True,
        verbose=False,
        generation_batch_size=8,
        answering_batch_size=8,
        lerc_batch_size=4,
    )
    # score_batch_qafacteval is the batch entry point documented in the QAFactEval README.
    results = metric.score_batch_qafacteval(list(sources), [[p] for p in preds], return_qa_pairs=True)
    scores = []
    for row in results:
        item = row[0]["qa-eval"].get("lerc_quip")
        scores.append(item)
    return scores


def main():
    args = parse_args()
    examples = load_jsonl(args.test_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    # rouge-score's default tokenizer keeps only [a-z0-9] tokens, which would
    # zero out ROUGE-L for Chinese text, so score at the character level.
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False, tokenizer=CharTokenizer())
    sources = []
    refs = []
    preds = []
    times_ms = []
    length_flags = []

    for ex in examples:
        inputs = tokenizer(
            ex.article,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_source_length,
        )
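        # The Chinese BART checkpoint uses a BERT-style tokenizer that emits
        # token_type_ids, which generate() on a seq2seq model does not accept.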
        inputs.pop("token_type_ids", None)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        start = time.perf_counter()
        with torch.no_grad():
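            # target_length is in characters; the BERT-style tokenizer splits
            # Chinese roughly one character per token, so budget ~110% of the
            # target, clamped to a reasonable range.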
            out = model.generate(
                **inputs,
                max_new_tokens=max(48, min(192, int(args.target_length * 1.1))),
                num_beams=4,
                no_repeat_ngram_size=3,
                length_penalty=1.0,
                early_stopping=True,
            )
        elapsed_ms = (time.perf_counter() - start) * 1000
        pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()

        sources.append(ex.article)
        refs.append(ex.summary)
        preds.append(pred)
        times_ms.append(elapsed_ms)
        length_flags.append(length_hit(pred, args.target_length, args.tolerance))

    rouge_ls = [scorer.score(ref, pred)["rougeL"].fmeasure for ref, pred in zip(refs, preds)]
    P, R, F1 = bertscore(preds, refs, lang="zh", verbose=False)
    qafacteval_scores = try_qafacteval(args.qafacteval_model_folder, sources, preds)

    rouge_l = sum(rouge_ls) / max(1, len(rouge_ls))
    bert_f1 = float(F1.mean())
    length_rate = sum(1 for v in length_flags if v) / max(1, len(length_flags))
    avg_latency = sum(times_ms) / max(1, len(times_ms))
    qafacteval_valid = [s for s in qafacteval_scores if s is not None]
    qafacteval_avg = sum(qafacteval_valid) / len(qafacteval_valid) if qafacteval_valid else None

    print(f"ROUGE-L: {rouge_l:.4f}")
    print(f"BERTScore: {bert_f1:.4f}")
    print(f"Length Hit Rate: {length_rate:.4f}")
    print(f"Avg Latency(ms): {avg_latency:.2f}")
    if qafacteval_avg is not None:
        print(f"QAFactEval: {qafacteval_avg:.4f}")
    else:
        print("QAFactEval: N/A")

    out_path = Path(args.output_csv)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["model", "rouge_l", "bertscore", "qafacteval", "length_hit_rate", "avg_latency_ms"])
        writer.writerow(
            [
                args.model_name,
                f"{rouge_l:.4f}",
                f"{bert_f1:.4f}",
                f"{qafacteval_avg:.4f}" if qafacteval_avg is not None else "",
                f"{length_rate:.4f}",
                f"{avg_latency:.2f}",
            ]
        )
    print(f"saved metrics to {out_path}")


if __name__ == "__main__":
    main()