"""Build a simplified probe tag set from informativeness data.

This script creates a small, bundle-balanced probe list intended for a single
structured LLM probe query. If reliability results are provided, it also
computes reliability-aware final selection flags.

Default outputs (overwritten on every run; override with ``--out-csv`` and
``--out-summary``):
- data/simplified_probe_tags.csv
- data/analysis/simplified_probe_tags_summary.json
"""

from __future__ import annotations

import argparse
import csv
import json
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Set

REPO = Path(__file__).resolve().parents[1]
PROBE_INFO_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv"
OUT_CSV = REPO / "data" / "simplified_probe_tags.csv"
OUT_SUMMARY = REPO / "data" / "analysis" / "simplified_probe_tags_summary.json"

# Per-bundle selection rules:
#   cap   - stop adding rows for the bundle once this many selected rows carry
#           the bundle's name in their "suggested_probe_bundle" column
#   force - tags tried first, in this order, before score-ranked candidates
#   deny  - tags never selected while processing this bundle
BUNDLE_SPECS = {
    "clothing_state": {
        "cap": 6,
        "force": ["clothing", "clothed", "topwear", "bottomwear", "topless", "nude"],
        "deny": {"5_fingers"},
    },
    "scene_pose": {
        "cap": 4,
        "force": ["simple_background", "standing", "sitting", "outside"],
        "deny": set(),
    },
    "gaze_expression": {
        "cap": 5,
        "force": ["smile", "looking_at_viewer", "open_mouth", "blush", "eyes_closed"],
        "deny": set(),
    },
    "text_symbols": {
        "cap": 3,
        "force": ["text", "dialogue", "<3"],
        "deny": set(),
    },
    "body_type_presence": {
        "cap": 4,
        "force": ["anthro", "feral", "biped", "humanoid"],
        "deny": set(),
    },
    "count_cardinality": {
        "cap": 5,
        "force": ["zero_pictured", "solo", "duo", "trio", "group"],
        "deny": {"husky", "marsupial", "black_bars"},
    },
    "body_shape_breasts": {
        "cap": 4,
        "force": ["breasts", "big_breasts", "wide_hips", "thick_thighs"],
        "deny": set(),
    },
    "species_taxonomy": {
        "cap": 6,
        "force": ["canid", "canis", "felid", "leporid", "bird", "bear", "unicorn", "equid"],
        "deny": {"mammal"},
    },
}

# Column order of the output CSV.
_OUT_FIELDS = [
    "tag",
    "bundle",
    "needs_glossary",
    "prevalence",
    "actionable_score",
    "selected_initial",
    "support_pos",
    "precision_strong",
    "recall_strong",
    "f1_strong",
    "reliability_weight",
    "final_score",
    "selected_final",
    "reliability_note",
]


def _load_probe_rows(path: Path) -> List[Dict[str, str]]:
    """Read the probe informativeness CSV at *path* into a list of row dicts."""
    with path.open("r", encoding="utf-8", newline="") as f:
        return list(csv.DictReader(f))


def _load_reliability(path: Optional[Path]) -> Dict[str, Dict[str, str]]:
    """Read the reliability CSV and index its rows by their ``tag`` column.

    Returns an empty dict when *path* is ``None`` or is not an existing file,
    so callers can treat reliability data as optional. If a tag occurs more
    than once, the last row wins.
    """
    if path is None or not path.is_file():
        return {}
    with path.open("r", encoding="utf-8", newline="") as f:
        return {row["tag"]: row for row in csv.DictReader(f)}


def _as_float(v: str, default: float = 0.0) -> float:
    """Parse *v* as a float, returning *default* when it is missing or malformed."""
    try:
        return float(v)
    except (TypeError, ValueError):
        return default


def _as_int(v: str, default: int = 0) -> int:
    """Parse *v* as an int, returning *default* when it is missing or malformed.

    Integral float spellings such as ``"5.0"`` are accepted as well, so counts
    that were serialized as floats are not silently replaced by *default*.
    Non-integral values (``"5.7"``, ``"nan"``, ``"inf"``) still yield *default*.
    """
    try:
        return int(v)
    except (TypeError, ValueError):
        pass
    try:
        as_float = float(v)
    except (TypeError, ValueError):
        return default
    return int(as_float) if as_float.is_integer() else default


