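"""Evaluation script for Chinese summarization models.

Generates summaries for a JSONL test set and reports ROUGE-L, BERTScore,
an optional QAFactEval factuality score, the summary-length hit rate, and
average generation latency, then writes the results to a CSV report.
"""
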
from __future__ import annotations

import argparse
import csv
import time
from pathlib import Path

try:
    from bert_score import score as bertscore
    from rouge_score import rouge_scorer
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except Exception as exc:  # pragma: no cover
    raise SystemExit(
        "Evaluation requires bert-score, rouge-score, torch and transformers. Install dependencies first."
    ) from exc

# Local helper: load_jsonl yields examples exposing .article and .summary attributes.
from data_utils import load_jsonl

def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate summarization models")
    parser.add_argument("--test-path", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--target-length", type=int, default=120)
    parser.add_argument("--tolerance", type=float, default=0.2)
    parser.add_argument("--output-csv", default="metrics_report.csv")
    parser.add_argument("--qafacteval-model-folder", default=None)
    return parser.parse_args()

def length_hit(text: str, target_length: int, tolerance: float) -> bool:
    """Return True if the summary length (in characters) is within ``tolerance`` of ``target_length``."""
    low = int(target_length * (1 - tolerance))
    high = int(target_length * (1 + tolerance))
    return low <= len(text) <= high

def try_qafacteval(model_folder: str | None, sources, preds):
    """Score factual consistency with QAFactEval when its pretrained models are
    available; otherwise return None placeholders so evaluation can continue."""
    if not model_folder:
        return [None] * len(preds)
    try:
        from qafacteval import QAFactEval
    except Exception:
        return [None] * len(preds)
    metric = QAFactEval(
        lerc_quip_path=f"{model_folder}/quip-512-mocha",
        generation_model_path=f"{model_folder}/generation/model.tar.gz",
        answering_model_dir=f"{model_folder}/answering",
        lerc_model_path=f"{model_folder}/lerc/model.tar.gz",
        lerc_pretrained_model_path=f"{model_folder}/lerc/pretraining.tar.gz",
        cuda_device=0 if torch.cuda.is_available() else -1,
        use_lerc_quip=True,
        verbose=False,
        generation_batch_size=8,
        answering_batch_size=8,
        lerc_batch_size=4,
    )
    # QAFactEval expects one list of candidate summaries per source document.
    results = metric.score_batch(list(sources), [[p] for p in preds], return_qa_pairs=True)
    return [row[0]["qa-eval"].get("lerc_quip") for row in results]
class CharTokenizer:
    """Character-level ROUGE tokenizer: rouge_score's default tokenizer keeps only
    ASCII alphanumerics, so raw Chinese text would score zero on ROUGE-L."""

    def tokenize(self, text):
        return [ch for ch in text if not ch.isspace()]


def main():
    args = parse_args()
    examples = load_jsonl(args.test_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    # Score ROUGE-L over characters so Chinese summaries are compared correctly.
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False, tokenizer=CharTokenizer())
    sources = []
    refs = []
    preds = []
    times_ms = []
    length_flags = []
    for ex in examples:
        inputs = tokenizer(
            ex.article,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_source_length,
        )
        # BERT-style Chinese tokenizers emit token_type_ids, which seq2seq generate() does not accept.
        inputs.pop("token_type_ids", None)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        start = time.perf_counter()
        with torch.no_grad():
            out = model.generate(
                **inputs,
                # Allow roughly 10% over the target length, clamped to [48, 192] new tokens.
                max_new_tokens=max(48, min(192, int(args.target_length * 1.1))),
                num_beams=4,
                no_repeat_ngram_size=3,
                length_penalty=1.0,
                early_stopping=True,
            )
        elapsed_ms = (time.perf_counter() - start) * 1000
        pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()
        sources.append(ex.article)
        refs.append(ex.summary)
        preds.append(pred)
        times_ms.append(elapsed_ms)
        length_flags.append(length_hit(pred, args.target_length, args.tolerance))
    rouge_ls = [scorer.score(ref, pred)["rougeL"].fmeasure for ref, pred in zip(refs, preds)]
    P, R, F1 = bertscore(preds, refs, lang="zh", verbose=False)
    qafacteval_scores = try_qafacteval(args.qafacteval_model_folder, sources, preds)
    rouge_l = sum(rouge_ls) / max(1, len(rouge_ls))
    bert_f1 = float(F1.mean().item()) if hasattr(F1.mean(), "item") else float(F1.mean())
    length_rate = sum(1 for v in length_flags if v) / max(1, len(length_flags))
    avg_latency = sum(times_ms) / max(1, len(times_ms))
    qafacteval_valid = [s for s in qafacteval_scores if s is not None]
    qafacteval_avg = sum(qafacteval_valid) / len(qafacteval_valid) if qafacteval_valid else None
    print(f"ROUGE-L: {rouge_l:.4f}")
    print(f"BERTScore: {bert_f1:.4f}")
    print(f"Length Hit Rate: {length_rate:.4f}")
    print(f"Avg Latency(ms): {avg_latency:.2f}")
    if qafacteval_avg is not None:
        print(f"QAFactEval: {qafacteval_avg:.4f}")
    else:
        print("QAFactEval: N/A")
    out_path = Path(args.output_csv)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["model", "rouge_l", "bertscore", "qafacteval", "length_hit_rate", "avg_latency_ms"])
        writer.writerow(
            [
                args.model_name,
                f"{rouge_l:.4f}",
                f"{bert_f1:.4f}",
                f"{qafacteval_avg:.4f}" if qafacteval_avg is not None else "",
                f"{length_rate:.4f}",
                f"{avg_latency:.2f}",
            ]
        )
    print(f"saved metrics to {out_path}")


if __name__ == "__main__":
    main()