"""Rank candidate probe tags by informativeness before any LLM queries. This is an offline metric pass combining: - entropy / information gain from sample co-occurrence, - lift against active groups/categories, - reduced TF-IDF semantic focus against group centroids. Compact outputs (overwrite in place): - data/analysis/probe_informativeness.csv - data/analysis/probe_informativeness_summary.json """ from __future__ import annotations import csv import json import math from collections import Counter from pathlib import Path from typing import Dict, List, Set, Tuple import numpy as np from psq_rag.retrieval.state import get_tfidf_tag_vectors REPO = Path(__file__).resolve().parents[1] COUNTS_CSV = REPO / "fluffyrock_3m.csv" SAMPLE_JSONL = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl" WIKI_GROUPS_JSON = REPO / "data" / "tag_groups.json" REGISTRY_CSV = REPO / "data" / "category_registry.csv" CATEGORY_TAG_GROUP_MAP_CSV = REPO / "data" / "analysis" / "category_tag_group_map.csv" OUT_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv" OUT_SUMMARY = REPO / "data" / "analysis" / "probe_informativeness_summary.json" MIN_COUNT = 200 MIN_PROBE_IMAGES = 5 MIN_GROUP_IMAGES = 20 SOFTMAX_TAU = 0.15 MMR_LAMBDA = 0.35 MMR_TOP_POOL = 120 MMR_K = 15 DOMAIN_JARGON = { "solo", "duo", "trio", "anthro", "feral", "gynomorph", "andromorph", "maleherm", "topwear", "bottomwear", "legwear", "handwear", "headwear", "footwear", "leporid", "canid", "canis", "felid", "felis", "equid", "haplorhine", "zero_pictured", "male/female", "male/male", "female/female", } def load_counts(path: Path) -> Dict[str, int]: out: Dict[str, int] = {} with path.open("r", encoding="utf-8", newline="") as f: reader = csv.reader(f) for row in reader: if len(row) < 3: continue try: out[row[0]] = int(row[2]) if row[2] else 0 except ValueError: out[row[0]] = 0 return out def load_image_tags(path: Path, counts: Dict[str, int], min_count: int) -> List[Set[str]]: rows: List[Set[str]] = [] with path.open("r", encoding="utf-8") as f: for line in f: obj = json.loads(line) raw = obj.get("tags_ground_truth_categorized", "") if not raw: continue try: d = json.loads(raw) except Exception: continue tags: Set[str] = set() if isinstance(d, dict): for vals in d.values(): if isinstance(vals, list): for t in vals: if isinstance(t, str) and counts.get(t, 0) >= min_count: tags.add(t) if tags: rows.append(tags) return rows def load_excluded_wiki_groups_from_policy(path: Path) -> Set[str]: """Read excluded wiki groups from the tag-group map file. Convention: - rows with enabled=1 and category_name starting with 'ignored_' - tag_group column contains the wiki group name to exclude. """ excluded: Set[str] = set() if not path.is_file(): return excluded with path.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: if (row.get("enabled") or "").strip() not in {"1", "true", "True"}: continue category = (row.get("category_name") or "").strip().lower() group = (row.get("tag_group") or "").strip() if category.startswith("ignored_") and group: excluded.add(group) return excluded def load_groups() -> Tuple[Dict[str, Set[str]], Set[str]]: groups: Dict[str, Set[str]] = {} excluded_wiki_groups = load_excluded_wiki_groups_from_policy(CATEGORY_TAG_GROUP_MAP_CSV) with WIKI_GROUPS_JSON.open("r", encoding="utf-8") as f: wiki = json.load(f) for g, tags in wiki.items(): if g in excluded_wiki_groups: continue if isinstance(tags, list): groups[f"wiki:{g}"] = {t for t in tags if isinstance(t, str) and t} with REGISTRY_CSV.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: if (row.get("category_enabled") or "").strip() not in {"1", "true", "True"}: continue c = (row.get("category_name") or "").strip() t = (row.get("tag") or "").strip() if c and t: groups.setdefault(f"cat:{c}", set()).add(t) return groups, excluded_wiki_groups def needs_glossary(tag: str) -> bool: if tag in DOMAIN_JARGON: return True if "/" in tag or "(" in tag or ")" in tag: return True if any(ch.isdigit() for ch in tag): return True # Taxonomy-ish suffixes often need disambiguation in prompts. if tag.endswith("id") or tag.endswith("ine"): return True return False def infer_probe_bundle(tag: str, semantic_top_group: str, strongest_group: str) -> str: t = tag g = f"{semantic_top_group} {strongest_group}".lower() if t in {"solo", "duo", "trio", "group", "zero_pictured"}: return "count_cardinality" if t in {"anthro", "feral", "humanoid", "biped", "quadruped"}: return "body_type_presence" if t in {"clothed", "clothing", "topless", "bottomless", "nude", "barefoot", "topwear", "bottomwear"}: return "clothing_state" if any(x in t for x in ["canid", "canis", "felid", "felis", "equid", "leporid", "species", "mammal", "bird", "bear", "unicorn", "reptile", "dragon"]): return "species_taxonomy" if any(x in t for x in ["breast", "thigh", "hips", "curvy", "muscular", "overweight", "chubby", "butt"]): return "body_shape_breasts" if any(x in t for x in ["look", "gaze", "eyes", "smile", "blush", "open_mouth", "eyes_closed"]): return "gaze_expression" if t in {"text", "dialogue", "<3"} or any(x in t for x in ["text", "dialogue", "logo", "symbol"]): return "text_symbols" if any(x in t for x in ["background", "outside", "inside", "indoors", "outdoors", "standing", "sitting"]): return "scene_pose" if "cat:clothing" in g or "wiki:clothes" in g: return "clothing_state" if "cat:count" in g: return "count_cardinality" return "other" def entropy_binary(p: float) -> float: p = min(max(p, 1e-12), 1 - 1e-12) return -(p * math.log2(p) + (1 - p) * math.log2(1 - p)) def softmax(x: np.ndarray, tau: float) -> np.ndarray: z = x / max(tau, 1e-6) z = z - np.max(z) e = np.exp(z) return e / max(np.sum(e), 1e-12) def binary_mi(a_idx: Set[int], b_idx: Set[int], n: int) -> float: # MI for Bernoulli variables in bits. n11 = len(a_idx & b_idx) n10 = len(a_idx - b_idx) n01 = len(b_idx - a_idx) n00 = n - n11 - n10 - n01 probs = { (1, 1): n11 / n, (1, 0): n10 / n, (0, 1): n01 / n, (0, 0): n00 / n, } pa = (n11 + n10) / n pb = (n11 + n01) / n mi = 0.0 for (a, b), p in probs.items(): if p <= 0: continue qa = pa if a == 1 else (1 - pa) qb = pb if b == 1 else (1 - pb) mi += p * math.log2(p / max(qa * qb, 1e-12)) return max(mi, 0.0) def main() -> None: counts = load_counts(COUNTS_CSV) image_tags = load_image_tags(SAMPLE_JSONL, counts, MIN_COUNT) n_images = len(image_tags) if n_images == 0: raise RuntimeError("No image tags loaded.") groups_all, excluded_wiki_groups = load_groups() probe_to_images: Dict[str, Set[int]] = {} tag_occ = Counter() for i, tags in enumerate(image_tags): for t in tags: tag_occ[t] += 1 probe_to_images.setdefault(t, set()).add(i) group_to_images: Dict[str, Set[int]] = {} for g, members in groups_all.items(): idxs: Set[int] = set() for i, tags in enumerate(image_tags): if tags & members: idxs.add(i) if len(idxs) >= MIN_GROUP_IMAGES: group_to_images[g] = idxs active_groups = sorted(group_to_images.keys()) if not active_groups: raise RuntimeError("No active groups after MIN_GROUP_IMAGES filter.") # Semantic centroids for active groups. vec = get_tfidf_tag_vectors() mat = vec["reduced_matrix_norm"] tag_to_row = vec["tag_to_row_index"] group_centroids: Dict[str, np.ndarray] = {} for g in active_groups: rows = [tag_to_row[t] for t in groups_all[g] if t in tag_to_row] if len(rows) < 2: continue c = mat[rows].mean(axis=0) n = np.linalg.norm(c) if n > 0: group_centroids[g] = c / n semantic_groups = sorted(group_centroids.keys()) C = np.stack([group_centroids[g] for g in semantic_groups], axis=0) if semantic_groups else None baseline_group_probs = {g: len(group_to_images[g]) / n_images for g in active_groups} baseline_top5_mass = sum(sorted(baseline_group_probs.values(), reverse=True)[:5]) rows_out: List[Dict[str, str]] = [] probe_scores: Dict[str, float] = {} for p, p_idxs in probe_to_images.items(): if len(p_idxs) < MIN_PROBE_IMAGES: continue q = len(p_idxs) / n_images if q <= 0.0 or q >= 1.0: continue ig_sum = 0.0 ig_vals = [] mean_abs_log_lift = 0.0 lifts: Dict[str, float] = {} p1_group_probs: Dict[str, float] = {} for g in active_groups: g_idxs = group_to_images[g] pg = len(g_idxs) / n_images pg1 = len(p_idxs & g_idxs) / len(p_idxs) p0 = n_images - len(p_idxs) pg0 = len((set(range(n_images)) - p_idxs) & g_idxs) / p0 if p0 > 0 else pg ig = entropy_binary(pg) - (q * entropy_binary(pg1) + (1 - q) * entropy_binary(pg0)) ig = max(ig, 0.0) ig_vals.append(ig) ig_sum += ig lift = (pg1 + 1e-9) / (pg + 1e-9) lifts[g] = lift p1_group_probs[g] = pg1 mean_abs_log_lift += abs(math.log2(lift + 1e-12)) mean_abs_log_lift /= len(active_groups) ig_mean = float(np.mean(ig_vals)) if ig_vals else 0.0 top5_mass_p1 = sum(sorted(p1_group_probs.values(), reverse=True)[:5]) delta_top5_mass = top5_mass_p1 - baseline_top5_mass strongest_group = max(lifts.items(), key=lambda kv: abs(math.log2(kv[1] + 1e-12))) strongest_group_name = strongest_group[0] strongest_group_lift = strongest_group[1] semantic_top_group = "" semantic_margin = 0.0 semantic_entropy_norm = 1.0 if C is not None and p in tag_to_row: sims = C @ mat[tag_to_row[p]] order = np.argsort(sims)[::-1] i1 = int(order[0]) i2 = int(order[1]) if len(order) > 1 else i1 semantic_top_group = semantic_groups[i1] semantic_margin = float(sims[i1] - sims[i2]) probs = softmax(sims, SOFTMAX_TAU) h = -float(np.sum(probs * np.log2(np.maximum(probs, 1e-12)))) semantic_entropy_norm = h / math.log2(len(probs)) if len(probs) > 1 else 0.0 prevalence_balance = math.sqrt(q * (1 - q)) focus = max(0.0, 1.0 - semantic_entropy_norm) combined_score = ig_sum * prevalence_balance * (0.5 + 0.5 * focus) probe_scores[p] = combined_score rows_out.append( { "tag": p, "sample_occurrences": str(len(p_idxs)), "fluffyrock_count": str(counts.get(p, 0)), "prevalence": f"{q:.6f}", "ig_sum_bits": f"{ig_sum:.6f}", "ig_mean_bits": f"{ig_mean:.6f}", "delta_top5_mass": f"{delta_top5_mass:.6f}", "mean_abs_log2_lift": f"{mean_abs_log_lift:.6f}", "semantic_top_group": semantic_top_group, "semantic_margin": f"{semantic_margin:.6f}", "semantic_entropy_norm": f"{semantic_entropy_norm:.6f}", "strongest_group_by_lift": strongest_group_name, "strongest_group_lift": f"{strongest_group_lift:.6f}", "suggested_probe_bundle": infer_probe_bundle(p, semantic_top_group, strongest_group_name), "needs_glossary": "1" if needs_glossary(p) else "0", "combined_score": f"{combined_score:.6f}", } ) # Add an actionability score that downweights very common probes and favors # probes that noticeably reshape top-group mass. for r in rows_out: q = float(r["prevalence"]) ig = float(r["ig_sum_bits"]) delta_top5 = max(0.0, float(r["delta_top5_mass"])) semantic_focus = max(0.0, 1.0 - float(r["semantic_entropy_norm"])) prevalence_penalty = max(0.0, 1.0 - abs(2 * q - 1.0)) actionable_score = ig * prevalence_penalty * delta_top5 * (0.5 + 0.5 * semantic_focus) r["actionable_score"] = f"{actionable_score:.6f}" rows_out.sort(key=lambda r: float(r["combined_score"]), reverse=True) # Diversified shortlist via MMR-like greedy on top pool. top_pool = [r["tag"] for r in rows_out[:MMR_TOP_POOL]] selected: List[str] = [] while len(selected) < MMR_K and top_pool: best_tag = None best_val = -1e9 for t in top_pool: rel = probe_scores.get(t, 0.0) if not selected: val = rel else: red = float(np.mean([binary_mi(probe_to_images[t], probe_to_images[s], n_images) for s in selected])) val = rel - MMR_LAMBDA * red if val > best_val: best_val = val best_tag = t if best_tag is None: break selected.append(best_tag) top_pool.remove(best_tag) OUT_CSV.parent.mkdir(parents=True, exist_ok=True) with OUT_CSV.open("w", encoding="utf-8", newline="") as f: writer = csv.DictWriter( f, fieldnames=[ "tag", "sample_occurrences", "fluffyrock_count", "prevalence", "ig_sum_bits", "ig_mean_bits", "delta_top5_mass", "mean_abs_log2_lift", "semantic_top_group", "semantic_margin", "semantic_entropy_norm", "strongest_group_by_lift", "strongest_group_lift", "suggested_probe_bundle", "needs_glossary", "combined_score", "actionable_score", ], ) writer.writeheader() writer.writerows(rows_out) # Aggregate bundle-level utility using top actionable tags per bundle. by_bundle: Dict[str, List[Dict[str, str]]] = {} for r in rows_out: by_bundle.setdefault(r["suggested_probe_bundle"], []).append(r) bundle_scores = [] for b, items in by_bundle.items(): items_sorted = sorted(items, key=lambda x: float(x["actionable_score"]), reverse=True) top_items = items_sorted[:5] score = sum(float(x["actionable_score"]) for x in top_items) glossary_rate = sum(1 for x in top_items if x["needs_glossary"] == "1") / len(top_items) if top_items else 0.0 bundle_scores.append( { "bundle": b, "bundle_score_top5_actionable": round(score, 6), "top_tags": [x["tag"] for x in top_items], "glossary_rate_top5": round(glossary_rate, 3), } ) bundle_scores.sort(key=lambda x: x["bundle_score_top5_actionable"], reverse=True) top_actionable = sorted(rows_out, key=lambda r: float(r["actionable_score"]), reverse=True) top_mid_prevalence = [ r for r in top_actionable if 0.03 <= float(r["prevalence"]) <= 0.35 ][:40] summary = { "config": { "min_count": MIN_COUNT, "min_probe_images": MIN_PROBE_IMAGES, "min_group_images": MIN_GROUP_IMAGES, "softmax_tau": SOFTMAX_TAU, "mmr_lambda": MMR_LAMBDA, "mmr_top_pool": MMR_TOP_POOL, "mmr_k": MMR_K, }, "n_images": n_images, "n_candidate_probes": len(rows_out), "n_active_groups": len(active_groups), "excluded_wiki_groups": sorted(excluded_wiki_groups), "top_probes_by_combined_score": rows_out[:25], "top_probes_by_actionable_score": top_actionable[:25], "top_actionable_mid_prevalence_for_manual_review": top_mid_prevalence, "bundle_scores": bundle_scores[:20], "diversified_probe_shortlist": selected, "outputs": { "csv": str(OUT_CSV), "summary_json": str(OUT_SUMMARY), }, } with OUT_SUMMARY.open("w", encoding="utf-8") as f: json.dump(summary, f, indent=2, ensure_ascii=False) print(f"Images: {n_images}") print(f"Active groups: {len(active_groups)}") print(f"Candidate probes: {len(rows_out)}") print(f"Top probes: {[r['tag'] for r in rows_out[:10]]}") print(f"Diversified shortlist: {selected}") print(f"Outputs: {OUT_CSV}, {OUT_SUMMARY}") if __name__ == "__main__": main()