Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, Iterable, List, Sequence, Set | |
| from psq_rag.llm.rewrite_local_t5 import local_t5_rewrite_prompt | |
| from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words | |
# Directory one level above the folder containing this file; used as the base
# for resolving relative paths throughout the script.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Sample file used when --samples is not given on the command line.
DEFAULT_SAMPLES = REPO_ROOT.joinpath(
    "data",
    "eval_samples",
    "e621_sfw_sample_1000_seed123_buffer10000_caption_evident_n30.jsonl",
)
def _canon_tag(tag: str) -> str:
    """Normalize a raw tag into its canonical comparison form.

    Surrounding whitespace is dropped, each internal whitespace run becomes a
    single underscore, the text is lower-cased, and backslash-escaped
    parentheses are unescaped. Falsy input yields the empty string.
    """
    words = str(tag or "").split()
    canonical = "_".join(words).lower()
    for escaped, plain in (("\\(", "("), ("\\)", ")")):
        canonical = canonical.replace(escaped, plain)
    return canonical
def _parse_tag_set(text: str) -> Set[str]:
    """Split a comma-separated tag string into a set of canonical tags.

    Every piece is normalized with ``_canon_tag``; pieces that normalize to
    the empty string are discarded. Falsy ``text`` produces an empty set.
    """
    pieces = (text or "").split(",")
    return {tag for tag in map(_canon_tag, pieces) if tag}
def _set_metrics(pred_sets: Sequence[Set[str]], gold_sets: Sequence[Set[str]]) -> Dict[str, float]:
    """Compute per-sample set-overlap metrics averaged over all samples.

    Args:
        pred_sets: Predicted tag set for each sample.
        gold_sets: Gold tag set for each sample, aligned with ``pred_sets``.

    Returns:
        A dict with the sample count ``n``, the mean per-sample
        ``set_precision`` / ``set_recall`` / ``set_f1``, the fraction of
        samples whose sets are identical (``exact_set_match``), and the mean
        set sizes (``avg_pred_tags`` / ``avg_gold_tags``). All values are
        floats; every metric is 0.0 when there are no samples.

    Raises:
        ValueError: If the two sequences differ in length. Pairing would
            otherwise silently truncate while still dividing by
            ``len(pred_sets)``, producing wrong averages.
    """
    n = len(pred_sets)
    if len(gold_sets) != n:
        raise ValueError(
            f"pred_sets and gold_sets must have the same length "
            f"(got {n} and {len(gold_sets)})"
        )
    if n == 0:
        return {
            "n": 0.0,
            "set_precision": 0.0,
            "set_recall": 0.0,
            "set_f1": 0.0,
            "exact_set_match": 0.0,
            "avg_pred_tags": 0.0,
            "avg_gold_tags": 0.0,
        }

    p_vals: List[float] = []
    r_vals: List[float] = []
    f_vals: List[float] = []
    exact = 0
    for pset, gset in zip(pred_sets, gold_sets):
        if pset == gset:
            exact += 1

        if not pset and not gset:
            # Both empty: nothing to predict and nothing predicted -> perfect.
            p, r, f = 1.0, 1.0, 1.0
        elif not pset or not gset:
            # Exactly one side is empty: no overlap is possible.
            p, r, f = 0.0, 0.0, 0.0
        else:
            tp = len(pset & gset)
            p = tp / len(pset)
            r = tp / len(gset)
            f = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0

        p_vals.append(p)
        r_vals.append(r)
        f_vals.append(f)

    return {
        "n": float(n),
        "set_precision": sum(p_vals) / n,
        "set_recall": sum(r_vals) / n,
        "set_f1": sum(f_vals) / n,
        "exact_set_match": exact / n,
        "avg_pred_tags": sum(len(s) for s in pred_sets) / n,
        "avg_gold_tags": sum(len(s) for s in gold_sets) / n,
    }
@dataclass
class SampleRow:
    """One evaluation sample: an identifier, its caption, and its gold tags.

    Declared as a dataclass so instances can be built with keyword arguments
    (``SampleRow(sample_id=..., caption=..., gold_tags=...)``); a plain class
    with only annotations has no such constructor.
    """

    # Integer identifier of the sample.
    sample_id: int
    # Caption text that is fed to the rewrite model.
    caption: str
    # Gold tags the predictions are scored against.
    gold_tags: Set[str]
