| | import os |
| | import json |
| | import logging |
| | from typing import Dict, List, Tuple, Any |
| | import numpy as np |
| | from rouge_score import rouge_scorer |
| | from bert_score import score as bert_score |
| | from transformers import AutoTokenizer |
| | import torch |
| | import argparse |
| |
|
| |
|
class SyntheticSummariesEvaluator:
    """Score B1/B2/B3 synthetic summaries against gold summaries.

    Metrics: ROUGE-Lsum F1 and BERTScore F1 (Bio_ClinicalBERT). Each
    (record, level) pair is combined into a weighted score and bucketed into
    best/good/worst via per-level quantile thresholds.
    """

    def __init__(
        self,
        input_path: str,
        output_dir: str = "metrics",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        max_length: int = 512,
        batch_size: int = 16,
        rescale_with_baseline: bool = False,
        include_article: bool = False,
        w_rouge: float = 0.5,
        w_bert: float = 0.5,
        worst_quantile: float = 0.33,
        good_quantile: float = 0.5,
        best_quantile: float = 0.67,
    ):
        """Load the dataset and prepare tokenizer/scorer state.

        Args:
            input_path: JSON file with records holding ``gold_summary`` and a
                ``synthetic_summary`` dict keyed by level (B1/B2/B3).
            output_dir: Directory for result files (created if missing).
            device: Torch device passed to BERTScore.
            max_length: Token cap applied when truncating texts for BERTScore.
            batch_size: BERTScore batch size.
            rescale_with_baseline: Enable BERTScore baseline rescaling.
            include_article: Copy the source article into the output records.
            w_rouge / w_bert: Combination weights (normalized to sum to 1).
            worst_quantile / good_quantile / best_quantile: Per-level
                percentile boundaries for the worst/good/best buckets and the
                ``is_good`` flag.
        """
        self.input_path = input_path
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        with open(input_path, "r", encoding="utf-8") as f:
            self.data: List[Dict[str, Any]] = json.load(f)

        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.max_length = max_length
        self.batch_size = batch_size
        self.rescale_with_baseline = rescale_with_baseline
        self.include_article = include_article

        # Normalize the metric weights so they sum to 1; `or 1.0` guards the
        # degenerate case where both weights are 0 (avoids division by zero).
        s = (w_rouge + w_bert) or 1.0
        self.w_rouge = float(w_rouge) / s
        self.w_bert = float(w_bert) / s

        if not (0.0 <= worst_quantile < best_quantile <= 1.0):
            logging.warning("Invalid quantiles; resetting to worst=0.33, best=0.67")
            worst_quantile, best_quantile = 0.33, 0.67
        # FIX: good_quantile was previously accepted unchecked; an out-of-range
        # value would make np.nanpercentile raise later in evaluate().
        if not (0.0 <= good_quantile <= 1.0):
            logging.warning("Invalid good_quantile; resetting to 0.5")
            good_quantile = 0.5
        self.worst_q = worst_quantile
        self.best_q = best_quantile
        self.good_q = good_quantile

        self.rouge = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)
