| import os |
| import json |
| import logging |
| from typing import Dict, List, Tuple, Any |
| import numpy as np |
| from rouge_score import rouge_scorer |
| from bert_score import score as bert_score |
| from transformers import AutoTokenizer |
| import torch |
| import argparse |
|
|
|
|
class SyntheticSummariesEvaluator:
    """Score synthetic summaries against gold summaries and bucket them by quality.

    The input file is a JSON list of records. Each record is read with
    ``rec.get("gold_summary")``, ``rec.get("synthetic_summary")`` (a mapping of
    level name -> summary text) and, optionally, ``rec.get("article")``.

    For every (record, level) pair the evaluator computes ROUGE-Lsum F1 and
    BERTScore F1, combines them into one weighted score, and then, per level
    in ``LEVELS``, derives quantile thresholds over the combined scores to label
    each summary ``worst`` / ``good`` / ``best`` plus a boolean ``is_good``.
    """

    # Levels that receive quantile thresholds and aggregate statistics.
    LEVELS: Tuple[str, ...] = ("B1", "B2", "B3")
    # Model used both for truncation (tokenizer) and for BERTScore.
    BERT_MODEL: str = "emilyalsentzer/Bio_ClinicalBERT"
    _THRESHOLD_NAMES: Tuple[str, ...] = ("worst_thr", "best_thr", "good_thr")

    def __init__(
        self,
        input_path: str,
        output_dir: str = "metrics",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        max_length: int = 512,
        batch_size: int = 16,
        rescale_with_baseline: bool = False,
        include_article: bool = False,
        w_rouge: float = 0.5,
        w_bert: float = 0.5,
        worst_quantile: float = 0.33,
        good_quantile: float = 0.5,
        best_quantile: float = 0.67,
    ):
        """Load the input records and configure the metrics.

        Args:
            input_path: Path of the JSON file holding the list of records.
            output_dir: Directory for the output files; created if missing.
            device: Device string passed to BERTScore.
            max_length: Token budget used when truncating texts for BERTScore.
            batch_size: BERTScore batch size.
            rescale_with_baseline: Passed through to BERTScore.
            include_article: Copy each record's ``article`` into the output.
            w_rouge: Weight of ROUGE-Lsum in the combined score.
            w_bert: Weight of BERTScore F1 in the combined score.
            worst_quantile: Scores below this quantile are labelled ``worst``.
            good_quantile: Scores at or above this quantile get ``is_good=True``.
            best_quantile: Scores at or above this quantile are labelled ``best``.

        Raises:
            ValueError: If the input file does not contain a JSON list.
        """
        self.input_path = input_path
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(
                f"Expected a JSON list of records in {input_path!r}, "
                f"got {type(data).__name__}"
            )
        self.data: List[Dict[str, Any]] = data

        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(self.BERT_MODEL)
        self.max_length = max_length
        self.batch_size = batch_size
        self.rescale_with_baseline = rescale_with_baseline
        self.include_article = include_article

        # Normalize the weights so they sum to 1. Negative weights, or weights
        # whose sum is not a positive finite number, cannot be normalized into
        # a meaningful average, so fall back to an even split.
        w_rouge, w_bert = float(w_rouge), float(w_bert)
        total = w_rouge + w_bert
        if w_rouge < 0 or w_bert < 0 or not (0.0 < total < float("inf")):
            logging.warning("Invalid weights; resetting to w_rouge=0.5, w_bert=0.5")
            w_rouge, w_bert, total = 0.5, 0.5, 1.0
        self.w_rouge = w_rouge / total
        self.w_bert = w_bert / total

        if not (0.0 <= worst_quantile < best_quantile <= 1.0):
            logging.warning("Invalid quantiles; resetting to worst=0.33, best=0.67")
            worst_quantile, best_quantile = 0.33, 0.67
        # An out-of-range good_quantile would otherwise only surface as a
        # ValueError from np.nanpercentile in the middle of evaluate().
        if not (0.0 <= good_quantile <= 1.0):
            logging.warning("Invalid good_quantile; resetting to 0.5")
            good_quantile = 0.5
        self.worst_q = worst_quantile
        self.best_q = best_quantile
        self.good_q = good_quantile

        self.rouge = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)

    def _truncate(self, text: str) -> str:
        """Return ``text`` cut to at most ``max_length`` tokens of the tokenizer.

        The text is encoded with truncation (special tokens are requested, so
        they count against the budget) and decoded back without them.
        """
        tokens = self.tokenizer.encode(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True,
        )
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def _compute_rougeLsum_f1(self, ref: str, hyp: str) -> float:
        """Return the ROUGE-Lsum F-measure of ``hyp`` against ``ref``."""
        result = self.rouge.score(ref, hyp)
        return float(result["rougeLsum"].fmeasure)

    def _combine(self, rouge: float, bert_f: float) -> float:
        """Return the weighted mean of the metrics that are not NaN.

        Weights are renormalized over the available metrics. If no available
        metric carries a positive weight (both are NaN, or the only valid one
        has weight 0) the result is NaN instead of a division by zero.
        """
        available = [
            (value, weight)
            for value, weight in ((rouge, self.w_rouge), (bert_f, self.w_bert))
            if not np.isnan(value)
        ]
        total = sum(weight for _, weight in available)
        if not total > 0:
            return float("nan")
        return float(sum(value * (weight / total) for value, weight in available))

    @staticmethod
    def _finite_or_none(value: float) -> Any:
        """Return ``value`` as a float, or None when it is NaN or infinite.

        Keeps the saved files valid JSON: ``json.dump`` would otherwise emit
        the non-standard tokens ``NaN`` / ``-Infinity``.
        """
        return float(value) if np.isfinite(value) else None

    @staticmethod
    def _gold_of(rec: Dict[str, Any]) -> str:
        """Return the record's gold summary as text (missing or null -> "")."""
        gold = rec.get("gold_summary") or ""
        return gold if isinstance(gold, str) else str(gold)

    @staticmethod
    def _summaries_of(rec: Dict[str, Any], index: int) -> Dict[str, str]:
        """Return the record's synthetic summaries as ``{level: text}``.

        Non-string values are stringified. A ``synthetic_summary`` that is not
        a mapping is reported and treated as empty.
        """
        syn = rec.get("synthetic_summary") or {}
        if not isinstance(syn, dict):
            logging.warning(
                "Record %d: 'synthetic_summary' is not a mapping; ignoring it", index
            )
            return {}
        return {
            key: value if isinstance(value, str) else str(value)
            for key, value in syn.items()
        }

    def _bert_f1(self, cands: List[str], refs: List[str]) -> List[float]:
        """Return BERTScore F1 per pair, or NaN for every pair if scoring fails."""
        if not cands:
            return []
        try:
            _, _, f1 = bert_score(
                cands=cands,
                refs=refs,
                model_type=self.BERT_MODEL,
                num_layers=12,
                lang="en",
                device=self.device,
                rescale_with_baseline=self.rescale_with_baseline,
                batch_size=self.batch_size,
            )
            return f1.tolist()
        except Exception as e:
            # Deliberate best effort: a BERTScore failure is logged and the
            # evaluation continues with ROUGE only.
            logging.error("Error computing BERTScore: %s", e, exc_info=True)
            return [float("nan")] * len(cands)

    def _level_thresholds(self, scores: List[float]) -> Dict[str, float]:
        """Return the worst/best/good quantile thresholds over ``scores``.

        NaN scores are ignored. With no usable score every threshold is
        ``-inf``.
        """
        arr = np.asarray(scores, dtype=float)
        if arr.size == 0 or np.all(np.isnan(arr)):
            return dict.fromkeys(self._THRESHOLD_NAMES, float("-inf"))
        worst_thr, best_thr, good_thr = np.nanpercentile(
            arr, [self.worst_q * 100, self.best_q * 100, self.good_q * 100]
        )
        return {
            "worst_thr": float(worst_thr),
            "best_thr": float(best_thr),
            "good_thr": float(good_thr),
        }

    @staticmethod
    def _categorize(combined: float, thr: Dict[str, float]) -> Tuple[str, bool]:
        """Return ``(category, is_good)`` for a combined score.

        A NaN score is always ``("worst", False)``.
        """
        if np.isnan(combined):
            return "worst", False
        if combined < thr["worst_thr"]:
            category = "worst"
        elif combined < thr["best_thr"]:
            category = "good"
        else:
            category = "best"
        return category, bool(combined >= thr["good_thr"])

    def evaluate(self) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Score every summary and aggregate the results per level.

        Returns:
            ``(results_per_record, dataset_level_metrics)``: the per-record
            output (text, scores and quality label for each summary) and the
            per-level aggregates together with the configuration used.
        """
        pair_keys: List[Tuple[int, str]] = []
        cands_trunc: List[str] = []
        refs_trunc: List[str] = []
        rouge_vals: List[float] = []
        results_per_record: List[Dict[str, Any]] = []

        # Pass 1: build the output skeleton, collect the pairs, compute ROUGE.
        for i, rec in enumerate(self.data):
            summaries = self._summaries_of(rec, i)
            out: Dict[str, Any] = {
                "id": i,
                "gold_summary": rec.get("gold_summary", ""),
                "synthetic_summary": {
                    key: {"text": text, "score": {}} for key, text in summaries.items()
                },
            }
            if self.include_article:
                out["article"] = rec.get("article", "")
            results_per_record.append(out)

            if not summaries:
                continue
            gold = self._gold_of(rec)
            # The reference is the same for every candidate of this record, so
            # tokenize/truncate it once.
            gold_trunc = self._truncate(gold)
            for key, text in summaries.items():
                pair_keys.append((i, key))
                cands_trunc.append(self._truncate(text))
                refs_trunc.append(gold_trunc)
                rouge_vals.append(self._compute_rougeLsum_f1(gold, text))

        bert_vals = self._bert_f1(cands_trunc, refs_trunc)
        combined_vals = [self._combine(r, f) for r, f in zip(rouge_vals, bert_vals)]

        # Thresholds are computed per level over that level's combined scores.
        level_scores: Dict[str, List[float]] = {level: [] for level in self.LEVELS}
        for (_, key), c in zip(pair_keys, combined_vals):
            if key in level_scores:
                level_scores[key].append(c)
        thresholds = {
            level: self._level_thresholds(level_scores[level]) for level in self.LEVELS
        }
        # Keys outside LEVELS have no thresholds of their own; with -inf
        # thresholds any non-NaN score is labelled "best" / is_good.
        no_thresholds = dict.fromkeys(self._THRESHOLD_NAMES, float("-inf"))

        agg: Dict[str, Dict[str, Any]] = {
            level: {
                "ROUGE-L-Sum": [],
                "BERTScore_F": [],
                "combined": [],
                "count": 0,
                "best": 0,
                "good": 0,
                "worst": 0,
                "good_true": 0,
            }
            for level in self.LEVELS
        }

        # Pass 2: label every pair and accumulate the per-level statistics.
        for (i, key), r, f, c in zip(pair_keys, rouge_vals, bert_vals, combined_vals):
            category, is_good = self._categorize(c, thresholds.get(key, no_thresholds))

            entry = results_per_record[i]["synthetic_summary"][key]
            entry["score"] = {
                "ROUGE-L-Sum": None if np.isnan(r) else float(r),
                "BERTScore_F": None if np.isnan(f) else float(f),
            }
            entry["quality"] = {
                "category": category,
                "is_good": is_good,
                "combined_score": None if np.isnan(c) else float(c),
            }

            if key not in agg:
                continue
            stats = agg[key]
            if not np.isnan(r):
                stats["ROUGE-L-Sum"].append(float(r))
            if not np.isnan(f):
                stats["BERTScore_F"].append(float(f))
            if not np.isnan(c):
                stats["combined"].append(float(c))
            stats["count"] += 1
            stats[category] += 1
            if is_good:
                stats["good_true"] += 1

        dataset_level_metrics: Dict[str, Any] = {
            "config": {
                "weights": {"w_rouge": self.w_rouge, "w_bert": self.w_bert},
                "quantiles": {
                    "worst_q": self.worst_q,
                    "best_q": self.best_q,
                    "good_q": self.good_q,
                },
                # -inf (level without scores) is reported as null so the
                # aggregate file stays valid JSON.
                "thresholds": {
                    level: {name: self._finite_or_none(v) for name, v in thr.items()}
                    for level, thr in thresholds.items()
                },
            }
        }
        for key, m in agg.items():
            # Avoid dividing by zero for a level that never occurred.
            count = max(1, m["count"])
            dataset_level_metrics[key] = {
                "ROUGE-L-Sum": float(np.mean(m["ROUGE-L-Sum"])) if m["ROUGE-L-Sum"] else None,
                "BERTScore_F": float(np.mean(m["BERTScore_F"])) if m["BERTScore_F"] else None,
                "combined_mean": float(np.mean(m["combined"])) if m["combined"] else None,
                "count": m["count"],
                "best_rate": m["best"] / count,
                "good_rate": m["good"] / count,
                "worst_rate": m["worst"] / count,
                "is_good_rate": m["good_true"] / count,
            }

        return results_per_record, dataset_level_metrics

    def save(self, per_record: List[Dict[str, Any]], dataset_metrics: Dict[str, Any]):
        """Write both results as JSON files into ``output_dir``.

        File names are derived from the input file's base name:
        ``<base>_scored.json`` and ``<base>_aggregate_metrics.json``.
        """
        base = os.path.splitext(os.path.basename(self.input_path))[0]
        per_record_path = os.path.join(self.output_dir, f"{base}_scored.json")
        aggregate_path = os.path.join(self.output_dir, f"{base}_aggregate_metrics.json")

        with open(per_record_path, "w", encoding="utf-8") as f:
            json.dump(per_record, f, ensure_ascii=False, indent=2)

        with open(aggregate_path, "w", encoding="utf-8") as f:
            json.dump(dataset_metrics, f, ensure_ascii=False, indent=2)

        print("Saved:")
        print(f"- Per-record scores: {per_record_path}")
        print(f"- Aggregate metrics: {aggregate_path}")
|
|
|
|
def main():
    """Parse the command line, run the evaluation, and write the result files."""
    cli = argparse.ArgumentParser(
        description="Evaluate B1/B2/B3 summaries vs gold. Metrics: ROUGE-Lsum F1, BERTScore F1. Per-level categories: best/good/worst + is_good."
    )
    # (flag, argparse options) in the order they appear in --help.
    option_table = [
        ("--input_path", {"required": True, "help": "Path to the es_syntheticV3.json file"}),
        ("--output_dir", {"default": "metrics", "help": "Where to save outputs"}),
        ("--batch_size", {"type": int, "default": 16, "help": "BERTScore batch size"}),
        ("--max_length", {"type": int, "default": 512, "help": "Max tokens for truncation (BERTScore)"}),
        ("--rescale_with_baseline", {"action": "store_true", "help": "Use BERTScore baseline rescaling"}),
        ("--include_article", {"action": "store_true", "help": "Include full article text in output JSON"}),
        ("--w_rouge", {"type": float, "default": 0.5, "help": "Weight for ROUGE-L-Sum in combined score"}),
        ("--w_bert", {"type": float, "default": 0.5, "help": "Weight for BERTScore_F in combined score"}),
        ("--worst_quantile", {"type": float, "default": 0.33, "help": "Bottom quantile -> 'worst'"}),
        ("--best_quantile", {"type": float, "default": 0.67, "help": "Top quantile boundary -> 'best'"}),
        ("--good_quantile", {"type": float, "default": 0.5, "help": "Quantile for is_good=True"}),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    parsed = cli.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    # Every parsed option is named after the evaluator keyword it configures,
    # so the namespace can be passed through as keyword arguments.
    evaluator = SyntheticSummariesEvaluator(**vars(parsed))
    scored_records, aggregate = evaluator.evaluate()
    evaluator.save(scored_records, aggregate)
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()