# Spaces: Running
"""
Analyze post-hoc retrieval score thresholds on Stage 3 selections.

This script re-scores evaluation outputs by removing Stage 3 selections
with retrieval score <= threshold, then recomputing metrics. In the rank
and phrase_rank modes, selections whose rank is greater than the threshold
are removed instead. This is an approximation that avoids re-running the
LLMs.
"""
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, Iterable, List, Set, Tuple | |
| _REPO_ROOT = Path(__file__).resolve().parents[1] | |
| if str(_REPO_ROOT) not in sys.path: | |
| sys.path.insert(0, str(_REPO_ROOT)) | |
| import csv | |
| from collections import defaultdict | |
| from psq_rag.retrieval.state import expand_tags_via_implications, get_leaf_tags | |
| from scripts.eval_pipeline import _EVAL_EXCLUDED_TAGS # reuse eval exclusions | |
def _compute_metrics(predicted: Set[str], ground_truth: Set[str]) -> Tuple[float, float, float]:
    """Return ``(precision, recall, f1)`` for a predicted tag set.

    Two empty sets count as a perfect match (all three values are 1.0).
    If exactly one side is empty, or the sets do not overlap, all three
    values are 0.0.
    """
    has_predicted = bool(predicted)
    has_truth = bool(ground_truth)
    if not (has_predicted or has_truth):
        return 1.0, 1.0, 1.0
    if not (has_predicted and has_truth):
        return 0.0, 0.0, 0.0

    overlap = len(predicted & ground_truth)
    if overlap == 0:
        # Precision and recall are both zero, so F1 is defined as zero too.
        return 0.0, 0.0, 0.0

    precision = overlap / len(predicted)
    recall = overlap / len(ground_truth)
    return precision, recall, 2 * precision * recall / (precision + recall)
def _load_rows(path: Path) -> Tuple[dict, List[dict]]:
    """Read a JSONL file and separate the metadata record from the data rows.

    Every non-blank line is parsed as one JSON object. An object with a
    truthy ``"_meta"`` entry is treated as the metadata record (the last one
    wins if several are present); every other object is appended to the row
    list in file order.

    Args:
        path: Location of the JSONL file, opened as UTF-8 text.

    Returns:
        ``(meta, rows)``; ``meta`` is ``{}`` when the file has no metadata
        record.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    meta: dict = {}
    rows: List[dict] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # Blank or whitespace-only lines (e.g. a doubled trailing newline)
            # carry no record; json.loads would raise on them.
            if not line.strip():
                continue
            row = json.loads(line)
            if row.get("_meta"):
                meta = row
            else:
                rows.append(row)
    return meta, rows
def _load_tag_db(repo_root: Path) -> Dict[str, int]:
    """Load the tag -> type-id mapping from ``fluffyrock_3m.csv`` in *repo_root*.

    The first column of each CSV row is used as the tag name (stripped) and
    the second column as its integer type id. Rows with fewer than two
    columns are skipped. A blank or non-integer type id is stored as ``-1``.
    When a tag occurs more than once, the last row wins.

    Args:
        repo_root: Directory expected to contain ``fluffyrock_3m.csv``.

    Returns:
        The mapping, or an empty dict when the CSV file does not exist.
    """
    tag_type: Dict[str, int] = {}
    db_path = repo_root / "fluffyrock_3m.csv"
    if not db_path.exists():
        return tag_type

    # newline="" is what the csv module asks for: without it, newlines that
    # are embedded in quoted fields are translated before the reader sees them.
    with db_path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 2:
                continue
            raw_id = row[1].strip()
            try:
                tid = int(raw_id) if raw_id else -1
            except ValueError:
                tid = -1
            tag_type[row[0].strip()] = tid
    return tag_type
# Names for the integer type ids found in the tag database.
TYPE_ID_NAMES = dict((
    (0, "general"),
    (1, "artist"),
    (3, "copyright"),
    (4, "character"),
    (5, "species"),
    (7, "meta"),
))

# Tags reported under the "taxonomy" category.
_TAXONOMY = frozenset("""
    mammal canid canine canis felid feline felis ursine cervid bovid equid equine
    mustelid procyonid reptile scalie avian bird fish marine arthropod insect arachnid
    amphibian primate rodent lagomorph leporid galliform gallus_(genus) phasianid passerine
    oscine dinosaur theropod cetacean pinniped chiroptera marsupial monotreme mephitid
    suid suina
