Spaces:
Running
Running
| """Analyze compact eval results (n=50) for patterns in missed and extra tags. | |
| Works with the new compact JSONL format (missed/extra diff sets, not full tag lists). | |
| """ | |
| from __future__ import annotations | |
| import csv, json, re, sys | |
| from collections import Counter, defaultdict | |
| from pathlib import Path | |
| from typing import Dict, List, Set, Tuple | |
# Repository root, resolved from this script's location (script lives one
# directory below the root — TODO confirm against repo layout).
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Tag-DB numeric category id -> human-readable name (presumably e621-style
# category ids; ids 2 and 6 are unused here — verify against the tag DB).
TYPE_ID_NAMES = {0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "species", 7: "meta"}
def load_tag_db(csv_path: Path | None = None) -> tuple[dict[str, int], dict[str, int]]:
    """Load the tag database CSV into two dicts: tag -> type id, tag -> post count.

    Rows must have at least three columns (tag, type id, count); shorter rows
    are skipped.  Blank or non-numeric numeric fields degrade gracefully
    (type id -> -1 meaning "unknown", count -> 0) instead of aborting the load.

    Args:
        csv_path: Optional override for the CSV location.  Defaults to
            ``fluffyrock_3m.csv`` at the repository root, preserving the
            original behavior.

    Returns:
        ``(tag_type, tag_count)`` dicts keyed by stripped tag name.
    """
    path = csv_path if csv_path is not None else _REPO_ROOT / "fluffyrock_3m.csv"
    tag_type: dict[str, int] = {}
    tag_count: dict[str, int] = {}
    with path.open("r", encoding="utf-8") as f:
        for row in csv.reader(f):
            if len(row) < 3:
                continue  # malformed/short row
            tag = row[0].strip()
            try:
                tid = int(row[1]) if row[1].strip() else -1
            except ValueError:
                tid = -1  # non-numeric type id -> unknown
            try:
                cnt = int(row[2]) if row[2].strip() else 0
            except ValueError:
                cnt = 0  # non-numeric count -> treat as zero
            tag_type[tag] = tid
            tag_count[tag] = cnt
    return tag_type, tag_count
def load_implications(csv_path: Path | None = None) -> dict[str, list[str]]:
    """Load active tag implications as ``antecedent -> [consequents]``.

    Only rows with ``status == "active"`` are kept.  A missing file yields an
    empty mapping.  Always returns a plain ``dict`` (the previous version
    returned a ``defaultdict`` in the missing-file branch and a ``dict``
    otherwise — now consistent).

    Args:
        csv_path: Optional override for the implications CSV.  Defaults to
            ``tag_implications-2023-07-20.csv`` at the repository root.
    """
    path = csv_path if csv_path is not None else _REPO_ROOT / "tag_implications-2023-07-20.csv"
    impl: defaultdict[str, list[str]] = defaultdict(list)
    if not path.is_file():
        return dict(impl)
    with path.open("r", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            if row.get("status") == "active":
                impl[row["antecedent_name"].strip()].append(row["consequent_name"].strip())
    return dict(impl)
def get_leaf_tags(tags, impl):
    """Return the subset of *tags* that no other tag in the set implies.

    For every tag, walk its transitive consequents via *impl*
    (antecedent -> list of consequents).  Any tag of the set that shows up
    as an ancestor of another member is dropped from the result.
    """
    implied_ancestors = set()
    for start in tags:
        seen = set()
        stack = [start]
        # Depth-first walk over the implication graph rooted at `start`.
        while stack:
            current = stack.pop()
            for consequent in impl.get(current, []):
                if consequent in seen:
                    continue  # already visited along another path
                seen.add(consequent)
                if consequent in tags:
                    implied_ancestors.add(consequent)
                stack.append(consequent)
    return tags - implied_ancestors
# ── Categorization ──
# Hand-maintained vocabularies used by categorize() to bucket tags that the
# tag DB only classifies as "general".

# Taxonomic/clade tags (presumably e621-style taxonomy vocabulary — TODO confirm).
_TAXONOMY = frozenset({"mammal","canid","canine","canis","felid","feline","felis","ursine","cervid","bovid","equid","equine","mustelid","procyonid","reptile","scalie","avian","bird","fish","marine","arthropod","insect","arachnid","amphibian","primate","rodent","lagomorph","leporid","galliform","gallus_(genus)","phasianid","passerine","oscine","dinosaur","theropod","cetacean","pinniped","chiroptera","marsupial","monotreme","mephitid","suid","suina"})
# Body-plan / character-construction tags.
_BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"})
# Pose, view and composition tags.
_POSE = frozenset({"solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching","kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down","looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up","portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms"})
# Anatomy-count tags, e.g. "4_fingers", "2_horns".
_COUNT_RE = re.compile(r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails)")
# Tags the structural-inference stage ("Stage 3s" in main()'s reports) is
# expected to be able to supply; used for structural-coverage accounting.
_STRUCTURAL = frozenset({
    # Character count
    "solo","duo","trio","group","zero_pictured",
    # Body type
    "anthro","feral","humanoid","taur",
    # Gender
    "male","female","ambiguous_gender","intersex",
    # Clothing state
    "clothed","nude","topless","bottomless",
    # Visual elements
    "looking_at_viewer","text",
})
def categorize(tag, tag_type):
    """Map a tag to a coarse reporting category.

    The tag DB's own category (artist/copyright/character/meta/species) wins
    first; remaining "general"/unknown tags fall through an ordered chain of
    heuristic set/prefix/substring checks.  Order matters: earlier buckets
    shadow later ones.
    """
    type_name = TYPE_ID_NAMES.get(tag_type.get(tag, -1), "unknown")
    if type_name == "species":
        return "species"
    if type_name in ("artist", "copyright", "character", "meta"):
        return type_name
    if tag in _TAXONOMY:
        return "taxonomy"
    if tag in _BODY_PLAN:
        return "body_plan"
    if tag in _POSE:
        return "pose/composition"
    if _COUNT_RE.match(tag):
        return "count/anatomy"
    if tag in ("male", "female", "intersex", "ambiguous_gender", "andromorph", "gynomorph"):
        return "gender"
    clothing_keywords = (
        "clothing", "clothed", "topwear", "bottomwear", "legwear", "handwear",
        "headwear", "footwear", "shirt", "pants", "shorts", "dress", "skirt",
        "jacket", "coat", "hat", "boots", "shoes", "gloves", "socks",
        "stockings", "belt", "collar", "scarf", "cape", "armor", "suit",
        "uniform", "costume", "outfit",
    )
    if any(keyword in tag for keyword in clothing_keywords):
        return "clothing"
    color_prefixes = tuple(color + "_" for color in (
        "red", "blue", "green", "yellow", "orange", "purple", "pink", "black",
        "white", "grey", "gray", "brown", "tan", "cream", "gold", "silver",
        "teal", "cyan", "magenta",
    ))
    if tag.startswith(color_prefixes):
        return "color/marking"
    if tag.endswith(("_coloring", "_markings")) or tag == "markings":
        return "color/marking"
    if "hair" in tag:
        return "hair"
    anatomy_keywords = (
        "muscle", "belly", "chest", "abs", "breast", "butt", "tail", "wing",
        "horn", "ear", "eye", "teeth", "fang", "claw", "paw", "hoof", "snout",
        "muzzle", "tongue", "fur", "scales", "feather", "tuft", "fluff", "mane",
    )
    if any(keyword in tag for keyword in anatomy_keywords):
        return "body/anatomy"
    expression_keywords = (
        "smile", "grin", "frown", "expression", "blush", "angry", "happy",
        "sad", "crying", "laughing", "open_mouth", "closed_eyes", "wink",
    )
    if any(keyword in tag for keyword in expression_keywords):
        return "expression"
    return "other_general"
def main():
    """Entry point: load eval results and print a series of diagnostic reports.

    Takes an optional CLI argument (path to an ``eval_*.jsonl`` file);
    otherwise analyzes the newest one under ``data/eval_results``.

    Fix vs. original: if every row is a ``_meta`` header or an errored sample
    (``N == 0``), the report math below would raise ``ZeroDivisionError``
    (over-selection average) and ``IndexError`` (median lookup); we now bail
    out early with a message instead.
    """
    # Default to the most recent eval file when no explicit path is given.
    path = Path(sys.argv[1]) if len(sys.argv) > 1 else sorted((_REPO_ROOT/"data"/"eval_results").glob("eval_*.jsonl"))[-1]
    tag_type, tag_count = load_tag_db()
    impl = load_implications()
    samples = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            if row.get("_meta"):
                # Header row carrying the eval run's configuration.
                print(f"Config: min_why={row.get('min_why')}, expand_impl={row.get('expand_implications')}, "
                      f"structural={row.get('infer_structural')}, n={row.get('n_samples')}")
                continue
            if row.get("err"):
                continue  # skip samples that failed during eval
            samples.append(row)
    N = len(samples)
    if N == 0:
        # Guard against empty input: everything below divides by N and
        # indexes sorted lists for medians.
        print(f"No usable samples found in {path.name}")
        return
    print(f"Analyzing {N} samples from {path.name}\n")
    # ── 1. Missed tags (GT tags not in selected) ──
    missed_counter = Counter()
    extra_counter = Counter()
    structural_results = []
    for s in samples:
        for t in s.get("missed", []):
            missed_counter[t] += 1
        for t in s.get("extra", []):
            extra_counter[t] += 1
        structural_results.append(s.get("structural", []))
    # ── REPORT 1: Missed by category ──
    print("=" * 70)
    print(f"MISSED TAGS β GT tags not selected ({sum(missed_counter.values())} total misses, {len(missed_counter)} unique)")
    print("=" * 70)
    cat_missed = defaultdict(Counter)
    for tag, cnt in missed_counter.items():
        cat_missed[categorize(tag, tag_type)][tag] = cnt
    cat_totals = {c: sum(v.values()) for c, v in cat_missed.items()}
    for cat in sorted(cat_totals, key=cat_totals.get, reverse=True):
        tags = cat_missed[cat]
        total = cat_totals[cat]
        # Is this category covered by structural inference?
        struct_covered = sum(1 for t in tags if t in _STRUCTURAL)
        struct_note = f" ({struct_covered} structural-coverable)" if struct_covered else ""
        print(f"\n [{cat}] β {total} misses across {len(tags)} unique tags{struct_note}")
        for tag, cnt in tags.most_common(10):
            freq = tag_count.get(tag, 0)
            struct_mark = " *STRUCTURAL*" if tag in _STRUCTURAL else ""
            print(f" {tag:40s} missed {cnt:>2}/{N}{struct_mark} freq={freq:>9,}")
    # ── REPORT 2: Missed tags that structural should catch ──
    print("\n" + "=" * 70)
    print("STRUCTURAL TAG ACCURACY")
    print("=" * 70)
    # Which structural tags are still being missed?
    structural_missed = {t: c for t, c in missed_counter.items() if t in _STRUCTURAL}
    if structural_missed:
        print("\n Structural tags STILL missed (Stage 3s should catch these):")
        for t, c in sorted(structural_missed.items(), key=lambda x: -x[1]):
            print(f" {t:30s} missed {c}/{N}")
    else:
        print("\n All structural tags covered!")
    # What structural tags are over-applied (false positives)?
    structural_extra = {t: c for t, c in extra_counter.items() if t in _STRUCTURAL}
    if structural_extra:
        print(f"\n Structural tags wrongly added (false positives):")
        for t, c in sorted(structural_extra.items(), key=lambda x: -x[1]):
            print(f" {t:30s} extra {c}/{N}")
    # Per-structural-tag stats from the structural field
    struct_tag_counts = Counter()
    for sl in structural_results:
        for t in sl:
            struct_tag_counts[t] += 1
    print(f"\n Structural tag selection frequency (how often Stage 3s picks each):")
    for t, c in struct_tag_counts.most_common():
        missed_c = structural_missed.get(t, 0)
        extra_c = structural_extra.get(t, 0)
        print(f" {t:30s} picked {c:>2}/{N} missed_in_GT={missed_c} false_pos={extra_c}")
    # ── REPORT 3: Extra tags (false positives) by category ──
    print("\n" + "=" * 70)
    print(f"EXTRA TAGS β Selected but not in GT ({sum(extra_counter.values())} total, {len(extra_counter)} unique)")
    print("=" * 70)
    cat_extra = defaultdict(Counter)
    for tag, cnt in extra_counter.items():
        cat_extra[categorize(tag, tag_type)][tag] = cnt
    cat_extra_totals = {c: sum(v.values()) for c, v in cat_extra.items()}
    for cat in sorted(cat_extra_totals, key=cat_extra_totals.get, reverse=True):
        tags = cat_extra[cat]
        total = cat_extra_totals[cat]
        print(f"\n [{cat}] β {total} false positives across {len(tags)} unique tags")
        for tag, cnt in tags.most_common(8):
            freq = tag_count.get(tag, 0)
            print(f" {tag:40s} extra {cnt:>2}/{N} freq={freq:>9,}")
    # ── REPORT 3b: Evidence sources for false positives ──
    # (Only available in new format with extra_evidence field)
    source_counts = Counter()  # source -> count of FP tags
    why_fp_counts = Counter()  # why level -> count of FP tags from stage3
    score_buckets = {"high (>0.5)": 0, "medium (0.2-0.5)": 0, "low (<0.2)": 0}
    has_evidence = False
    for s in samples:
        ev = s.get("extra_evidence", {})
        if ev:
            has_evidence = True
            for tag, info in ev.items():
                src = info.get("source", "unknown")
                source_counts[src] += 1
                if src == "stage3":
                    why_fp_counts[info.get("why", "unknown")] += 1
                    score = info.get("retrieval_score", 0)
                    if score > 0.5:
                        score_buckets["high (>0.5)"] += 1
                    elif score > 0.2:
                        score_buckets["medium (0.2-0.5)"] += 1
                    else:
                        score_buckets["low (<0.2)"] += 1
    if has_evidence:
        print("\n" + "=" * 70)
        print("FALSE POSITIVE EVIDENCE SOURCES")
        print("=" * 70)
        total_fp = sum(source_counts.values())
        print(f"\n How did {total_fp} false positive tags get through?")
        for src, cnt in source_counts.most_common():
            print(f" {src:20s} {cnt:>4} ({cnt/max(1,total_fp)*100:.0f}%)")
        if why_fp_counts:
            print(f"\n Stage 3 false positives by 'why' level:")
            for why, cnt in why_fp_counts.most_common():
                print(f" {why:20s} {cnt:>4}")
        # NOTE(review): score buckets are printed whenever evidence exists;
        # they are only populated by stage3 FPs, so they may be all zeros.
        print(f"\n Stage 3 false positives by retrieval score:")
        for bucket, cnt in score_buckets.items():
            print(f" {bucket:20s} {cnt:>4}")
    # ── REPORT 4: Leaf vs non-leaf in missed ──
    print("\n" + "=" * 70)
    print("MISSED: LEAF vs IMPLIED ANCESTORS")
    print("=" * 70)
    all_missed = set(missed_counter.keys())
    leaf_missed = get_leaf_tags(all_missed, impl)
    anc_missed = all_missed - leaf_missed
    leaf_vol = sum(missed_counter[t] for t in leaf_missed)
    anc_vol = sum(missed_counter[t] for t in anc_missed)
    total_vol = leaf_vol + anc_vol
    print(f"\n Unique missed: {len(all_missed)} tags")
    print(f" Leaf: {len(leaf_missed)} ({len(leaf_missed)/max(1,len(all_missed))*100:.0f}%)")
    print(f" Ancestor: {len(anc_missed)} ({len(anc_missed)/max(1,len(all_missed))*100:.0f}%)")
    print(f" Miss volume: {total_vol}")
    print(f" From leaf: {leaf_vol} ({leaf_vol/max(1,total_vol)*100:.0f}%)")
    print(f" From ancestor: {anc_vol} ({anc_vol/max(1,total_vol)*100:.0f}%) β recoverable via implications")
    # ── REPORT 5: Frequency distribution ──
    print("\n" + "=" * 70)
    print("FREQUENCY DISTRIBUTION OF MISSED TAGS")
    print("=" * 70)
    buckets = {"very_rare (<100)": 0, "rare (100-1k)": 0, "medium (1k-10k)": 0,
               "common (10k-100k)": 0, "very_common (100k+)": 0, "not_in_db": 0}
    for tag in missed_counter:
        freq = tag_count.get(tag, -1)
        if freq < 0:
            buckets["not_in_db"] += 1
        elif freq < 100:
            buckets["very_rare (<100)"] += 1
        elif freq < 1000:
            buckets["rare (100-1k)"] += 1
        elif freq < 10000:
            buckets["medium (1k-10k)"] += 1
        elif freq < 100000:
            buckets["common (10k-100k)"] += 1
        else:
            buckets["very_common (100k+)"] += 1
    for b, c in buckets.items():
        print(f" {b:25s} {c:4d} unique tags ({c/max(1,len(missed_counter))*100:.0f}%)")
    # ── REPORT 6: Over-selection analysis ──
    print("\n" + "=" * 70)
    print("OVER-SELECTION ANALYSIS")
    print("=" * 70)
    over_sels = [s["over_sel"] for s in samples]
    over_sels.sort()
    print(f"\n Avg over-selection ratio: {sum(over_sels)/N:.2f}x")
    print(f" Median: {over_sels[N//2]:.2f}x")
    print(f" Min: {over_sels[0]:.2f}x")
    print(f" Max: {over_sels[-1]:.2f}x")
    tight = sum(1 for x in over_sels if 0.8 <= x <= 1.5)
    over = sum(1 for x in over_sels if x > 2.0)
    under = sum(1 for x in over_sels if x < 0.5)
    print(f" Tight (0.8-1.5x): {tight}/{N}")
    print(f" Over (>2.0x): {over}/{N}")
    print(f" Under (<0.5x): {under}/{N}")
    # Worst over-selectors
    worst = sorted(samples, key=lambda s: -s["over_sel"])[:5]
    print(f"\n Worst over-selectors:")
    for s in worst:
        print(f" id={s['id']:>8} over_sel={s['over_sel']:.2f}x selected={s['n_selected']} gt={s['n_gt']} "
              f"F1={s['F1']:.3f} n_extra={len(s.get('extra',[]))}")
    # ── REPORT 7: Aggregate metrics ──
    print("\n" + "=" * 70)
    print("AGGREGATE METRICS")
    print("=" * 70)
    for metric, key in [("F1", "F1"), ("Precision", "P"), ("Recall", "R"),
                        ("Leaf F1", "leaf_F1"), ("Leaf P", "leaf_P"), ("Leaf R", "leaf_R"),
                        ("Retrieval Recall", "ret_R")]:
        vals = [s[key] for s in samples]
        avg = sum(vals)/N
        vals.sort()
        med = vals[N//2]
        print(f" {metric:20s} avg={avg:.4f} median={med:.4f} min={vals[0]:.4f} max={vals[-1]:.4f}")
    # ── REPORT 8: Samples sorted by F1 ──
    print("\n" + "=" * 70)
    print("WORST 10 SAMPLES BY F1")
    print("=" * 70)
    by_f1 = sorted(samples, key=lambda s: s["F1"])
    for s in by_f1[:10]:
        n_missed = len(s.get("missed", []))
        n_extra = len(s.get("extra", []))
        print(f" id={s['id']:>8} F1={s['F1']:.3f} P={s['P']:.3f} R={s['R']:.3f} "
              f"gt={s['n_gt']} sel={s['n_selected']} missed={n_missed} extra={n_extra} "
              f"structural={s.get('structural',[])} over_sel={s['over_sel']:.2f}x")
    print()


if __name__ == "__main__":
    main()