def _is_eligible(
    row: Dict[str, str],
    deny: Set[str],
    min_prevalence: float,
    max_prevalence: float,
) -> bool:
    """Return True if *row* is not denied and its prevalence is within bounds."""
    prevalence = _as_float(row.get("prevalence", "0"))
    return row["tag"] not in deny and min_prevalence <= prevalence <= max_prevalence


def _select_initial(
    rows: List[Dict[str, str]],
    min_prevalence: float,
    max_prevalence: float,
) -> List[Dict[str, str]]:
    """Pick the bundle-balanced initial probe rows according to BUNDLE_SPECS.

    For each bundle, forced tags are tried first (in spec order), then the
    bundle's own rows in descending ``actionable_score`` order. A tag is
    selected at most once overall. The cap is checked after each addition
    and counts selected rows by their own ``suggested_probe_bundle`` value,
    so a forced tag that the CSV assigns to a different bundle counts
    toward that other bundle.
    """
    by_bundle: Dict[str, List[Dict[str, str]]] = {}
    by_tag: Dict[str, Dict[str, str]] = {}
    for row in rows:
        by_tag[row["tag"]] = row
        by_bundle.setdefault(row.get("suggested_probe_bundle", "other"), []).append(row)
    for bundle_rows in by_bundle.values():
        bundle_rows.sort(
            key=lambda x: _as_float(x.get("actionable_score", "0")),
            reverse=True,
        )

    selected: List[Dict[str, str]] = []
    selected_tags: Set[str] = set()
    # Running per-bundle tally; avoids rescanning `selected` after every pick.
    per_bundle: Counter = Counter()

    def take(row: Dict[str, str]) -> None:
        selected.append(row)
        selected_tags.add(row["tag"])
        per_bundle[row.get("suggested_probe_bundle")] += 1

    for bundle, spec in BUNDLE_SPECS.items():
        cap = int(spec["cap"])
        deny = set(spec["deny"])

        for tag in spec["force"]:
            row = by_tag.get(tag)
            if not row or tag in selected_tags:
                continue
            if not _is_eligible(row, deny, min_prevalence, max_prevalence):
                continue
            take(row)
            if per_bundle[bundle] >= cap:
                break
        if per_bundle[bundle] >= cap:
            continue

        for row in by_bundle.get(bundle, []):
            if row["tag"] in selected_tags:
                continue
            if not _is_eligible(row, deny, min_prevalence, max_prevalence):
                continue
            take(row)
            if per_bundle[bundle] >= cap:
                break

    return selected


