"""Deep analysis of retrieval gaps: what kinds of tags are missing after Stage 2.
Reads the latest eval results JSONL and categorizes missed tags by:
- Tag type (general, species, character, meta, etc.)
- Whether the miss is a leaf tag or an implied ancestor
- Semantic category (taxonomy, body/anatomy, clothing, color, pose, etc.)
- Whether the tag appears in the rewrite phrases
- Frequency in the tag database (common vs rare tags)
Usage:
python scripts/analyze_retrieval_gaps.py [path/to/eval_results.jsonl]
If no path given, uses the latest file in data/eval_results/.
"""
from __future__ import annotations
import csv
import json
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
# Repository root: this script is expected to live one level down (scripts/).
_REPO_ROOT = Path(__file__).resolve().parents[1]
# ── Load tag database ──────────────────────────────────────────────────────
# Numeric tag-type ids → readable names; ids absent here fall back to
# "unknown" in categorize_tag().
TYPE_ID_NAMES = {0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "species", 7: "meta"}
def load_tag_db() -> Tuple[Dict[str, int], Dict[str, int]]:
    """Return (tag→type_id, tag→count) from fluffyrock_3m.csv.

    Rows with fewer than 3 columns are skipped. A non-numeric or empty
    type column yields -1 ("unknown"); a non-numeric or empty count
    column yields 0 — this also makes a header row harmless.
    """
    tag_type: Dict[str, int] = {}
    tag_count: Dict[str, int] = {}
    csv_path = _REPO_ROOT / "fluffyrock_3m.csv"
    # newline="" as recommended by the csv module docs, so quoted fields
    # containing embedded newlines are parsed correctly.
    with csv_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        for row in reader:
            if len(row) < 3:
                continue
            tag = row[0].strip()
            try:
                tid = int(row[1]) if row[1].strip() else -1
            except ValueError:
                tid = -1  # e.g. a header row → unknown type
            try:
                count = int(row[2]) if row[2].strip() else 0
            except ValueError:
                count = 0
            tag_type[tag] = tid
            tag_count[tag] = count
    return tag_type, tag_count
def load_implications() -> Dict[str, List[str]]:
    """Return antecedent → [consequent, ...] from the tag implications CSV.

    Only rows whose ``status`` column equals "active" are kept. If the
    CSV is missing, an empty mapping is returned so callers degrade
    gracefully. Always returns a plain dict (the original returned a
    defaultdict on the missing-file path, which would silently grow on
    lookups of absent keys).
    """
    impl: Dict[str, List[str]] = defaultdict(list)
    csv_path = _REPO_ROOT / "tag_implications-2023-07-20.csv"
    if not csv_path.is_file():
        return {}
    # newline="" per the csv module docs for correct quoted-field parsing.
    with csv_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row.get("status") != "active":
                continue
            ant = row["antecedent_name"].strip()
            con = row["consequent_name"].strip()
            impl[ant].append(con)
    return dict(impl)
def get_leaf_tags(tags: Set[str], impl: Dict[str, List[str]]) -> Set[str]:
    """Return the subset of *tags* that no other tag in the set implies.

    Walks the implication graph transitively from every tag; any member
    of *tags* reached as an ancestor is "implied" and therefore not a leaf.
    """
    implied_members: Set[str] = set()
    for start in tags:
        seen: Set[str] = set()
        stack = [start]
        # Depth-first walk over the transitive implication closure of `start`.
        while stack:
            current = stack.pop()
            for ancestor in impl.get(current, []):
                if ancestor in seen:
                    continue
                seen.add(ancestor)
                if ancestor in tags:
                    implied_members.add(ancestor)
                stack.append(ancestor)
    return tags - implied_members