def _canon_tags_from_list(values: object) -> Set[str]:
    """Return the non-empty canonical tags of ``values`` if it is a list, else an empty set."""
    if not isinstance(values, list):
        return set()
    return {tag for tag in (_canon_tag(str(v)) for v in values) if tag}


def _load_rows(samples_path: Path, caption_field: str) -> List[SampleRow]:
    """Load evaluation samples from a JSONL file.

    Blank lines, records with a truthy ``_meta`` field, and records whose
    ``caption_field`` value is empty after stripping are skipped. Gold tags
    come from ``tags_ground_truth_expanded`` when that yields at least one
    tag; otherwise from the list values of ``tags_ground_truth_categorized``,
    which may be a dict or a JSON-encoded string.

    Args:
        samples_path: Path to the UTF-8 JSONL samples file.
        caption_field: Name of the record field holding the caption.

    Returns:
        One ``SampleRow`` per kept record, in file order. ``sample_id`` is -1
        when the record's ``id`` is missing or null.
    """
    rows: List[SampleRow] = []
    with samples_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if obj.get("_meta"):
                continue
            cap = str(obj.get(caption_field, "") or "").strip()
            if not cap:
                continue

            gt = _canon_tags_from_list(obj.get("tags_ground_truth_expanded"))
            if not gt:
                cat = obj.get("tags_ground_truth_categorized")
                if isinstance(cat, str):
                    try:
                        cat = json.loads(cat)
                    except ValueError:
                        # json.JSONDecodeError subclasses ValueError; an
                        # unparsable string is treated as "no categories".
                        cat = {}
                if isinstance(cat, dict):
                    for vals in cat.values():
                        gt |= _canon_tags_from_list(vals)

            # An explicit null id is treated like a missing one instead of
            # crashing in int(None).
            raw_id = obj.get("id")
            rows.append(
                SampleRow(
                    sample_id=-1 if raw_id is None else int(raw_id),
                    caption=cap,
                    gold_tags=gt,
                )
            )
    return rows
def _eval_model(
    model_dir: Path,
    rows: Sequence[SampleRow],
    num_beams: int,
    max_new_tokens: int,
    max_source_length: int,
) -> Dict[str, Dict[str, float]]:
    """Score one model's rewrite output against the gold tags of ``rows``.

    For each row the caption is passed to ``local_t5_rewrite_prompt`` and the
    output is parsed into a tag set. A second set additionally receives the
    canonicalized terms returned by ``extract_user_provided_tags_upto_3_words``
    for the caption, but only when the rewrite output is not blank.

    Returns:
        Metrics from ``_set_metrics`` under the keys
        ``t5_local_rewrite_no_heur`` and ``t5_local_rewrite_with_heur``.
    """
    model_path = str(model_dir)
    model_only: List[Set[str]] = []
    model_plus_heur: List[Set[str]] = []
    gold: List[Set[str]] = []

    for row in rows:
        rewritten = local_t5_rewrite_prompt(
            row.caption,
            log=lambda _msg: None,  # discard log messages from the rewrite call
            model_dir=model_path,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            max_source_length=max_source_length,
        )
        base_tags = _parse_tag_set(rewritten)
        augmented = set(base_tags)
        if rewritten.strip():
            caption_terms = extract_user_provided_tags_upto_3_words(row.caption)
            augmented.update(tag for tag in map(_canon_tag, caption_terms) if tag)

        model_only.append(base_tags)
        model_plus_heur.append(augmented)
        gold.append(row.gold_tags)

    no_heur_metrics = _set_metrics(model_only, gold)
    with_heur_metrics = _set_metrics(model_plus_heur, gold)
    return {
        "t5_local_rewrite_no_heur": no_heur_metrics,
        "t5_local_rewrite_with_heur": with_heur_metrics,
    }
def _iter_model_dirs(raw_models: Iterable[str]) -> List[Path]:
    """Turn model directory strings into paths, in input order.

    Absolute paths are kept as given; relative ones are joined onto
    ``REPO_ROOT`` and resolved.
    """
    candidates = (Path(item) for item in raw_models)
    return [
        path if path.is_absolute() else (REPO_ROOT / path).resolve()
        for path in candidates
    ]
