"""Measure per-tag LLM reliability for probe tags (selection-only, no retrieval).

Process:
- Use caption as query text.
- Ask Stage 3 selector to choose among a fixed probe-tag candidate list.
- Compare selected tags to ground-truth tag presence.

This estimates whether a probe tag is worth asking the LLM about.

Outputs (overwritten by suffix):
- data/analysis/probe_reliability_<suffix>.csv
- data/analysis/probe_reliability_<suffix>.json
"""
from __future__ import annotations

import argparse
import csv
import json
import os
import random
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple
# Repo root is one level above this script's directory; make repo-local
# packages importable and run with the repo root as the working directory
# so the relative data/output paths below resolve consistently.
REPO = Path(__file__).resolve().parents[1]
if str(REPO) not in sys.path:
    sys.path.insert(0, str(REPO))
os.chdir(REPO)

# Default input sample, probe-tag definition CSV, and output directory.
EVAL_DATA_RAW = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
PROBE_SET_CSV = REPO / "data" / "simplified_probe_tags.csv"
OUT_DIR = REPO / "data" / "analysis"
| def _flatten_ground_truth(tags_categorized_str: str) -> Set[str]: | |
| if not tags_categorized_str: | |
| return set() | |
| try: | |
| cats = json.loads(tags_categorized_str) | |
| except Exception: | |
| return set() | |
| out: Set[str] = set() | |
| if isinstance(cats, dict): | |
| for vals in cats.values(): | |
| if isinstance(vals, list): | |
| for t in vals: | |
| if isinstance(t, str): | |
| out.add(t.strip()) | |
| return out | |
| def _metrics(tp: int, fp: int, fn: int) -> Tuple[float, float, float]: | |
| p = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| r = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| f1 = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0 | |
| return p, r, f1 | |
def _flag_to_int(value: str) -> int:
    """Parse a CSV flag cell to an int: "true"/"True" -> 1, int-like -> int, else 0."""
    v = (value or "").strip()
    if v in {"true", "True"}:
        return 1
    try:
        return int(v)
    except ValueError:
        return 0


def _load_probe_rows(probe_csv: Path) -> List[Dict[str, str]]:
    """Read the probe CSV and keep only rows flagged selected_initial (1/true/True)."""
    # Use a context manager so the handle is closed (the original leaked it).
    with probe_csv.open("r", encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))
    return [r for r in rows if (r.get("selected_initial") or "0").strip() in {"1", "true", "True"}]


