"""Rewrite-only evaluation of local T5 tag-rewrite models.

For each ``--model-dir`` given on the command line, every caption in the
samples JSONL file is rewritten into a comma-separated tag string by
``local_t5_rewrite_prompt``. The predicted tag sets are scored against each
sample's ground-truth tags, both as-is and with the extra tags returned by
``extract_user_provided_tags_upto_3_words`` merged in. Per-model metrics are
printed and written to a JSON report.
"""

from __future__ import annotations

import argparse
import json
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Set

from psq_rag.llm.rewrite_local_t5 import local_t5_rewrite_prompt
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words

REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_SAMPLES = (
    REPO_ROOT
    / "data"
    / "eval_samples"
    / "e621_sfw_sample_1000_seed123_buffer10000_caption_evident_n30.jsonl"
)


def _resolve_repo_path(path: Path) -> Path:
    """Return *path* unchanged if absolute, else resolved relative to REPO_ROOT."""
    return path if path.is_absolute() else (REPO_ROOT / path).resolve()


def _canon_tag(tag: str) -> str:
    r"""Normalize a tag to the canonical form used for set comparison.

    Collapses runs of whitespace, lowercases, joins words with underscores and
    unescapes ``\(`` / ``\)``, so ``"Red  Fox \(Animal\)"`` and
    ``"red_fox_(animal)"`` compare equal. ``None``/empty input yields ``""``.
    """
    # str.split() with no argument also drops leading/trailing whitespace.
    t = " ".join(str(tag or "").split()).lower()
    return t.replace(" ", "_").replace("\\(", "(").replace("\\)", ")")


def _parse_tag_set(text: str) -> Set[str]:
    """Split a comma-separated tag string into a set of canonical tags.

    Fragments that canonicalize to ``""`` (e.g. from trailing commas) are dropped.
    """
    tags = (_canon_tag(raw) for raw in (text or "").split(","))
    return {t for t in tags if t}


def _set_metrics(
    pred_sets: Sequence[Set[str]], gold_sets: Sequence[Set[str]]
) -> Dict[str, float]:
    """Compute macro-averaged set metrics over aligned predicted/gold tag sets.

    Per sample: precision = |P & G| / |P|, recall = |P & G| / |G|, and F1 is
    their harmonic mean. A sample where both sets are empty scores 1.0 on all
    three; a sample where exactly one set is empty scores 0.0. Per-sample
    scores are then averaged over all samples.

    Args:
        pred_sets: Predicted tag set for each sample.
        gold_sets: Ground-truth tag set for each sample, aligned with
            ``pred_sets``.

    Returns:
        Dict with ``n`` (sample count), ``set_precision``, ``set_recall``,
        ``set_f1``, ``exact_set_match`` (fraction of samples with P == G),
        ``avg_pred_tags`` and ``avg_gold_tags``; all values are floats.

    Raises:
        ValueError: If ``pred_sets`` and ``gold_sets`` differ in length.
    """
    if len(pred_sets) != len(gold_sets):
        # zip() would silently truncate while the averages divide by len(pred_sets).
        raise ValueError(
            "pred_sets and gold_sets differ in length: "
            f"{len(pred_sets)} != {len(gold_sets)}"
        )
    n = len(pred_sets)
    if n == 0:
        return {
            "n": 0.0,
            "set_precision": 0.0,
            "set_recall": 0.0,
            "set_f1": 0.0,
            "exact_set_match": 0.0,
            "avg_pred_tags": 0.0,
            "avg_gold_tags": 0.0,
        }

    p_vals: List[float] = []
    r_vals: List[float] = []
    f_vals: List[float] = []
    exact = 0
    pred_sizes: List[int] = []
    gold_sizes: List[int] = []
    for pset, gset in zip(pred_sets, gold_sets):
        pred_sizes.append(len(pset))
        gold_sizes.append(len(gset))
        if pset == gset:
            exact += 1
        if not pset and not gset:
            # Correctly predicting "no tags" counts as a perfect score.
            p_vals.append(1.0)
            r_vals.append(1.0)
            f_vals.append(1.0)
            continue
        if not pset or not gset:
            p_vals.append(0.0)
            r_vals.append(0.0)
            f_vals.append(0.0)
            continue
        tp = len(pset & gset)
        p = tp / len(pset)
        r = tp / len(gset)
        f = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0
        p_vals.append(p)
        r_vals.append(r)
        f_vals.append(f)

    return {
        "n": float(n),
        "set_precision": sum(p_vals) / n,
        "set_recall": sum(r_vals) / n,
        "set_f1": sum(f_vals) / n,
        "exact_set_match": exact / n,
        "avg_pred_tags": sum(pred_sizes) / n,
        "avg_gold_tags": sum(gold_sizes) / n,
    }