def main() -> int:
    """Evaluate one or more model directories and save the metrics as JSON.

    Parses the command line, validates the samples file and each model
    directory (which must contain ``model.safetensors``), loads the sample
    rows (truncated to ``--limit`` when positive), evaluates every model with
    ``_eval_model``, prints a one-line recall/F1 summary per model, and writes
    all results plus run metadata to a JSON file.

    ``--out-json`` is used as the output file when it ends in ``.json``;
    otherwise it is treated as a directory and a timestamped file name is
    generated inside it.

    Returns:
        0 on success.

    Raises:
        FileNotFoundError: If the samples file, a model directory, or a
            model's ``model.safetensors`` file does not exist.
    """
    ap = argparse.ArgumentParser(description="Rewrite-only T5 evaluation on caption-evident n30 set.")
    ap.add_argument("--samples", type=Path, default=DEFAULT_SAMPLES)
    ap.add_argument("--caption-field", type=str, default="caption_cogvlm")
    ap.add_argument("--limit", type=int, default=30)
    ap.add_argument("--num-beams", type=int, default=4)
    ap.add_argument("--max-new-tokens", type=int, default=128)
    ap.add_argument("--max-source-length", type=int, default=160)
    ap.add_argument("--model-dir", action="append", required=True,
                    help="Model directory; repeat for multiple models.")
    ap.add_argument("--out-json", type=Path, default=REPO_ROOT / "data" / "analysis")
    args = ap.parse_args()

    samples_path = args.samples if args.samples.is_absolute() else (REPO_ROOT / args.samples).resolve()
    if not samples_path.is_file():
        raise FileNotFoundError(f"Samples file not found: {samples_path}")

    model_dirs = _iter_model_dirs(args.model_dir)
    for d in model_dirs:
        if not d.is_dir():
            raise FileNotFoundError(f"Model directory not found: {d}")
        if not (d / "model.safetensors").is_file():
            raise FileNotFoundError(f"model.safetensors missing in: {d}")

    rows = _load_rows(samples_path, args.caption_field)
    if args.limit > 0:
        rows = rows[: args.limit]

    # Clamp generation settings once so the values passed to the model and the
    # values recorded in the output metadata cannot disagree.
    num_beams = max(1, args.num_beams)
    max_new_tokens = max(8, args.max_new_tokens)
    max_source_length = max(16, args.max_source_length)

    result_rows = []
    for d in model_dirs:
        metrics = _eval_model(
            d,
            rows=rows,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            max_source_length=max_source_length,
        )
        result_rows.append(
            {
                "model_dir": str(d),
                "metrics": metrics,
            }
        )
        no_h = metrics["t5_local_rewrite_no_heur"]
        with_h = metrics["t5_local_rewrite_with_heur"]
        print(
            f"{d.name}: "
            f"no_heur R={no_h['set_recall']:.4f} F1={no_h['set_f1']:.4f} "
            f"| with_heur R={with_h['set_recall']:.4f} F1={with_h['set_f1']:.4f}"
        )

    # One clock reading feeds both the file name and the metadata timestamp.
    now = datetime.now()
    out_base = args.out_json if args.out_json.is_absolute() else (REPO_ROOT / args.out_json).resolve()
    if out_base.suffix.lower() == ".json":
        out_path = out_base
        out_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        out_base.mkdir(parents=True, exist_ok=True)
        ts = now.strftime("%Y%m%d_%H%M%S")
        out_path = out_base / f"rewrite_only_compare_n30_t5_sweep_{ts}.json"

    payload = {
        "meta": {
            "timestamp": now.isoformat(),
            "samples_path": str(samples_path),
            "caption_field": args.caption_field,
            "n_samples": len(rows),
            "num_beams": num_beams,
            "max_new_tokens": max_new_tokens,
            "max_source_length": max_source_length,
            "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"),
        },
        "rows": result_rows,
    }
    out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"Saved: {out_path}")
    return 0
if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the
    # process exit status.
    exit_status = main()
    raise SystemExit(exit_status)