"""Analyze compact eval results (e.g. an n=50 run) for patterns in missed and extra tags.

Works with the new compact JSONL format (missed/extra diff sets, not full tag lists).

Usage::

    python <script> [path/to/eval_results.jsonl]

With no argument, the last ``data/eval_results/eval_*.jsonl`` (in sorted filename
order) under the repository root is analyzed.
"""
from __future__ import annotations

import csv
import json
import re
import statistics
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

_REPO_ROOT = Path(__file__).resolve().parents[1]
_TAG_DB_PATH = _REPO_ROOT / "fluffyrock_3m.csv"
_IMPLICATIONS_PATH = _REPO_ROOT / "tag_implications-2023-07-20.csv"
_EVAL_RESULTS_DIR = _REPO_ROOT / "data" / "eval_results"

# Maps the numeric type column of the tag database to a readable type name.
TYPE_ID_NAMES = {0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "species", 7: "meta"}

_RULE = "=" * 70


# ── Data loading ──

def _parse_int(text: str, default: int) -> int:
    """Return ``int(text)``, or *default* when *text* is blank or not an integer."""
    if not text.strip():
        return default
    try:
        return int(text)
    except ValueError:
        return default


def load_tag_db(path: Path = _TAG_DB_PATH) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Load tag type ids and post counts from the tag database CSV.

    Each row is read as ``tag, type_id, count[, ...]``; rows with fewer than three
    columns are skipped. A blank or non-integer type id becomes -1 and a blank or
    non-integer count becomes 0.

    Args:
        path: CSV file to read; defaults to ``fluffyrock_3m.csv`` in the repo root.

    Returns:
        ``(tag_type, tag_count)``, both keyed by the stripped tag name.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    tag_type: Dict[str, int] = {}
    tag_count: Dict[str, int] = {}
    # newline="" is what the csv module requires for correct handling of quoted fields.
    with Path(path).open("r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 3:
                continue
            tag = row[0].strip()
            tag_type[tag] = _parse_int(row[1], -1)
            tag_count[tag] = _parse_int(row[2], 0)
    return tag_type, tag_count


def load_implications(path: Path = _IMPLICATIONS_PATH) -> Dict[str, List[str]]:
    """Load active tag implications as ``antecedent -> [consequents]``.

    Only rows whose ``status`` column equals ``"active"`` are kept. A missing file
    yields an empty mapping, so implication-based reports degrade instead of failing.

    Args:
        path: Implications CSV with ``antecedent_name``, ``consequent_name`` and
            ``status`` columns; defaults to the dated dump in the repo root.
    """
    path = Path(path)
    if not path.is_file():
        return {}
    impl: Dict[str, List[str]] = defaultdict(list)
    with path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            if row.get("status") == "active":
                impl[row["antecedent_name"].strip()].append(row["consequent_name"].strip())
    return dict(impl)


def get_leaf_tags(tags: Set[str], impl: Dict[str, List[str]]) -> Set[str]:
    """Return the members of *tags* that are not implied by any other member.

    Walks the implication graph transitively from every tag; each member reached
    that way is an implied ancestor and is excluded. Every walk tracks the tags it
    has visited, so cycles in the graph cannot loop forever.
    """
    tag_set = set(tags)
    non_leaves: Set[str] = set()
    for tag in tag_set:
        stack = [tag]
        seen: Set[str] = set()
        while stack:
            for parent in impl.get(stack.pop(), []):
                if parent in seen:
                    continue
                seen.add(parent)
                if parent in tag_set:
                    non_leaves.add(parent)
                stack.append(parent)
    return tag_set - non_leaves


# ── Categorization ──

_TAXONOMY = frozenset({"mammal","canid","canine","canis","felid","feline","felis","ursine","cervid","bovid","equid","equine","mustelid","procyonid","reptile","scalie","avian","bird","fish","marine","arthropod","insect","arachnid","amphibian","primate","rodent","lagomorph","leporid","galliform","gallus_(genus)","phasianid","passerine","oscine","dinosaur","theropod","cetacean","pinniped","chiroptera","marsupial","monotreme","mephitid","suid","suina"})

_BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"})

_POSE = frozenset({"solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching","kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down","looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up","portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms"})

_COUNT_RE = re.compile(r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails)")

_STRUCTURAL = frozenset({
    # Character count
    "solo","duo","trio","group","zero_pictured",
    # Body type
    "anthro","feral","humanoid","taur",
    # Gender
    "male","female","ambiguous_gender","intersex",
    # Clothing state
    "clothed","nude","topless","bottomless",
    # Visual elements
    "looking_at_viewer","text",
})

# Tag types reported under their own type name rather than a heuristic bucket.
_NAMED_TYPE_CATEGORIES = frozenset({"artist", "copyright", "character", "meta"})

_GENDER_TAGS = frozenset({"male", "female", "intersex", "ambiguous_gender", "andromorph", "gynomorph"})

# NOTE: the keyword lists below are plain substring matches and deliberately coarse
# (e.g. "hat" also matches inside longer words); order of checks in categorize() matters.
_CLOTHING_KEYWORDS = ("clothing","clothed","topwear","bottomwear","legwear","handwear","headwear","footwear","shirt","pants","shorts","dress","skirt","jacket","coat","hat","boots","shoes","gloves","socks","stockings","belt","collar","scarf","cape","armor","suit","uniform","costume","outfit")

_COLOR_PREFIXES = tuple(c + "_" for c in ("red","blue","green","yellow","orange","purple","pink","black","white","grey","gray","brown","tan","cream","gold","silver","teal","cyan","magenta"))

_MARKING_SUFFIXES = ("_coloring", "_markings")

_ANATOMY_KEYWORDS = ("muscle","belly","chest","abs","breast","butt","tail","wing","horn","ear","eye","teeth","fang","claw","paw","hoof","snout","muzzle","tongue","fur","scales","feather","tuft","fluff","mane")

_EXPRESSION_KEYWORDS = ("smile","grin","frown","expression","blush","angry","happy","sad","crying","laughing","open_mouth","closed_eyes","wink")


def categorize(tag: str, tag_type: Dict[str, int]) -> str:
    """Assign *tag* to a coarse reporting category.

    The database type id wins first ("species", or the type name for artist /
    copyright / character / meta tags). Remaining tags go through the hand-made
    sets and keyword heuristics in a fixed order; the first match decides, and
    anything unmatched is "other_general".
    """
    type_name = TYPE_ID_NAMES.get(tag_type.get(tag, -1), "unknown")
    if type_name == "species":
        return "species"
    if type_name in _NAMED_TYPE_CATEGORIES:
        return type_name
    if tag in _TAXONOMY:
        return "taxonomy"
    if tag in _BODY_PLAN:
        return "body_plan"
    if tag in _POSE:
        return "pose/composition"
    if _COUNT_RE.match(tag):
        return "count/anatomy"
    if tag in _GENDER_TAGS:
        return "gender"
    if any(k in tag for k in _CLOTHING_KEYWORDS):
        return "clothing"
    if tag.startswith(_COLOR_PREFIXES):
        return "color/marking"
    if tag.endswith(_MARKING_SUFFIXES) or tag == "markings":
        return "color/marking"
    if "hair" in tag:
        return "hair"
    if any(k in tag for k in _ANATOMY_KEYWORDS):
        return "body/anatomy"
    if any(k in tag for k in _EXPRESSION_KEYWORDS):
        return "expression"
    return "other_general"


# ── Report helpers ──

def _print_banner(title: str, *, leading_blank: bool = True) -> None:
    """Print a section title framed by 70-character rules."""
    print(("\n" + _RULE) if leading_blank else _RULE)
    print(title)
    print(_RULE)


def _summarize(values: List[float]) -> Tuple[float, float, float, float]:
    """Return ``(mean, median, min, max)`` of a non-empty list of numbers.

    The median averages the two middle values when the length is even.
    """
    ordered = sorted(values)
    return sum(ordered) / len(ordered), statistics.median(ordered), ordered[0], ordered[-1]


def _group_by_category(counter: Counter, tag_type: Dict[str, int]) -> List[Tuple[str, int, Counter]]:
    """Group a tag counter by categorize(); return (category, total, tags) by descending total."""
    grouped: Dict[str, Counter] = defaultdict(Counter)
    for tag, cnt in counter.items():
        grouped[categorize(tag, tag_type)][tag] = cnt
    totals = {cat: sum(tags.values()) for cat, tags in grouped.items()}
    return [(cat, totals[cat], grouped[cat]) for cat in sorted(totals, key=totals.get, reverse=True)]


def _resolve_eval_path(args: List[str]) -> Path:
    """Return the results path given in *args*, else the last default eval file."""
    if args:
        return Path(args[0])
    candidates = sorted(_EVAL_RESULTS_DIR.glob("eval_*.jsonl"))
    if not candidates:
        sys.exit(f"No eval_*.jsonl files found in {_EVAL_RESULTS_DIR}; pass a results path explicitly.")
    return candidates[-1]


def _load_samples(path: Path) -> List[dict]:
    """Read per-sample rows from a compact eval JSONL file.

    The ``_meta`` row is echoed as a one-line config summary, rows with a truthy
    ``err`` field are skipped, and blank lines are ignored.
    """
    samples: List[dict] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            if row.get("_meta"):
                print(f"Config: min_why={row.get('min_why')}, expand_impl={row.get('expand_implications')}, "
                      f"structural={row.get('infer_structural')}, n={row.get('n_samples')}")
                continue
            if row.get("err"):
                continue
            samples.append(row)
    return samples


def _report_missed_by_category(missed_counter: Counter, tag_type: Dict[str, int],
                               tag_count: Dict[str, int], n: int) -> None:
    """REPORT 1: missed GT tags grouped by category, top 10 per category."""
    _print_banner(
        f"MISSED TAGS — GT tags not selected ({sum(missed_counter.values())} total misses, "
        f"{len(missed_counter)} unique)",
        leading_blank=False,
    )
    for cat, total, tags in _group_by_category(missed_counter, tag_type):
        # How many of these tags the structural-inference stage is meant to cover.
        struct_covered = sum(1 for t in tags if t in _STRUCTURAL)
        struct_note = f" ({struct_covered} structural-coverable)" if struct_covered else ""
        print(f"\n [{cat}] — {total} misses across {len(tags)} unique tags{struct_note}")
        for tag, cnt in tags.most_common(10):
            freq = tag_count.get(tag, 0)
            struct_mark = " *STRUCTURAL*" if tag in _STRUCTURAL else ""
            print(f" {tag:40s} missed {cnt:>2}/{n}{struct_mark} freq={freq:>9,}")


def _report_structural(samples: List[dict], missed_counter: Counter,
                       extra_counter: Counter, n: int) -> None:
    """REPORT 2: misses, false positives and pick rates for the structural tag set."""
    _print_banner("STRUCTURAL TAG ACCURACY")
    structural_missed = {t: c for t, c in missed_counter.items() if t in _STRUCTURAL}
    if structural_missed:
        print("\n Structural tags STILL missed (Stage 3s should catch these):")
        for t, c in sorted(structural_missed.items(), key=lambda x: -x[1]):
            print(f" {t:30s} missed {c}/{n}")
    else:
        print("\n All structural tags covered!")

    structural_extra = {t: c for t, c in extra_counter.items() if t in _STRUCTURAL}
    if structural_extra:
        print("\n Structural tags wrongly added (false positives):")
        for t, c in sorted(structural_extra.items(), key=lambda x: -x[1]):
            print(f" {t:30s} extra {c}/{n}")

    # A null "structural" field is treated as an empty list.
    struct_tag_counts = Counter(t for s in samples for t in (s.get("structural") or []))
    print("\n Structural tag selection frequency (how often Stage 3s picks each):")
    for t, c in struct_tag_counts.most_common():
        print(f" {t:30s} picked {c:>2}/{n} missed_in_GT={structural_missed.get(t, 0)} "
              f"false_pos={structural_extra.get(t, 0)}")


def _report_extra_by_category(extra_counter: Counter, tag_type: Dict[str, int],
                              tag_count: Dict[str, int], n: int) -> None:
    """REPORT 3: false-positive tags grouped by category, top 8 per category."""
    _print_banner(
        f"EXTRA TAGS — Selected but not in GT ({sum(extra_counter.values())} total, "
        f"{len(extra_counter)} unique)"
    )
    for cat, total, tags in _group_by_category(extra_counter, tag_type):
        print(f"\n [{cat}] — {total} false positives across {len(tags)} unique tags")
        for tag, cnt in tags.most_common(8):
            freq = tag_count.get(tag, 0)
            print(f" {tag:40s} extra {cnt:>2}/{n} freq={freq:>9,}")


def _report_fp_evidence(samples: List[dict]) -> None:
    """REPORT 3b: where false positives came from, using the ``extra_evidence`` field.

    Prints nothing when no sample carries evidence (older result files).
    """
    source_counts: Counter = Counter()  # source -> count of FP tags
    why_fp_counts: Counter = Counter()  # why level -> count of FP tags from stage3
    score_buckets = {"high (>0.5)": 0, "medium (0.2-0.5)": 0, "low (<0.2)": 0}
    has_evidence = False
    for s in samples:
        ev = s.get("extra_evidence") or {}
        if not ev:
            continue
        has_evidence = True
        for info in ev.values():
            src = info.get("source", "unknown")
            source_counts[src] += 1
            if src != "stage3":
                continue
            why_fp_counts[info.get("why", "unknown")] += 1
            # A JSON null score is treated the same as a missing one.
            score = info.get("retrieval_score") or 0
            if score > 0.5:
                score_buckets["high (>0.5)"] += 1
            elif score > 0.2:
                score_buckets["medium (0.2-0.5)"] += 1
            else:
                score_buckets["low (<0.2)"] += 1

    if not has_evidence:
        return
    _print_banner("FALSE POSITIVE EVIDENCE SOURCES")
    total_fp = sum(source_counts.values())
    print(f"\n How did {total_fp} false positive tags get through?")
    for src, cnt in source_counts.most_common():
        # !s keeps the fixed-width column working for null / non-string sources.
        print(f" {src!s:20s} {cnt:>4} ({cnt/max(1,total_fp)*100:.0f}%)")
    if why_fp_counts:
        print("\n Stage 3 false positives by 'why' level:")
        for why, cnt in why_fp_counts.most_common():
            print(f" {why!s:20s} {cnt:>4}")
        print("\n Stage 3 false positives by retrieval score:")
        for bucket, cnt in score_buckets.items():
            print(f" {bucket:20s} {cnt:>4}")


def _report_leaf_vs_ancestor(samples: List[dict], impl: Dict[str, List[str]]) -> None:
    """REPORT 4: how many misses are leaves vs ancestors implied by another miss.

    Classification is done per sample: a missed tag only counts as an implied
    ancestor when a tag implying it was missed in the *same* sample. A tag counts
    as an ancestor in the unique-tag split if it was an ancestor in any sample.
    """
    _print_banner("MISSED: LEAF vs IMPLIED ANCESTORS")
    all_missed: Set[str] = set()
    ancestor_tags: Set[str] = set()
    leaf_vol = anc_vol = 0
    for s in samples:
        missed_list = s.get("missed", [])
        missed = set(missed_list)
        leaves = get_leaf_tags(missed, impl)
        all_missed |= missed
        ancestor_tags |= missed - leaves
        for t in missed_list:
            if t in leaves:
                leaf_vol += 1
            else:
                anc_vol += 1
    leaf_tags = all_missed - ancestor_tags
    total_vol = leaf_vol + anc_vol
    n_unique = max(1, len(all_missed))
    print(f"\n Unique missed: {len(all_missed)} tags")
    print(f" Leaf: {len(leaf_tags)} ({len(leaf_tags)/n_unique*100:.0f}%)")
    print(f" Ancestor: {len(ancestor_tags)} ({len(ancestor_tags)/n_unique*100:.0f}%)")
    print(f" Miss volume: {total_vol}")
    print(f" From leaf: {leaf_vol} ({leaf_vol/max(1,total_vol)*100:.0f}%)")
    print(f" From ancestor: {anc_vol} ({anc_vol/max(1,total_vol)*100:.0f}%) — recoverable via implications")


def _frequency_bucket(freq: int) -> str:
    """Map a tag post count (-1 = not in the tag DB) to its report bucket label."""
    if freq < 0:
        return "not_in_db"
    if freq < 100:
        return "very_rare (<100)"
    if freq < 1000:
        return "rare (100-1k)"
    if freq < 10000:
        return "medium (1k-10k)"
    if freq < 100000:
        return "common (10k-100k)"
    return "very_common (100k+)"


def _report_missed_frequency(missed_counter: Counter, tag_count: Dict[str, int]) -> None:
    """REPORT 5: distribution of unique missed tags over tag-frequency buckets."""
    _print_banner("FREQUENCY DISTRIBUTION OF MISSED TAGS")
    buckets = {"very_rare (<100)": 0, "rare (100-1k)": 0, "medium (1k-10k)": 0,
               "common (10k-100k)": 0, "very_common (100k+)": 0, "not_in_db": 0}
    for tag in missed_counter:
        buckets[_frequency_bucket(tag_count.get(tag, -1))] += 1
    n_unique = max(1, len(missed_counter))
    for b, c in buckets.items():
        print(f" {b:25s} {c:4d} unique tags ({c/n_unique*100:.0f}%)")


def _report_over_selection(samples: List[dict]) -> None:
    """REPORT 6: over-selection ratio statistics and the five worst samples."""
    _print_banner("OVER-SELECTION ANALYSIS")
    n = len(samples)
    over_sels = [s["over_sel"] for s in samples]
    avg, med, lo, hi = _summarize(over_sels)
    print(f"\n Avg over-selection ratio: {avg:.2f}x")
    print(f" Median: {med:.2f}x")
    print(f" Min: {lo:.2f}x")
    print(f" Max: {hi:.2f}x")
    tight = sum(1 for x in over_sels if 0.8 <= x <= 1.5)
    over = sum(1 for x in over_sels if x > 2.0)
    under = sum(1 for x in over_sels if x < 0.5)
    print(f" Tight (0.8-1.5x): {tight}/{n}")
    print(f" Over (>2.0x): {over}/{n}")
    print(f" Under (<0.5x): {under}/{n}")

    worst = sorted(samples, key=lambda s: s["over_sel"], reverse=True)[:5]
    print("\n Worst over-selectors:")
    for s in worst:
        print(f" id={s['id']:>8} over_sel={s['over_sel']:.2f}x selected={s['n_selected']} gt={s['n_gt']} "
              f"F1={s['F1']:.3f} n_extra={len(s.get('extra',[]))}")


_AGGREGATE_METRICS = (
    ("F1", "F1"), ("Precision", "P"), ("Recall", "R"),
    ("Leaf F1", "leaf_F1"), ("Leaf P", "leaf_P"), ("Leaf R", "leaf_R"),
    ("Retrieval Recall", "ret_R"),
)


def _report_aggregate(samples: List[dict]) -> None:
    """REPORT 7: mean / median / min / max of each per-sample metric."""
    _print_banner("AGGREGATE METRICS")
    for metric, key in _AGGREGATE_METRICS:
        avg, med, lo, hi = _summarize([s[key] for s in samples])
        print(f" {metric:20s} avg={avg:.4f} median={med:.4f} min={lo:.4f} max={hi:.4f}")


def _report_worst_by_f1(samples: List[dict]) -> None:
    """REPORT 8: the ten lowest-F1 samples with their per-sample details."""
    _print_banner("WORST 10 SAMPLES BY F1")
    for s in sorted(samples, key=lambda s: s["F1"])[:10]:
        n_missed = len(s.get("missed", []))
        n_extra = len(s.get("extra", []))
        print(f" id={s['id']:>8} F1={s['F1']:.3f} P={s['P']:.3f} R={s['R']:.3f} "
              f"gt={s['n_gt']} sel={s['n_selected']} missed={n_missed} extra={n_extra} "
              f"structural={s.get('structural',[])} over_sel={s['over_sel']:.2f}x")


def main(argv: Optional[List[str]] = None) -> None:
    """Print every analysis report for one compact eval results file.

    Args:
        argv: Arguments excluding the program name; defaults to ``sys.argv[1:]``.
            The first argument, if given, is the JSONL results path.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    path = _resolve_eval_path(args)
    samples = _load_samples(path)
    n = len(samples)
    print(f"Analyzing {n} samples from {path.name}\n")
    if not samples:
        # Several reports divide by or index into the sample list; nothing to do.
        print("No successful samples to analyze.")
        return

    tag_type, tag_count = load_tag_db()
    impl = load_implications()
    missed_counter = Counter(t for s in samples for t in s.get("missed", []))
    extra_counter = Counter(t for s in samples for t in s.get("extra", []))

    _report_missed_by_category(missed_counter, tag_type, tag_count, n)
    _report_structural(samples, missed_counter, extra_counter, n)
    _report_extra_by_category(extra_counter, tag_type, tag_count, n)
    _report_fp_evidence(samples)
    _report_leaf_vs_ancestor(samples, impl)
    _report_missed_frequency(missed_counter, tag_count)
    _report_over_selection(samples)
    _report_aggregate(samples)
    _report_worst_by_f1(samples)
    print()


if __name__ == "__main__":
    main()