# Prompt_Squirrel_RAG / scripts / analyze_threshold_grid.py
# (branch: Food Desert; commit 73f56cf — "Add eval audit tools,
#  caption-evident set, and logging")
"""
Analyze post-hoc retrieval score thresholds on Stage 3 selections.
This script re-scores evaluation outputs by removing Stage 3 selections
with retrieval score <= threshold, then recomputing metrics. This is an
approximation that avoids re-running the LLMs.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Dict, Iterable, List, Set, Tuple
# Make the repository root importable (for `psq_rag` and `scripts`) when this
# file is executed directly rather than as an installed package.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
import csv
from collections import defaultdict
from psq_rag.retrieval.state import expand_tags_via_implications, get_leaf_tags
from scripts.eval_pipeline import _EVAL_EXCLUDED_TAGS # reuse eval exclusions
def _compute_metrics(predicted: Set[str], ground_truth: Set[str]) -> Tuple[float, float, float]:
if not predicted and not ground_truth:
return 1.0, 1.0, 1.0
if not predicted:
return 0.0, 0.0, 0.0
if not ground_truth:
return 0.0, 0.0, 0.0
tp = len(predicted & ground_truth)
precision = tp / len(predicted)
recall = tp / len(ground_truth)
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
return precision, recall, f1
def _load_rows(path: Path) -> Tuple[dict, List[dict]]:
meta = None
rows = []
with path.open("r", encoding="utf-8") as f:
for line in f:
row = json.loads(line)
if row.get("_meta"):
meta = row
continue
rows.append(row)
if meta is None:
meta = {}
return meta, rows
def _load_tag_db(repo_root: Path) -> Dict[str, int]:
tag_type: Dict[str, int] = {}
db_path = repo_root / "fluffyrock_3m.csv"
if not db_path.exists():
return tag_type
with db_path.open("r", encoding="utf-8") as f:
for row in csv.reader(f):
if len(row) < 2:
continue
tag = row[0].strip()
try:
tid = int(row[1]) if row[1].strip() else -1
except ValueError:
tid = -1
tag_type[tag] = tid
return tag_type
TYPE_ID_NAMES = {
0: "general",
1: "artist",
3: "copyright",
4: "character",
5: "species",
7: "meta",
}
_TAXONOMY = frozenset({
"mammal","canid","canine","canis","felid","feline","felis","ursine","cervid","bovid","equid","equine",
"mustelid","procyonid","reptile","scalie","avian","bird","fish","marine","arthropod","insect","arachnid",
"amphibian","primate","rodent","lagomorph","leporid","galliform","gallus_(genus)","phasianid","passerine",
"oscine","dinosaur","theropod","cetacean","pinniped","chiroptera","marsupial","monotreme","mephitid",
"suid","suina"
})
_BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"})
_POSE = frozenset({
"solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching",
"kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down",
"looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up",
"portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms"
})
def _categorize(tag: str, tag_type: Dict[str, int]) -> str:
tid = tag_type.get(tag, -1)
tn = TYPE_ID_NAMES.get(tid, "unknown")
if tn == "species":
return "species"
if tn in ("artist", "copyright", "character", "meta"):
return tn
if tag in _TAXONOMY:
return "taxonomy"
if tag in _BODY_PLAN:
return "body_plan"
if tag in _POSE:
return "pose/composition"
if tag.startswith(tuple(str(i) + "_" for i in range(10))) and any(
tag.endswith(s) for s in ("fingers","toes","horns","arms","legs","eyes","ears","wings","tails")
):
return "count/anatomy"
if tag in ("male","female","intersex","ambiguous_gender","andromorph","gynomorph"):
return "gender"
if any(k in tag for k in (
"clothing","clothed","topwear","bottomwear","legwear","handwear","headwear","footwear","shirt","pants",
"shorts","dress","skirt","jacket","coat","hat","boots","shoes","gloves","socks","stockings","belt",
"collar","scarf","cape","armor","suit","uniform","costume","outfit"
)):
return "clothing"
if any(tag.startswith(c + "_") for c in (
"red","blue","green","yellow","orange","purple","pink","black","white","grey","gray","brown","tan","cream",
"gold","silver","teal","cyan","magenta"
)):
return "color/marking"
if tag.endswith("_coloring") or tag.endswith("_markings") or tag == "markings":
return "color/marking"
if "hair" in tag:
return "hair"
if any(k in tag for k in (
"muscle","belly","chest","abs","breast","butt","tail","wing","horn","ear","eye","teeth","fang","claw",
"paw","hoof","snout","muzzle","tongue","fur","scales","feather","tuft","fluff","mane"
)):
return "body/anatomy"
if any(k in tag for k in (
"smile","grin","frown","expression","blush","angry","happy","sad","crying","laughing","open_mouth",
"closed_eyes","wink"
)):
return "expression"
return "other_general"
def _iter_thresholds(values: Iterable[float], min_v: float, max_v: float, step: float) -> List[float]:
if values:
return sorted(set(values))
thresholds = []
v = min_v
while v <= max_v + 1e-9:
thresholds.append(round(v, 4))
v += step
return thresholds
def _sparkline(values: List[float], width: int = 50) -> str:
if not values:
return ""
charset = " .:-=+*#%@"
vmin = min(values)
vmax = max(values)
if vmax == vmin:
return charset[0] * min(width, len(values))
out = []
for v in values:
norm = (v - vmin) / (vmax - vmin)
idx = int(round(norm * (len(charset) - 1)))
out.append(charset[idx])
return "".join(out)
def analyze(
    path: Path,
    thresholds: List[float],
    expand_implications: bool,
    category_curves: bool,
    mode: str,
) -> Tuple[List[dict], List[dict]]:
    """Re-score Stage 3 selections at each threshold and aggregate metrics.

    For each threshold, Stage 3 selections on the wrong side of the cutoff
    are dropped (score mode keeps score > thr; rank modes keep rank <= thr;
    tags without a recorded score/rank are always kept), then per-row
    precision/recall/F1 — full, leaf-only, and an oracle upper bound — are
    recomputed and averaged.

    Args:
        path: Compact eval JSONL produced by the eval pipeline.
        thresholds: Score thresholds (floats) or rank cutoffs.
        expand_implications: Re-expand tags via implications; forced on when
            the file's meta row recorded ``expand_implications``.
        category_curves: Also accumulate per-category P/R/F1 curves (loads
            the tag DB for categorization).
        mode: "score" (default), "rank", or "phrase_rank".

    Returns:
        (results, category_rows): per-threshold metric dicts, and (possibly
        empty) per-threshold per-category metric dicts.
    """
    meta, rows = _load_rows(path)
    expand = expand_implications or bool(meta.get("expand_implications"))
    tag_type = _load_tag_db(_REPO_ROOT) if category_curves else {}
    results: List[dict] = []
    category_rows: List[dict] = []
    for thr in thresholds:
        # Rank cutoffs are integers; convert once per threshold, not per tag.
        rank_cutoff = int(thr) if mode in ("rank", "phrase_rank") else None
        total_p = total_r = total_f1 = 0.0
        total_lp = total_lr = total_lf1 = 0.0
        total_sel = 0
        total_gt = 0
        total_oracle_r = 0.0
        total_oracle_f1 = 0.0
        n = 0
        if category_curves:
            cat_totals = defaultdict(lambda: {"p": 0.0, "r": 0.0, "f1": 0.0, "n": 0})
        for row in rows:
            gt = set(row.get("ground_truth_tags", []))
            gt -= _EVAL_EXCLUDED_TAGS
            stage3_selected = set(row.get("stage3_selected", []))
            stage3_scores: Dict[str, float] = row.get("stage3_selected_scores", {}) or {}
            stage3_ranks: Dict[str, int] = row.get("stage3_selected_ranks", {}) or {}
            stage3_phrase_ranks: Dict[str, int] = row.get("stage3_selected_phrase_ranks", {}) or {}
            structural = set(row.get("structural", []))
            # Remove low-scoring (or low-ranked) Stage 3 selections; tags with
            # no recorded score/rank survive filtering unconditionally.
            filtered_stage3 = set()
            for t in stage3_selected:
                if mode == "rank":
                    rank = stage3_ranks.get(t)
                    if rank is None or rank <= rank_cutoff:
                        filtered_stage3.add(t)
                elif mode == "phrase_rank":
                    rank = stage3_phrase_ranks.get(t)
                    if rank is None or rank <= rank_cutoff:
                        filtered_stage3.add(t)
                else:
                    score = stage3_scores.get(t)
                    if score is None or score > thr:
                        filtered_stage3.add(t)
            available = filtered_stage3 | structural
            if expand and available:
                available, _ = expand_tags_via_implications(available)
            # Build `selected` as a NEW set. The previous `selected = available;
            # selected -= ...` aliased `available` and mutated it in place before
            # the oracle computation below — harmless only because `gt` is also
            # filtered by the same exclusions. Keep `available` intact.
            selected = available - _EVAL_EXCLUDED_TAGS
            p, r, f1 = _compute_metrics(selected, gt)
            total_p += p
            total_r += r
            total_f1 += f1
            leaf_sel = get_leaf_tags(selected)
            leaf_gt = get_leaf_tags(gt)
            lp, lr, lf1 = _compute_metrics(leaf_sel, leaf_gt)
            total_lp += lp
            total_lr += lr
            total_lf1 += lf1
            # Oracle max: a perfect selector choosing only from `available`
            # would have precision 1, so F1 = 2R / (1 + R).
            if gt:
                oracle_r = len(gt & available) / len(gt)
                oracle_f1 = (2 * oracle_r / (1 + oracle_r)) if oracle_r > 0 else 0.0
            else:
                oracle_r = 1.0
                oracle_f1 = 1.0
            total_oracle_r += oracle_r
            total_oracle_f1 += oracle_f1
            if category_curves:
                cat_gt: Dict[str, Set[str]] = defaultdict(set)
                cat_sel: Dict[str, Set[str]] = defaultdict(set)
                for t in gt:
                    cat_gt[_categorize(t, tag_type)].add(t)
                for t in selected:
                    cat_sel[_categorize(t, tag_type)].add(t)
                for cat in set(cat_gt.keys()) | set(cat_sel.keys()):
                    cp, cr, cf1 = _compute_metrics(cat_sel.get(cat, set()), cat_gt.get(cat, set()))
                    cat_totals[cat]["p"] += cp
                    cat_totals[cat]["r"] += cr
                    cat_totals[cat]["f1"] += cf1
                    cat_totals[cat]["n"] += 1
            total_sel += len(selected)
            total_gt += len(gt)
            n += 1
        if n == 0:
            continue
        results.append({
            "threshold": thr,
            "P": total_p / n,
            "R": total_r / n,
            "F1": total_f1 / n,
            "leaf_P": total_lp / n,
            "leaf_R": total_lr / n,
            "leaf_F1": total_lf1 / n,
            "avg_selected": total_sel / n,
            "avg_gt": total_gt / n,
            "oracle_R": total_oracle_r / n,
            "oracle_F1": total_oracle_f1 / n,
        })
        if category_curves:
            for cat, stats in sorted(cat_totals.items()):
                if stats["n"] == 0:
                    continue
                category_rows.append({
                    "threshold": thr,
                    "category": cat,
                    "P": stats["p"] / stats["n"],
                    "R": stats["r"] / stats["n"],
                    "F1": stats["f1"] / stats["n"],
                })
    return results, category_rows
def main() -> int:
    """CLI entry point: print per-threshold metrics as CSV plus sparklines.

    Returns a process exit code (0 on success, 1 when no eval file is found).
    """
    ap = argparse.ArgumentParser(description="Analyze post-hoc Stage3 score thresholds.")
    ap.add_argument("path", nargs="?", type=str, default=None,
                    help="Path to compact eval JSONL (default: latest in data/eval_results)")
    ap.add_argument("--min", dest="min_v", type=float, default=0.0, help="Min threshold")
    ap.add_argument("--max", dest="max_v", type=float, default=1.0, help="Max threshold")
    ap.add_argument("--step", type=float, default=0.05, help="Threshold step size")
    ap.add_argument("--values", type=str, default="",
                    help="Comma-separated explicit thresholds (overrides min/max/step)")
    ap.add_argument("--mode", choices=["score", "rank", "phrase_rank"], default="score",
                    help="Threshold mode: score (default), rank (global), or phrase_rank (per-phrase)")
    ap.add_argument("--rank-min", type=int, default=1, help="Min rank threshold (rank mode)")
    ap.add_argument("--rank-max", type=int, default=300, help="Max rank threshold (rank mode)")
    ap.add_argument("--rank-step", type=int, default=10, help="Rank threshold step (rank mode)")
    ap.add_argument("--no-expand-implications", action="store_true",
                    help="Do not re-expand tags via implications")
    ap.add_argument("--category-curves", action="store_true",
                    help="Emit category-level precision/recall/F1 curves")
    args = ap.parse_args()

    if args.path:
        path = Path(args.path)
    else:
        # Fall back to the most recent eval output; fail with a clear message
        # instead of an IndexError when none exists.
        candidates = sorted((_REPO_ROOT / "data" / "eval_results").glob("eval_*.jsonl"))
        if not candidates:
            print("No eval_*.jsonl files found in data/eval_results; pass a path explicitly.",
                  file=sys.stderr)
            return 1
        path = candidates[-1]

    values = []
    if args.values.strip():
        values = [float(v.strip()) for v in args.values.split(",") if v.strip()]
    rank_mode = args.mode in ("rank", "phrase_rank")
    if rank_mode:
        if values:
            thresholds = sorted(set(int(v) for v in values))
        else:
            thresholds = list(range(args.rank_min, args.rank_max + 1, args.rank_step))
    else:
        thresholds = _iter_thresholds(values, args.min_v, args.max_v, args.step)

    results, category_rows = analyze(
        path,
        thresholds,
        expand_implications=not args.no_expand_implications,
        category_curves=args.category_curves,
        mode=args.mode,
    )

    # The only mode-dependent difference in the CSV output is how the
    # threshold column is labeled and formatted.
    thr_label = "rank_max" if rank_mode else "threshold"

    def fmt_thr(t: float) -> str:
        return str(int(t)) if rank_mode else f"{t:.4f}"

    # Write CSV to stdout.
    print(f"{thr_label},P,R,F1,leaf_P,leaf_R,leaf_F1,avg_selected,avg_gt,oracle_R,oracle_F1")
    for row in results:
        print(
            f"{fmt_thr(row['threshold'])},{row['P']:.4f},{row['R']:.4f},{row['F1']:.4f},"
            f"{row['leaf_P']:.4f},{row['leaf_R']:.4f},{row['leaf_F1']:.4f},"
            f"{row['avg_selected']:.2f},{row['avg_gt']:.2f},"
            f"{row['oracle_R']:.4f},{row['oracle_F1']:.4f}"
        )

    # ASCII sparkline graph for core metrics.
    p_vals = [r["P"] for r in results]
    r_vals = [r["R"] for r in results]
    f1_vals = [r["F1"] for r in results]
    print("\nP " + _sparkline(p_vals))
    print("R " + _sparkline(r_vals))
    print("F1 " + _sparkline(f1_vals))

    if args.category_curves and category_rows:
        print("\nCATEGORY_CURVES")
        print(f"{thr_label},category,P,R,F1")
        for row in category_rows:
            print(
                f"{fmt_thr(row['threshold'])},{row['category']},"
                f"{row['P']:.4f},{row['R']:.4f},{row['F1']:.4f}"
            )
    return 0
if __name__ == "__main__":
    sys.exit(main())