Spaces:
Running
Running
| """Deep analysis of retrieval gaps: what kinds of tags are missing after Stage 2. | |
| Reads the latest eval results JSONL and categorizes missed tags by: | |
| - Tag type (general, species, character, meta, etc.) | |
| - Whether the miss is a leaf tag or an implied ancestor | |
| - Semantic category (taxonomy, body/anatomy, clothing, color, pose, etc.) | |
| - Whether the tag appears in the rewrite phrases | |
| - Frequency in the tag database (common vs rare tags) | |
| Usage: | |
| python scripts/analyze_retrieval_gaps.py [path/to/eval_results.jsonl] | |
| If no path given, uses the latest file in data/eval_results/. | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import json | |
| import re | |
| import sys | |
| from collections import Counter, defaultdict | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Set, Tuple | |
# Repository root: this file lives in scripts/, so the root is one level up.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# ββ Load tag database ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Tag-type id -> human-readable name. Ids not listed here (e.g. 2, 6) fall
# through to "unknown" in categorize_tag(). NOTE(review): presumably the
# e621-style category ids used by the tag CSV — confirm against the data file.
TYPE_ID_NAMES = {0: "general", 1: "artist", 3: "copyright", 4: "character", 5: "species", 7: "meta"}
def load_tag_db() -> Tuple[Dict[str, int], Dict[str, int]]:
    """Load fluffyrock_3m.csv into two dicts: tag -> type_id and tag -> count.

    Rows with fewer than three columns are skipped. A blank or unparsable
    type column becomes -1; a blank or unparsable count column becomes 0.
    """
    types: Dict[str, int] = {}
    counts: Dict[str, int] = {}
    with (_REPO_ROOT / "fluffyrock_3m.csv").open("r", encoding="utf-8") as fh:
        for row in csv.reader(fh):
            if len(row) < 3:
                continue  # malformed row — not enough columns
            name = row[0].strip()
            try:
                type_id = int(row[1]) if row[1].strip() else -1
            except ValueError:
                type_id = -1
            try:
                post_count = int(row[2]) if row[2].strip() else 0
            except ValueError:
                post_count = 0
            types[name] = type_id
            counts[name] = post_count
    return types, counts
def load_implications() -> Dict[str, List[str]]:
    """Read the tag-implications CSV into antecedent -> [consequents].

    Only rows whose status column is "active" are kept. Returns an empty
    mapping when the CSV is absent, so callers degrade gracefully.
    """
    mapping: Dict[str, List[str]] = defaultdict(list)
    path = _REPO_ROOT / "tag_implications-2023-07-20.csv"
    if not path.is_file():
        return mapping
    with path.open("r", encoding="utf-8") as fh:
        for record in csv.DictReader(fh):
            if record.get("status") != "active":
                continue
            antecedent = record["antecedent_name"].strip()
            consequent = record["consequent_name"].strip()
            mapping[antecedent].append(consequent)
    return dict(mapping)
def get_leaf_tags(tags: Set[str], impl: Dict[str, List[str]]) -> Set[str]:
    """Return the subset of *tags* not implied by any other tag in the set.

    For each tag we walk its transitive implication ancestors; any ancestor
    that is itself a member of *tags* is marked as implied. The leaves are
    whatever remains unmarked. A visited set guards against cycles.
    """
    implied_members: Set[str] = set()
    for start in tags:
        seen: Set[str] = set()
        stack = [start]
        while stack:
            current = stack.pop()
            for ancestor in impl.get(current, []):
                if ancestor in seen:
                    continue
                seen.add(ancestor)
                if ancestor in tags:
                    implied_members.add(ancestor)
                stack.append(ancestor)
    return tags - implied_members
# ββ Semantic categorization heuristics βββββββββββββββββββββββββββββββββββββ
# Taxonomy / body plan tags that are almost always implied, not directly described
_TAXONOMY_TAGS = frozenset({
    "mammal", "canid", "canine", "canis", "felid", "feline", "felis",
    "ursine", "cervid", "bovid", "equid", "equine", "mustelid", "procyonid",
    "reptile", "scalie", "avian", "bird", "fish", "marine", "aquatic",
    "arthropod", "insect", "arachnid", "mollusk", "amphibian",
    "primate", "hominid", "rodent", "lagomorph", "leporid", "chiroptera",
    "marsupial", "monotreme", "pinniped", "cetacean", "ungulate",
    "galliform", "gallus_(genus)", "phasianid", "passerine", "oscine",
    "dinosaur", "theropod",
})
# Tags describing the character's overall body plan / construction rather
# than its species.
_BODY_PLAN_TAGS = frozenset({
    "anthro", "feral", "biped", "quadruped", "taur", "humanoid",
    "semi-anthro", "animatronic", "robot", "machine", "plushie",
    "kemono",
})
# Matches numeric anatomy-count tags such as "4_fingers" or "2_horns".
_COUNT_TAGS_RE = re.compile(
    r"^\d+_(fingers|toes|horns|arms|legs|eyes|ears|wings|tails|heads|claws|fangs|nipples|breasts|penises|balls|teats)$"
)
# Pose, camera-angle, and composition tags (exact-match lookup).
_POSE_TAGS = frozenset({
    "solo", "duo", "group", "standing", "sitting", "lying", "running",
    "walking", "flying", "swimming", "crouching", "kneeling", "jumping",
    "looking_at_viewer", "looking_away", "looking_back", "looking_up",
    "looking_down", "looking_aside", "front_view", "side_view", "back_view",
    "three-quarter_view", "from_above", "from_below", "worm's-eye_view",
    "bird's-eye_view", "close-up", "portrait", "full-length_portrait",
    "butt_pose", "spread_legs", "all_fours", "on_back", "on_side",
    "hand_on_hip", "arms_crossed", "hands_behind_back",
})
def categorize_tag(tag: str, tag_type: Dict[str, int]) -> str:
    """Assign a semantic category to a missed tag.

    Non-general database types (species/artist/copyright/character/meta) win
    outright; everything else runs through an ordered chain of heuristics
    (exact sets, a count regex, then substring/prefix matching). The first
    matching rule decides, so ordering is significant.
    """
    type_name = TYPE_ID_NAMES.get(tag_type.get(tag, -1), "unknown")
    if type_name == "species":
        return "species"
    if type_name in ("artist", "copyright", "character", "meta"):
        return type_name

    # General (or unknown) tags: subcategorize by heuristics, exact sets first.
    if tag in _TAXONOMY_TAGS:
        return "taxonomy"
    if tag in _BODY_PLAN_TAGS:
        return "body_plan"
    if tag in _POSE_TAGS:
        return "pose/composition"
    if _COUNT_TAGS_RE.match(tag):
        return "count/anatomy"

    # Clothing-related keywords (substring match, so broad by design).
    clothing_keywords = (
        "clothing", "clothed", "topwear", "bottomwear", "legwear", "handwear",
        "headwear", "footwear", "shirt", "pants", "shorts", "dress", "skirt",
        "jacket", "coat", "hat", "boots", "shoes", "gloves", "socks",
        "stockings", "belt", "collar", "scarf", "cape", "armor", "suit",
        "uniform", "costume", "outfit", "underwear", "bra", "panties",
        "thigh_highs", "knee_highs",
    )
    if any(kw in tag for kw in clothing_keywords):
        return "clothing"

    # Color tags like "red_eyes"; checked before the hair/body rules so
    # "blue_hair" lands here, matching the original priority.
    color_prefixes = (
        "red", "blue", "green", "yellow", "orange", "purple", "pink",
        "black", "white", "grey", "gray", "brown", "tan", "cream",
        "gold", "silver", "teal", "cyan", "magenta",
    )
    if any(tag.startswith(color + "_") for color in color_prefixes):
        return "color/marking"
    if tag == "markings" or tag.endswith(("_coloring", "_markings")):
        return "color/marking"

    if "hair" in tag:
        return "hair"

    # Body features (substring match).
    body_keywords = (
        "muscle", "belly", "chest", "abs", "breast", "butt", "tail", "wing",
        "horn", "ear", "eye", "teeth", "fang", "claw", "paw", "hoof",
        "snout", "muzzle", "tongue", "fur", "scales", "feather",
        "tuft", "fluff", "mane",
    )
    if any(kw in tag for kw in body_keywords):
        return "body/anatomy"

    # Gender/sex (exact match only, so e.g. "female" never trips "male").
    if tag in ("male", "female", "intersex", "ambiguous_gender",
               "andromorph", "gynomorph", "herm", "maleherm"):
        return "gender"

    # Expression/emotion keywords.
    expression_keywords = ("smile", "grin", "frown", "expression", "blush",
                           "angry", "happy", "sad", "crying", "laughing",
                           "open_mouth", "closed_eyes", "wink")
    if any(kw in tag for kw in expression_keywords):
        return "expression"

    return "other_general"
| # ββ Analysis βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def analyze_eval_file(eval_path: Path) -> None:
    """Print a multi-section retrieval-gap report for one eval JSONL file.

    Expects one JSON object per line: an optional "_meta" row with the eval
    config, "error" rows (skipped), and sample rows carrying at least
    ground_truth_tags, retrieved_tags, selected_tags, sample_id, and
    retrieval_recall (rewrite_phrases is optional). All output goes to
    stdout as seven sequential reports.
    """
    tag_type, tag_count = load_tag_db()
    impl = load_implications()
    # Parse the JSONL: echo the config row, skip errored samples.
    samples = []
    with eval_path.open("r", encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            if row.get("_meta"):
                print(f"Eval config: min_why={row.get('min_why')}, "
                      f"expand_implications={row.get('expand_implications')}, "
                      f"n={row.get('n_samples')}, seed={row.get('seed')}")
                continue
            if row.get("error"):
                continue
            samples.append(row)
    print(f"Analyzing {len(samples)} samples from {eval_path.name}\n")
    # Collect all misses across samples
    all_retrieval_misses: Counter = Counter()  # tag -> how many samples missed it in retrieval
    all_selection_misses: Counter = Counter()  # tag -> missed in selection (retrieved but not selected)
    retrieval_miss_details: Dict[str, List] = defaultdict(list)  # tag -> [sample_ids]
    per_sample_stats = []
    for s in samples:
        gt = set(s["ground_truth_tags"])
        retrieved = set(s["retrieved_tags"])
        selected = set(s["selected_tags"])
        phrases = s.get("rewrite_phrases", [])
        sid = s["sample_id"]
        # Stage-2 misses: in ground truth but never retrieved at all.
        retrieval_misses = gt - retrieved
        selection_misses = (gt & retrieved) - selected  # retrieved but dropped
        for tag in retrieval_misses:
            all_retrieval_misses[tag] += 1
            retrieval_miss_details[tag].append(sid)
        for tag in selection_misses:
            all_selection_misses[tag] += 1
        # Check if missed tags appear in rewrite phrases
        phrase_text = " ".join(phrases).lower()
        misses_in_phrases = {t for t in retrieval_misses
                             if t.replace("_", " ") in phrase_text or t in phrase_text}
        per_sample_stats.append({
            "id": sid,
            "gt_count": len(gt),
            "retrieved_count": len(retrieved),
            "retrieval_misses": len(retrieval_misses),
            "misses_in_phrases": len(misses_in_phrases),
            "retrieval_recall": s["retrieval_recall"],
        })
    # ββ Report 1: Retrieval misses by category ββ
    print("=" * 70)
    print("RETRIEVAL GAPS β Tags in GT but never retrieved (Stage 2 misses)")
    print("=" * 70)
    # Bucket every missed tag by its semantic category, keeping per-tag counts.
    category_misses: Dict[str, Counter] = defaultdict(Counter)
    for tag, miss_count in all_retrieval_misses.items():
        cat = categorize_tag(tag, tag_type)
        category_misses[cat][tag] = miss_count
    # Sort categories by total miss volume
    cat_totals = {cat: sum(c.values()) for cat, c in category_misses.items()}
    for cat in sorted(cat_totals, key=cat_totals.get, reverse=True):
        tags_in_cat = category_misses[cat]
        total_misses = cat_totals[cat]
        unique_tags = len(tags_in_cat)
        print(f"\n  [{cat}] β {total_misses} total misses across {unique_tags} unique tags")
        # Show top tags in this category
        for tag, cnt in tags_in_cat.most_common(8):
            freq = tag_count.get(tag, 0)
            leaf_marker = ""
            # Check if this tag is typically a leaf or implied
            if tag in _TAXONOMY_TAGS or tag in _BODY_PLAN_TAGS:
                leaf_marker = " (implied ancestor)"
            in_db = "YES" if tag in tag_type else "NO"
            print(f"    {tag:40s} missed {cnt}/{len(samples)} samples  "
                  f"freq={freq:>8,}  in_db={in_db}")
    # ββ Report 2: Leaf vs non-leaf misses ββ
    print("\n" + "=" * 70)
    print("LEAF vs IMPLIED ANCESTOR MISSES")
    print("=" * 70)
    # Ancestor misses could be recovered mechanically by expanding the
    # implication graph; leaf misses are genuine retrieval failures.
    all_missed_tags = set(all_retrieval_misses.keys())
    leaf_misses = get_leaf_tags(all_missed_tags, impl)
    ancestor_misses = all_missed_tags - leaf_misses
    leaf_miss_volume = sum(all_retrieval_misses[t] for t in leaf_misses)
    ancestor_miss_volume = sum(all_retrieval_misses[t] for t in ancestor_misses)
    total_miss_volume = leaf_miss_volume + ancestor_miss_volume
    print(f"\n  Unique missed tags: {len(all_missed_tags)}")
    print(f"    Leaf tags:       {len(leaf_misses)} ({len(leaf_misses)/max(1,len(all_missed_tags))*100:.0f}%)")
    print(f"    Ancestor tags:   {len(ancestor_misses)} ({len(ancestor_misses)/max(1,len(all_missed_tags))*100:.0f}%)")
    print(f"\n  Total miss volume: {total_miss_volume}")
    print(f"    From leaf tags:  {leaf_miss_volume} ({leaf_miss_volume/max(1,total_miss_volume)*100:.0f}%)")
    print(f"    From ancestors:  {ancestor_miss_volume} ({ancestor_miss_volume/max(1,total_miss_volume)*100:.0f}%)")
    print(f"\n  Ancestor misses recoverable by implication expansion: "
          f"{ancestor_miss_volume} ({ancestor_miss_volume/max(1,total_miss_volume)*100:.0f}%)")
    # ββ Report 3: Most-missed leaf tags ββ
    print("\n" + "=" * 70)
    print("TOP MISSED LEAF TAGS (not recoverable via implications)")
    print("=" * 70)
    leaf_miss_counter = Counter({t: all_retrieval_misses[t] for t in leaf_misses})
    for tag, cnt in leaf_miss_counter.most_common(30):
        cat = categorize_tag(tag, tag_type)
        freq = tag_count.get(tag, 0)
        sids = retrieval_miss_details[tag]
        print(f"  {tag:40s} missed {cnt}/{len(samples)}  cat={cat:20s} freq={freq:>8,}  samples={sids}")
    # ββ Report 4: Tags that were in rewrite phrases but not retrieved ββ
    print("\n" + "=" * 70)
    print("TAGS MENTIONED IN REWRITE PHRASES BUT NOT RETRIEVED")
    print("=" * 70)
    # Re-scan the samples: a miss whose text appears in the rewrite phrases
    # means the retriever had the wording available and still failed.
    phrase_miss_counter: Counter = Counter()
    for s in samples:
        gt = set(s["ground_truth_tags"])
        retrieved = set(s["retrieved_tags"])
        phrases = s.get("rewrite_phrases", [])
        phrase_text = " ".join(phrases).lower()
        for tag in (gt - retrieved):
            tag_text = tag.replace("_", " ")
            if tag_text in phrase_text or tag in phrase_text:
                phrase_miss_counter[tag] += 1
    if phrase_miss_counter:
        for tag, cnt in phrase_miss_counter.most_common(20):
            cat = categorize_tag(tag, tag_type)
            print(f"  {tag:40s} mentioned but not retrieved {cnt}x  cat={cat}")
    else:
        print("  (none found)")
    # ββ Report 5: Selection drops (retrieved but not selected) ββ
    print("\n" + "=" * 70)
    print("SELECTION DROPS β Retrieved GT tags dropped by Stage 3")
    print("=" * 70)
    if all_selection_misses:
        for tag, cnt in all_selection_misses.most_common(20):
            cat = categorize_tag(tag, tag_type)
            freq = tag_count.get(tag, 0)
            print(f"  {tag:40s} dropped {cnt}/{len(samples)}  cat={cat:20s} freq={freq:>8,}")
    else:
        print("  (none β all retrieved GT tags were selected)")
    # ββ Report 6: Frequency distribution of missed tags ββ
    print("\n" + "=" * 70)
    print("FREQUENCY DISTRIBUTION OF MISSED TAGS")
    print("=" * 70)
    freq_buckets = {"very_rare (<100)": 0, "rare (100-1k)": 0, "medium (1k-10k)": 0,
                    "common (10k-100k)": 0, "very_common (100k+)": 0, "not_in_db": 0}
    # Default of -1 distinguishes "absent from DB" from a stored count of 0.
    for tag in all_retrieval_misses:
        freq = tag_count.get(tag, -1)
        if freq < 0:
            freq_buckets["not_in_db"] += 1
        elif freq < 100:
            freq_buckets["very_rare (<100)"] += 1
        elif freq < 1000:
            freq_buckets["rare (100-1k)"] += 1
        elif freq < 10000:
            freq_buckets["medium (1k-10k)"] += 1
        elif freq < 100000:
            freq_buckets["common (10k-100k)"] += 1
        else:
            freq_buckets["very_common (100k+)"] += 1
    for bucket, count in freq_buckets.items():
        pct = count / max(1, len(all_retrieval_misses)) * 100
        print(f"  {bucket:25s} {count:4d} unique tags ({pct:.0f}%)")
    # ββ Report 7: Per-sample retrieval stats ββ
    print("\n" + "=" * 70)
    print("PER-SAMPLE RETRIEVAL STATS")
    print("=" * 70)
    # Worst recall first, so problem samples lead the listing.
    for stat in sorted(per_sample_stats, key=lambda x: x["retrieval_recall"]):
        print(f"  id={stat['id']:>8}  recall={stat['retrieval_recall']:.3f}  "
              f"gt={stat['gt_count']:3d}  retrieved={stat['retrieved_count']:3d}  "
              f"missed={stat['retrieval_misses']:3d}  in_phrases={stat['misses_in_phrases']}")
    print()
def main():
    """Resolve the eval-results path, then run the gap analysis on it.

    An explicit path may be given as the first CLI argument; otherwise the
    lexicographically last eval_*.jsonl under data/eval_results/ is used.
    Exits with status 1 when no usable file can be found.
    """
    if len(sys.argv) > 1:
        eval_path = Path(sys.argv[1])
    else:
        # No argument: fall back to the newest results file by name order.
        results_dir = _REPO_ROOT / "data" / "eval_results"
        candidates = sorted(results_dir.glob("eval_*.jsonl"))
        if not candidates:
            print("No eval results found in data/eval_results/")
            sys.exit(1)
        eval_path = candidates[-1]
        print(f"Using latest eval: {eval_path.name}\n")
    if not eval_path.is_file():
        print(f"File not found: {eval_path}")
        sys.exit(1)
    analyze_eval_file(eval_path)


if __name__ == "__main__":
    main()