Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Sequence, Set | |
| from scripts.eval_pipeline import run_eval | |
| REPO_ROOT = Path(__file__).resolve().parents[1] | |
| DEFAULT_EVAL_PATH = REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000_caption_evident_n30.jsonl" | |
| def _canon_tag(tag: str) -> str: | |
| t = " ".join(str(tag or "").strip().split()).lower() | |
| return t.replace(" ", "_").replace("\\(", "(").replace("\\)", ")") | |
def _parse_tag_set(text: str) -> Set[str]:
    """Split a comma-separated tag string into a set of canonical, non-empty tags."""
    pieces = (text or "").split(",")
    return {canon for canon in (_canon_tag(piece) for piece in pieces) if canon}
| def _set_metrics(pred_sets: Sequence[Set[str]], gold_sets: Sequence[Set[str]]) -> Dict[str, float]: | |
| n = len(pred_sets) | |
| if n == 0: | |
| return { | |
| "set_precision": 0.0, | |
| "set_recall": 0.0, | |
| "set_f1": 0.0, | |
| "avg_pred_tags": 0.0, | |
| "avg_gold_tags": 0.0, | |
| } | |
| p_vals: List[float] = [] | |
| r_vals: List[float] = [] | |
| f_vals: List[float] = [] | |
| pred_sizes: List[int] = [] | |
| gold_sizes: List[int] = [] | |
| for pset, gset in zip(pred_sets, gold_sets): | |
| pred_sizes.append(len(pset)) | |
| gold_sizes.append(len(gset)) | |
| if not pset or not gset: | |
| p_vals.append(0.0 if pset or gset else 1.0) | |
| r_vals.append(0.0 if pset or gset else 1.0) | |
| f_vals.append(0.0 if pset or gset else 1.0) | |
| continue | |
| tp = len(pset & gset) | |
| p = tp / len(pset) | |
| r = tp / len(gset) | |
| f1 = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0 | |
| p_vals.append(p) | |
| r_vals.append(r) | |
| f_vals.append(f1) | |
| return { | |
| "set_precision": sum(p_vals) / n, | |
| "set_recall": sum(r_vals) / n, | |
| "set_f1": sum(f_vals) / n, | |
| "avg_pred_tags": sum(pred_sizes) / n, | |
| "avg_gold_tags": sum(gold_sizes) / n, | |
| } | |
def _summarize(results) -> Dict[str, float]:
    """Aggregate per-sample eval results into one flat metrics dict.

    Samples whose ``error`` attribute is non-None are excluded from every
    average and only contribute to ``n_errors``.
    """
    ok = [item for item in results if item.error is None]
    if not ok:
        # Nothing usable: zeros everywhere, all inputs counted as errors.
        return {
            "n_valid": 0,
            "n_errors": len(results),
            "ret_R": 0.0,
            "P": 0.0,
            "R": 0.0,
            "F1": 0.0,
            "leaf_F1": 0.0,
            "t1": 0.0,
            "t2": 0.0,
            "t3": 0.0,
            "t_total": 0.0,
            "rw_P": 0.0,
            "rw_R": 0.0,
            "rw_F1": 0.0,
            "rw_avg_pred": 0.0,
            "rw_avg_gt": 0.0,
        }
    n_ok = len(ok)

    def mean(values: List[float]) -> float:
        return sum(values) / n_ok

    # Rewrite quality: compare canonicalized predicted phrases against gold tags.
    predicted_sets: List[Set[str]] = []
    gold_tag_sets: List[Set[str]] = []
    for sample in ok:
        joined = ", ".join((sample.rewrite_phrases or []))
        predicted_sets.append(_parse_tag_set(joined))
        gold_tag_sets.append({_canon_tag(t) for t in (sample.ground_truth_tags or set()) if t})
    rewrite = _set_metrics(predicted_sets, gold_tag_sets)
    stage1 = mean([sample.stage1_time for sample in ok])
    stage2 = mean([sample.stage2_time for sample in ok])
    stage3 = mean([sample.stage3_time for sample in ok])
    return {
        "n_valid": n_ok,
        "n_errors": len(results) - n_ok,
        "ret_R": mean([sample.retrieval_recall for sample in ok]),
        "P": mean([sample.selection_precision for sample in ok]),
        "R": mean([sample.selection_recall for sample in ok]),
        "F1": mean([sample.selection_f1 for sample in ok]),
        "leaf_F1": mean([sample.leaf_f1 for sample in ok]),
        "t1": stage1,
        "t2": stage2,
        "t3": stage3,
        "t_total": stage1 + stage2 + stage3,
        "rw_P": rewrite["set_precision"],
        "rw_R": rewrite["set_recall"],
        "rw_F1": rewrite["set_f1"],
        "rw_avg_pred": rewrite["avg_pred_tags"],
        "rw_avg_gt": rewrite["avg_gold_tags"],
    }
