"""Evaluate synthetic summaries (B1/B2/B3, ...) against gold summaries.

Metrics: ROUGE-Lsum F1 and BERTScore F1 (Bio_ClinicalBERT), merged into a
weighted combined score that drives per-level best/good/worst categories.
"""
import argparse
import json
import logging
import math
import os
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from bert_score import score as bert_score
from rouge_score import rouge_scorer
from transformers import AutoTokenizer


class SyntheticSummariesEvaluator:
    """Score synthetic summaries against a gold summary and rank them per level.

    The input file must contain a JSON list of records. Each record may carry
    a ``gold_summary`` string and a ``synthetic_summary`` mapping from a level
    key (normally ``"B1"``, ``"B2"``, ``"B3"``) to a candidate summary.

    For every (record, level) pair the evaluator computes ROUGE-Lsum F1 and
    BERTScore F1, merges them into a weighted ``combined`` score, and labels
    the candidate ``worst`` / ``good`` / ``best`` (plus a boolean ``is_good``)
    using quantile thresholds computed independently for each level.
    """

    # Hugging Face model used both for truncation and for BERTScore.
    MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
    # Levels always reported in the dataset-level metrics, even when absent
    # from the data. Any additional level keys found in the data are scored
    # with their own thresholds too.
    LEVELS: Tuple[str, ...] = ("B1", "B2", "B3")

    def __init__(
        self,
        input_path: str,
        output_dir: str = "metrics",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        max_length: int = 512,
        batch_size: int = 16,
        rescale_with_baseline: bool = False,
        include_article: bool = False,
        w_rouge: float = 0.5,
        w_bert: float = 0.5,
        worst_quantile: float = 0.33,
        good_quantile: float = 0.5,  # per-level threshold for is_good
        best_quantile: float = 0.67,
    ):
        """Load the input data and configure the metrics.

        Args:
            input_path: Path to a JSON file containing a list of records.
            output_dir: Directory for output files; created if missing.
            device: Torch device passed to BERTScore.
            max_length: Max tokenizer length (special tokens included) used to
                truncate texts before BERTScore.
            batch_size: BERTScore batch size.
            rescale_with_baseline: Forwarded to ``bert_score.score``.
            include_article: Copy each record's ``article`` into the output.
            w_rouge, w_bert: Non-negative weights of the combined score,
                normalised to sum to 1. Invalid values fall back to 0.5/0.5.
            worst_quantile, best_quantile: Per-level quantiles in [0, 1]
                (worst < best) separating worst/good/best. Invalid values fall
                back to 0.33/0.67.
            good_quantile: Per-level quantile in [0, 1]; combined scores at or
                above it get ``is_good=True``. Invalid values fall back to 0.5.

        Raises:
            ValueError: If the input JSON is not a list of records.
        """
        self.input_path = input_path
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(
                f"Expected a JSON list of records in {input_path}, "
                f"got {type(data).__name__}"
            )
        self.data: List[Dict[str, Any]] = data

        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.max_length = max_length
        self.batch_size = batch_size
        self.rescale_with_baseline = rescale_with_baseline
        self.include_article = include_article

        # Weights must be non-negative with a positive sum, otherwise the
        # normalisation is meaningless (or divides by zero). Written as a
        # positive comparison so NaN weights are rejected as well.
        if not (w_rouge >= 0 and w_bert >= 0 and w_rouge + w_bert > 0):
            logging.warning("Invalid weights; resetting to w_rouge=0.5, w_bert=0.5")
            w_rouge, w_bert = 0.5, 0.5
        total = float(w_rouge) + float(w_bert)
        self.w_rouge = float(w_rouge) / total
        self.w_bert = float(w_bert) / total

        # Quantiles are applied per level (B1/B2/B3/...).
        if not (0.0 <= worst_quantile < best_quantile <= 1.0):
            logging.warning("Invalid quantiles; resetting to worst=0.33, best=0.67")
            worst_quantile, best_quantile = 0.33, 0.67
        if not (0.0 <= good_quantile <= 1.0):
            # np.nanpercentile would otherwise raise later inside evaluate().
            logging.warning("Invalid good_quantile; resetting to 0.5")
            good_quantile = 0.5
        self.worst_q = worst_quantile
        self.best_q = best_quantile
        self.good_q = good_quantile

        self.rouge = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)

    # ------------------------------------------------------------------ #
    # Per-pair metrics
    # ------------------------------------------------------------------ #
    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce a summary value to text; ``None`` becomes an empty string."""
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    def _truncate(self, text: str) -> str:
        """Cut *text* to at most ``max_length`` model tokens and decode it back.

        Special tokens count towards the limit but are dropped from the
        returned text.
        """
        tokens = self.tokenizer.encode(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True,
        )
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def _compute_rougeLsum_f1(self, ref: str, hyp: str) -> float:
        """Return the ROUGE-Lsum F-measure of *hyp* against *ref*.

        NOTE(review): rouge_score's rougeLsum treats newline-separated lines as
        sentences; texts without newlines are scored as one sentence — confirm
        this is the intended granularity.
        """
        result = self.rouge.score(ref, hyp)
        return float(result["rougeLsum"].fmeasure)

    def _combine(self, rouge: float, bert_f: float) -> float:
        """Return the weighted average of the available (non-NaN) metrics.

        The weights of the metrics that are present are renormalised to sum
        to 1. NaN is returned when no metric is present, or when every present
        metric has weight 0 (e.g. ``w_rouge=0`` and BERTScore failed).
        """
        present = [
            (value, weight)
            for value, weight in ((rouge, self.w_rouge), (bert_f, self.w_bert))
            if not math.isnan(value)
        ]
        total = sum(weight for _, weight in present)
        if total <= 0:
            return float("nan")
        return float(sum(value * (weight / total) for value, weight in present))

    # ------------------------------------------------------------------ #
    # evaluate() helpers
    # ------------------------------------------------------------------ #
    def _collect_pairs(
        self,
    ) -> Tuple[List[Tuple[int, str]], List[str], List[str], List[float]]:
        """Flatten records into (record_idx, level) pairs with their inputs.

        Returns parallel lists: pair keys, truncated candidates, truncated
        references, and ROUGE-Lsum F1 computed on the untruncated texts.
        """
        pairs: List[Tuple[int, str]] = []
        cands_trunc: List[str] = []
        refs_trunc: List[str] = []
        rouge_vals: List[float] = []
        for i, rec in enumerate(self.data):
            syn = rec.get("synthetic_summary", {}) or {}
            if not syn:
                continue
            gold = self._as_text(rec.get("gold_summary", ""))
            gold_trunc = self._truncate(gold)  # once per record, not per level
            for key, value in syn.items():
                cand = self._as_text(value)
                pairs.append((i, key))
                cands_trunc.append(self._truncate(cand))
                refs_trunc.append(gold_trunc)
                rouge_vals.append(self._compute_rougeLsum_f1(gold, cand))
        return pairs, cands_trunc, refs_trunc, rouge_vals

    def _compute_bertscore_f1(self, cands: List[str], refs: List[str]) -> List[float]:
        """Return BERTScore F1 for each (cand, ref) pair; all NaN on failure.

        Failures are logged instead of raised so ROUGE-based results are still
        produced; the combined score then relies on ROUGE alone (NaN if
        ``w_rouge`` is 0).
        """
        if not cands:
            return []
        try:
            _, _, f1 = bert_score(
                cands=cands,
                refs=refs,
                model_type=self.MODEL_NAME,
                num_layers=12,
                # NOTE(review): with an explicit model_type, lang appears to
                # matter only for baseline lookup; bert_score may ship no
                # baseline for this model, making rescaling a no-op — verify.
                lang="en",
                device=self.device,
                rescale_with_baseline=self.rescale_with_baseline,
                batch_size=self.batch_size,
            )
        except Exception as e:
            logging.error(f"Error computing BERTScore: {e}", exc_info=True)
            return [float("nan")] * len(cands)
        return [float(v) for v in f1.tolist()]

    def _init_records(self) -> List[Dict[str, Any]]:
        """Build the per-record output skeleton (texts only, scores empty)."""
        results: List[Dict[str, Any]] = []
        for i, rec in enumerate(self.data):
            out: Dict[str, Any] = {
                "id": i,
                "gold_summary": rec.get("gold_summary", ""),
                "synthetic_summary": {},
            }
            if self.include_article:
                out["article"] = rec.get("article", "")
            syn = rec.get("synthetic_summary", {}) or {}
            for key, value in syn.items():
                out["synthetic_summary"][key] = {
                    "text": self._as_text(value),
                    "score": {},
                }
            results.append(out)
        return results

    def _compute_thresholds(
        self, level_scores: Dict[str, List[float]]
    ) -> Dict[str, Dict[str, float]]:
        """Compute worst/best/good quantile thresholds independently per level."""
        thresholds: Dict[str, Dict[str, float]] = {}
        for level, values in level_scores.items():
            scores = np.asarray(values, dtype=float)
            if scores.size > 0 and not np.isnan(scores).all():
                worst_thr = float(np.nanpercentile(scores, self.worst_q * 100))
                best_thr = float(np.nanpercentile(scores, self.best_q * 100))
                good_thr = float(np.nanpercentile(scores, self.good_q * 100))
            else:
                # No usable scores: every pair of this level is NaN -> "worst".
                worst_thr = best_thr = good_thr = float("-inf")
            thresholds[level] = {
                "worst_thr": worst_thr,
                "best_thr": best_thr,
                "good_thr": good_thr,
            }
        return thresholds

    @staticmethod
    def _categorize(c: float, thr: Dict[str, float]) -> Tuple[str, bool]:
        """Map a combined score to (category, is_good) for its level's thresholds."""
        if math.isnan(c):
            return "worst", False
        if c < thr["worst_thr"]:
            category = "worst"
        elif c < thr["best_thr"]:
            category = "good"
        else:
            category = "best"
        return category, c >= thr["good_thr"]

    @staticmethod
    def _empty_aggregate() -> Dict[str, Any]:
        """Return a fresh accumulator for one level's dataset-level metrics."""
        return {
            "ROUGE-L-Sum": [], "BERTScore_F": [], "combined": [],
            "count": 0, "best": 0, "good": 0, "worst": 0, "good_true": 0,
        }

    @staticmethod
    def _summarize(m: Dict[str, Any]) -> Dict[str, Any]:
        """Turn one level's accumulator into mean scores and category rates."""
        count = max(1, m["count"])  # avoid division by zero for empty levels
        return {
            "ROUGE-L-Sum": float(np.mean(m["ROUGE-L-Sum"])) if m["ROUGE-L-Sum"] else None,
            "BERTScore_F": float(np.mean(m["BERTScore_F"])) if m["BERTScore_F"] else None,
            "combined_mean": float(np.mean(m["combined"])) if m["combined"] else None,
            "count": m["count"],
            "best_rate": m["best"] / count,
            "good_rate": m["good"] / count,
            "worst_rate": m["worst"] / count,
            "is_good_rate": m["good_true"] / count,
        }

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def evaluate(self) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Score every synthetic summary and build the outputs.

        Returns:
            ``(per_record, dataset_level_metrics)``. ``per_record`` has one
            entry per input record; each level carries ``score``
            (ROUGE-L-Sum, BERTScore_F) and ``quality`` (category, is_good,
            combined_score). ``dataset_level_metrics`` holds the weights,
            quantiles and thresholds under ``"config"`` plus mean scores and
            category rates for each level.
        """
        pairs, cands_trunc, refs_trunc, rouge_vals = self._collect_pairs()
        bert_vals = self._compute_bertscore_f1(cands_trunc, refs_trunc)
        combined_vals = [self._combine(r, f) for r, f in zip(rouge_vals, bert_vals)]

        # Group combined scores by level: standard levels first (always
        # reported), extra keys afterwards in order of first appearance.
        level_scores: Dict[str, List[float]] = {level: [] for level in self.LEVELS}
        for (_, key), c in zip(pairs, combined_vals):
            level_scores.setdefault(key, []).append(c)
        thresholds = self._compute_thresholds(level_scores)

        results_per_record = self._init_records()
        agg = {level: self._empty_aggregate() for level in level_scores}

        for (i, key), r, f, c in zip(pairs, rouge_vals, bert_vals, combined_vals):
            category, is_good = self._categorize(c, thresholds[key])

            entry = results_per_record[i]["synthetic_summary"][key]
            entry["score"] = {
                "ROUGE-L-Sum": None if math.isnan(r) else float(r),
                "BERTScore_F": None if math.isnan(f) else float(f),
            }
            entry["quality"] = {
                "category": category,
                "is_good": bool(is_good),
                "combined_score": None if math.isnan(c) else float(c),
            }

            m = agg[key]
            if not math.isnan(r):
                m["ROUGE-L-Sum"].append(float(r))
            if not math.isnan(f):
                m["BERTScore_F"].append(float(f))
            if not math.isnan(c):
                m["combined"].append(float(c))
            m["count"] += 1
            m[category] += 1
            if is_good:
                m["good_true"] += 1

        dataset_level_metrics: Dict[str, Any] = {
            "config": {
                "weights": {"w_rouge": self.w_rouge, "w_bert": self.w_bert},
                "quantiles": {"worst_q": self.worst_q, "best_q": self.best_q, "good_q": self.good_q},
                "thresholds": thresholds,  # per-level thresholds used
            }
        }
        for level, m in agg.items():
            dataset_level_metrics[level] = self._summarize(m)

        return results_per_record, dataset_level_metrics

    def save(self, per_record: List[Dict[str, Any]], dataset_metrics: Dict[str, Any]):
        """Write per-record and aggregate results as JSON into ``output_dir``.

        File names are derived from the input file's base name:
        ``<base>_scored.json`` and ``<base>_aggregate_metrics.json``.
        """
        base = os.path.splitext(os.path.basename(self.input_path))[0]
        per_record_path = os.path.join(self.output_dir, f"{base}_scored.json")
        aggregate_path = os.path.join(self.output_dir, f"{base}_aggregate_metrics.json")

        with open(per_record_path, "w", encoding="utf-8") as f:
            json.dump(per_record, f, ensure_ascii=False, indent=2)
        with open(aggregate_path, "w", encoding="utf-8") as f:
            json.dump(dataset_metrics, f, ensure_ascii=False, indent=2)

        print("Saved:")
        print(f"- Per-record scores: {per_record_path}")
        print(f"- Aggregate metrics: {aggregate_path}")


def main():
    """Command-line entry point: parse arguments, evaluate, and save results."""
    parser = argparse.ArgumentParser(
        description="Evaluate B1/B2/B3 summaries vs gold. Metrics: ROUGE-Lsum F1, BERTScore F1. Per-level categories: best/good/worst + is_good."
    )
    parser.add_argument("--input_path", required=True, help="Path to the es_syntheticV3.json file")
    parser.add_argument("--output_dir", default="metrics", help="Where to save outputs")
    parser.add_argument("--batch_size", type=int, default=16, help="BERTScore batch size")
    parser.add_argument("--max_length", type=int, default=512, help="Max tokens for truncation (BERTScore)")
    parser.add_argument("--rescale_with_baseline", action="store_true", help="Use BERTScore baseline rescaling")
    parser.add_argument("--include_article", action="store_true", help="Include full article text in output JSON")
    parser.add_argument("--w_rouge", type=float, default=0.5, help="Weight for ROUGE-L-Sum in combined score")
    parser.add_argument("--w_bert", type=float, default=0.5, help="Weight for BERTScore_F in combined score")
    parser.add_argument("--worst_quantile", type=float, default=0.33, help="Bottom quantile -> 'worst'")
    parser.add_argument("--best_quantile", type=float, default=0.67, help="Top quantile boundary -> 'best'")
    parser.add_argument("--good_quantile", type=float, default=0.5, help="Quantile for is_good=True")
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Torch device for BERTScore (e.g. cuda, cpu)",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    evaluator = SyntheticSummariesEvaluator(
        input_path=args.input_path,
        output_dir=args.output_dir,
        device=args.device,
        batch_size=args.batch_size,
        max_length=args.max_length,
        rescale_with_baseline=args.rescale_with_baseline,
        include_article=args.include_article,
        w_rouge=args.w_rouge,
        w_bert=args.w_bert,
        worst_quantile=args.worst_quantile,
        best_quantile=args.best_quantile,
        good_quantile=args.good_quantile,
    )
    per_record, dataset_metrics = evaluator.evaluate()
    evaluator.save(per_record, dataset_metrics)


if __name__ == "__main__":
    main()