# Prompt_Squirrel_RAG / scripts / analyze_compact_eval.py
# (provenance from the copied GitHub page header: commit 684cf99,
#  "Redesign structural inference as group-based system with wiki data", by Claude)
"""Analyze compact eval results (n=50) for patterns in missed and extra tags.
Works with the new compact JSONL format (missed/extra diff sets, not full tag lists).
"""
from __future__ import annotations
import csv, json, re, sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple
# Repository root: this script lives one directory below it (scripts/).
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Tag-DB type id -> human-readable category name (ids 2 and 6 are absent;
# presumably booru-style category ids — TODO confirm against the CSV source).
TYPE_ID_NAMES = {0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "species", 7: "meta"}
def load_tag_db():
    """Load the bundled tag database CSV into two lookup maps.

    Returns:
        (tag_type, tag_count): tag name -> type id (``-1`` when missing or
        unparsable) and tag name -> post count (``0`` when missing or
        unparsable). Rows with fewer than three columns are skipped.
    """
    types = {}
    counts = {}
    db_path = _REPO_ROOT / "fluffyrock_3m.csv"
    with db_path.open("r", encoding="utf-8") as handle:
        for record in csv.reader(handle):
            if len(record) < 3:
                continue
            name = record[0].strip()
            # Column 1: type id; blank or malformed -> -1.
            try:
                type_id = int(record[1]) if record[1].strip() else -1
            except ValueError:
                type_id = -1
            # Column 2: post count; blank or malformed -> 0.
            try:
                post_count = int(record[2]) if record[2].strip() else 0
            except ValueError:
                post_count = 0
            types[name] = type_id
            counts[name] = post_count
    return types, counts
def load_implications():
    """Load active tag implications as antecedent -> [consequents].

    Returns an empty mapping when the implications CSV is not present.
    Only rows whose ``status`` column equals ``"active"`` are kept.
    """
    graph = defaultdict(list)
    csv_path = _REPO_ROOT / "tag_implications-2023-07-20.csv"
    if not csv_path.is_file():
        return graph
    with csv_path.open("r", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            if record.get("status") != "active":
                continue
            antecedent = record["antecedent_name"].strip()
            consequent = record["consequent_name"].strip()
            graph[antecedent].append(consequent)
    return dict(graph)
def get_leaf_tags(tags, impl):
    """Return the members of *tags* that no other member implies.

    A tag is a non-leaf when it is reachable as a (transitive) consequent of
    some tag in the set via the implication graph *impl*; everything else is
    a leaf. Returns ``tags - non_leaves`` (so *tags* must be a set).
    """
    implied_members = set()
    for start in tags:
        # Depth-first walk of everything *start* implies.
        seen = set()
        stack = [start]
        while stack:
            current = stack.pop()
            for consequent in impl.get(current, []):
                if consequent in seen:
                    continue
                seen.add(consequent)
                if consequent in tags:
                    implied_members.add(consequent)
                stack.append(consequent)
    return tags - implied_members
# ── Categorization ──
# Curated sets used by categorize() to bucket general tags for reporting.
# Taxonomic-rank / species-group names (distinct from the DB's "species" type).
_TAXONOMY = frozenset({"mammal","canid","canine","canis","felid","feline","felis","ursine","cervid","bovid","equid","equine","mustelid","procyonid","reptile","scalie","avian","bird","fish","marine","arthropod","insect","arachnid","amphibian","primate","rodent","lagomorph","leporid","galliform","gallus_(genus)","phasianid","passerine","oscine","dinosaur","theropod","cetacean","pinniped","chiroptera","marsupial","monotreme","mephitid","suid","suina"})
# Body-plan descriptors (anthro vs feral, robots, plushies, ...).
_BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"})
# Pose, framing and composition tags.
_POSE = frozenset({"solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching","kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down","looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up","portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms"})
# Counted-anatomy tags such as "4_fingers" or "2_horns".
_COUNT_RE = re.compile(r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails)")
# Tags the structural-inference stage ("Stage 3s" in the reports) is expected
# to be able to produce; used to flag misses/false positives it should own.
_STRUCTURAL = frozenset({
    # Character count
    "solo","duo","trio","group","zero_pictured",
    # Body type
    "anthro","feral","humanoid","taur",
    # Gender
    "male","female","ambiguous_gender","intersex",
    # Clothing state
    "clothed","nude","topless","bottomless",
    # Visual elements
    "looking_at_viewer","text",
})
def categorize(tag, tag_type):
    """Assign *tag* to a coarse reporting category.

    DB-typed categories (species / artist / copyright / character / meta)
    take precedence; remaining tags fall through curated sets and then
    substring heuristics. Check order matters — earlier rules shadow later
    ones — so it must not be rearranged.
    """
    type_name = TYPE_ID_NAMES.get(tag_type.get(tag, -1), "unknown")
    if type_name == "species":
        return "species"
    if type_name in ("artist", "copyright", "character", "meta"):
        return type_name
    if tag in _TAXONOMY:
        return "taxonomy"
    if tag in _BODY_PLAN:
        return "body_plan"
    if tag in _POSE:
        return "pose/composition"
    if _COUNT_RE.match(tag):
        return "count/anatomy"
    if tag in ("male", "female", "intersex", "ambiguous_gender", "andromorph", "gynomorph"):
        return "gender"
    clothing_keys = ("clothing","clothed","topwear","bottomwear","legwear","handwear","headwear","footwear","shirt","pants","shorts","dress","skirt","jacket","coat","hat","boots","shoes","gloves","socks","stockings","belt","collar","scarf","cape","armor","suit","uniform","costume","outfit")
    if any(key in tag for key in clothing_keys):
        return "clothing"
    color_names = ("red","blue","green","yellow","orange","purple","pink","black","white","grey","gray","brown","tan","cream","gold","silver","teal","cyan","magenta")
    # startswith accepts a tuple of prefixes — equivalent to the any()-loop form.
    if tag.startswith(tuple(color + "_" for color in color_names)):
        return "color/marking"
    if tag.endswith(("_coloring", "_markings")) or tag == "markings":
        return "color/marking"
    if "hair" in tag:
        return "hair"
    anatomy_keys = ("muscle","belly","chest","abs","breast","butt","tail","wing","horn","ear","eye","teeth","fang","claw","paw","hoof","snout","muzzle","tongue","fur","scales","feather","tuft","fluff","mane")
    if any(key in tag for key in anatomy_keys):
        return "body/anatomy"
    expression_keys = ("smile","grin","frown","expression","blush","angry","happy","sad","crying","laughing","open_mouth","closed_eyes","wink")
    if any(key in tag for key in expression_keys):
        return "expression"
    return "other_general"
def main():
    """Print a multi-part diagnostic report for one compact eval JSONL file.

    Usage: ``analyze_compact_eval.py [path/to/eval_*.jsonl]``. With no
    argument, the lexicographically newest ``eval_*.jsonl`` under
    ``data/eval_results/`` is used.

    Fixes over the original: exits with a message instead of IndexError when
    no eval file can be found, and instead of ZeroDivisionError when every
    sample row is errored (N == 0); repairs mojibake em dashes in output.
    """
    # ── Input selection ──
    if len(sys.argv) > 1:
        path = Path(sys.argv[1])
    else:
        candidates = sorted((_REPO_ROOT / "data" / "eval_results").glob("eval_*.jsonl"))
        if not candidates:
            sys.exit("No eval_*.jsonl found under data/eval_results and no path given")
        path = candidates[-1]
    tag_type, tag_count = load_tag_db()
    impl = load_implications()
    # ── Load sample rows; the header row carries run config, errored rows are dropped ──
    samples = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            if row.get("_meta"):
                print(f"Config: min_why={row.get('min_why')}, expand_impl={row.get('expand_implications')}, "
                      f"structural={row.get('infer_structural')}, n={row.get('n_samples')}")
                continue
            if row.get("err"):
                continue
            samples.append(row)
    N = len(samples)
    if N == 0:
        # Every report below divides by N and indexes sorted lists.
        sys.exit(f"No usable samples in {path.name}")
    print(f"Analyzing {N} samples from {path.name}\n")
    # ── 1. Missed tags (GT tags not in selected) ──
    missed_counter = Counter()
    extra_counter = Counter()
    structural_results = []
    for s in samples:
        for t in s.get("missed", []):
            missed_counter[t] += 1
        for t in s.get("extra", []):
            extra_counter[t] += 1
        structural_results.append(s.get("structural", []))
    # ── REPORT 1: Missed by category ──
    print("=" * 70)
    print(f"MISSED TAGS — GT tags not selected ({sum(missed_counter.values())} total misses, {len(missed_counter)} unique)")
    print("=" * 70)
    cat_missed = defaultdict(Counter)
    for tag, cnt in missed_counter.items():
        cat_missed[categorize(tag, tag_type)][tag] = cnt
    cat_totals = {c: sum(v.values()) for c, v in cat_missed.items()}
    for cat in sorted(cat_totals, key=cat_totals.get, reverse=True):
        tags = cat_missed[cat]
        total = cat_totals[cat]
        # Is this category covered by structural inference?
        struct_covered = sum(1 for t in tags if t in _STRUCTURAL)
        struct_note = f" ({struct_covered} structural-coverable)" if struct_covered else ""
        print(f"\n [{cat}] — {total} misses across {len(tags)} unique tags{struct_note}")
        for tag, cnt in tags.most_common(10):
            freq = tag_count.get(tag, 0)
            struct_mark = " *STRUCTURAL*" if tag in _STRUCTURAL else ""
            print(f" {tag:40s} missed {cnt:>2}/{N}{struct_mark} freq={freq:>9,}")
    # ── REPORT 2: Missed tags that structural should catch ──
    print("\n" + "=" * 70)
    print("STRUCTURAL TAG ACCURACY")
    print("=" * 70)
    # Which structural tags are still being missed?
    structural_missed = {t: c for t, c in missed_counter.items() if t in _STRUCTURAL}
    if structural_missed:
        print("\n Structural tags STILL missed (Stage 3s should catch these):")
        for t, c in sorted(structural_missed.items(), key=lambda x: -x[1]):
            print(f" {t:30s} missed {c}/{N}")
    else:
        print("\n All structural tags covered!")
    # What structural tags are over-applied (false positives)?
    structural_extra = {t: c for t, c in extra_counter.items() if t in _STRUCTURAL}
    if structural_extra:
        print("\n Structural tags wrongly added (false positives):")
        for t, c in sorted(structural_extra.items(), key=lambda x: -x[1]):
            print(f" {t:30s} extra {c}/{N}")
    # Per-structural-tag stats from the structural field
    struct_tag_counts = Counter()
    for sl in structural_results:
        for t in sl:
            struct_tag_counts[t] += 1
    print("\n Structural tag selection frequency (how often Stage 3s picks each):")
    for t, c in struct_tag_counts.most_common():
        missed_c = structural_missed.get(t, 0)
        extra_c = structural_extra.get(t, 0)
        print(f" {t:30s} picked {c:>2}/{N} missed_in_GT={missed_c} false_pos={extra_c}")
    # ── REPORT 3: Extra tags (false positives) by category ──
    print("\n" + "=" * 70)
    print(f"EXTRA TAGS — Selected but not in GT ({sum(extra_counter.values())} total, {len(extra_counter)} unique)")
    print("=" * 70)
    cat_extra = defaultdict(Counter)
    for tag, cnt in extra_counter.items():
        cat_extra[categorize(tag, tag_type)][tag] = cnt
    cat_extra_totals = {c: sum(v.values()) for c, v in cat_extra.items()}
    for cat in sorted(cat_extra_totals, key=cat_extra_totals.get, reverse=True):
        tags = cat_extra[cat]
        total = cat_extra_totals[cat]
        print(f"\n [{cat}] — {total} false positives across {len(tags)} unique tags")
        for tag, cnt in tags.most_common(8):
            freq = tag_count.get(tag, 0)
            print(f" {tag:40s} extra {cnt:>2}/{N} freq={freq:>9,}")
    # ── REPORT 3b: Evidence sources for false positives ──
    # (Only available in new format with extra_evidence field)
    source_counts = Counter()   # source -> count of FP tags
    why_fp_counts = Counter()   # why level -> count of FP tags from stage3
    score_buckets = {"high (>0.5)": 0, "medium (0.2-0.5)": 0, "low (<0.2)": 0}
    has_evidence = False
    for s in samples:
        ev = s.get("extra_evidence", {})
        if ev:
            has_evidence = True
            for tag, info in ev.items():
                src = info.get("source", "unknown")
                source_counts[src] += 1
                if src == "stage3":
                    why_fp_counts[info.get("why", "unknown")] += 1
                    score = info.get("retrieval_score", 0)
                    if score > 0.5:
                        score_buckets["high (>0.5)"] += 1
                    elif score > 0.2:
                        score_buckets["medium (0.2-0.5)"] += 1
                    else:
                        score_buckets["low (<0.2)"] += 1
    if has_evidence:
        print("\n" + "=" * 70)
        print("FALSE POSITIVE EVIDENCE SOURCES")
        print("=" * 70)
        total_fp = sum(source_counts.values())
        print(f"\n How did {total_fp} false positive tags get through?")
        for src, cnt in source_counts.most_common():
            print(f" {src:20s} {cnt:>4} ({cnt/max(1,total_fp)*100:.0f}%)")
        if why_fp_counts:
            print("\n Stage 3 false positives by 'why' level:")
            for why, cnt in why_fp_counts.most_common():
                print(f" {why:20s} {cnt:>4}")
            print("\n Stage 3 false positives by retrieval score:")
            for bucket, cnt in score_buckets.items():
                print(f" {bucket:20s} {cnt:>4}")
    # ── REPORT 4: Leaf vs non-leaf in missed ──
    print("\n" + "=" * 70)
    print("MISSED: LEAF vs IMPLIED ANCESTORS")
    print("=" * 70)
    all_missed = set(missed_counter.keys())
    leaf_missed = get_leaf_tags(all_missed, impl)
    anc_missed = all_missed - leaf_missed
    leaf_vol = sum(missed_counter[t] for t in leaf_missed)
    anc_vol = sum(missed_counter[t] for t in anc_missed)
    total_vol = leaf_vol + anc_vol
    print(f"\n Unique missed: {len(all_missed)} tags")
    print(f" Leaf: {len(leaf_missed)} ({len(leaf_missed)/max(1,len(all_missed))*100:.0f}%)")
    print(f" Ancestor: {len(anc_missed)} ({len(anc_missed)/max(1,len(all_missed))*100:.0f}%)")
    print(f" Miss volume: {total_vol}")
    print(f" From leaf: {leaf_vol} ({leaf_vol/max(1,total_vol)*100:.0f}%)")
    print(f" From ancestor: {anc_vol} ({anc_vol/max(1,total_vol)*100:.0f}%) — recoverable via implications")
    # ── REPORT 5: Frequency distribution ──
    print("\n" + "=" * 70)
    print("FREQUENCY DISTRIBUTION OF MISSED TAGS")
    print("=" * 70)
    buckets = {"very_rare (<100)": 0, "rare (100-1k)": 0, "medium (1k-10k)": 0,
               "common (10k-100k)": 0, "very_common (100k+)": 0, "not_in_db": 0}
    for tag in missed_counter:
        freq = tag_count.get(tag, -1)
        if freq < 0:
            buckets["not_in_db"] += 1
        elif freq < 100:
            buckets["very_rare (<100)"] += 1
        elif freq < 1000:
            buckets["rare (100-1k)"] += 1
        elif freq < 10000:
            buckets["medium (1k-10k)"] += 1
        elif freq < 100000:
            buckets["common (10k-100k)"] += 1
        else:
            buckets["very_common (100k+)"] += 1
    for b, c in buckets.items():
        print(f" {b:25s} {c:4d} unique tags ({c/max(1,len(missed_counter))*100:.0f}%)")
    # ── REPORT 6: Over-selection analysis ──
    print("\n" + "=" * 70)
    print("OVER-SELECTION ANALYSIS")
    print("=" * 70)
    over_sels = sorted(s["over_sel"] for s in samples)
    print(f"\n Avg over-selection ratio: {sum(over_sels)/N:.2f}x")
    print(f" Median: {over_sels[N//2]:.2f}x")
    print(f" Min: {over_sels[0]:.2f}x")
    print(f" Max: {over_sels[-1]:.2f}x")
    tight = sum(1 for x in over_sels if 0.8 <= x <= 1.5)
    over = sum(1 for x in over_sels if x > 2.0)
    under = sum(1 for x in over_sels if x < 0.5)
    print(f" Tight (0.8-1.5x): {tight}/{N}")
    print(f" Over (>2.0x): {over}/{N}")
    print(f" Under (<0.5x): {under}/{N}")
    # Worst over-selectors
    worst = sorted(samples, key=lambda s: -s["over_sel"])[:5]
    print("\n Worst over-selectors:")
    for s in worst:
        print(f" id={s['id']:>8} over_sel={s['over_sel']:.2f}x selected={s['n_selected']} gt={s['n_gt']} "
              f"F1={s['F1']:.3f} n_extra={len(s.get('extra',[]))}")
    # ── REPORT 7: Aggregate metrics ──
    print("\n" + "=" * 70)
    print("AGGREGATE METRICS")
    print("=" * 70)
    for metric, key in [("F1", "F1"), ("Precision", "P"), ("Recall", "R"),
                        ("Leaf F1", "leaf_F1"), ("Leaf P", "leaf_P"), ("Leaf R", "leaf_R"),
                        ("Retrieval Recall", "ret_R")]:
        vals = [s[key] for s in samples]
        avg = sum(vals)/N
        vals.sort()
        med = vals[N//2]
        print(f" {metric:20s} avg={avg:.4f} median={med:.4f} min={vals[0]:.4f} max={vals[-1]:.4f}")
    # ── REPORT 8: Samples sorted by F1 ──
    print("\n" + "=" * 70)
    print("WORST 10 SAMPLES BY F1")
    print("=" * 70)
    by_f1 = sorted(samples, key=lambda s: s["F1"])
    for s in by_f1[:10]:
        n_missed = len(s.get("missed", []))
        n_extra = len(s.get("extra", []))
        print(f" id={s['id']:>8} F1={s['F1']:.3f} P={s['P']:.3f} R={s['R']:.3f} "
              f"gt={s['n_gt']} sel={s['n_selected']} missed={n_missed} extra={n_extra} "
              f"structural={s.get('structural',[])} over_sel={s['over_sel']:.2f}x")
    print()
# Script entry point: run the report when invoked directly.
if __name__ == "__main__":
    main()