def _score_row(
    row: Dict[str, str],
    rel_row: Dict[str, str],
    *,
    min_support_pos: int,
    min_f1_strong: float,
    min_precision_strong: float,
) -> Dict[str, str]:
    """Build one output CSV row, applying reliability-aware scoring.

    With reliability data (*rel_row* non-empty) the final score is
    ``actionable * (0.25 + 0.75 * f1_strong)`` and ``selected_final`` is 1
    only if the support, F1 and precision thresholds are all met. Without
    it the final score equals the actionable score and ``selected_final``
    is 0.
    """
    support_pos = _as_int(rel_row.get("support_pos", "0"))
    precision_strong = _as_float(rel_row.get("precision_strong", "0"))
    recall_strong = _as_float(rel_row.get("recall_strong", "0"))
    f1_strong = _as_float(rel_row.get("f1_strong", "0"))
    actionable = _as_float(row.get("actionable_score", "0"))

    if rel_row:
        reliability_weight = f1_strong
        final_score = actionable * (0.25 + 0.75 * reliability_weight)
        selected_final = int(
            support_pos >= min_support_pos
            and f1_strong >= min_f1_strong
            and precision_strong >= min_precision_strong
        )
        rel_note = (
            f"support={support_pos}, f1={f1_strong:.3f}, "
            f"prec={precision_strong:.3f}, rec={recall_strong:.3f}"
        )
    else:
        reliability_weight = 0.0
        final_score = actionable
        selected_final = 0
        rel_note = "no_reliability_data"

    return {
        "tag": row["tag"],
        "bundle": row.get("suggested_probe_bundle", "other"),
        "needs_glossary": row.get("needs_glossary", "0"),
        "prevalence": row.get("prevalence", ""),
        "actionable_score": f"{actionable:.6f}",
        "selected_initial": "1",
        "support_pos": str(support_pos),
        "precision_strong": f"{precision_strong:.6f}",
        "recall_strong": f"{recall_strong:.6f}",
        "f1_strong": f"{f1_strong:.6f}",
        "reliability_weight": f"{reliability_weight:.6f}",
        "final_score": f"{final_score:.6f}",
        "selected_final": str(selected_final),
        "reliability_note": rel_note,
    }


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run the probe-set builder.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If the probe informativeness CSV does not exist.
    """
    ap = argparse.ArgumentParser(description="Build simplified probe set.")
    ap.add_argument("--probe-info", type=Path, default=PROBE_INFO_CSV)
    ap.add_argument("--reliability-csv", type=Path, default=None, help="Optional probe reliability CSV.")
    ap.add_argument("--min-prevalence", type=float, default=0.01)
    ap.add_argument("--max-prevalence", type=float, default=0.70)
    ap.add_argument("--min-support-pos", type=int, default=5)
    ap.add_argument("--min-f1-strong", type=float, default=0.45)
    ap.add_argument("--min-precision-strong", type=float, default=0.50)
    ap.add_argument("--out-csv", type=Path, default=OUT_CSV, help="Output CSV path.")
    ap.add_argument("--out-summary", type=Path, default=OUT_SUMMARY, help="Output summary JSON path.")
    args = ap.parse_args(argv)

    if not args.probe_info.is_file():
        raise FileNotFoundError(f"Missing probe informativeness CSV: {args.probe_info}")
    if args.reliability_csv is not None and not args.reliability_csv.is_file():
        # Stay best-effort, but make it visible why nothing gets selected_final=1.
        print(
            f"Warning: reliability CSV not found: {args.reliability_csv}; "
            "continuing without reliability data.",
            file=sys.stderr,
        )

    rows = _load_probe_rows(args.probe_info)
    rel = _load_reliability(args.reliability_csv)

    selected_initial = _select_initial(rows, args.min_prevalence, args.max_prevalence)

    out_rows = [
        _score_row(
            row,
            rel.get(row["tag"], {}),
            min_support_pos=args.min_support_pos,
            min_f1_strong=args.min_f1_strong,
            min_precision_strong=args.min_precision_strong,
        )
        for row in selected_initial
    ]
    out_rows.sort(
        key=lambda x: (_as_float(x["final_score"]), _as_float(x["actionable_score"])),
        reverse=True,
    )

    args.out_csv.parent.mkdir(parents=True, exist_ok=True)
    with args.out_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_OUT_FIELDS)
        writer.writeheader()
        writer.writerows(out_rows)

    selected_final_tags = [r["tag"] for r in out_rows if r["selected_final"] == "1"]
    # Sets are not JSON-serializable; emit deny lists in sorted order.
    bundle_specs_json = {
        name: {
            "cap": spec["cap"],
            "force": list(spec["force"]),
            "deny": sorted(spec["deny"]),
        }
        for name, spec in BUNDLE_SPECS.items()
    }
    summary = {
        "probe_info_csv": str(args.probe_info),
        "reliability_csv": str(args.reliability_csv) if args.reliability_csv else None,
        "n_selected_initial": len(out_rows),
        "n_selected_final": len(selected_final_tags),
        "selected_final_tags": selected_final_tags,
        "bundle_specs": bundle_specs_json,
        "thresholds": {
            "min_prevalence": args.min_prevalence,
            "max_prevalence": args.max_prevalence,
            "min_support_pos": args.min_support_pos,
            "min_f1_strong": args.min_f1_strong,
            "min_precision_strong": args.min_precision_strong,
        },
        "outputs": {
            "csv": str(args.out_csv),
            "summary_json": str(args.out_summary),
        },
    }
    # The summary may live in a different directory than the CSV.
    args.out_summary.parent.mkdir(parents=True, exist_ok=True)
    with args.out_summary.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"Selected initial probes: {len(out_rows)}")
    print(f"Selected final probes: {len(selected_final_tags)}")
    print(f"Outputs: {args.out_csv}, {args.out_summary}")


if __name__ == "__main__":
    main()