| |
|
| | def _truncate(self, text: str) -> str: |
| | tokens = self.tokenizer.encode( |
| | text, |
| | add_special_tokens=True, |
| | max_length=self.max_length, |
| | truncation=True, |
| | ) |
| | return self.tokenizer.decode(tokens, skip_special_tokens=True) |
| |
|
| | def _compute_rougeLsum_f1(self, ref: str, hyp: str) -> float: |
| | result = self.rouge.score(ref, hyp) |
| | return float(result["rougeLsum"].fmeasure) |
| |
|
| | def _combine(self, rouge: float, bert_f: float) -> float: |
| | |
| | vals, ws = [], [] |
| | if rouge == rouge: |
| | vals.append(rouge); ws.append(self.w_rouge) |
| | if bert_f == bert_f: |
| | vals.append(bert_f); ws.append(self.w_bert) |
| | if not ws: |
| | return float("nan") |
| | s = sum(ws) |
| | ws = [w / s for w in ws] |
| | return float(sum(v * w for v, w in zip(vals, ws))) |
| |
|
    def evaluate(self) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Score every (record, level) summary pair and aggregate per level.

        Returns:
            ``(results_per_record, dataset_level_metrics)`` — the first item
            mirrors the input records with per-pair scores and quality labels
            attached; the second holds the config echo plus per-level means,
            rates, and the thresholds used.
        """
        # Pass 1: flatten all (record index, level key) pairs so BERTScore can
        # run in one batched call; ROUGE is computed per pair as we go.
        pair_indices: List[Tuple[int, str]] = []
        cands_trunc, refs_trunc = [], []
        rouge_store: Dict[Tuple[int, str], float] = {}

        for i, rec in enumerate(self.data):
            gold = rec.get("gold_summary", "")
            syn = rec.get("synthetic_summary", {}) or {}

            for key in syn.keys():
                cand = syn[key] if isinstance(syn[key], str) else str(syn[key])
                # BERTScore sees the truncated texts; ROUGE below is computed
                # on the untruncated gold/candidate strings.
                cands_trunc.append(self._truncate(cand))
                refs_trunc.append(self._truncate(gold))
                pair_indices.append((i, key))
                rouge_store[(i, key)] = self._compute_rougeLsum_f1(gold, cand)

        # Single batched BERTScore call; if it fails, every pair keeps NaN and
        # the combined score falls back to ROUGE alone (see _combine).
        F_vals = [np.nan] * len(pair_indices)
        if len(pair_indices) > 0:
            try:
                _, _, F = bert_score(
                    cands=cands_trunc,
                    refs=refs_trunc,
                    model_type="emilyalsentzer/Bio_ClinicalBERT",
                    num_layers=12,
                    lang="en",
                    device=self.device,
                    rescale_with_baseline=self.rescale_with_baseline,
                    batch_size=self.batch_size,
                )
                F_vals = F.tolist()
            except Exception as e:
                logging.error(f"Error computing BERTScore: {e}", exc_info=True)

        # Build the skeleton of the per-record output; scores are filled in
        # during the categorization pass below.
        results_per_record: List[Dict[str, Any]] = []
        for i, rec in enumerate(self.data):
            out = {
                "id": i,
                "gold_summary": rec.get("gold_summary", ""),
                "synthetic_summary": {}
            }
            if self.include_article:
                out["article"] = rec.get("article", "")
            syn = rec.get("synthetic_summary", {}) or {}
            for key in syn.keys():
                out["synthetic_summary"][key] = {
                    "text": syn[key] if isinstance(syn[key], str) else str(syn[key]),
                    "score": {}
                }
            results_per_record.append(out)

        # Map each (record, level) pair back to its position in the flat lists.
        idx_map = {(i_k[0], i_k[1]): idx for idx, i_k in enumerate(pair_indices)}

        # Combined score per pair, grouped by level for threshold estimation.
        per_pair_combined: Dict[Tuple[int, str], float] = {}
        level_scores = {"B1": [], "B2": [], "B3": []}
        for (i, key), idx in idx_map.items():
            r = rouge_store[(i, key)]
            f = F_vals[idx]
            c = self._combine(r, f)
            per_pair_combined[(i, key)] = c
            if key in level_scores:
                level_scores[key].append(c)

        # Per-level percentile thresholds. -inf is the "no data" sentinel: a
        # finite combined score compared against -inf thresholds classifies as
        # "best"/is_good (relevant for keys outside B1/B2/B3, see thr below).
        thresholds = {}
        for key in ["B1", "B2", "B3"]:
            scores = np.array(level_scores[key], dtype=float)
            if scores.size > 0 and np.any(scores == scores):  # any non-NaN
                worst_thr = float(np.nanpercentile(scores, self.worst_q * 100))
                best_thr = float(np.nanpercentile(scores, self.best_q * 100))
                good_thr = float(np.nanpercentile(scores, self.good_q * 100))
            else:
                worst_thr = best_thr = good_thr = float("-inf")
            thresholds[key] = {
                "worst_thr": worst_thr,
                "best_thr": best_thr,
                "good_thr": good_thr
            }

        # Per-level accumulators for the aggregate report.
        agg = {
            "B1": {"ROUGE-L-Sum": [], "BERTScore_F": [], "combined": [], "count": 0,
                   "best": 0, "good": 0, "worst": 0, "good_true": 0},
            "B2": {"ROUGE-L-Sum": [], "BERTScore_F": [], "combined": [], "count": 0,
                   "best": 0, "good": 0, "worst": 0, "good_true": 0},
            "B3": {"ROUGE-L-Sum": [], "BERTScore_F": [], "combined": [], "count": 0,
                   "best": 0, "good": 0, "worst": 0, "good_true": 0},
        }

        # Pass 2: attach per-pair scores, categorize, and accumulate.
        for (i, key), idx in idx_map.items():
            r = rouge_store[(i, key)]
            f = F_vals[idx]
            c = per_pair_combined[(i, key)]

            # `x == x` is False only for NaN; missing metrics serialize as None.
            results_per_record[i]["synthetic_summary"][key]["score"] = {
                "ROUGE-L-Sum": float(r) if r == r else None,
                "BERTScore_F": float(f) if f == f else None,
            }

            # Categorize against this level's thresholds; NaN combined scores
            # are forced into "worst".
            thr = thresholds.get(key, {"worst_thr": float("-inf"), "best_thr": float("-inf"), "good_thr": float("-inf")})
            if not (c == c):
                category = "worst"
                is_good = False
            else:
                if c < thr["worst_thr"]:
                    category = "worst"
                elif c < thr["best_thr"]:
                    category = "good"
                else:
                    category = "best"
                is_good = c >= thr["good_thr"]

            results_per_record[i]["synthetic_summary"][key]["quality"] = {
                "category": category,
                "is_good": bool(is_good),
                "combined_score": float(c) if c == c else None
            }

            # Only the known levels B1/B2/B3 contribute to the aggregates.
            if key in agg:
                if r == r:
                    agg[key]["ROUGE-L-Sum"].append(float(r))
                if f == f:
                    agg[key]["BERTScore_F"].append(float(f))
                if c == c:
                    agg[key]["combined"].append(float(c))
                agg[key]["count"] += 1
                agg[key][category] += 1
                if is_good:
                    agg[key]["good_true"] += 1

        # Aggregate report: config echo plus per-level means and rates.
        dataset_level_metrics = {
            "config": {
                "weights": {"w_rouge": self.w_rouge, "w_bert": self.w_bert},
                "quantiles": {"worst_q": self.worst_q, "best_q": self.best_q, "good_q": self.good_q},
                "thresholds": thresholds,
            }
        }
        for key, m in agg.items():
            count = max(1, m["count"])  # avoid division by zero for empty levels
            dataset_level_metrics[key] = {
                "ROUGE-L-Sum": float(np.mean(m["ROUGE-L-Sum"])) if m["ROUGE-L-Sum"] else None,
                "BERTScore_F": float(np.mean(m["BERTScore_F"])) if m["BERTScore_F"] else None,
                "combined_mean": float(np.mean(m["combined"])) if m["combined"] else None,
                "count": m["count"],
                "best_rate": m["best"] / count,
                "good_rate": m["good"] / count,
                "worst_rate": m["worst"] / count,
                "is_good_rate": m["good_true"] / count
            }

        return results_per_record, dataset_level_metrics
| |
|
| | def save(self, per_record: List[Dict[str, Any]], dataset_metrics: Dict[str, Dict[str, float]]): |
| | base = os.path.splitext(os.path.basename(self.input_path))[0] |
| | per_record_path = os.path.join(self.output_dir, f"{base}_scored.json") |
| | aggregate_path = os.path.join(self.output_dir, f"{base}_aggregate_metrics.json") |
| |
|
| | with open(per_record_path, "w", encoding="utf-8") as f: |
| | json.dump(per_record, f, ensure_ascii=False, indent=2) |
| |
|
| | with open(aggregate_path, "w", encoding="utf-8") as f: |
| | json.dump(dataset_metrics, f, ensure_ascii=False, indent=2) |
| |
|
| | print("Saved:") |
| | print(f"- Per-record scores: {per_record_path}") |
| | print(f"- Aggregate metrics: {aggregate_path}") |
| |
|
| |
|
def main():
    """CLI entry point: parse arguments, score the dataset, write outputs."""
    arg_parser = argparse.ArgumentParser(
        description="Evaluate B1/B2/B3 summaries vs gold. Metrics: ROUGE-Lsum F1, BERTScore F1. Per-level categories: best/good/worst + is_good."
    )
    arg_parser.add_argument("--input_path", required=True, help="Path to the es_syntheticV3.json file")
    arg_parser.add_argument("--output_dir", default="metrics", help="Where to save outputs")
    arg_parser.add_argument("--batch_size", type=int, default=16, help="BERTScore batch size")
    arg_parser.add_argument("--max_length", type=int, default=512, help="Max tokens for truncation (BERTScore)")
    arg_parser.add_argument("--rescale_with_baseline", action="store_true", help="Use BERTScore baseline rescaling")
    arg_parser.add_argument("--include_article", action="store_true", help="Include full article text in output JSON")
    arg_parser.add_argument("--w_rouge", type=float, default=0.5, help="Weight for ROUGE-L-Sum in combined score")
    arg_parser.add_argument("--w_bert", type=float, default=0.5, help="Weight for BERTScore_F in combined score")
    arg_parser.add_argument("--worst_quantile", type=float, default=0.33, help="Bottom quantile -> 'worst'")
    arg_parser.add_argument("--best_quantile", type=float, default=0.67, help="Top quantile boundary -> 'best'")
    arg_parser.add_argument("--good_quantile", type=float, default=0.5, help="Quantile for is_good=True")
    opts = arg_parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    # The argparse dest names match the constructor's keyword parameters
    # one-to-one, so the namespace can be unpacked directly.
    evaluator = SyntheticSummariesEvaluator(**vars(opts))
    per_record, dataset_metrics = evaluator.evaluate()
    evaluator.save(per_record, dataset_metrics)
| |
|
| |
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()