# ── Semantic categorization heuristics ─────────────────────────────────────
# Taxonomy / body plan tags that are almost always implied, not directly described
_TAXONOMY_TAGS = frozenset({
    "mammal", "canid", "canine", "canis", "felid", "feline", "felis",
    "ursine", "cervid", "bovid", "equid", "equine", "mustelid", "procyonid",
    "reptile", "scalie", "avian", "bird", "fish", "marine", "aquatic",
    "arthropod", "insect", "arachnid", "mollusk", "amphibian",
    "primate", "hominid", "rodent", "lagomorph", "leporid", "chiroptera",
    "marsupial", "monotreme", "pinniped", "cetacean", "ungulate",
    "galliform", "gallus_(genus)", "phasianid", "passerine", "oscine",
    "dinosaur", "theropod",
})
# High-level "body plan" tags (anthro vs feral vs robot, etc.).
_BODY_PLAN_TAGS = frozenset({
    "anthro", "feral", "biped", "quadruped", "taur", "humanoid",
    "semi-anthro", "animatronic", "robot", "machine", "plushie",
    "kemono",
})
# Matches counted-anatomy tags like "2_horns" or "4_ears".
_COUNT_TAGS_RE = re.compile(
    r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails|heads|claws|fangs|nipples|breasts|penises|balls|teats)$"
)
# Pose / composition / camera-angle tags.
_POSE_TAGS = frozenset({
    "solo", "duo", "group", "standing", "sitting", "lying", "running",
    "walking", "flying", "swimming", "crouching", "kneeling", "jumping",
    "looking_at_viewer", "looking_away", "looking_back", "looking_up",
    "looking_down", "looking_aside", "front_view", "side_view", "back_view",
    "three-quarter_view", "from_above", "from_below", "worm's-eye_view",
    "bird's-eye_view", "close-up", "portrait", "full-length_portrait",
    "butt_pose", "spread_legs", "all_fours", "on_back", "on_side",
    "hand_on_hip", "arms_crossed", "hands_behind_back",
})
def categorize_tag(tag: str, tag_type: Dict[str, int]) -> str:
    """Assign a semantic category to a missed tag.

    Non-general tag types map straight to their type name; general tags
    fall through a chain of membership and keyword heuristics, where the
    first matching rule wins.
    """
    type_name = TYPE_ID_NAMES.get(tag_type.get(tag, -1), "unknown")
    if type_name == "species":
        return "species"
    if type_name in ("artist", "copyright", "character", "meta"):
        return type_name
    # General tags — subcategorize by curated sets first.
    if tag in _TAXONOMY_TAGS:
        return "taxonomy"
    if tag in _BODY_PLAN_TAGS:
        return "body_plan"
    if tag in _POSE_TAGS:
        return "pose/composition"
    if _COUNT_TAGS_RE.match(tag):
        return "count/anatomy"
    # Clothing-related keywords (substring match anywhere in the tag).
    clothing_keywords = ("clothing", "clothed", "topwear", "bottomwear",
                         "legwear", "handwear", "headwear", "footwear",
                         "shirt", "pants", "shorts", "dress", "skirt",
                         "jacket", "coat", "hat", "boots", "shoes",
                         "gloves", "socks", "stockings", "belt",
                         "collar", "scarf", "cape", "armor", "suit",
                         "uniform", "costume", "outfit", "underwear",
                         "bra", "panties", "thigh_highs", "knee_highs")
    if any(kw in tag for kw in clothing_keywords):
        return "clothing"
    # Color tags: "<color>_..." prefixes, plus coloring/marking suffixes.
    color_prefixes = tuple(c + "_" for c in (
        "red", "blue", "green", "yellow", "orange", "purple", "pink",
        "black", "white", "grey", "gray", "brown", "tan", "cream",
        "gold", "silver", "teal", "cyan", "magenta",
    ))
    if tag.startswith(color_prefixes):
        return "color/marking"
    if tag == "markings" or tag.endswith(("_coloring", "_markings")):
        return "color/marking"
    # Hair
    if "hair" in tag:
        return "hair"
    # Body features
    anatomy_keywords = ("muscle", "belly", "chest", "abs",
                        "breast", "butt", "tail", "wing",
                        "horn", "ear", "eye", "teeth", "fang",
                        "claw", "paw", "hoof", "snout", "muzzle",
                        "tongue", "fur", "scales", "feather",
                        "tuft", "fluff", "mane")
    if any(kw in tag for kw in anatomy_keywords):
        return "body/anatomy"
    # Gender/sex
    if tag in ("male", "female", "intersex", "ambiguous_gender",
               "andromorph", "gynomorph", "herm", "maleherm"):
        return "gender"
    # Expression/emotion
    expression_keywords = ("smile", "grin", "frown", "expression",
                           "blush", "angry", "happy", "sad",
                           "crying", "laughing", "open_mouth",
                           "closed_eyes", "wink")
    if any(kw in tag for kw in expression_keywords):
        return "expression"
    return "other_general"
