# Transformer/evaluate.py
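"""Evaluate a Chinese summarization model on a JSONL test set.

Reports ROUGE-L, BERTScore F1, a length hit rate against a character
budget, average per-example generation latency, and (optionally)
QAFactEval, then writes the results to a CSV report.

Example invocation (paths are illustrative):
    python evaluate.py --test-path data/test.jsonl --model-name fnlp/bart-base-chinese
"""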
from __future__ import annotations

import argparse
import csv
import time
from pathlib import Path

try:
    from bert_score import score as bertscore
    from rouge_score import rouge_scorer
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except Exception as exc:  # pragma: no cover
    raise SystemExit(
        "Evaluation requires bert-score, rouge-score, torch and transformers. Install dependencies first."
    ) from exc

from data_utils import load_jsonl


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate summarization models")
    parser.add_argument("--test-path", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--target-length", type=int, default=120)
    parser.add_argument("--tolerance", type=float, default=0.2)
    parser.add_argument("--output-csv", default="metrics_report.csv")
    parser.add_argument("--qafacteval-model-folder", default=None)
    return parser.parse_args()
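

# A prediction "hits" the length budget when its character count falls within
# target_length * (1 +/- tolerance): 96-144 characters for the defaults.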
def length_hit(text: str, target_length: int, tolerance: float) -> bool:
    low = int(target_length * (1 - tolerance))
    high = int(target_length * (1 + tolerance))
    return low <= len(text) <= high
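

# rouge-score's default tokenizer lowercases and keeps only [a-z0-9] runs,
# which zeroes out Chinese text entirely; scoring on individual characters
# gives a meaningful ROUGE-L instead. RougeScorer accepts any object with a
# tokenize(text) method via its `tokenizer` argument.
class ChineseCharTokenizer:
    def tokenize(self, text):
        # One token per non-whitespace character.
        return [ch for ch in text if not ch.isspace()]


# QAFactEval is optional: scoring falls back to a column of None when no model
# folder is supplied or the package is not installed.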
def try_qafacteval(model_folder: str | None, sources, preds):
    if not model_folder:
        return [None] * len(preds)
    try:
        from qafacteval import QAFactEval
    except Exception:
        return [None] * len(preds)
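    # These paths match the models folder layout distributed with the
    # QAFactEval release (quip-512-mocha, generation, answering, lerc).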
    metric = QAFactEval(
        lerc_quip_path=f"{model_folder}/quip-512-mocha",
        generation_model_path=f"{model_folder}/generation/model.tar.gz",
        answering_model_dir=f"{model_folder}/answering",
        lerc_model_path=f"{model_folder}/lerc/model.tar.gz",
        lerc_pretrained_model_path=f"{model_folder}/lerc/pretraining.tar.gz",
        cuda_device=0 if torch.cuda.is_available() else -1,
        use_lerc_quip=True,
        verbose=False,
        generation_batch_size=8,
        answering_batch_size=8,
        lerc_batch_size=4,
    )
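    # One candidate summary per source document; lerc_quip is the score the
    # QAFactEval README reports.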
    results = metric.score_batch_qafacteval(
        list(sources), [[p] for p in preds], return_qa_pairs=True
    )
    return [row[0]["qa-eval"].get("lerc_quip") for row in results]


def main():
    args = parse_args()
    examples = load_jsonl(args.test_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False, tokenizer=ChineseCharTokenizer())
    sources = []
    refs = []
    preds = []
    times_ms = []
    length_flags = []
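    # Generate one summary per test example, timing each generation call.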
    for ex in examples:
        inputs = tokenizer(
            ex.article,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_source_length,
        )
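        # BERT-style Chinese tokenizers emit token_type_ids, which
        # generate() does not accept for seq2seq models.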
        inputs.pop("token_type_ids", None)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        start = time.perf_counter()
        with torch.no_grad():
            out = model.generate(
                **inputs,
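                # target_length is a character budget; Chinese BART tokens are
                # roughly one character each, so allow ~10% headroom, clamped
                # to [48, 192] new tokens.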
                max_new_tokens=max(48, min(192, int(args.target_length * 1.1))),
                num_beams=4,
                no_repeat_ngram_size=3,
                length_penalty=1.0,
                early_stopping=True,
            )
        elapsed_ms = (time.perf_counter() - start) * 1000
        pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()
        sources.append(ex.article)
        refs.append(ex.summary)
        preds.append(pred)
        times_ms.append(elapsed_ms)
        length_flags.append(length_hit(pred, args.target_length, args.tolerance))
    rouge_ls = [scorer.score(ref, pred)["rougeL"].fmeasure for ref, pred in zip(refs, preds)]
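    # bert-score picks a default model for lang="zh" (downloaded on first use).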
    P, R, F1 = bertscore(preds, refs, lang="zh", verbose=False)
    qafacteval_scores = try_qafacteval(args.qafacteval_model_folder, sources, preds)
    rouge_l = sum(rouge_ls) / max(1, len(rouge_ls))
    bert_f1 = float(F1.mean())
    length_rate = sum(1 for v in length_flags if v) / max(1, len(length_flags))
    avg_latency = sum(times_ms) / max(1, len(times_ms))
    qafacteval_valid = [s for s in qafacteval_scores if s is not None]
    qafacteval_avg = sum(qafacteval_valid) / len(qafacteval_valid) if qafacteval_valid else None
print(f"ROUGE-L: {rouge_l:.4f}")
print(f"BERTScore: {bert_f1:.4f}")
print(f"Length Hit Rate: {length_rate:.4f}")
print(f"Avg Latency(ms): {avg_latency:.2f}")
if qafacteval_avg is not None:
print(f"QAFactEval: {qafacteval_avg:.4f}")
else:
print("QAFactEval: N/A")
    out_path = Path(args.output_csv)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["model", "rouge_l", "bertscore", "qafacteval", "length_hit_rate", "avg_latency_ms"])
        writer.writerow(
            [
                args.model_name,
                f"{rouge_l:.4f}",
                f"{bert_f1:.4f}",
                f"{qafacteval_avg:.4f}" if qafacteval_avg is not None else "",
                f"{length_rate:.4f}",
                f"{avg_latency:.2f}",
            ]
        )
    print(f"Saved metrics to {out_path}")


if __name__ == "__main__":
    main()