Spaces:
Running
Running
File size: 15,486 Bytes
"""Analyze compact eval results (n=50) for patterns in missed and extra tags.
Works with the new compact JSONL format (missed/extra diff sets, not full tag lists).
"""
from __future__ import annotations
import csv, json, re, sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple
_REPO_ROOT = Path(__file__).resolve().parents[1]
TYPE_ID_NAMES = {0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "species", 7: "meta"}
def load_tag_db():
    """Load the tag database from fluffyrock_3m.csv.

    Returns a pair of dicts keyed by tag name: (tag -> type id, tag -> post
    count). Rows with fewer than three columns are skipped; a blank or
    unparseable type id becomes -1, and a blank or unparseable count becomes 0.
    """
    tag_type = {}
    tag_count = {}
    db_path = _REPO_ROOT / "fluffyrock_3m.csv"
    with db_path.open("r", encoding="utf-8") as handle:
        for row in csv.reader(handle):
            if len(row) < 3:
                continue
            name = row[0].strip()
            raw_type = row[1]
            raw_count = row[2]
            if raw_type.strip():
                try:
                    type_id = int(raw_type)
                except ValueError:
                    type_id = -1
            else:
                type_id = -1
            if raw_count.strip():
                try:
                    count = int(raw_count)
                except ValueError:
                    count = 0
            else:
                count = 0
            tag_type[name] = type_id
            tag_count[name] = count
    return tag_type, tag_count