# ββ Analysis βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def analyze_eval_file(eval_path: Path) -> None:
    """Print a seven-part gap analysis for one eval-results JSONL file.

    Loads the tag database and implication graph, then for every sample
    splits ground-truth tags into retrieval misses (never retrieved) and
    selection misses (retrieved but dropped by selection), and prints
    reports on category breakdown, leaf vs implied-ancestor misses,
    phrase mentions, selection drops, frequency buckets, and per-sample
    recall. Output goes to stdout only; nothing is returned.
    """
    tag_type, tag_count = load_tag_db()
    impl = load_implications()
    samples = []
    with eval_path.open("r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            # A "_meta" row carries the eval configuration: echo it and skip.
            if row.get("_meta"):
                print(f"Eval config: min_why={row.get('min_why')}, "
                      f"expand_implications={row.get('expand_implications')}, "
                      f"n={row.get('n_samples')}, seed={row.get('seed')}")
                continue
            # Errored samples are excluded from all statistics.
            if row.get("error"):
                continue
            samples.append(row)
    print(f"Analyzing {len(samples)} samples from {eval_path.name}\n")
    # Collect all misses across samples
    all_retrieval_misses: Counter = Counter()  # tag → how many samples missed it in retrieval
    all_selection_misses: Counter = Counter()  # tag → missed in selection (retrieved but not selected)
    retrieval_miss_details: Dict[str, List] = defaultdict(list)  # tag → [sample_ids]
    per_sample_stats = []
    for s in samples:
        gt = set(s["ground_truth_tags"])
        retrieved = set(s["retrieved_tags"])
        selected = set(s["selected_tags"])
        phrases = s.get("rewrite_phrases", [])
        sid = s["sample_id"]
        retrieval_misses = gt - retrieved
        selection_misses = (gt & retrieved) - selected  # retrieved but dropped
        for tag in retrieval_misses:
            all_retrieval_misses[tag] += 1
            retrieval_miss_details[tag].append(sid)
        for tag in selection_misses:
            all_selection_misses[tag] += 1
        # Check if missed tags appear in rewrite phrases
        # (match either "multi word" form or the raw underscored tag).
        phrase_text = " ".join(phrases).lower()
        misses_in_phrases = {t for t in retrieval_misses
                             if t.replace("_", " ") in phrase_text or t in phrase_text}
        per_sample_stats.append({
            "id": sid,
            "gt_count": len(gt),
            "retrieved_count": len(retrieved),
            "retrieval_misses": len(retrieval_misses),
            "misses_in_phrases": len(misses_in_phrases),
            "retrieval_recall": s["retrieval_recall"],
        })
    # ── Report 1: Retrieval misses by category ──
    print("=" * 70)
    print("RETRIEVAL GAPS β Tags in GT but never retrieved (Stage 2 misses)")
    print("=" * 70)
    category_misses: Dict[str, Counter] = defaultdict(Counter)
    for tag, miss_count in all_retrieval_misses.items():
        cat = categorize_tag(tag, tag_type)
        category_misses[cat][tag] = miss_count
    # Sort categories by total miss volume
    cat_totals = {cat: sum(c.values()) for cat, c in category_misses.items()}
    for cat in sorted(cat_totals, key=cat_totals.get, reverse=True):
        tags_in_cat = category_misses[cat]
        total_misses = cat_totals[cat]
        unique_tags = len(tags_in_cat)
        print(f"\n [{cat}] β {total_misses} total misses across {unique_tags} unique tags")
        # Show top tags in this category
        for tag, cnt in tags_in_cat.most_common(8):
            freq = tag_count.get(tag, 0)
            leaf_marker = ""
            # Check if this tag is typically a leaf or implied
            # NOTE(review): leaf_marker is computed but never printed below —
            # looks like a dropped insertion in the f-string; confirm intent.
            if tag in _TAXONOMY_TAGS or tag in _BODY_PLAN_TAGS:
                leaf_marker = " (implied ancestor)"
            in_db = "YES" if tag in tag_type else "NO"
            print(f" {tag:40s} missed {cnt}/{len(samples)} samples "
                  f"freq={freq:>8,} in_db={in_db}")
    # ── Report 2: Leaf vs non-leaf misses ──
    print("\n" + "=" * 70)
    print("LEAF vs IMPLIED ANCESTOR MISSES")
    print("=" * 70)
    all_missed_tags = set(all_retrieval_misses.keys())
    leaf_misses = get_leaf_tags(all_missed_tags, impl)
    ancestor_misses = all_missed_tags - leaf_misses
    leaf_miss_volume = sum(all_retrieval_misses[t] for t in leaf_misses)
    ancestor_miss_volume = sum(all_retrieval_misses[t] for t in ancestor_misses)
    total_miss_volume = leaf_miss_volume + ancestor_miss_volume
    # max(1, ...) guards every percentage against division by zero.
    print(f"\n Unique missed tags: {len(all_missed_tags)}")
    print(f" Leaf tags: {len(leaf_misses)} ({len(leaf_misses)/max(1,len(all_missed_tags))*100:.0f}%)")
    print(f" Ancestor tags: {len(ancestor_misses)} ({len(ancestor_misses)/max(1,len(all_missed_tags))*100:.0f}%)")
    print(f"\n Total miss volume: {total_miss_volume}")
    print(f" From leaf tags: {leaf_miss_volume} ({leaf_miss_volume/max(1,total_miss_volume)*100:.0f}%)")
    print(f" From ancestors: {ancestor_miss_volume} ({ancestor_miss_volume/max(1,total_miss_volume)*100:.0f}%)")
    print(f"\n Ancestor misses recoverable by implication expansion: "
          f"{ancestor_miss_volume} ({ancestor_miss_volume/max(1,total_miss_volume)*100:.0f}%)")
    # ── Report 3: Most-missed leaf tags ──
    print("\n" + "=" * 70)
    print("TOP MISSED LEAF TAGS (not recoverable via implications)")
    print("=" * 70)
    leaf_miss_counter = Counter({t: all_retrieval_misses[t] for t in leaf_misses})
    for tag, cnt in leaf_miss_counter.most_common(30):
        cat = categorize_tag(tag, tag_type)
        freq = tag_count.get(tag, 0)
        sids = retrieval_miss_details[tag]
        print(f" {tag:40s} missed {cnt}/{len(samples)} cat={cat:20s} freq={freq:>8,} samples={sids}")
    # ── Report 4: Tags that were in rewrite phrases but not retrieved ──
    print("\n" + "=" * 70)
    print("TAGS MENTIONED IN REWRITE PHRASES BUT NOT RETRIEVED")
    print("=" * 70)
    phrase_miss_counter: Counter = Counter()
    for s in samples:
        gt = set(s["ground_truth_tags"])
        retrieved = set(s["retrieved_tags"])
        phrases = s.get("rewrite_phrases", [])
        phrase_text = " ".join(phrases).lower()
        for tag in (gt - retrieved):
            tag_text = tag.replace("_", " ")
            if tag_text in phrase_text or tag in phrase_text:
                phrase_miss_counter[tag] += 1
    if phrase_miss_counter:
        for tag, cnt in phrase_miss_counter.most_common(20):
            cat = categorize_tag(tag, tag_type)
            print(f" {tag:40s} mentioned but not retrieved {cnt}x cat={cat}")
    else:
        print(" (none found)")
    # ── Report 5: Selection drops (retrieved but not selected) ──
    print("\n" + "=" * 70)
    print("SELECTION DROPS β Retrieved GT tags dropped by Stage 3")
    print("=" * 70)
    if all_selection_misses:
        for tag, cnt in all_selection_misses.most_common(20):
            cat = categorize_tag(tag, tag_type)
            freq = tag_count.get(tag, 0)
            print(f" {tag:40s} dropped {cnt}/{len(samples)} cat={cat:20s} freq={freq:>8,}")
    else:
        print(" (none β all retrieved GT tags were selected)")
    # ── Report 6: Frequency distribution of missed tags ──
    print("\n" + "=" * 70)
    print("FREQUENCY DISTRIBUTION OF MISSED TAGS")
    print("=" * 70)
    freq_buckets = {"very_rare (<100)": 0, "rare (100-1k)": 0, "medium (1k-10k)": 0,
                    "common (10k-100k)": 0, "very_common (100k+)": 0, "not_in_db": 0}
    for tag in all_retrieval_misses:
        # -1 default distinguishes "absent from db" from a genuine 0 count.
        freq = tag_count.get(tag, -1)
        if freq < 0:
            freq_buckets["not_in_db"] += 1
        elif freq < 100:
            freq_buckets["very_rare (<100)"] += 1
        elif freq < 1000:
            freq_buckets["rare (100-1k)"] += 1
        elif freq < 10000:
            freq_buckets["medium (1k-10k)"] += 1
        elif freq < 100000:
            freq_buckets["common (10k-100k)"] += 1
        else:
            freq_buckets["very_common (100k+)"] += 1
    for bucket, count in freq_buckets.items():
        pct = count / max(1, len(all_retrieval_misses)) * 100
        print(f" {bucket:25s} {count:4d} unique tags ({pct:.0f}%)")
    # ── Report 7: Per-sample retrieval stats ──
    print("\n" + "=" * 70)
    print("PER-SAMPLE RETRIEVAL STATS")
    print("=" * 70)
    # Worst-recall samples first.
    for stat in sorted(per_sample_stats, key=lambda x: x["retrieval_recall"]):
        print(f" id={stat['id']:>8} recall={stat['retrieval_recall']:.3f} "
              f"gt={stat['gt_count']:3d} retrieved={stat['retrieved_count']:3d} "
              f"missed={stat['retrieval_misses']:3d} in_phrases={stat['misses_in_phrases']}")
    print()
def main():
    """Entry point: resolve the eval-results path from argv or pick the newest file."""
    if len(sys.argv) > 1:
        eval_path = Path(sys.argv[1])
    else:
        # No explicit path given — fall back to the most recent eval_*.jsonl.
        results_dir = _REPO_ROOT / "data" / "eval_results"
        candidates = sorted(results_dir.glob("eval_*.jsonl"))
        if not candidates:
            print("No eval results found in data/eval_results/")
            sys.exit(1)
        eval_path = candidates[-1]
        print(f"Using latest eval: {eval_path.name}\n")
    if not eval_path.is_file():
        print(f"File not found: {eval_path}")
        sys.exit(1)
    analyze_eval_file(eval_path)


if __name__ == "__main__":
    main()