# readctrl/code/old/evalV3.py
# Uploaded by shahidul034 ("Add files using upload-large-folder tool"), commit c7a6fe6 (verified).
import os
import json
import logging
from typing import Dict, List, Tuple, Any
import numpy as np
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from transformers import AutoTokenizer
import torch
import argparse
class SyntheticSummariesEvaluator:
    """Score synthetic summaries (levels B1/B2/B3, ...) against a gold summary.

    The input JSON is expected to be a list of records, each with a
    ``gold_summary`` string and a ``synthetic_summary`` mapping from level key
    (e.g. ``"B1"``) to candidate text; an optional ``article`` field is copied
    to the output when ``include_article`` is set. For every (record, level)
    pair the evaluator computes ROUGE-Lsum F1 and BERTScore F1
    (Bio_ClinicalBERT), blends them into a weighted ``combined`` score, and
    labels the pair ``worst``/``good``/``best`` plus an ``is_good`` flag using
    quantile thresholds computed separately for each level.
    """

    # Levels always reported in thresholds/aggregates, even when absent from the data.
    _DEFAULT_LEVELS: Tuple[str, ...] = ("B1", "B2", "B3")

    def __init__(
        self,
        input_path: str,
        output_dir: str = "metrics",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        max_length: int = 512,
        batch_size: int = 16,
        rescale_with_baseline: bool = False,
        include_article: bool = False,
        w_rouge: float = 0.5,
        w_bert: float = 0.5,
        worst_quantile: float = 0.33,
        good_quantile: float = 0.5,
        best_quantile: float = 0.67,
    ):
        """Load the input JSON and set up tokenizer, scorers and thresholds.

        Args:
            input_path: JSON file holding the list of records to score.
            output_dir: Directory for the files written by ``save``; created
                if missing.
            device: Torch device passed to BERTScore.
            max_length: Token budget used to truncate texts before BERTScore.
            batch_size: BERTScore batch size.
            rescale_with_baseline: Forwarded to ``bert_score``.
                NOTE(review): bert_score may ship no baseline file for
                Bio_ClinicalBERT, in which case it presumably warns and returns
                unrescaled scores — verify before relying on this flag.
            include_article: Copy each record's ``article`` into the output.
            w_rouge, w_bert: Weights of ROUGE and BERTScore in the combined
                score; normalised to sum to 1 (a zero sum divides by 1).
            worst_quantile: Per-level quantile below which a pair is "worst".
            good_quantile: Per-level quantile at/above which ``is_good`` is set.
            best_quantile: Per-level quantile at/above which a pair is "best".
        """
        self.input_path = input_path
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        with open(input_path, "r", encoding="utf-8") as f:
            self.data: List[Dict[str, Any]] = json.load(f)
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.max_length = max_length
        self.batch_size = batch_size
        self.rescale_with_baseline = rescale_with_baseline
        self.include_article = include_article
        # Normalise weights to sum to 1; a zero sum falls back to dividing by 1.
        total = (w_rouge + w_bert) or 1.0
        self.w_rouge = float(w_rouge) / total
        self.w_bert = float(w_bert) / total
        self.worst_q, self.good_q, self.best_q = self._sanitize_quantiles(
            worst_quantile, good_quantile, best_quantile
        )
        self.rouge = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)

    @staticmethod
    def _sanitize_quantiles(worst: float, good: float, best: float) -> Tuple[float, float, float]:
        """Return ``(worst, good, best)`` with invalid quantiles replaced by defaults.

        ``worst``/``best`` must satisfy ``0 <= worst < best <= 1`` (otherwise
        both reset to 0.33/0.67). ``good`` must lie in ``[0, 1]`` (otherwise it
        resets to 0.5), because ``np.nanpercentile`` rejects percentiles
        outside [0, 100] and would fail midway through ``evaluate``.
        """
        if not (0.0 <= worst < best <= 1.0):
            logging.warning("Invalid quantiles; resetting to worst=0.33, best=0.67")
            worst, best = 0.33, 0.67
        if not (0.0 <= good <= 1.0):
            logging.warning("Invalid good_quantile; resetting to 0.5")
            good = 0.5
        return worst, good, best

    def _truncate(self, text: str) -> str:
        """Clip ``text`` to at most ``max_length`` tokenizer tokens.

        Uses an encode/decode round-trip (special tokens count towards the
        limit and are dropped on decode), so the returned string may not be
        byte-identical to the input.
        """
        token_ids = self.tokenizer.encode(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True,
        )
        return self.tokenizer.decode(token_ids, skip_special_tokens=True)

    def _compute_rougeLsum_f1(self, ref: str, hyp: str) -> float:
        """Return the ROUGE-Lsum F-measure of ``hyp`` against ``ref`` (untruncated).

        NOTE(review): rouge_score's rougeLsum treats newlines as sentence
        boundaries; single-line texts are scored as one sentence — confirm
        that is intended.
        """
        return float(self.rouge.score(ref, hyp)["rougeLsum"].fmeasure)

    def _combine(self, rouge: float, bert_f: float) -> float:
        """Weighted mean of ROUGE and BERTScore F1, ignoring NaN inputs.

        The weights of the non-NaN inputs are renormalised, so a single valid
        score is returned unchanged. Returns NaN when no valid input carries
        any weight (both inputs NaN, or the only valid one has weight 0),
        instead of dividing by zero.
        """
        # x == x is False only for NaN.
        valid = [(v, w) for v, w in ((rouge, self.w_rouge), (bert_f, self.w_bert)) if v == v]
        total = sum(w for _, w in valid)
        if not valid or total == 0:
            return float("nan")
        return float(sum(v * (w / total) for v, w in valid))

    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce a JSON value to scoring text: None -> "", str unchanged, else ``str()``."""
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    @staticmethod
    def _nan_to_none(value: float) -> Any:
        """Return ``float(value)``, or None for NaN (JSON has no NaN literal)."""
        return float(value) if value == value else None

    @staticmethod
    def _synthetic_of(rec: Dict[str, Any], idx: int) -> Dict[str, Any]:
        """Return the record's level -> summary mapping ({} if missing or not an object)."""
        syn = rec.get("synthetic_summary") or {}
        if not isinstance(syn, dict):
            logging.warning("Record %d: 'synthetic_summary' is not an object; skipping it", idx)
            return {}
        return syn

    def _collect_pairs(self) -> Tuple[List[Tuple[int, str]], List[str], List[str]]:
        """Flatten the data into parallel lists, one entry per (record, level) pair.

        Returns:
            ``(pairs, golds, cands)`` where ``pairs[k]`` is ``(record_idx,
            level_key)`` and ``golds[k]``/``cands[k]`` are its texts.
        """
        pairs: List[Tuple[int, str]] = []
        golds: List[str] = []
        cands: List[str] = []
        for i, rec in enumerate(self.data):
            gold = self._as_text(rec.get("gold_summary"))
            for key, value in self._synthetic_of(rec, i).items():
                pairs.append((i, key))
                golds.append(gold)
                cands.append(self._as_text(value))
        return pairs, golds, cands

    def _bertscore_f1(self, cands: List[str], refs: List[str]) -> List[float]:
        """Return BERTScore F1 per pair, or NaN for every pair if scoring fails.

        Texts are truncated with ``_truncate`` first; each distinct reference
        is truncated once. BERTScore failures are logged rather than raised so
        the combined score can fall back to ROUGE alone.
        """
        f_vals: List[float] = [np.nan] * len(cands)
        if not cands:
            return f_vals
        ref_cut = {g: self._truncate(g) for g in dict.fromkeys(refs)}
        cands_trunc = [self._truncate(c) for c in cands]
        refs_trunc = [ref_cut[g] for g in refs]
        try:
            _, _, F = bert_score(
                cands=cands_trunc,
                refs=refs_trunc,
                model_type="emilyalsentzer/Bio_ClinicalBERT",
                num_layers=12,
                lang="en",
                device=self.device,
                rescale_with_baseline=self.rescale_with_baseline,
                batch_size=self.batch_size,
            )
        except Exception as e:  # best-effort: keep NaNs and continue with ROUGE only
            logging.error(f"Error computing BERTScore: {e}", exc_info=True)
            return f_vals
        return F.tolist()

    def _level_thresholds(self, scores: List[float]) -> Dict[str, float]:
        """Quantile thresholds of one level's combined scores (NaNs ignored).

        Returns -inf for every threshold when the level has no non-NaN score.
        """
        arr = np.asarray(scores, dtype=float)
        if arr.size == 0 or np.all(np.isnan(arr)):
            neg_inf = float("-inf")
            return {"worst_thr": neg_inf, "best_thr": neg_inf, "good_thr": neg_inf}
        return {
            "worst_thr": float(np.nanpercentile(arr, self.worst_q * 100)),
            "best_thr": float(np.nanpercentile(arr, self.best_q * 100)),
            "good_thr": float(np.nanpercentile(arr, self.good_q * 100)),
        }

    @staticmethod
    def _categorize(combined: float, thr: Dict[str, float]) -> Tuple[str, bool]:
        """Map a combined score to ``(category, is_good)`` under a level's thresholds."""
        if combined != combined:  # NaN: no usable score at all
            return "worst", False
        if combined < thr["worst_thr"]:
            category = "worst"
        elif combined < thr["best_thr"]:
            category = "good"
        else:
            category = "best"
        return category, bool(combined >= thr["good_thr"])

    @staticmethod
    def _summarize_level(m: Dict[str, Any]) -> Dict[str, Any]:
        """Turn one level's raw counters into means (None if empty) and category rates."""
        denom = max(1, m["count"])  # a level with no pairs reports 0.0 rates, not 0/0

        def mean_or_none(values: List[float]) -> Any:
            return float(np.mean(values)) if values else None

        return {
            "ROUGE-L-Sum": mean_or_none(m["ROUGE-L-Sum"]),
            "BERTScore_F": mean_or_none(m["BERTScore_F"]),
            "combined_mean": mean_or_none(m["combined"]),
            "count": m["count"],
            "best_rate": m["best"] / denom,
            "good_rate": m["good"] / denom,
            "worst_rate": m["worst"] / denom,
            "is_good_rate": m["good_true"] / denom,
        }

    def evaluate(self) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Score every (record, level) pair and aggregate the results per level.

        Returns:
            ``(per_record, dataset_metrics)``. ``per_record`` mirrors the input
            records (``id``, ``gold_summary``, ``synthetic_summary`` and
            optionally ``article``), with each summary carrying ``text``,
            ``score`` and ``quality`` entries. ``dataset_metrics`` holds a
            ``config`` entry (weights, quantiles, per-level thresholds, with
            undefined thresholds as None) plus per-level means and category
            rates. B1/B2/B3 are always present; other level keys found in the
            data are scored with their own thresholds.
            NOTE: a level literally named "config" would collide with the
            config entry.
        """
        pairs, golds, cands = self._collect_pairs()
        rouge_vals = [self._compute_rougeLsum_f1(g, c) for g, c in zip(golds, cands)]
        bert_vals = self._bertscore_f1(cands, golds)
        combined = [self._combine(r, f) for r, f in zip(rouge_vals, bert_vals)]

        # Stable level order: defaults first, then other keys in order of first appearance.
        levels = list(dict.fromkeys([*self._DEFAULT_LEVELS, *(key for _, key in pairs)]))
        level_scores: Dict[str, List[float]] = {lvl: [] for lvl in levels}
        for (_, key), c in zip(pairs, combined):
            level_scores[key].append(c)
        thresholds = {lvl: self._level_thresholds(level_scores[lvl]) for lvl in levels}

        per_record: List[Dict[str, Any]] = []
        for i, rec in enumerate(self.data):
            out: Dict[str, Any] = {
                "id": i,
                "gold_summary": rec.get("gold_summary", ""),
                "synthetic_summary": {},
            }
            if self.include_article:
                out["article"] = rec.get("article", "")
            per_record.append(out)

        counters = {
            lvl: {"ROUGE-L-Sum": [], "BERTScore_F": [], "combined": [], "count": 0,
                  "best": 0, "good": 0, "worst": 0, "good_true": 0}
            for lvl in levels
        }
        for (i, key), text, r, f, c in zip(pairs, cands, rouge_vals, bert_vals, combined):
            category, is_good = self._categorize(c, thresholds[key])
            per_record[i]["synthetic_summary"][key] = {
                "text": text,
                "score": {
                    "ROUGE-L-Sum": self._nan_to_none(r),
                    "BERTScore_F": self._nan_to_none(f),
                },
                "quality": {
                    "category": category,
                    "is_good": is_good,
                    "combined_score": self._nan_to_none(c),
                },
            }
            m = counters[key]
            for name, value in (("ROUGE-L-Sum", r), ("BERTScore_F", f), ("combined", c)):
                if value == value:  # skip NaN
                    m[name].append(float(value))
            m["count"] += 1
            m[category] += 1
            m["good_true"] += int(is_good)

        dataset_metrics: Dict[str, Any] = {
            "config": {
                "weights": {"w_rouge": self.w_rouge, "w_bert": self.w_bert},
                "quantiles": {"worst_q": self.worst_q, "best_q": self.best_q, "good_q": self.good_q},
                # Undefined (-inf) thresholds are emitted as null: strict JSON has no -Infinity.
                "thresholds": {
                    lvl: {name: (v if np.isfinite(v) else None) for name, v in thr.items()}
                    for lvl, thr in thresholds.items()
                },
            }
        }
        for lvl, m in counters.items():
            dataset_metrics[lvl] = self._summarize_level(m)
        return per_record, dataset_metrics

    def save(self, per_record: List[Dict[str, Any]], dataset_metrics: Dict[str, Any]):
        """Write per-record scores and aggregate metrics as JSON into ``output_dir``.

        Files are named ``<input stem>_scored.json`` and
        ``<input stem>_aggregate_metrics.json``; their paths are printed.
        """
        stem = os.path.splitext(os.path.basename(self.input_path))[0]
        per_record_path = os.path.join(self.output_dir, f"{stem}_scored.json")
        aggregate_path = os.path.join(self.output_dir, f"{stem}_aggregate_metrics.json")
        for path, payload in ((per_record_path, per_record), (aggregate_path, dataset_metrics)):
            with open(path, "w", encoding="utf-8") as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
        print("Saved:")
        print(f"- Per-record scores: {per_record_path}")
        print(f"- Aggregate metrics: {aggregate_path}")
def main(argv=None):
    """Command-line entry point: score an input file and write the results.

    Args:
        argv: Optional argument list; when None, argparse reads ``sys.argv``.
            Lets ``main`` be called programmatically.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate B1/B2/B3 summaries vs gold. Metrics: ROUGE-Lsum F1, BERTScore F1. Per-level categories: best/good/worst + is_good."
    )
    parser.add_argument("--input_path", required=True, help="Path to the es_syntheticV3.json file")
    parser.add_argument("--output_dir", default="metrics", help="Where to save outputs")
    parser.add_argument("--batch_size", type=int, default=16, help="BERTScore batch size")
    parser.add_argument("--max_length", type=int, default=512, help="Max tokens for truncation (BERTScore)")
    parser.add_argument("--rescale_with_baseline", action="store_true", help="Use BERTScore baseline rescaling")
    parser.add_argument("--include_article", action="store_true", help="Include full article text in output JSON")
    parser.add_argument("--w_rouge", type=float, default=0.5, help="Weight for ROUGE-L-Sum in combined score")
    parser.add_argument("--w_bert", type=float, default=0.5, help="Weight for BERTScore_F in combined score")
    parser.add_argument("--worst_quantile", type=float, default=0.33, help="Bottom quantile -> 'worst'")
    parser.add_argument("--best_quantile", type=float, default=0.67, help="Top quantile boundary -> 'best'")
    parser.add_argument("--good_quantile", type=float, default=0.5, help="Quantile for is_good=True")
    parser.add_argument(
        "--device",
        default=None,
        help="Torch device for BERTScore, e.g. 'cuda' or 'cpu' (default: cuda if available, else cpu)",
    )
    args = parser.parse_args(argv)
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    # Forward --device only when given, so the evaluator keeps its own default otherwise.
    device_kwargs = {"device": args.device} if args.device else {}
    evaluator = SyntheticSummariesEvaluator(
        input_path=args.input_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        max_length=args.max_length,
        rescale_with_baseline=args.rescale_with_baseline,
        include_article=args.include_article,
        w_rouge=args.w_rouge,
        w_bert=args.w_bert,
        worst_quantile=args.worst_quantile,
        best_quantile=args.best_quantile,
        good_quantile=args.good_quantile,
        **device_kwargs,
    )
    per_record, dataset_metrics = evaluator.evaluate()
    evaluator.save(per_record, dataset_metrics)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()