def _load_samples(data_path: Path, caption_field: str) -> List[Dict[str, object]]:
    """Load JSONL eval rows that have both a non-empty caption and ground-truth tags."""
    out: List[Dict[str, object]] = []
    with data_path.open("r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            caption = (row.get(caption_field) or "").strip()
            if not caption:
                continue
            gt = _flatten_ground_truth(row.get("tags_ground_truth_categorized", ""))
            if not gt:
                continue
            out.append({"id": row.get("id"), "caption": caption, "gt": gt})
    return out


def main() -> None:
    """Evaluate per-tag probe reliability (selection-only) and write CSV/JSON outputs."""
    ap = argparse.ArgumentParser(description="Evaluate per-tag probe reliability (selection-only).")
    ap.add_argument("--probe-csv", type=Path, default=PROBE_SET_CSV)
    ap.add_argument("--data", type=Path, default=EVAL_DATA_RAW)
    ap.add_argument("--caption-field", default="caption_cogvlm")
    ap.add_argument("--n", type=int, default=10, help="Number of samples.")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--suffix", default="sanity10")
    ap.add_argument("--retries", type=int, default=2)
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--max-tokens", type=int, default=700)
    ap.add_argument("--workers-note", default="sequential", help="for logging only; this script runs sequentially.")
    ap.add_argument("--verbose", action="store_true")
    args = ap.parse_args()

    if not args.probe_csv.is_file():
        raise FileNotFoundError(f"Probe CSV not found: {args.probe_csv}")
    if not args.data.is_file():
        raise FileNotFoundError(f"Eval data not found: {args.data}")

    # Deferred import: relies on the repo root added to sys.path at module import.
    from psq_rag.llm.select import llm_select_indices, WHY_RANK

    # Load probe tags from the selected_initial list.
    probe_rows = _load_probe_rows(args.probe_csv)
    probe_tags = [r["tag"] for r in probe_rows if r.get("tag")]
    if not probe_tags:
        raise RuntimeError("No probe tags found with selected_initial=1.")
    tag_meta = {r["tag"]: r for r in probe_rows}

    # Load, shuffle deterministically, and take the first n usable rows.
    all_rows = _load_samples(args.data, args.caption_field)
    if not all_rows:
        raise RuntimeError(f"No usable rows in {args.data}.")
    rnd = random.Random(args.seed)
    rnd.shuffle(all_rows)
    samples = all_rows[: max(1, min(args.n, len(all_rows)))]

    # Two selection-confidence thresholds, expressed as maximum "why" rank.
    thresholds = {
        "explicit": {"max_rank": WHY_RANK["explicit"]},
        "strong": {"max_rank": WHY_RANK["strong_implied"]},  # explicit + strong_implied
    }
    # Per-tag and overall confusion counts, one table per threshold.
    conf = {th: {t: {"tp": 0, "fp": 0, "fn": 0, "tn": 0} for t in probe_tags} for th in thresholds}
    overall = {th: {"tp": 0, "fp": 0, "fn": 0, "tn": 0} for th in thresholds}

    diag_rows = []
    parse_fail_count = 0
    call_exhaust_count = 0

    def _log(msg: str) -> None:
        if args.verbose:
            print(msg)

    for i, s in enumerate(samples):
        caption = s["caption"]
        gt = s["gt"]
        # IMPORTANT: per_phrase_k controls per-call budget when candidate strings have no sources.
        # Set it to len(probe_tags) so the model can choose all true tags if needed.
        idxs, tag_why, diag = llm_select_indices(
            query_text=caption,
            candidates=probe_tags,
            max_pick=len(probe_tags),
            log=_log,
            retries=args.retries,
            mode="single_shot",
            chunk_size=max(1, len(probe_tags)),
            per_phrase_k=max(1, len(probe_tags)),
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            return_metadata=True,
            return_diagnostics=True,
            min_why=None,
        )
        # Map selected indices back to tag strings, dropping out-of-range indices.
        selected_all = set()
        for idx in idxs:
            if 0 <= idx < len(probe_tags):
                selected_all.add(probe_tags[idx])
        if float(diag.get("attempt_failure_rate", 0.0)) > 0.0:
            parse_fail_count += 1
        if float(diag.get("call_exhaustion_rate", 0.0)) > 0.0:
            call_exhaust_count += 1
        diag_rows.append(
            {
                "sample_id": s["id"],
                "selected_any": len(selected_all),
                "attempt_failure_rate": float(diag.get("attempt_failure_rate", 0.0)),
                "call_exhaustion_rate": float(diag.get("call_exhaustion_rate", 0.0)),
            }
        )
        # Apply each threshold by why-rank, then update confusion counts for every probe tag.
        for th, cfg in thresholds.items():
            max_rank = cfg["max_rank"]
            selected = set()
            for t in selected_all:
                why = tag_why.get(t, "other")
                if WHY_RANK.get(why, 999) <= max_rank:
                    selected.add(t)
            for t in probe_tags:
                gt_pos = t in gt
                pred_pos = t in selected
                if gt_pos and pred_pos:
                    conf[th][t]["tp"] += 1
                    overall[th]["tp"] += 1
                elif (not gt_pos) and pred_pos:
                    conf[th][t]["fp"] += 1
                    overall[th]["fp"] += 1
                elif gt_pos and (not pred_pos):
                    conf[th][t]["fn"] += 1
                    overall[th]["fn"] += 1
                else:
                    conf[th][t]["tn"] += 1
                    overall[th]["tn"] += 1

    # Per-tag reliability table (counts and metrics per threshold).
    out_rows = []
    for t in probe_tags:
        r = {"tag": t}
        r["bundle"] = tag_meta[t].get("bundle", "")
        r["needs_glossary"] = tag_meta[t].get("needs_glossary", "")
        support_pos = conf["strong"][t]["tp"] + conf["strong"][t]["fn"]
        support_neg = conf["strong"][t]["tn"] + conf["strong"][t]["fp"]
        r["support_pos"] = str(support_pos)
        r["support_neg"] = str(support_neg)
        for th in ("explicit", "strong"):
            tp = conf[th][t]["tp"]
            fp = conf[th][t]["fp"]
            fn = conf[th][t]["fn"]
            p, rc, f1 = _metrics(tp, fp, fn)
            r[f"tp_{th}"] = str(tp)
            r[f"fp_{th}"] = str(fp)
            r[f"fn_{th}"] = str(fn)
            r[f"precision_{th}"] = f"{p:.6f}"
            r[f"recall_{th}"] = f"{rc:.6f}"
            r[f"f1_{th}"] = f"{f1:.6f}"
        out_rows.append(r)
    # Best F1 first; _flag_to_int tolerates "true"/"True" flags (a bare int()
    # would raise ValueError on them, as the probe CSV allows for selected_initial).
    out_rows.sort(
        key=lambda x: (
            float(x["f1_strong"]),
            int(x["support_pos"]),
            -_flag_to_int(x["needs_glossary"]),
        ),
        reverse=True,
    )

    # Overall (micro-averaged) metrics per threshold.
    overall_metrics = {}
    for th in ("explicit", "strong"):
        tp = overall[th]["tp"]
        fp = overall[th]["fp"]
        fn = overall[th]["fn"]
        p, rc, f1 = _metrics(tp, fp, fn)
        overall_metrics[th] = {
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "precision": round(p, 6),
            "recall": round(rc, 6),
            "f1": round(f1, 6),
        }

    suffix = args.suffix.strip() or f"n{len(samples)}"
    out_csv = OUT_DIR / f"probe_reliability_{suffix}.csv"
    out_json = OUT_DIR / f"probe_reliability_{suffix}.json"
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    with out_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "tag",
                "bundle",
                "needs_glossary",
                "support_pos",
                "support_neg",
                "tp_explicit",
                "fp_explicit",
                "fn_explicit",
                "precision_explicit",
                "recall_explicit",
                "f1_explicit",
                "tp_strong",
                "fp_strong",
                "fn_strong",
                "precision_strong",
                "recall_strong",
                "f1_strong",
            ],
        )
        writer.writeheader()
        writer.writerows(out_rows)

    # diag_rows is non-empty (samples has at least one row), but guard anyway.
    n_diag = len(diag_rows) or 1
    summary = {
        "settings": {
            "n": len(samples),
            "seed": args.seed,
            "caption_field": args.caption_field,
            "probe_count": len(probe_tags),
            "retries": args.retries,
            "temperature": args.temperature,
            "max_tokens": args.max_tokens,
            "model_env": os.environ.get("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"),
        },
        "overall_metrics": overall_metrics,
        "diagnostics": {
            "samples_with_attempt_failures": parse_fail_count,
            "samples_with_call_exhaustion": call_exhaust_count,
            "avg_attempt_failure_rate": sum(d["attempt_failure_rate"] for d in diag_rows) / n_diag,
            "avg_call_exhaustion_rate": sum(d["call_exhaustion_rate"] for d in diag_rows) / n_diag,
        },
        "top_tags_by_f1_strong": out_rows[:20],
        "outputs": {
            "csv": str(out_csv),
            "json": str(out_json),
        },
    }
    with out_json.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"Samples evaluated: {len(samples)}")
    print(f"Probe tags evaluated: {len(probe_tags)}")
    print(f"Overall strong: P={overall_metrics['strong']['precision']:.4f} "
          f"R={overall_metrics['strong']['recall']:.4f} F1={overall_metrics['strong']['f1']:.4f}")
    print(f"Diagnostics: attempt_fail_samples={parse_fail_count}, call_exhaust_samples={call_exhaust_count}")
    print(f"Outputs: {out_csv}, {out_json}")
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()