@dataclass
class SampleRow:
    """One evaluation sample: a caption and its canonical ground-truth tags."""

    sample_id: int  # "id" field of the JSONL record, or -1 when missing/null
    caption: str  # stripped text of the configured caption field
    gold_tags: Set[str]  # canonicalized ground-truth tags (may be empty)


def _gold_tags_from_record(obj: dict) -> Set[str]:
    """Extract canonical ground-truth tags from one sample record.

    Prefers the flat ``tags_ground_truth_expanded`` list. If that yields no
    tags, falls back to ``tags_ground_truth_categorized``: a mapping of
    category -> list of tags, which may itself be stored as a JSON string.
    Malformed or unexpected shapes contribute no tags.
    """
    gt: Set[str] = set()
    expanded = obj.get("tags_ground_truth_expanded")
    if isinstance(expanded, list):
        gt.update(c for c in (_canon_tag(str(t)) for t in expanded) if c)
    if gt:
        return gt

    cat = obj.get("tags_ground_truth_categorized")
    if isinstance(cat, str):
        try:
            cat = json.loads(cat)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            cat = {}
    if isinstance(cat, dict):
        for vals in cat.values():
            if isinstance(vals, list):
                gt.update(c for c in (_canon_tag(str(t)) for t in vals) if c)
    return gt


