"""Measure per-tag LLM reliability for probe tags (selection-only, no retrieval). Process: - Use caption as query text. - Ask Stage 3 selector to choose among a fixed probe-tag candidate list. - Compare selected tags to ground-truth tag presence. This estimates whether a probe tag is worth asking the LLM about. Outputs (overwrite by suffix): - data/analysis/probe_reliability_.csv - data/analysis/probe_reliability_.json """ from __future__ import annotations import argparse import csv import json import random import sys from collections import Counter, defaultdict from pathlib import Path from typing import Dict, List, Set, Tuple REPO = Path(__file__).resolve().parents[1] if str(REPO) not in sys.path: sys.path.insert(0, str(REPO)) os_chdir = __import__("os").chdir os_chdir(REPO) EVAL_DATA_RAW = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl" PROBE_SET_CSV = REPO / "data" / "simplified_probe_tags.csv" OUT_DIR = REPO / "data" / "analysis" def _flatten_ground_truth(tags_categorized_str: str) -> Set[str]: if not tags_categorized_str: return set() try: cats = json.loads(tags_categorized_str) except Exception: return set() out: Set[str] = set() if isinstance(cats, dict): for vals in cats.values(): if isinstance(vals, list): for t in vals: if isinstance(t, str): out.add(t.strip()) return out def _metrics(tp: int, fp: int, fn: int) -> Tuple[float, float, float]: p = tp / (tp + fp) if (tp + fp) > 0 else 0.0 r = tp / (tp + fn) if (tp + fn) > 0 else 0.0 f1 = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0 return p, r, f1 def main() -> None: ap = argparse.ArgumentParser(description="Evaluate per-tag probe reliability (selection-only).") ap.add_argument("--probe-csv", type=Path, default=PROBE_SET_CSV) ap.add_argument("--data", type=Path, default=EVAL_DATA_RAW) ap.add_argument("--caption-field", default="caption_cogvlm") ap.add_argument("--n", type=int, default=10, help="Number of samples.") ap.add_argument("--seed", type=int, default=42) ap.add_argument("--suffix", default="sanity10") ap.add_argument("--retries", type=int, default=2) ap.add_argument("--temperature", type=float, default=0.0) ap.add_argument("--max-tokens", type=int, default=700) ap.add_argument("--workers-note", default="sequential", help="for logging only; this script runs sequentially.") ap.add_argument("--verbose", action="store_true") args = ap.parse_args() if not args.probe_csv.is_file(): raise FileNotFoundError(f"Probe CSV not found: {args.probe_csv}") if not args.data.is_file(): raise FileNotFoundError(f"Eval data not found: {args.data}") from psq_rag.llm.select import llm_select_indices, WHY_RANK # Load probe tags from selected_initial list. probe_rows = list(csv.DictReader(args.probe_csv.open("r", encoding="utf-8", newline=""))) probe_rows = [r for r in probe_rows if (r.get("selected_initial") or "0").strip() in {"1", "true", "True"}] probe_tags = [r["tag"] for r in probe_rows if r.get("tag")] if not probe_tags: raise RuntimeError("No probe tags found with selected_initial=1.") tag_meta = {r["tag"]: r for r in probe_rows} # Load and sample data. all_rows = [] with args.data.open("r", encoding="utf-8") as f: for line in f: row = json.loads(line) cap = (row.get(args.caption_field) or "").strip() if not cap: continue gt = _flatten_ground_truth(row.get("tags_ground_truth_categorized", "")) if not gt: continue all_rows.append({"id": row.get("id"), "caption": cap, "gt": gt}) if not all_rows: raise RuntimeError(f"No usable rows in {args.data}.") rnd = random.Random(args.seed) rnd.shuffle(all_rows) samples = all_rows[: max(1, min(args.n, len(all_rows)))] # Tag-level confusion by threshold. thresholds = { "explicit": {"max_rank": WHY_RANK["explicit"]}, "strong": {"max_rank": WHY_RANK["strong_implied"]}, # explicit + strong_implied } conf = {th: {t: {"tp": 0, "fp": 0, "fn": 0, "tn": 0} for t in probe_tags} for th in thresholds} overall = {th: {"tp": 0, "fp": 0, "fn": 0, "tn": 0} for th in thresholds} diag_rows = [] parse_fail_count = 0 call_exhaust_count = 0 def _log(msg: str) -> None: if args.verbose: print(msg) for i, s in enumerate(samples): caption = s["caption"] gt = s["gt"] # IMPORTANT: per_phrase_k controls per-call budget when candidate strings have no sources. # Set it to len(probe_tags) so the model can choose all true tags if needed. idxs, tag_why, diag = llm_select_indices( query_text=caption, candidates=probe_tags, max_pick=len(probe_tags), log=_log, retries=args.retries, mode="single_shot", chunk_size=max(1, len(probe_tags)), per_phrase_k=max(1, len(probe_tags)), temperature=args.temperature, max_tokens=args.max_tokens, return_metadata=True, return_diagnostics=True, min_why=None, ) # Map selected indices to tags. selected_all = set() for idx in idxs: if 0 <= idx < len(probe_tags): selected_all.add(probe_tags[idx]) if float(diag.get("attempt_failure_rate", 0.0)) > 0.0: parse_fail_count += 1 if float(diag.get("call_exhaustion_rate", 0.0)) > 0.0: call_exhaust_count += 1 diag_rows.append( { "sample_id": s["id"], "selected_any": len(selected_all), "attempt_failure_rate": float(diag.get("attempt_failure_rate", 0.0)), "call_exhaustion_rate": float(diag.get("call_exhaustion_rate", 0.0)), } ) # Apply thresholds by why rank. for th, cfg in thresholds.items(): max_rank = cfg["max_rank"] selected = set() for t in selected_all: why = tag_why.get(t, "other") if WHY_RANK.get(why, 999) <= max_rank: selected.add(t) for t in probe_tags: gt_pos = t in gt pred_pos = t in selected if gt_pos and pred_pos: conf[th][t]["tp"] += 1 overall[th]["tp"] += 1 elif (not gt_pos) and pred_pos: conf[th][t]["fp"] += 1 overall[th]["fp"] += 1 elif gt_pos and (not pred_pos): conf[th][t]["fn"] += 1 overall[th]["fn"] += 1 else: conf[th][t]["tn"] += 1 overall[th]["tn"] += 1 # Per-tag reliability table. out_rows = [] for t in probe_tags: r = {"tag": t} r["bundle"] = tag_meta[t].get("bundle", "") r["needs_glossary"] = tag_meta[t].get("needs_glossary", "") support_pos = conf["strong"][t]["tp"] + conf["strong"][t]["fn"] support_neg = conf["strong"][t]["tn"] + conf["strong"][t]["fp"] r["support_pos"] = str(support_pos) r["support_neg"] = str(support_neg) for th in ("explicit", "strong"): tp = conf[th][t]["tp"] fp = conf[th][t]["fp"] fn = conf[th][t]["fn"] p, rc, f1 = _metrics(tp, fp, fn) r[f"tp_{th}"] = str(tp) r[f"fp_{th}"] = str(fp) r[f"fn_{th}"] = str(fn) r[f"precision_{th}"] = f"{p:.6f}" r[f"recall_{th}"] = f"{rc:.6f}" r[f"f1_{th}"] = f"{f1:.6f}" out_rows.append(r) out_rows.sort( key=lambda x: (float(x["f1_strong"]), int(x["support_pos"]), -int(x["needs_glossary"] or "0")), reverse=True, ) # Overall metrics. overall_metrics = {} for th in ("explicit", "strong"): tp = overall[th]["tp"] fp = overall[th]["fp"] fn = overall[th]["fn"] p, rc, f1 = _metrics(tp, fp, fn) overall_metrics[th] = { "tp": tp, "fp": fp, "fn": fn, "precision": round(p, 6), "recall": round(rc, 6), "f1": round(f1, 6), } suffix = args.suffix.strip() or f"n{len(samples)}" out_csv = OUT_DIR / f"probe_reliability_{suffix}.csv" out_json = OUT_DIR / f"probe_reliability_{suffix}.json" OUT_DIR.mkdir(parents=True, exist_ok=True) with out_csv.open("w", encoding="utf-8", newline="") as f: writer = csv.DictWriter( f, fieldnames=[ "tag", "bundle", "needs_glossary", "support_pos", "support_neg", "tp_explicit", "fp_explicit", "fn_explicit", "precision_explicit", "recall_explicit", "f1_explicit", "tp_strong", "fp_strong", "fn_strong", "precision_strong", "recall_strong", "f1_strong", ], ) writer.writeheader() writer.writerows(out_rows) summary = { "settings": { "n": len(samples), "seed": args.seed, "caption_field": args.caption_field, "probe_count": len(probe_tags), "retries": args.retries, "temperature": args.temperature, "max_tokens": args.max_tokens, "model_env": __import__("os").environ.get("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct"), }, "overall_metrics": overall_metrics, "diagnostics": { "samples_with_attempt_failures": parse_fail_count, "samples_with_call_exhaustion": call_exhaust_count, "avg_attempt_failure_rate": sum(d["attempt_failure_rate"] for d in diag_rows) / len(diag_rows), "avg_call_exhaustion_rate": sum(d["call_exhaustion_rate"] for d in diag_rows) / len(diag_rows), }, "top_tags_by_f1_strong": out_rows[:20], "outputs": { "csv": str(out_csv), "json": str(out_json), }, } with out_json.open("w", encoding="utf-8") as f: json.dump(summary, f, indent=2, ensure_ascii=False) print(f"Samples evaluated: {len(samples)}") print(f"Probe tags evaluated: {len(probe_tags)}") print(f"Overall strong: P={overall_metrics['strong']['precision']:.4f} " f"R={overall_metrics['strong']['recall']:.4f} F1={overall_metrics['strong']['f1']:.4f}") print(f"Diagnostics: attempt_fail_samples={parse_fail_count}, call_exhaust_samples={call_exhaust_count}") print(f"Outputs: {out_csv}, {out_json}") if __name__ == "__main__": main()