def main() -> int:
    """Run the 2x2 rewrite ablation: rewrite source (llm/t5) x heuristic phrase append (off/on).

    Executes ``run_eval`` once per configuration, prints each per-config
    summary as JSON, and writes a timestamped JSON (full payload with
    per-sample errors) and CSV (one summary row per config) under
    ``data/eval_results``. Returns 0 on success.

    Raises:
        FileNotFoundError: if the resolved ``--eval-path`` does not exist.
    """
    ap = argparse.ArgumentParser(description="Run n30 rewrite ablation: LLM vs T5, heuristic phrase append off/on")
    ap.add_argument("--eval-path", type=Path, default=DEFAULT_EVAL_PATH)
    ap.add_argument("--caption-field", type=str, default="caption_cogvlm")
    ap.add_argument("--n", type=int, default=30)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--workers", type=int, default=1)
    ap.add_argument("--mode", type=str, default="chunked_map_union", choices=["single_shot", "chunked_map_union"])
    ap.add_argument("--chunk-size", type=int, default=60)
    ap.add_argument("--per-phrase-k", type=int, default=2)
    ap.add_argument("--per-phrase-final-k", type=int, default=1)
    ap.add_argument("--min-why", type=str, default="strong_implied")
    # Paired on/off flags: the feature defaults to True; the --no-* companion
    # writes False into the same dest to disable it.
    ap.add_argument("--infer-structural", action="store_true", default=True)
    ap.add_argument("--no-infer-structural", dest="infer_structural", action="store_false")
    ap.add_argument("--infer-probe", action="store_true", default=True)
    ap.add_argument("--no-infer-probe", dest="infer_probe", action="store_false")
    ap.add_argument("--t5-model-dir", type=str, default="models/finetune/t5-rewrite")
    ap.add_argument("--t5-num-beams", type=int, default=4)
    ap.add_argument("--t5-max-new-tokens", type=int, default=128)
    args = ap.parse_args()
    # Resolve relative eval paths against the repo root so invocation works
    # regardless of the current working directory.
    eval_path = args.eval_path if args.eval_path.is_absolute() else (REPO_ROOT / args.eval_path).resolve()
    if not eval_path.is_file():
        raise FileNotFoundError(f"Eval path not found: {eval_path}")
    # The full 2x2 ablation grid; everything else is held fixed across runs.
    configs = [
        {"rewrite_source": "llm", "append_heuristic_phrases": False},
        {"rewrite_source": "llm", "append_heuristic_phrases": True},
        {"rewrite_source": "t5", "append_heuristic_phrases": False},
        {"rewrite_source": "t5", "append_heuristic_phrases": True},
    ]
    rows: List[Dict[str, Any]] = []  # one summary row per config (for the CSV)
    details: Dict[str, Any] = {}  # per-config summary + per-sample errors (for the JSON)
    for cfg in configs:
        name = f"{cfg['rewrite_source']}_heur_{'on' if cfg['append_heuristic_phrases'] else 'off'}"
        print("\n" + "=" * 80)
        print(f"Running config: {name}")
        print("=" * 80)
        results = run_eval(
            n_samples=args.n,
            caption_field=args.caption_field,
            skip_rewrite=False,
            allow_nsfw=False,
            mode=args.mode,
            chunk_size=args.chunk_size,
            per_phrase_k=args.per_phrase_k,
            per_phrase_final_k=args.per_phrase_final_k,
            # Deterministic generation so the ablation isolates the config axes.
            temperature=0.0,
            max_tokens=512,
            verbose=False,
            shuffle=True,
            seed=args.seed,
            workers=args.workers,
            # "none" on the CLI means no min-why filtering at all.
            min_why=None if args.min_why == "none" else args.min_why,
            eval_path=str(eval_path),
            expand_implications=False,
            infer_structural=args.infer_structural,
            infer_probe=args.infer_probe,
            rewrite_source=cfg["rewrite_source"],
            t5_model_dir=args.t5_model_dir,
            t5_num_beams=args.t5_num_beams,
            t5_max_new_tokens=args.t5_max_new_tokens,
            append_heuristic_phrases=cfg["append_heuristic_phrases"],
        )
        summary = _summarize(results)
        # Tag the summary row with its config so the CSV is self-describing.
        summary.update(cfg)
        rows.append(summary)
        details[name] = {
            "summary": summary,
            "errors": [
                {
                    "id": r.sample_id,
                    "error": r.error,
                    "issues": r.issues,
                }
                for r in results
                if r.error
            ],
        }
        print(json.dumps(summary, ensure_ascii=False, indent=2))
    out_dir = REPO_ROOT / "data" / "eval_results"
    out_dir.mkdir(parents=True, exist_ok=True)
    # Timestamp both output files so repeated runs never overwrite each other.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    json_out = out_dir / f"rewrite_ablation_n{args.n}_{ts}.json"
    csv_out = out_dir / f"rewrite_ablation_n{args.n}_{ts}.csv"
    payload = {
        "meta": {
            "timestamp": datetime.now().isoformat(),
            "eval_path": str(eval_path),
            "caption_field": args.caption_field,
            "n": args.n,
            "seed": args.seed,
            "workers": args.workers,
            "mode": args.mode,
            "chunk_size": args.chunk_size,
            "per_phrase_k": args.per_phrase_k,
            "per_phrase_final_k": args.per_phrase_final_k,
            "min_why": args.min_why,
            "infer_structural": args.infer_structural,
            "infer_probe": args.infer_probe,
            "t5_model_dir": args.t5_model_dir,
            "t5_num_beams": args.t5_num_beams,
            "t5_max_new_tokens": args.t5_max_new_tokens,
        },
        "rows": rows,
        "details": details,
    }
    with json_out.open("w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    # Column order for the CSV: config axes first, then rewrite metrics,
    # then pipeline metrics and stage timings.
    fieldnames = [
        "rewrite_source",
        "append_heuristic_phrases",
        "n_valid",
        "n_errors",
        "rw_P",
        "rw_R",
        "rw_F1",
        "rw_avg_pred",
        "rw_avg_gt",
        "ret_R",
        "P",
        "R",
        "F1",
        "leaf_F1",
        "t1",
        "t2",
        "t3",
        "t_total",
    ]
    with csv_out.open("w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for row in rows:
            w.writerow(row)
    print(f"\nSaved ablation JSON: {json_out}")
    print(f"Saved ablation CSV: {csv_out}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value to the shell as the exit code.
    raise SystemExit(main())