Spaces:
Running
Running
| """Rank candidate probe tags by informativeness before any LLM queries. | |
| This is an offline metric pass combining: | |
| - entropy / information gain from sample co-occurrence, | |
| - lift against active groups/categories, | |
| - reduced TF-IDF semantic focus against group centroids. | |
| Compact outputs (overwrite in place): | |
| - data/analysis/probe_informativeness.csv | |
| - data/analysis/probe_informativeness_summary.json | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import json | |
| import math | |
| from collections import Counter | |
| from pathlib import Path | |
| from typing import Dict, List, Set, Tuple | |
| import numpy as np | |
| from psq_rag.retrieval.state import get_tfidf_tag_vectors | |
# --- Paths (repo root is two directory levels above this file) ---------------
REPO = Path(__file__).resolve().parents[1]
# Global tag counts CSV; load_counts() reads tag from column 0, count from column 2.
COUNTS_CSV = REPO / "fluffyrock_3m.csv"
# Evaluation sample: one JSON object per line with categorized ground-truth tags.
SAMPLE_JSONL = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
# Wiki tag groups: JSON mapping group name -> list of member tags.
WIKI_GROUPS_JSON = REPO / "data" / "tag_groups.json"
# Category registry CSV with category_enabled / category_name / tag columns.
REGISTRY_CSV = REPO / "data" / "category_registry.csv"
# Policy file naming wiki groups to exclude (rows with enabled=1, category_name "ignored_*").
CATEGORY_TAG_GROUP_MAP_CSV = REPO / "data" / "analysis" / "category_tag_group_map.csv"
# --- Outputs (overwritten in place on every run) -----------------------------
OUT_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv"
OUT_SUMMARY = REPO / "data" / "analysis" / "probe_informativeness_summary.json"
# --- Tuning knobs ------------------------------------------------------------
# Minimum global count for a tag to be considered at all.
MIN_COUNT = 200
# Minimum sample images a probe tag must appear in to be scored.
MIN_PROBE_IMAGES = 5
# Minimum sample images a group must cover to stay "active".
MIN_GROUP_IMAGES = 20
# Temperature for the softmax over probe->group-centroid similarities.
SOFTMAX_TAU = 0.15
# Redundancy penalty weight in the MMR-style shortlist selection.
MMR_LAMBDA = 0.35
# Pool size (top combined-score probes) fed to the MMR selection.
MMR_TOP_POOL = 120
# Final diversified shortlist size.
MMR_K = 15
# Tags that always warrant a glossary entry when used in prompts (see needs_glossary()).
DOMAIN_JARGON = {
    "solo", "duo", "trio", "anthro", "feral", "gynomorph", "andromorph", "maleherm",
    "topwear", "bottomwear", "legwear", "handwear", "headwear", "footwear",
    "leporid", "canid", "canis", "felid", "felis", "equid", "haplorhine",
    "zero_pictured", "male/female", "male/male", "female/female",
}
def load_counts(path: Path) -> Dict[str, int]:
    """Load a tag -> global count mapping from a CSV.

    Rows shorter than 3 columns are skipped; an empty or non-integer
    count column yields 0 for that tag.
    """
    counts: Dict[str, int] = {}
    with path.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.reader(handle):
            if len(record) < 3:
                continue
            tag, raw_count = record[0], record[2]
            try:
                counts[tag] = int(raw_count) if raw_count else 0
            except ValueError:
                counts[tag] = 0
    return counts
def load_image_tags(path: Path, counts: Dict[str, int], min_count: int) -> List[Set[str]]:
    """Load per-image tag sets from a JSONL sample file.

    Each line is a JSON object whose "tags_ground_truth_categorized" field
    holds a JSON-encoded dict of {category: [tag, ...]}. A tag is kept only
    when its global count reaches *min_count*; images that retain no tags
    are dropped.

    Args:
        path: JSONL file to read.
        counts: tag -> global occurrence count.
        min_count: minimum global count for a tag to be kept.

    Returns:
        One set of surviving tags per retained image, in file order.
    """
    rows: List[Set[str]] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (trailing newline, blank separators).
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed records instead of aborting the whole pass,
                # consistent with the best-effort inner parse below.
                continue
            raw = obj.get("tags_ground_truth_categorized", "")
            if not raw:
                continue
            try:
                d = json.loads(raw)
            except Exception:
                continue
            tags: Set[str] = set()
            if isinstance(d, dict):
                for vals in d.values():
                    if isinstance(vals, list):
                        for t in vals:
                            if isinstance(t, str) and counts.get(t, 0) >= min_count:
                                tags.add(t)
            if tags:
                rows.append(tags)
    return rows
def load_excluded_wiki_groups_from_policy(path: Path) -> Set[str]:
    """Read excluded wiki groups from the tag-group map file.

    Convention:
      - rows with enabled=1 and category_name starting with 'ignored_'
      - tag_group column contains the wiki group name to exclude.

    Returns an empty set when the file does not exist.
    """
    if not path.is_file():
        return set()
    truthy = {"1", "true", "True"}
    result: Set[str] = set()
    with path.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.DictReader(handle):
            if (record.get("enabled") or "").strip() not in truthy:
                continue
            name = (record.get("category_name") or "").strip().lower()
            group_name = (record.get("tag_group") or "").strip()
            if group_name and name.startswith("ignored_"):
                result.add(group_name)
    return result
def load_groups() -> Tuple[Dict[str, Set[str]], Set[str]]:
    """Assemble tag groups from the wiki JSON plus enabled registry categories.

    Group keys are namespaced: "wiki:<name>" for wiki groups (minus those the
    policy file excludes) and "cat:<name>" for enabled registry categories.

    Returns:
        (groups, excluded_wiki_groups)
    """
    excluded = load_excluded_wiki_groups_from_policy(CATEGORY_TAG_GROUP_MAP_CSV)
    groups: Dict[str, Set[str]] = {}
    with WIKI_GROUPS_JSON.open("r", encoding="utf-8") as handle:
        wiki_data = json.load(handle)
    for name, members in wiki_data.items():
        if name in excluded:
            continue
        if isinstance(members, list):
            groups[f"wiki:{name}"] = {m for m in members if isinstance(m, str) and m}
    truthy = {"1", "true", "True"}
    with REGISTRY_CSV.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.DictReader(handle):
            if (record.get("category_enabled") or "").strip() not in truthy:
                continue
            cat = (record.get("category_name") or "").strip()
            tag = (record.get("tag") or "").strip()
            if cat and tag:
                groups.setdefault(f"cat:{cat}", set()).add(tag)
    return groups, excluded
def needs_glossary(tag: str) -> bool:
    """Return True when *tag* likely needs a glossary entry in prompts."""
    if tag in DOMAIN_JARGON:
        return True
    # Slashes and parentheses mark combo/disambiguation tags.
    if any(ch in tag for ch in "/()"):
        return True
    if any(ch.isdigit() for ch in tag):
        return True
    # Taxonomy-ish suffixes often need disambiguation in prompts.
    return tag.endswith(("id", "ine"))
def infer_probe_bundle(tag: str, semantic_top_group: str, strongest_group: str) -> str:
    """Map a probe tag to a coarse probe-bundle name via ordered heuristics.

    Exact-membership rules are tried first, then substring rules, then
    fallbacks based on the tag's associated groups; "other" if nothing hits.
    """
    group_hint = f"{semantic_top_group} {strongest_group}".lower()
    exact_rules = (
        ({"solo", "duo", "trio", "group", "zero_pictured"}, "count_cardinality"),
        ({"anthro", "feral", "humanoid", "biped", "quadruped"}, "body_type_presence"),
        ({"clothed", "clothing", "topless", "bottomless", "nude", "barefoot", "topwear", "bottomwear"}, "clothing_state"),
    )
    for members, bundle in exact_rules:
        if tag in members:
            return bundle
    substring_rules = (
        (("canid", "canis", "felid", "felis", "equid", "leporid", "species", "mammal", "bird", "bear", "unicorn", "reptile", "dragon"), "species_taxonomy"),
        (("breast", "thigh", "hips", "curvy", "muscular", "overweight", "chubby", "butt"), "body_shape_breasts"),
        (("look", "gaze", "eyes", "smile", "blush", "open_mouth", "eyes_closed"), "gaze_expression"),
    )
    for needles, bundle in substring_rules:
        if any(needle in tag for needle in needles):
            return bundle
    if tag in {"text", "dialogue", "<3"} or any(x in tag for x in ("text", "dialogue", "logo", "symbol")):
        return "text_symbols"
    if any(x in tag for x in ("background", "outside", "inside", "indoors", "outdoors", "standing", "sitting")):
        return "scene_pose"
    # Group-based fallbacks when the tag itself gave no signal.
    if "cat:clothing" in group_hint or "wiki:clothes" in group_hint:
        return "clothing_state"
    if "cat:count" in group_hint:
        return "count_cardinality"
    return "other"
def entropy_binary(p: float) -> float:
    """Bernoulli entropy in bits, with p clamped away from {0, 1} for safety."""
    clamped = min(max(p, 1e-12), 1 - 1e-12)
    return -clamped * math.log2(clamped) - (1 - clamped) * math.log2(1 - clamped)
def softmax(x: np.ndarray, tau: float) -> np.ndarray:
    """Temperature softmax over *x*; tau is floored at 1e-6 for stability."""
    scaled = x / max(tau, 1e-6)
    # Subtract the max before exponentiating to avoid overflow.
    exps = np.exp(scaled - scaled.max())
    return exps / max(np.sum(exps), 1e-12)
def binary_mi(a_idx: Set[int], b_idx: Set[int], n: int) -> float:
    """Mutual information (bits) between two Bernoulli indicators over n items.

    The indicators are "item in a_idx" and "item in b_idx"; the result is
    clamped at 0 to absorb tiny negative rounding error.
    """
    both = len(a_idx & b_idx)
    only_a = len(a_idx) - both
    only_b = len(b_idx) - both
    neither = n - both - only_a - only_b
    prob_a = (both + only_a) / n
    prob_b = (both + only_b) / n
    # Each cell: (joint count, marginal for a-side, marginal for b-side).
    cells = (
        (both, prob_a, prob_b),
        (only_a, prob_a, 1 - prob_b),
        (only_b, 1 - prob_a, prob_b),
        (neither, 1 - prob_a, 1 - prob_b),
    )
    total = 0.0
    for count, qa, qb in cells:
        joint = count / n
        if joint > 0:
            total += joint * math.log2(joint / max(qa * qb, 1e-12))
    return max(total, 0.0)
def main() -> None:
    """Score candidate probe tags offline and write CSV + JSON summaries.

    Pipeline:
      1. Load global tag counts and per-image tag sets from the eval sample.
      2. Build probe->images and group->images inverted indexes.
      3. Compute reduced-TF-IDF centroids for groups with >= 2 embeddable tags.
      4. Per probe: information gain vs. each group, lift, top-5 group mass
         shift, semantic focus; combine into combined/actionable scores.
      5. Pick a diversified shortlist with an MMR-style greedy pass.
      6. Overwrite OUT_CSV and OUT_SUMMARY in place.
    """
    counts = load_counts(COUNTS_CSV)
    image_tags = load_image_tags(SAMPLE_JSONL, counts, MIN_COUNT)
    n_images = len(image_tags)
    if n_images == 0:
        raise RuntimeError("No image tags loaded.")
    groups_all, excluded_wiki_groups = load_groups()
    # Inverted index: probe tag -> indices of images containing it.
    probe_to_images: Dict[str, Set[int]] = {}
    tag_occ = Counter()  # NOTE(review): written below but never read — candidate for removal.
    for i, tags in enumerate(image_tags):
        for t in tags:
            tag_occ[t] += 1
            probe_to_images.setdefault(t, set()).add(i)
    # Inverted index: group -> indices of images containing any member tag.
    group_to_images: Dict[str, Set[int]] = {}
    for g, members in groups_all.items():
        idxs: Set[int] = set()
        for i, tags in enumerate(image_tags):
            if tags & members:
                idxs.add(i)
        if len(idxs) >= MIN_GROUP_IMAGES:
            group_to_images[g] = idxs
    active_groups = sorted(group_to_images.keys())
    if not active_groups:
        raise RuntimeError("No active groups after MIN_GROUP_IMAGES filter.")
    # Semantic centroids for active groups.
    vec = get_tfidf_tag_vectors()
    mat = vec["reduced_matrix_norm"]  # assumes rows are per-tag unit vectors — TODO confirm
    tag_to_row = vec["tag_to_row_index"]
    group_centroids: Dict[str, np.ndarray] = {}
    for g in active_groups:
        rows = [tag_to_row[t] for t in groups_all[g] if t in tag_to_row]
        if len(rows) < 2:
            # Need at least two embeddable members for a meaningful centroid.
            continue
        c = mat[rows].mean(axis=0)
        n = np.linalg.norm(c)
        if n > 0:
            group_centroids[g] = c / n  # re-normalize so dot products act like cosine sims
    semantic_groups = sorted(group_centroids.keys())
    C = np.stack([group_centroids[g] for g in semantic_groups], axis=0) if semantic_groups else None
    baseline_group_probs = {g: len(group_to_images[g]) / n_images for g in active_groups}
    baseline_top5_mass = sum(sorted(baseline_group_probs.values(), reverse=True)[:5])
    rows_out: List[Dict[str, str]] = []
    probe_scores: Dict[str, float] = {}
    for p, p_idxs in probe_to_images.items():
        if len(p_idxs) < MIN_PROBE_IMAGES:
            continue
        q = len(p_idxs) / n_images  # probe prevalence in the sample
        if q <= 0.0 or q >= 1.0:
            # A probe present in all (or no) images carries no information.
            continue
        ig_sum = 0.0
        ig_vals: List[float] = []
        mean_abs_log_lift = 0.0
        lifts: Dict[str, float] = {}
        p1_group_probs: Dict[str, float] = {}
        for g in active_groups:
            g_idxs = group_to_images[g]
            pg = len(g_idxs) / n_images  # P(group)
            pg1 = len(p_idxs & g_idxs) / len(p_idxs)  # P(group | probe present)
            p0 = n_images - len(p_idxs)
            # P(group | probe absent); fall back to the prior if the probe covers all images.
            pg0 = len((set(range(n_images)) - p_idxs) & g_idxs) / p0 if p0 > 0 else pg
            # Information gain: entropy drop of the group indicator given the probe.
            ig = entropy_binary(pg) - (q * entropy_binary(pg1) + (1 - q) * entropy_binary(pg0))
            ig = max(ig, 0.0)
            ig_vals.append(ig)
            ig_sum += ig
            lift = (pg1 + 1e-9) / (pg + 1e-9)  # epsilon-smoothed lift
            lifts[g] = lift
            p1_group_probs[g] = pg1
            mean_abs_log_lift += abs(math.log2(lift + 1e-12))
        mean_abs_log_lift /= len(active_groups)
        ig_mean = float(np.mean(ig_vals)) if ig_vals else 0.0
        # How much the probe concentrates mass into the top-5 groups vs. baseline.
        top5_mass_p1 = sum(sorted(p1_group_probs.values(), reverse=True)[:5])
        delta_top5_mass = top5_mass_p1 - baseline_top5_mass
        # Strongest association in either direction (|log2 lift| handles lift < 1 too).
        strongest_group = max(lifts.items(), key=lambda kv: abs(math.log2(kv[1] + 1e-12)))
        strongest_group_name = strongest_group[0]
        strongest_group_lift = strongest_group[1]
        semantic_top_group = ""
        semantic_margin = 0.0
        semantic_entropy_norm = 1.0  # "maximally diffuse" default when no embedding exists
        if C is not None and p in tag_to_row:
            sims = C @ mat[tag_to_row[p]]  # similarity of the probe to each group centroid
            order = np.argsort(sims)[::-1]
            i1 = int(order[0])
            i2 = int(order[1]) if len(order) > 1 else i1
            semantic_top_group = semantic_groups[i1]
            semantic_margin = float(sims[i1] - sims[i2])
            probs = softmax(sims, SOFTMAX_TAU)
            h = -float(np.sum(probs * np.log2(np.maximum(probs, 1e-12))))
            # Normalize entropy to [0, 1] by log2 of the number of groups.
            semantic_entropy_norm = h / math.log2(len(probs)) if len(probs) > 1 else 0.0
        prevalence_balance = math.sqrt(q * (1 - q))  # peaks at q = 0.5
        focus = max(0.0, 1.0 - semantic_entropy_norm)
        combined_score = ig_sum * prevalence_balance * (0.5 + 0.5 * focus)
        probe_scores[p] = combined_score
        rows_out.append(
            {
                "tag": p,
                "sample_occurrences": str(len(p_idxs)),
                "fluffyrock_count": str(counts.get(p, 0)),
                "prevalence": f"{q:.6f}",
                "ig_sum_bits": f"{ig_sum:.6f}",
                "ig_mean_bits": f"{ig_mean:.6f}",
                "delta_top5_mass": f"{delta_top5_mass:.6f}",
                "mean_abs_log2_lift": f"{mean_abs_log_lift:.6f}",
                "semantic_top_group": semantic_top_group,
                "semantic_margin": f"{semantic_margin:.6f}",
                "semantic_entropy_norm": f"{semantic_entropy_norm:.6f}",
                "strongest_group_by_lift": strongest_group_name,
                "strongest_group_lift": f"{strongest_group_lift:.6f}",
                "suggested_probe_bundle": infer_probe_bundle(p, semantic_top_group, strongest_group_name),
                "needs_glossary": "1" if needs_glossary(p) else "0",
                "combined_score": f"{combined_score:.6f}",
            }
        )
    # Add an actionability score that downweights very common probes and favors
    # probes that noticeably reshape top-group mass.
    for r in rows_out:
        q = float(r["prevalence"])
        ig = float(r["ig_sum_bits"])
        delta_top5 = max(0.0, float(r["delta_top5_mass"]))
        semantic_focus = max(0.0, 1.0 - float(r["semantic_entropy_norm"]))
        prevalence_penalty = max(0.0, 1.0 - abs(2 * q - 1.0))  # triangular, peaks at q = 0.5
        actionable_score = ig * prevalence_penalty * delta_top5 * (0.5 + 0.5 * semantic_focus)
        r["actionable_score"] = f"{actionable_score:.6f}"
    rows_out.sort(key=lambda r: float(r["combined_score"]), reverse=True)
    # Diversified shortlist via MMR-like greedy on top pool.
    top_pool = [r["tag"] for r in rows_out[:MMR_TOP_POOL]]
    selected: List[str] = []
    while len(selected) < MMR_K and top_pool:
        best_tag = None
        best_val = -1e9
        for t in top_pool:
            rel = probe_scores.get(t, 0.0)
            if not selected:
                val = rel
            else:
                # Redundancy = mean mutual information with already-selected probes.
                red = float(np.mean([binary_mi(probe_to_images[t], probe_to_images[s], n_images) for s in selected]))
                val = rel - MMR_LAMBDA * red
            if val > best_val:
                best_val = val
                best_tag = t
        if best_tag is None:
            break
        selected.append(best_tag)
        top_pool.remove(best_tag)
    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with OUT_CSV.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "tag",
                "sample_occurrences",
                "fluffyrock_count",
                "prevalence",
                "ig_sum_bits",
                "ig_mean_bits",
                "delta_top5_mass",
                "mean_abs_log2_lift",
                "semantic_top_group",
                "semantic_margin",
                "semantic_entropy_norm",
                "strongest_group_by_lift",
                "strongest_group_lift",
                "suggested_probe_bundle",
                "needs_glossary",
                "combined_score",
                "actionable_score",
            ],
        )
        writer.writeheader()
        writer.writerows(rows_out)
    # Aggregate bundle-level utility using top actionable tags per bundle.
    by_bundle: Dict[str, List[Dict[str, str]]] = {}
    for r in rows_out:
        by_bundle.setdefault(r["suggested_probe_bundle"], []).append(r)
    bundle_scores = []
    for b, items in by_bundle.items():
        items_sorted = sorted(items, key=lambda x: float(x["actionable_score"]), reverse=True)
        top_items = items_sorted[:5]
        score = sum(float(x["actionable_score"]) for x in top_items)
        glossary_rate = sum(1 for x in top_items if x["needs_glossary"] == "1") / len(top_items) if top_items else 0.0
        bundle_scores.append(
            {
                "bundle": b,
                "bundle_score_top5_actionable": round(score, 6),
                "top_tags": [x["tag"] for x in top_items],
                "glossary_rate_top5": round(glossary_rate, 3),
            }
        )
    bundle_scores.sort(key=lambda x: x["bundle_score_top5_actionable"], reverse=True)
    top_actionable = sorted(rows_out, key=lambda r: float(r["actionable_score"]), reverse=True)
    # Mid-prevalence band: rare enough to discriminate, common enough to matter.
    top_mid_prevalence = [
        r for r in top_actionable if 0.03 <= float(r["prevalence"]) <= 0.35
    ][:40]
    summary = {
        "config": {
            "min_count": MIN_COUNT,
            "min_probe_images": MIN_PROBE_IMAGES,
            "min_group_images": MIN_GROUP_IMAGES,
            "softmax_tau": SOFTMAX_TAU,
            "mmr_lambda": MMR_LAMBDA,
            "mmr_top_pool": MMR_TOP_POOL,
            "mmr_k": MMR_K,
        },
        "n_images": n_images,
        "n_candidate_probes": len(rows_out),
        "n_active_groups": len(active_groups),
        "excluded_wiki_groups": sorted(excluded_wiki_groups),
        "top_probes_by_combined_score": rows_out[:25],
        "top_probes_by_actionable_score": top_actionable[:25],
        "top_actionable_mid_prevalence_for_manual_review": top_mid_prevalence,
        "bundle_scores": bundle_scores[:20],
        "diversified_probe_shortlist": selected,
        "outputs": {
            "csv": str(OUT_CSV),
            "summary_json": str(OUT_SUMMARY),
        },
    }
    with OUT_SUMMARY.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"Images: {n_images}")
    print(f"Active groups: {len(active_groups)}")
    print(f"Candidate probes: {len(rows_out)}")
    print(f"Top probes: {[r['tag'] for r in rows_out[:10]]}")
    print(f"Diversified shortlist: {selected}")
    print(f"Outputs: {OUT_CSV}, {OUT_SUMMARY}")
| if __name__ == "__main__": | |
| main() | |