""" Analyze post-hoc retrieval score thresholds on Stage 3 selections. This script re-scores evaluation outputs by removing Stage 3 selections with retrieval score <= threshold, then recomputing metrics. This is an approximation that avoids re-running the LLMs. """ from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Dict, Iterable, List, Set, Tuple _REPO_ROOT = Path(__file__).resolve().parents[1] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) import csv from collections import defaultdict from psq_rag.retrieval.state import expand_tags_via_implications, get_leaf_tags from scripts.eval_pipeline import _EVAL_EXCLUDED_TAGS # reuse eval exclusions def _compute_metrics(predicted: Set[str], ground_truth: Set[str]) -> Tuple[float, float, float]: if not predicted and not ground_truth: return 1.0, 1.0, 1.0 if not predicted: return 0.0, 0.0, 0.0 if not ground_truth: return 0.0, 0.0, 0.0 tp = len(predicted & ground_truth) precision = tp / len(predicted) recall = tp / len(ground_truth) f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 return precision, recall, f1 def _load_rows(path: Path) -> Tuple[dict, List[dict]]: meta = None rows = [] with path.open("r", encoding="utf-8") as f: for line in f: row = json.loads(line) if row.get("_meta"): meta = row continue rows.append(row) if meta is None: meta = {} return meta, rows def _load_tag_db(repo_root: Path) -> Dict[str, int]: tag_type: Dict[str, int] = {} db_path = repo_root / "fluffyrock_3m.csv" if not db_path.exists(): return tag_type with db_path.open("r", encoding="utf-8") as f: for row in csv.reader(f): if len(row) < 2: continue tag = row[0].strip() try: tid = int(row[1]) if row[1].strip() else -1 except ValueError: tid = -1 tag_type[tag] = tid return tag_type TYPE_ID_NAMES = { 0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "species", 7: "meta", } _TAXONOMY = frozenset({ "mammal","canid","canine","canis","felid","feline","felis","ursine","cervid","bovid","equid","equine", "mustelid","procyonid","reptile","scalie","avian","bird","fish","marine","arthropod","insect","arachnid", "amphibian","primate","rodent","lagomorph","leporid","galliform","gallus_(genus)","phasianid","passerine", "oscine","dinosaur","theropod","cetacean","pinniped","chiroptera","marsupial","monotreme","mephitid", "suid","suina" }) _BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"}) _POSE = frozenset({ "solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching", "kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down", "looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up", "portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms" }) def _categorize(tag: str, tag_type: Dict[str, int]) -> str: tid = tag_type.get(tag, -1) tn = TYPE_ID_NAMES.get(tid, "unknown") if tn == "species": return "species" if tn in ("artist", "copyright", "character", "meta"): return tn if tag in _TAXONOMY: return "taxonomy" if tag in _BODY_PLAN: return "body_plan" if tag in _POSE: return "pose/composition" if tag.startswith(tuple(str(i) + "_" for i in range(10))) and any( tag.endswith(s) for s in ("fingers","toes","horns","arms","legs","eyes","ears","wings","tails") ): return 
"count/anatomy" if tag in ("male","female","intersex","ambiguous_gender","andromorph","gynomorph"): return "gender" if any(k in tag for k in ( "clothing","clothed","topwear","bottomwear","legwear","handwear","headwear","footwear","shirt","pants", "shorts","dress","skirt","jacket","coat","hat","boots","shoes","gloves","socks","stockings","belt", "collar","scarf","cape","armor","suit","uniform","costume","outfit" )): return "clothing" if any(tag.startswith(c + "_") for c in ( "red","blue","green","yellow","orange","purple","pink","black","white","grey","gray","brown","tan","cream", "gold","silver","teal","cyan","magenta" )): return "color/marking" if tag.endswith("_coloring") or tag.endswith("_markings") or tag == "markings": return "color/marking" if "hair" in tag: return "hair" if any(k in tag for k in ( "muscle","belly","chest","abs","breast","butt","tail","wing","horn","ear","eye","teeth","fang","claw", "paw","hoof","snout","muzzle","tongue","fur","scales","feather","tuft","fluff","mane" )): return "body/anatomy" if any(k in tag for k in ( "smile","grin","frown","expression","blush","angry","happy","sad","crying","laughing","open_mouth", "closed_eyes","wink" )): return "expression" return "other_general" def _iter_thresholds(values: Iterable[float], min_v: float, max_v: float, step: float) -> List[float]: if values: return sorted(set(values)) thresholds = [] v = min_v while v <= max_v + 1e-9: thresholds.append(round(v, 4)) v += step return thresholds def _sparkline(values: List[float], width: int = 50) -> str: if not values: return "" charset = " .:-=+*#%@" vmin = min(values) vmax = max(values) if vmax == vmin: return charset[0] * min(width, len(values)) out = [] for v in values: norm = (v - vmin) / (vmax - vmin) idx = int(round(norm * (len(charset) - 1))) out.append(charset[idx]) return "".join(out) def analyze( path: Path, thresholds: List[float], expand_implications: bool, category_curves: bool, mode: str, ) -> Tuple[List[dict], List[dict]]: meta, rows = _load_rows(path) expand = expand_implications or bool(meta.get("expand_implications")) tag_type = _load_tag_db(_REPO_ROOT) if category_curves else {} results = [] category_rows = [] for thr in thresholds: total_p = total_r = total_f1 = 0.0 total_lp = total_lr = total_lf1 = 0.0 total_sel = 0 total_gt = 0 total_oracle_r = 0.0 total_oracle_f1 = 0.0 n = 0 if category_curves: cat_totals = defaultdict(lambda: {"p": 0.0, "r": 0.0, "f1": 0.0, "n": 0}) for row in rows: gt = set(row.get("ground_truth_tags", [])) gt -= _EVAL_EXCLUDED_TAGS stage3_selected = set(row.get("stage3_selected", [])) stage3_scores: Dict[str, float] = row.get("stage3_selected_scores", {}) or {} stage3_ranks: Dict[str, int] = row.get("stage3_selected_ranks", {}) or {} stage3_phrase_ranks: Dict[str, int] = row.get("stage3_selected_phrase_ranks", {}) or {} structural = set(row.get("structural", [])) # Remove low-scoring Stage 3 selections. 
def analyze(
    path: Path,
    thresholds: List[float],
    expand_implications: bool,
    category_curves: bool,
    mode: str,
) -> Tuple[List[dict], List[dict]]:
    """Sweep thresholds over saved eval rows, returning overall and per-category metric rows."""
    meta, rows = _load_rows(path)
    expand = expand_implications or bool(meta.get("expand_implications"))
    tag_type = _load_tag_db(_REPO_ROOT) if category_curves else {}

    results = []
    category_rows = []
    for thr in thresholds:
        total_p = total_r = total_f1 = 0.0
        total_lp = total_lr = total_lf1 = 0.0
        total_sel = 0
        total_gt = 0
        total_oracle_r = 0.0
        total_oracle_f1 = 0.0
        n = 0
        if category_curves:
            cat_totals = defaultdict(lambda: {"p": 0.0, "r": 0.0, "f1": 0.0, "n": 0})
        for row in rows:
            gt = set(row.get("ground_truth_tags", []))
            gt -= _EVAL_EXCLUDED_TAGS
            stage3_selected = set(row.get("stage3_selected", []))
            stage3_scores: Dict[str, float] = row.get("stage3_selected_scores", {}) or {}
            stage3_ranks: Dict[str, int] = row.get("stage3_selected_ranks", {}) or {}
            stage3_phrase_ranks: Dict[str, int] = row.get("stage3_selected_phrase_ranks", {}) or {}
            structural = set(row.get("structural", []))

            # Remove low-scoring Stage 3 selections. Tags without a recorded
            # score or rank are kept conservatively.
            filtered_stage3 = set()
            for t in stage3_selected:
                if mode == "rank":
                    rank = stage3_ranks.get(t)
                    if rank is None or rank <= int(thr):
                        filtered_stage3.add(t)
                elif mode == "phrase_rank":
                    rank = stage3_phrase_ranks.get(t)
                    if rank is None or rank <= int(thr):
                        filtered_stage3.add(t)
                else:
                    score = stage3_scores.get(t)
                    if score is None or score > thr:
                        filtered_stage3.add(t)

            available = filtered_stage3 | structural
            if expand and available:
                available, _ = expand_tags_via_implications(available)
            # Copy via set difference so `available` (read by the oracle
            # computation below) is not mutated through a shared reference.
            selected = available - _EVAL_EXCLUDED_TAGS

            p, r, f1 = _compute_metrics(selected, gt)
            total_p += p
            total_r += r
            total_f1 += f1

            leaf_sel = get_leaf_tags(selected)
            leaf_gt = get_leaf_tags(gt)
            lp, lr, lf1 = _compute_metrics(leaf_sel, leaf_gt)
            total_lp += lp
            total_lr += lr
            total_lf1 += lf1

            # Oracle max: perfect selection from available tags, i.e. precision
            # 1.0, so F1 reduces to 2R / (1 + R).
            if gt:
                oracle_r = len(gt & available) / len(gt)
                oracle_f1 = (2 * oracle_r / (1 + oracle_r)) if oracle_r > 0 else 0.0
            else:
                oracle_r = 1.0
                oracle_f1 = 1.0
            total_oracle_r += oracle_r
            total_oracle_f1 += oracle_f1

            if category_curves:
                cat_gt: Dict[str, Set[str]] = defaultdict(set)
                cat_sel: Dict[str, Set[str]] = defaultdict(set)
                for t in gt:
                    cat_gt[_categorize(t, tag_type)].add(t)
                for t in selected:
                    cat_sel[_categorize(t, tag_type)].add(t)
                for cat in set(cat_gt.keys()) | set(cat_sel.keys()):
                    cp, cr, cf1 = _compute_metrics(cat_sel.get(cat, set()), cat_gt.get(cat, set()))
                    cat_totals[cat]["p"] += cp
                    cat_totals[cat]["r"] += cr
                    cat_totals[cat]["f1"] += cf1
                    cat_totals[cat]["n"] += 1

            total_sel += len(selected)
            total_gt += len(gt)
            n += 1

        if n == 0:
            continue
        results.append({
            "threshold": thr,
            "P": total_p / n,
            "R": total_r / n,
            "F1": total_f1 / n,
            "leaf_P": total_lp / n,
            "leaf_R": total_lr / n,
            "leaf_F1": total_lf1 / n,
            "avg_selected": total_sel / n,
            "avg_gt": total_gt / n,
            "oracle_R": total_oracle_r / n,
            "oracle_F1": total_oracle_f1 / n,
        })
        if category_curves:
            for cat, stats in sorted(cat_totals.items()):
                if stats["n"] == 0:
                    continue
                category_rows.append({
                    "threshold": thr,
                    "category": cat,
                    "P": stats["p"] / stats["n"],
                    "R": stats["r"] / stats["n"],
                    "F1": stats["f1"] / stats["n"],
                })
    return results, category_rows
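# Oracle sanity check (illustrative numbers): if 8 of 10 ground-truth tags
# survive into `available`, oracle_R = 0.8 and, with precision pinned at 1.0,
# oracle_F1 = 2 * 0.8 / 1.8 ≈ 0.889. The oracle curves bound what any selector
# could achieve from the same candidate pool at a given threshold.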
def main() -> int:
    ap = argparse.ArgumentParser(description="Analyze post-hoc Stage3 score thresholds.")
    ap.add_argument("path", nargs="?", type=str, default=None,
                    help="Path to compact eval JSONL (default: latest in data/eval_results)")
    ap.add_argument("--min", dest="min_v", type=float, default=0.0, help="Min threshold")
    ap.add_argument("--max", dest="max_v", type=float, default=1.0, help="Max threshold")
    ap.add_argument("--step", type=float, default=0.05, help="Threshold step size")
    ap.add_argument("--values", type=str, default="",
                    help="Comma-separated explicit thresholds (overrides min/max/step)")
    ap.add_argument("--mode", choices=["score", "rank", "phrase_rank"], default="score",
                    help="Threshold mode: score (default), rank (global), or phrase_rank (per-phrase)")
    ap.add_argument("--rank-min", type=int, default=1, help="Min rank threshold (rank mode)")
    ap.add_argument("--rank-max", type=int, default=300, help="Max rank threshold (rank mode)")
    ap.add_argument("--rank-step", type=int, default=10, help="Rank threshold step (rank mode)")
    ap.add_argument("--no-expand-implications", action="store_true",
                    help="Do not re-expand tags via implications")
    ap.add_argument("--category-curves", action="store_true",
                    help="Emit category-level precision/recall/F1 curves")
    args = ap.parse_args()

    if args.path:
        path = Path(args.path)
    else:
        candidates = sorted((_REPO_ROOT / "data" / "eval_results").glob("eval_*.jsonl"))
        if not candidates:
            ap.error("no eval_*.jsonl files found in data/eval_results")
        path = candidates[-1]

    values = []
    if args.values.strip():
        values = [float(v.strip()) for v in args.values.split(",") if v.strip()]

    if args.mode in ("rank", "phrase_rank"):
        if values:
            thresholds = sorted(set(int(v) for v in values))
        else:
            thresholds = list(range(args.rank_min, args.rank_max + 1, args.rank_step))
    else:
        thresholds = _iter_thresholds(values, args.min_v, args.max_v, args.step)

    results, category_rows = analyze(
        path,
        thresholds,
        expand_implications=not args.no_expand_implications,
        category_curves=args.category_curves,
        mode=args.mode,
    )

    # Write CSV to stdout. Rank modes label the first column rank_max;
    # score mode labels it threshold.
    rank_mode = args.mode in ("rank", "phrase_rank")
    first_col = "rank_max" if rank_mode else "threshold"
    print(f"{first_col},P,R,F1,leaf_P,leaf_R,leaf_F1,avg_selected,avg_gt,oracle_R,oracle_F1")
    for row in results:
        thr_str = str(int(row["threshold"])) if rank_mode else f"{row['threshold']:.4f}"
        print(
            f"{thr_str},{row['P']:.4f},{row['R']:.4f},{row['F1']:.4f},"
            f"{row['leaf_P']:.4f},{row['leaf_R']:.4f},{row['leaf_F1']:.4f},"
            f"{row['avg_selected']:.2f},{row['avg_gt']:.2f},"
            f"{row['oracle_R']:.4f},{row['oracle_F1']:.4f}"
        )

    # ASCII sparkline graph for the core metrics.
    p_vals = [r["P"] for r in results]
    r_vals = [r["R"] for r in results]
    f1_vals = [r["F1"] for r in results]
    print("\nP  " + _sparkline(p_vals))
    print("R  " + _sparkline(r_vals))
    print("F1 " + _sparkline(f1_vals))

    if args.category_curves and category_rows:
        print("\nCATEGORY_CURVES")
        print(f"{first_col},category,P,R,F1")
        for row in category_rows:
            thr_str = str(int(row["threshold"])) if rank_mode else f"{row['threshold']:.4f}"
            print(f"{thr_str},{row['category']},{row['P']:.4f},{row['R']:.4f},{row['F1']:.4f}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
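# Example invocations (the script path is illustrative; adjust to wherever this
# file lives in the repo, and substitute a real eval output file):
#   python scripts/analyze_stage3_thresholds.py --min 0.1 --max 0.6 --step 0.05
#   python scripts/analyze_stage3_thresholds.py --mode rank --rank-max 200 --category-curves
#   python scripts/analyze_stage3_thresholds.py data/eval_results/eval_example.jsonl --values 0.2,0.3,0.4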