def _load_rows(samples_path: Path, caption_field: str) -> List[SampleRow]:
    """Load evaluation samples from a JSONL file, in file order.

    Skipped records: blank lines, records with a truthy ``_meta`` key, and
    records whose ``caption_field`` is missing or blank. Records with no
    recoverable ground-truth tags are kept with an empty ``gold_tags`` set.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    rows: List[SampleRow] = []
    with samples_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if obj.get("_meta"):
                continue
            cap = str(obj.get(caption_field, "") or "").strip()
            if not cap:
                continue
            raw_id = obj.get("id")
            rows.append(
                SampleRow(
                    # A JSON null id would make int() raise; treat it as missing.
                    sample_id=int(raw_id) if raw_id is not None else -1,
                    caption=cap,
                    gold_tags=_gold_tags_from_record(obj),
                )
            )
    return rows


def _eval_model(
    model_dir: Path,
    rows: Sequence[SampleRow],
    num_beams: int,
    max_new_tokens: int,
    max_source_length: int,
) -> Dict[str, Dict[str, float]]:
    """Rewrite every caption with one model and score the resulting tag sets.

    Two variants are scored against the same gold sets:
    ``t5_local_rewrite_no_heur`` uses only the parsed model output, and
    ``t5_local_rewrite_with_heur`` additionally merges in the tags returned by
    ``extract_user_provided_tags_upto_3_words`` for the caption.

    Returns:
        Mapping of variant name -> ``_set_metrics`` result.
    """
    pred_no_heur: List[Set[str]] = []
    pred_with_heur: List[Set[str]] = []
    gold_sets: List[Set[str]] = []
    for r in rows:
        out = local_t5_rewrite_prompt(
            r.caption,
            log=lambda _msg: None,  # silence per-call logging during the sweep
            model_dir=str(model_dir),
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            max_source_length=max_source_length,
        )
        p0 = _parse_tag_set(out)
        p1 = set(p0)
        # Heuristic tags are only merged when the model produced some output.
        # NOTE(review): presumably mirrors the production pipeline — confirm.
        if out.strip():
            for term in extract_user_provided_tags_upto_3_words(r.caption):
                c = _canon_tag(term)
                if c:
                    p1.add(c)
        pred_no_heur.append(p0)
        pred_with_heur.append(p1)
        gold_sets.append(r.gold_tags)
    return {
        "t5_local_rewrite_no_heur": _set_metrics(pred_no_heur, gold_sets),
        "t5_local_rewrite_with_heur": _set_metrics(pred_with_heur, gold_sets),
    }


def _iter_model_dirs(raw_models: Iterable[str]) -> List[Path]:
    """Convert raw ``--model-dir`` values to paths, resolving relative ones against REPO_ROOT."""
    return [_resolve_repo_path(Path(raw)) for raw in raw_models]


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    ap = argparse.ArgumentParser(
        description="Rewrite-only T5 evaluation on caption-evident n30 set."
    )
    ap.add_argument("--samples", type=Path, default=DEFAULT_SAMPLES)
    ap.add_argument("--caption-field", type=str, default="caption_cogvlm")
    ap.add_argument("--limit", type=int, default=30)
    ap.add_argument("--num-beams", type=int, default=4)
    ap.add_argument("--max-new-tokens", type=int, default=128)
    ap.add_argument("--max-source-length", type=int, default=160)
    ap.add_argument(
        "--model-dir",
        action="append",
        required=True,
        help="Model directory; repeat for multiple models.",
    )
    ap.add_argument("--out-json", type=Path, default=REPO_ROOT / "data" / "analysis")
    return ap


def _resolve_output_path(out_arg: Path, ts: str) -> Path:
    """Return the report path for *out_arg*, creating directories as needed.

    A path ending in ``.json`` (case-insensitive) is used as the report file
    itself. Any other path is treated as a directory that receives a report
    named with the timestamp *ts*.
    """
    out_base = _resolve_repo_path(out_arg)
    if out_base.suffix.lower() == ".json":
        out_base.parent.mkdir(parents=True, exist_ok=True)
        return out_base
    out_base.mkdir(parents=True, exist_ok=True)
    return out_base / f"rewrite_only_compare_n30_t5_sweep_{ts}.json"


def main() -> int:
    """Evaluate every ``--model-dir`` on the samples and write a JSON report.

    Relative ``--samples``, ``--model-dir`` and ``--out-json`` paths are
    resolved against REPO_ROOT, not the current working directory.

    Returns:
        Process exit code (0 on success).

    Raises:
        FileNotFoundError: If the samples file, a model directory, or a model
            directory's ``model.safetensors`` is missing.
    """
    args = _build_arg_parser().parse_args()

    samples_path = _resolve_repo_path(args.samples)
    if not samples_path.is_file():
        raise FileNotFoundError(f"Samples file not found: {samples_path}")

    model_dirs = _iter_model_dirs(args.model_dir)
    for d in model_dirs:
        if not d.is_dir():
            raise FileNotFoundError(f"Model directory not found: {d}")
        if not (d / "model.safetensors").is_file():
            raise FileNotFoundError(f"model.safetensors missing in: {d}")

    rows = _load_rows(samples_path, args.caption_field)
    if args.limit > 0:  # --limit <= 0 means "use every loaded row"
        rows = rows[: args.limit]

    # Clamp generation settings once so the values passed to the model and the
    # values recorded in the report are guaranteed to match.
    num_beams = max(1, args.num_beams)
    max_new_tokens = max(8, args.max_new_tokens)
    max_source_length = max(16, args.max_source_length)

    result_rows = []
    for d in model_dirs:
        metrics = _eval_model(
            d,
            rows=rows,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            max_source_length=max_source_length,
        )
        result_rows.append({"model_dir": str(d), "metrics": metrics})
        no_h = metrics["t5_local_rewrite_no_heur"]
        with_h = metrics["t5_local_rewrite_with_heur"]
        print(
            f"{d.name}: "
            f"no_heur R={no_h['set_recall']:.4f} F1={no_h['set_f1']:.4f} "
            f"| with_heur R={with_h['set_recall']:.4f} F1={with_h['set_f1']:.4f}"
        )

    # Read the clock once so the filename timestamp and meta timestamp agree.
    now = datetime.now()
    out_path = _resolve_output_path(args.out_json, now.strftime("%Y%m%d_%H%M%S"))
    payload = {
        "meta": {
            "timestamp": now.isoformat(),
            "samples_path": str(samples_path),
            "caption_field": args.caption_field,
            "n_samples": len(rows),
            "num_beams": num_beams,
            "max_new_tokens": max_new_tokens,
            "max_source_length": max_source_length,
            "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"),
        },
        "rows": result_rows,
    }
    out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"Saved: {out_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())