""".split())

# Tags reported under the "body_plan" category.
_BODY_PLAN = frozenset("""
    anthro feral biped quadruped taur humanoid semi-anthro animatronic robot machine
    plushie kemono
""".split())

# Tags reported under the "pose/composition" category.
_POSE = frozenset("""
    solo duo group trio standing sitting lying running walking flying swimming crouching
    kneeling jumping looking_at_viewer looking_away looking_back looking_up looking_down
    looking_aside front_view side_view back_view three-quarter_view from_above from_below close-up
    portrait full-length_portrait hand_on_hip arms_crossed all_fours on_back on_side crossed_arms
""".split())
# Keyword tables for _categorize, built once instead of on every call.
_TYPED_CATEGORIES = ("species", "artist", "copyright", "character", "meta")
_COUNT_PREFIXES = tuple(str(i) + "_" for i in range(10))
_COUNT_SUFFIXES = ("fingers", "toes", "horns", "arms", "legs", "eyes", "ears", "wings", "tails")
_GENDER_TAGS = ("male", "female", "intersex", "ambiguous_gender", "andromorph", "gynomorph")
_CLOTHING_KEYWORDS = (
    "clothing", "clothed", "topwear", "bottomwear", "legwear", "handwear", "headwear", "footwear",
    "shirt", "pants", "shorts", "dress", "skirt", "jacket", "coat", "hat", "boots", "shoes",
    "gloves", "socks", "stockings", "belt", "collar", "scarf", "cape", "armor", "suit", "uniform",
    "costume", "outfit",
)
_COLOR_PREFIXES = tuple(c + "_" for c in (
    "red", "blue", "green", "yellow", "orange", "purple", "pink", "black", "white", "grey", "gray",
    "brown", "tan", "cream", "gold", "silver", "teal", "cyan", "magenta",
))
_ANATOMY_KEYWORDS = (
    "muscle", "belly", "chest", "abs", "breast", "butt", "tail", "wing", "horn", "ear", "eye",
    "teeth", "fang", "claw", "paw", "hoof", "snout", "muzzle", "tongue", "fur", "scales",
    "feather", "tuft", "fluff", "mane",
)
_EXPRESSION_KEYWORDS = (
    "smile", "grin", "frown", "expression", "blush", "angry", "happy", "sad", "crying",
    "laughing", "open_mouth", "closed_eyes", "wink",
)


def _categorize(tag: str, tag_type: Dict[str, int]) -> str:
    """Assign *tag* to a coarse reporting category.

    The checks run in a fixed order and the first match wins:

    1. The tag's type id from *tag_type* (species, artist, copyright,
       character, meta).
    2. Exact membership in the taxonomy, body-plan and pose tag sets.
    3. Name heuristics: digit-prefixed counts, gender tags, clothing
       keywords, colour prefixes / marking suffixes, hair, anatomy keywords
       and expression keywords. Keyword checks are plain substring matches.

    Args:
        tag: Tag name to classify.
        tag_type: Mapping of tag name to integer type id; tags missing from
            it are treated as having an unknown type.

    Returns:
        The category name, or ``"other_general"`` when nothing matches.
    """
    type_name = TYPE_ID_NAMES.get(tag_type.get(tag, -1), "unknown")
    if type_name in _TYPED_CATEGORIES:
        return type_name

    if tag in _TAXONOMY:
        return "taxonomy"
    if tag in _BODY_PLAN:
        return "body_plan"
    if tag in _POSE:
        return "pose/composition"

    if tag.startswith(_COUNT_PREFIXES) and tag.endswith(_COUNT_SUFFIXES):
        return "count/anatomy"
    if tag in _GENDER_TAGS:
        return "gender"
    if any(k in tag for k in _CLOTHING_KEYWORDS):
        return "clothing"
    if tag.startswith(_COLOR_PREFIXES):
        return "color/marking"
    if tag.endswith(("_coloring", "_markings")) or tag == "markings":
        return "color/marking"
    if "hair" in tag:
        return "hair"

    # "closed_eyes" is listed as an expression keyword but contains the
    # anatomy keyword "eye"; without this check the anatomy test below would
    # always claim it and the expression keyword could never match.
    if "closed_eyes" in tag:
        return "expression"
    if any(k in tag for k in _ANATOMY_KEYWORDS):
        return "body/anatomy"
    if any(k in tag for k in _EXPRESSION_KEYWORDS):
        return "expression"
    return "other_general"
def _iter_thresholds(values: Iterable[float], min_v: float, max_v: float, step: float) -> List[float]:
    """Return the list of thresholds to evaluate, in ascending order.

    Explicit *values* take priority: when any are given they are
    de-duplicated and sorted, and the range arguments are ignored. Otherwise
    an evenly spaced grid from *min_v* to *max_v* (inclusive, with a small
    tolerance for float error) is generated, each entry rounded to four
    decimals.

    Args:
        values: Explicit thresholds; may be empty. Any iterable is accepted.
        min_v: First grid value.
        max_v: Upper bound of the grid.
        step: Grid spacing; must be positive when the grid is used.

    Returns:
        The thresholds; empty when the grid is used and ``min_v > max_v``.

    Raises:
        ValueError: If the grid is needed and *step* is not positive (the
            grid would never reach *max_v*).
    """
    # Materialise first: an iterator is always truthy, even when exhausted.
    explicit = list(values)
    if explicit:
        return sorted(set(explicit))

    if not step > 0:
        raise ValueError(f"step must be positive, got {step!r}")

    thresholds: List[float] = []
    index = 0
    while True:
        # Multiply instead of accumulating so float error does not build up.
        v = min_v + index * step
        if v > max_v + 1e-9:
            break
        thresholds.append(round(v, 4))
        index += 1
    return thresholds
def _sparkline(values: List[float], width: int = 50) -> str:
    """Render *values* as a one-line ASCII intensity chart.

    Each plotted value becomes one character from a ten-step ramp, scaled
    between the minimum and maximum of the whole series. A constant series
    is drawn with the lowest ramp character throughout.

    Args:
        values: Series to plot.
        width: Maximum number of characters to emit. A longer series is
            resampled at evenly spaced positions that include the first and
            last value.

    Returns:
        The chart, at most *width* characters long; ``""`` for an empty
        series or a non-positive *width*.
    """
    if not values or width <= 0:
        return ""

    charset = " .:-=+*#%@"
    count = len(values)
    if count <= width:
        sampled = values
    elif width == 1:
        sampled = values[:1]
    else:
        # Evenly spaced indices from 0 to count - 1 inclusive.
        sampled = [values[(i * (count - 1)) // (width - 1)] for i in range(width)]

    vmin = min(values)
    vmax = max(values)
    if vmax == vmin:
        return charset[0] * len(sampled)

    top = len(charset) - 1
    out = []
    for v in sampled:
        norm = (v - vmin) / (vmax - vmin)
        out.append(charset[int(round(norm * top))])
    return "".join(out)
def _keeps_selection(value, thr: float, by_rank: bool) -> bool:
    """Decide whether one Stage 3 tag survives the threshold *thr*.

    Tags with no recorded value are always kept. In rank modes a tag is kept
    when its rank is at most ``int(thr)``; in score mode when its score is
    strictly greater than *thr*.
    """
    if value is None:
        return True
    if by_rank:
        return value <= int(thr)
    return value > thr


def analyze(
    path: Path,
    thresholds: List[float],
    expand_implications: bool,
    category_curves: bool,
    mode: str,
) -> Tuple[List[dict], List[dict]]:
    """Re-score an eval file at each threshold and average the metrics.

    For every threshold and every row, the Stage 3 selections are filtered,
    merged with the row's structural tags, optionally expanded via
    implications, stripped of the eval-excluded tags and compared with the
    row's ground truth.

    Args:
        path: JSONL eval file to load.
        thresholds: Thresholds to evaluate, in output order.
        expand_implications: Expand the surviving tags via implications.
            Expansion also happens when the file's metadata record has a
            truthy ``expand_implications`` entry.
        category_curves: Also compute per-category metrics (loads the tag
            database from the repository root).
        mode: ``"rank"`` or ``"phrase_rank"`` to filter on the matching rank
            mapping; any other value filters on the score mapping.

    Returns:
        ``(results, category_rows)``: one averaged metrics dict per threshold,
        and one dict per (threshold, category) when *category_curves* is set.
        Both are empty when the file has no data rows.
    """
    meta, rows = _load_rows(path)
    expand = expand_implications or bool(meta.get("expand_implications"))
    tag_type = _load_tag_db(_REPO_ROOT) if category_curves else {}

    results: List[dict] = []
    category_rows: List[dict] = []
    n = len(rows)
    if n == 0:
        return results, category_rows

    # Row field holding the per-tag value that this mode thresholds on.
    by_rank = mode in ("rank", "phrase_rank")
    if mode == "rank":
        value_key = "stage3_selected_ranks"
    elif mode == "phrase_rank":
        value_key = "stage3_selected_phrase_ranks"
    else:
        value_key = "stage3_selected_scores"

    metric_names = ("P", "R", "F1", "leaf_P", "leaf_R", "leaf_F1", "oracle_R", "oracle_F1")

    for thr in thresholds:
        sums = dict.fromkeys(metric_names, 0.0)
        selected_count = 0
        gt_count = 0
        cat_totals = None
        if category_curves:
            cat_totals = defaultdict(lambda: {"p": 0.0, "r": 0.0, "f1": 0.0, "n": 0})

        for row in rows:
            gt = set(row.get("ground_truth_tags", []))
            gt -= _EVAL_EXCLUDED_TAGS

            # Drop the Stage 3 selections that fail the threshold.
            per_tag: Dict = row.get(value_key, {}) or {}
            kept = {
                t
                for t in set(row.get("stage3_selected", []))
                if _keeps_selection(per_tag.get(t), thr, by_rank)
            }

            available = kept | set(row.get("structural", []))
            if expand and available:
                available, _ = expand_tags_via_implications(available)
            # NOTE: ``selected`` is the same object as ``available``; the
            # in-place subtraction below is what the oracle step also sees.
            selected = available
            selected -= _EVAL_EXCLUDED_TAGS

            precision, recall, f1 = _compute_metrics(selected, gt)
            sums["P"] += precision
            sums["R"] += recall
            sums["F1"] += f1

            leaf_p, leaf_r, leaf_f1 = _compute_metrics(get_leaf_tags(selected), get_leaf_tags(gt))
            sums["leaf_P"] += leaf_p
            sums["leaf_R"] += leaf_r
            sums["leaf_F1"] += leaf_f1

            # Oracle max: perfect selection from available tags.
            if not gt:
                oracle_r = 1.0
                oracle_f1 = 1.0
            else:
                oracle_r = len(gt & available) / len(gt)
                oracle_f1 = (2 * oracle_r / (1 + oracle_r)) if oracle_r > 0 else 0.0
            sums["oracle_R"] += oracle_r
            sums["oracle_F1"] += oracle_f1

            if category_curves:
                gt_by_cat: Dict[str, Set[str]] = defaultdict(set)
                sel_by_cat: Dict[str, Set[str]] = defaultdict(set)
                for t in gt:
                    gt_by_cat[_categorize(t, tag_type)].add(t)
                for t in selected:
                    sel_by_cat[_categorize(t, tag_type)].add(t)
                for cat in gt_by_cat.keys() | sel_by_cat.keys():
                    cp, cr, cf1 = _compute_metrics(
                        sel_by_cat.get(cat, set()), gt_by_cat.get(cat, set())
                    )
                    bucket = cat_totals[cat]
                    bucket["p"] += cp
                    bucket["r"] += cr
                    bucket["f1"] += cf1
                    bucket["n"] += 1

            selected_count += len(selected)
            gt_count += len(gt)

        results.append({
            "threshold": thr,
            "P": sums["P"] / n,
            "R": sums["R"] / n,
            "F1": sums["F1"] / n,
            "leaf_P": sums["leaf_P"] / n,
            "leaf_R": sums["leaf_R"] / n,
            "leaf_F1": sums["leaf_F1"] / n,
            "avg_selected": selected_count / n,
            "avg_gt": gt_count / n,
            "oracle_R": sums["oracle_R"] / n,
            "oracle_F1": sums["oracle_F1"] / n,
        })

        if category_curves:
            for cat in sorted(cat_totals):
                stats = cat_totals[cat]
                seen = stats["n"]
                if seen == 0:
                    continue
                category_rows.append({
                    "threshold": thr,
                    "category": cat,
                    "P": stats["p"] / seen,
                    "R": stats["r"] / seen,
                    "F1": stats["f1"] / seen,
                })

    return results, category_rows
def main() -> int:
    """Command-line entry point: evaluate thresholds and print CSV to stdout.

    Parses the arguments, builds the threshold list for the chosen mode,
    runs ``analyze`` and prints the per-threshold metrics as CSV, followed
    by ASCII sparklines for P, R and F1 and, when requested, the
    per-category curves.

    Returns:
        ``0`` on success. Invalid arguments, or a missing default input
        file, end the process through ``argparse``'s error exit.
    """
    ap = argparse.ArgumentParser(description="Analyze post-hoc Stage3 score thresholds.")
    ap.add_argument("path", nargs="?", type=str, default=None,
                    help="Path to compact eval JSONL (default: latest in data/eval_results)")
    ap.add_argument("--min", dest="min_v", type=float, default=0.0, help="Min threshold")
    ap.add_argument("--max", dest="max_v", type=float, default=1.0, help="Max threshold")
    ap.add_argument("--step", type=float, default=0.05, help="Threshold step size")
    ap.add_argument("--values", type=str, default="",
                    help="Comma-separated explicit thresholds (overrides min/max/step)")
    ap.add_argument("--mode", choices=["score", "rank", "phrase_rank"], default="score",
                    help="Threshold mode: score (default), rank (global), or phrase_rank (per-phrase)")
    ap.add_argument("--rank-min", type=int, default=1, help="Min rank threshold (rank mode)")
    ap.add_argument("--rank-max", type=int, default=300, help="Max rank threshold (rank mode)")
    ap.add_argument("--rank-step", type=int, default=10, help="Rank threshold step (rank mode)")
    ap.add_argument("--no-expand-implications", action="store_true",
                    help="Do not re-expand tags via implications")
    ap.add_argument("--category-curves", action="store_true",
                    help="Emit category-level precision/recall/F1 curves")
    args = ap.parse_args()
    rank_mode = args.mode in ("rank", "phrase_rank")

    if args.path:
        path = Path(args.path)
    else:
        # Default to the lexicographically last eval file; report a usage
        # error instead of an IndexError when there is none.
        results_dir = _REPO_ROOT / "data" / "eval_results"
        candidates = sorted(results_dir.glob("eval_*.jsonl"))
        if not candidates:
            ap.error(f"no eval_*.jsonl files found in {results_dir}; pass a path explicitly")
        path = candidates[-1]

    try:
        values = [float(v.strip()) for v in args.values.split(",") if v.strip()]
    except ValueError:
        ap.error(f"--values must be comma-separated numbers, got {args.values!r}")

    if rank_mode:
        if values:
            thresholds = sorted(set(int(v) for v in values))
        else:
            if args.rank_step <= 0:
                ap.error("--rank-step must be a positive integer")
            thresholds = list(range(args.rank_min, args.rank_max + 1, args.rank_step))
    else:
        # A non-positive step can never reach --max.
        if not values and not args.step > 0:
            ap.error("--step must be positive")
        thresholds = _iter_thresholds(values, args.min_v, args.max_v, args.step)

    results, category_rows = analyze(
        path,
        thresholds,
        expand_implications=not args.no_expand_implications,
        category_curves=args.category_curves,
        mode=args.mode,
    )

    # The first CSV column is an integer rank cutoff in the rank modes and a
    # four-decimal score threshold otherwise.
    first_column = "rank_max" if rank_mode else "threshold"

    def fmt_threshold(value) -> str:
        return str(int(value)) if rank_mode else f"{value:.4f}"

    # Write CSV to stdout
    print(f"{first_column},P,R,F1,leaf_P,leaf_R,leaf_F1,avg_selected,avg_gt,oracle_R,oracle_F1")
    for row in results:
        print(
            f"{fmt_threshold(row['threshold'])},{row['P']:.4f},{row['R']:.4f},{row['F1']:.4f},"
            f"{row['leaf_P']:.4f},{row['leaf_R']:.4f},{row['leaf_F1']:.4f},"
            f"{row['avg_selected']:.2f},{row['avg_gt']:.2f},"
            f"{row['oracle_R']:.4f},{row['oracle_F1']:.4f}"
        )

    # ASCII sparkline graph for core metrics
    print("\nP " + _sparkline([r["P"] for r in results]))
    print("R " + _sparkline([r["R"] for r in results]))
    print("F1 " + _sparkline([r["F1"] for r in results]))

    if args.category_curves and category_rows:
        print("\nCATEGORY_CURVES")
        print(f"{first_column},category,P,R,F1")
        for row in category_rows:
            print(
                f"{fmt_threshold(row['threshold'])},{row['category']},"
                f"{row['P']:.4f},{row['R']:.4f},{row['F1']:.4f}"
            )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())