def load_implications():
    """Load active tag implications as {antecedent: [consequent, ...]}.

    Reads tag_implications-2023-07-20.csv from the repo root, keeping only
    rows whose status is "active". Returns an empty dict when the file is
    absent.

    Fix: the original returned a ``defaultdict(list)`` in the missing-file
    case but a plain ``dict`` otherwise; a leaked defaultdict silently
    auto-creates entries on ``impl[tag]`` access. Now always a plain dict.
    """
    p = _REPO_ROOT / "tag_implications-2023-07-20.csv"
    if not p.is_file():
        return {}
    impl = defaultdict(list)
    with p.open("r", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            if row.get("status") == "active":
                impl[row["antecedent_name"].strip()].append(row["consequent_name"].strip())
    return dict(impl)
def get_leaf_tags(tags, impl):
    """Return the subset of *tags* not implied by any other tag in *tags*.

    For every tag, walk the implication graph (``impl`` maps antecedent ->
    list of consequents) transitively; any member of *tags* reachable this
    way is an ancestor, not a leaf, and is excluded from the result.
    """
    implied_members = set()
    for start in tags:
        seen = set()
        stack = [start]
        while stack:
            current = stack.pop()
            for parent in impl.get(current, []):
                if parent in seen:
                    continue
                seen.add(parent)
                if parent in tags:
                    implied_members.add(parent)
                stack.append(parent)
    return tags - implied_members
# ── Categorization ──
# Higher-level taxonomy tags (class/family/genus names) typically implied by species tags.
_TAXONOMY = frozenset({"mammal","canid","canine","canis","felid","feline","felis","ursine","cervid","bovid","equid","equine","mustelid","procyonid","reptile","scalie","avian","bird","fish","marine","arthropod","insect","arachnid","amphibian","primate","rodent","lagomorph","leporid","galliform","gallus_(genus)","phasianid","passerine","oscine","dinosaur","theropod","cetacean","pinniped","chiroptera","marsupial","monotreme","mephitid","suid","suina"})
# Body-plan tags (how the character is constructed, not what species it is).
_BODY_PLAN = frozenset({"anthro","feral","biped","quadruped","taur","humanoid","semi-anthro","animatronic","robot","machine","plushie","kemono"})
# Pose, camera angle, and composition tags.
_POSE = frozenset({"solo","duo","group","trio","standing","sitting","lying","running","walking","flying","swimming","crouching","kneeling","jumping","looking_at_viewer","looking_away","looking_back","looking_up","looking_down","looking_aside","front_view","side_view","back_view","three-quarter_view","from_above","from_below","close-up","portrait","full-length_portrait","hand_on_hip","arms_crossed","all_fours","on_back","on_side","crossed_arms"})
# Matches counted-anatomy tags such as "4_fingers" or "2_horns" (prefix match; no anchor at the end).
_COUNT_RE = re.compile(r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails)")
# Tags that a structural-inference stage ("Stage 3s" in the reports) is expected to supply.
_STRUCTURAL = frozenset({
    # Character count
    "solo","duo","trio","group","zero_pictured",
    # Body type
    "anthro","feral","humanoid","taur",
    # Gender
    "male","female","ambiguous_gender","intersex",
    # Clothing state
    "clothed","nude","topless","bottomless",
    # Visual elements
    "looking_at_viewer","text",
})
def categorize(tag, tag_type):
    """Map a tag to a coarse reporting category.

    Database type (species/artist/copyright/character/meta) wins first;
    otherwise a fixed cascade of set memberships and keyword heuristics
    decides, falling through to "other_general".
    """
    type_name = TYPE_ID_NAMES.get(tag_type.get(tag, -1), "unknown")
    if type_name == "species":
        return "species"
    if type_name in ("artist", "copyright", "character", "meta"):
        return type_name
    if tag in _TAXONOMY:
        return "taxonomy"
    if tag in _BODY_PLAN:
        return "body_plan"
    if tag in _POSE:
        return "pose/composition"
    if _COUNT_RE.match(tag):
        return "count/anatomy"
    if tag in ("male", "female", "intersex", "ambiguous_gender", "andromorph", "gynomorph"):
        return "gender"
    clothing_keys = ("clothing", "clothed", "topwear", "bottomwear", "legwear", "handwear",
                     "headwear", "footwear", "shirt", "pants", "shorts", "dress", "skirt",
                     "jacket", "coat", "hat", "boots", "shoes", "gloves", "socks", "stockings",
                     "belt", "collar", "scarf", "cape", "armor", "suit", "uniform", "costume",
                     "outfit")
    if any(key in tag for key in clothing_keys):
        return "clothing"
    color_names = ("red", "blue", "green", "yellow", "orange", "purple", "pink", "black",
                   "white", "grey", "gray", "brown", "tan", "cream", "gold", "silver",
                   "teal", "cyan", "magenta")
    if tag.startswith(tuple(color + "_" for color in color_names)):
        return "color/marking"
    if tag == "markings" or tag.endswith(("_coloring", "_markings")):
        return "color/marking"
    if "hair" in tag:
        return "hair"
    anatomy_keys = ("muscle", "belly", "chest", "abs", "breast", "butt", "tail", "wing",
                    "horn", "ear", "eye", "teeth", "fang", "claw", "paw", "hoof", "snout",
                    "muzzle", "tongue", "fur", "scales", "feather", "tuft", "fluff", "mane")
    if any(key in tag for key in anatomy_keys):
        return "body/anatomy"
    expression_keys = ("smile", "grin", "frown", "expression", "blush", "angry", "happy",
                       "sad", "crying", "laughing", "open_mouth", "closed_eyes", "wink")
    if any(key in tag for key in expression_keys):
        return "expression"
    return "other_general"
def main():
    """Analyze one compact eval-results JSONL file and print pattern reports.

    Usage: ``script.py [path/to/eval_*.jsonl]``; with no argument, picks the
    newest ``eval_*.jsonl`` under ``data/eval_results/``.

    Fixes over the original:
    - exits with a clear message when no eval files exist (was a bare
      IndexError from ``sorted(...)[-1]``);
    - returns early when zero usable samples are loaded (was a
      ZeroDivisionError / IndexError in reports 6 and 7).
    """
    if len(sys.argv) > 1:
        path = Path(sys.argv[1])
    else:
        candidates = sorted((_REPO_ROOT / "data" / "eval_results").glob("eval_*.jsonl"))
        if not candidates:
            sys.exit("No eval_*.jsonl files found in data/eval_results/")
        path = candidates[-1]
    tag_type, tag_count = load_tag_db()
    impl = load_implications()
    samples = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            if row.get("_meta"):
                print(f"Config: min_why={row.get('min_why')}, expand_impl={row.get('expand_implications')}, "
                      f"structural={row.get('infer_structural')}, n={row.get('n_samples')}")
                continue
            if row.get("err"):
                continue
            samples.append(row)
    N = len(samples)
    if N == 0:
        # Guard: every report below divides by N or indexes sorted per-sample lists.
        print(f"No usable samples found in {path.name}")
        return
    print(f"Analyzing {N} samples from {path.name}\n")
    # ── 1. Missed tags (GT tags not in selected) ──
    missed_counter = Counter()
    extra_counter = Counter()
    structural_results = []
    for s in samples:
        for t in s.get("missed", []):
            missed_counter[t] += 1
        for t in s.get("extra", []):
            extra_counter[t] += 1
        structural_results.append(s.get("structural", []))
    # ── REPORT 1: Missed by category ──
    print("=" * 70)
    print(f"MISSED TAGS β GT tags not selected ({sum(missed_counter.values())} total misses, {len(missed_counter)} unique)")
    print("=" * 70)
    cat_missed = defaultdict(Counter)
    for tag, cnt in missed_counter.items():
        cat_missed[categorize(tag, tag_type)][tag] = cnt
    cat_totals = {c: sum(v.values()) for c, v in cat_missed.items()}
    for cat in sorted(cat_totals, key=cat_totals.get, reverse=True):
        tags = cat_missed[cat]
        total = cat_totals[cat]
        # Is this category covered by structural inference?
        struct_covered = sum(1 for t in tags if t in _STRUCTURAL)
        struct_note = f" ({struct_covered} structural-coverable)" if struct_covered else ""
        print(f"\n [{cat}] β {total} misses across {len(tags)} unique tags{struct_note}")
        for tag, cnt in tags.most_common(10):
            freq = tag_count.get(tag, 0)
            struct_mark = " *STRUCTURAL*" if tag in _STRUCTURAL else ""
            print(f" {tag:40s} missed {cnt:>2}/{N}{struct_mark} freq={freq:>9,}")
    # ── REPORT 2: Missed tags that structural should catch ──
    print("\n" + "=" * 70)
    print("STRUCTURAL TAG ACCURACY")
    print("=" * 70)
    # Which structural tags are still being missed?
    structural_missed = {t: c for t, c in missed_counter.items() if t in _STRUCTURAL}
    if structural_missed:
        print("\n Structural tags STILL missed (Stage 3s should catch these):")
        for t, c in sorted(structural_missed.items(), key=lambda x: -x[1]):
            print(f" {t:30s} missed {c}/{N}")
    else:
        print("\n All structural tags covered!")
    # What structural tags are over-applied (false positives)?
    structural_extra = {t: c for t, c in extra_counter.items() if t in _STRUCTURAL}
    if structural_extra:
        print(f"\n Structural tags wrongly added (false positives):")
        for t, c in sorted(structural_extra.items(), key=lambda x: -x[1]):
            print(f" {t:30s} extra {c}/{N}")
    # Per-structural-tag stats from the structural field
    struct_tag_counts = Counter()
    for sl in structural_results:
        for t in sl:
            struct_tag_counts[t] += 1
    print(f"\n Structural tag selection frequency (how often Stage 3s picks each):")
    for t, c in struct_tag_counts.most_common():
        missed_c = structural_missed.get(t, 0)
        extra_c = structural_extra.get(t, 0)
        print(f" {t:30s} picked {c:>2}/{N} missed_in_GT={missed_c} false_pos={extra_c}")
    # ── REPORT 3: Extra tags (false positives) by category ──
    print("\n" + "=" * 70)
    print(f"EXTRA TAGS β Selected but not in GT ({sum(extra_counter.values())} total, {len(extra_counter)} unique)")
    print("=" * 70)
    cat_extra = defaultdict(Counter)
    for tag, cnt in extra_counter.items():
        cat_extra[categorize(tag, tag_type)][tag] = cnt
    cat_extra_totals = {c: sum(v.values()) for c, v in cat_extra.items()}
    for cat in sorted(cat_extra_totals, key=cat_extra_totals.get, reverse=True):
        tags = cat_extra[cat]
        total = cat_extra_totals[cat]
        print(f"\n [{cat}] β {total} false positives across {len(tags)} unique tags")
        for tag, cnt in tags.most_common(8):
            freq = tag_count.get(tag, 0)
            print(f" {tag:40s} extra {cnt:>2}/{N} freq={freq:>9,}")
    # ── REPORT 3b: Evidence sources for false positives ──
    # (Only available in new format with extra_evidence field)
    source_counts = Counter()  # source -> count of FP tags
    why_fp_counts = Counter()  # why level -> count of FP tags from stage3
    score_buckets = {"high (>0.5)": 0, "medium (0.2-0.5)": 0, "low (<0.2)": 0}
    has_evidence = False
    for s in samples:
        ev = s.get("extra_evidence", {})
        if ev:
            has_evidence = True
            for tag, info in ev.items():
                src = info.get("source", "unknown")
                source_counts[src] += 1
                if src == "stage3":
                    why_fp_counts[info.get("why", "unknown")] += 1
                score = info.get("retrieval_score", 0)
                if score > 0.5:
                    score_buckets["high (>0.5)"] += 1
                elif score > 0.2:
                    score_buckets["medium (0.2-0.5)"] += 1
                else:
                    score_buckets["low (<0.2)"] += 1
    if has_evidence:
        print("\n" + "=" * 70)
        print("FALSE POSITIVE EVIDENCE SOURCES")
        print("=" * 70)
        total_fp = sum(source_counts.values())
        print(f"\n How did {total_fp} false positive tags get through?")
        for src, cnt in source_counts.most_common():
            print(f" {src:20s} {cnt:>4} ({cnt/max(1,total_fp)*100:.0f}%)")
        if why_fp_counts:
            print(f"\n Stage 3 false positives by 'why' level:")
            for why, cnt in why_fp_counts.most_common():
                print(f" {why:20s} {cnt:>4}")
        print(f"\n Stage 3 false positives by retrieval score:")
        for bucket, cnt in score_buckets.items():
            print(f" {bucket:20s} {cnt:>4}")
    # ── REPORT 4: Leaf vs non-leaf in missed ──
    print("\n" + "=" * 70)
    print("MISSED: LEAF vs IMPLIED ANCESTORS")
    print("=" * 70)
    all_missed = set(missed_counter.keys())
    leaf_missed = get_leaf_tags(all_missed, impl)
    anc_missed = all_missed - leaf_missed
    leaf_vol = sum(missed_counter[t] for t in leaf_missed)
    anc_vol = sum(missed_counter[t] for t in anc_missed)
    total_vol = leaf_vol + anc_vol
    print(f"\n Unique missed: {len(all_missed)} tags")
    print(f" Leaf: {len(leaf_missed)} ({len(leaf_missed)/max(1,len(all_missed))*100:.0f}%)")
    print(f" Ancestor: {len(anc_missed)} ({len(anc_missed)/max(1,len(all_missed))*100:.0f}%)")
    print(f" Miss volume: {total_vol}")
    print(f" From leaf: {leaf_vol} ({leaf_vol/max(1,total_vol)*100:.0f}%)")
    print(f" From ancestor: {anc_vol} ({anc_vol/max(1,total_vol)*100:.0f}%) β recoverable via implications")
    # ── REPORT 5: Frequency distribution ──
    print("\n" + "=" * 70)
    print("FREQUENCY DISTRIBUTION OF MISSED TAGS")
    print("=" * 70)
    buckets = {"very_rare (<100)": 0, "rare (100-1k)": 0, "medium (1k-10k)": 0,
               "common (10k-100k)": 0, "very_common (100k+)": 0, "not_in_db": 0}
    for tag in missed_counter:
        freq = tag_count.get(tag, -1)
        if freq < 0:
            buckets["not_in_db"] += 1
        elif freq < 100:
            buckets["very_rare (<100)"] += 1
        elif freq < 1000:
            buckets["rare (100-1k)"] += 1
        elif freq < 10000:
            buckets["medium (1k-10k)"] += 1
        elif freq < 100000:
            buckets["common (10k-100k)"] += 1
        else:
            buckets["very_common (100k+)"] += 1
    for b, c in buckets.items():
        print(f" {b:25s} {c:4d} unique tags ({c/max(1,len(missed_counter))*100:.0f}%)")
    # ── REPORT 6: Over-selection analysis ──
    print("\n" + "=" * 70)
    print("OVER-SELECTION ANALYSIS")
    print("=" * 70)
    over_sels = [s["over_sel"] for s in samples]
    over_sels.sort()
    # NOTE: vals[N//2] is the upper median for even N (kept from original).
    print(f"\n Avg over-selection ratio: {sum(over_sels)/N:.2f}x")
    print(f" Median: {over_sels[N//2]:.2f}x")
    print(f" Min: {over_sels[0]:.2f}x")
    print(f" Max: {over_sels[-1]:.2f}x")
    tight = sum(1 for x in over_sels if 0.8 <= x <= 1.5)
    over = sum(1 for x in over_sels if x > 2.0)
    under = sum(1 for x in over_sels if x < 0.5)
    print(f" Tight (0.8-1.5x): {tight}/{N}")
    print(f" Over (>2.0x): {over}/{N}")
    print(f" Under (<0.5x): {under}/{N}")
    # Worst over-selectors
    worst = sorted(samples, key=lambda s: -s["over_sel"])[:5]
    print(f"\n Worst over-selectors:")
    for s in worst:
        print(f" id={s['id']:>8} over_sel={s['over_sel']:.2f}x selected={s['n_selected']} gt={s['n_gt']} "
              f"F1={s['F1']:.3f} n_extra={len(s.get('extra',[]))}")
    # ── REPORT 7: Aggregate metrics ──
    print("\n" + "=" * 70)
    print("AGGREGATE METRICS")
    print("=" * 70)
    for metric, key in [("F1", "F1"), ("Precision", "P"), ("Recall", "R"),
                        ("Leaf F1", "leaf_F1"), ("Leaf P", "leaf_P"), ("Leaf R", "leaf_R"),
                        ("Retrieval Recall", "ret_R")]:
        vals = [s[key] for s in samples]
        avg = sum(vals)/N
        vals.sort()
        med = vals[N//2]
        print(f" {metric:20s} avg={avg:.4f} median={med:.4f} min={vals[0]:.4f} max={vals[-1]:.4f}")
    # ── REPORT 8: Samples sorted by F1 ──
    print("\n" + "=" * 70)
    print("WORST 10 SAMPLES BY F1")
    print("=" * 70)
    by_f1 = sorted(samples, key=lambda s: s["F1"])
    for s in by_f1[:10]:
        n_missed = len(s.get("missed", []))
        n_extra = len(s.get("extra", []))
        print(f" id={s['id']:>8} F1={s['F1']:.3f} P={s['P']:.3f} R={s['R']:.3f} "
              f"gt={s['n_gt']} sel={s['n_selected']} missed={n_missed} extra={n_extra} "
              f"structural={s.get('structural',[])} over_sel={s['over_sel']:.2f}x")
    print()
if __name__ == "__main__":
    main()
|