# Prompt_Squirrel_RAG / scripts / analyze_probe_informativeness.py
# (scraped page header preserved as comments so the module stays importable:
#  branch "Food Desert", commit "Consolidate probe configs and eval artifacts
#  on main", 6e50f4d)
"""Rank candidate probe tags by informativeness before any LLM queries.
This is an offline metric pass combining:
- entropy / information gain from sample co-occurrence,
- lift against active groups/categories,
- reduced TF-IDF semantic focus against group centroids.
Compact outputs (overwrite in place):
- data/analysis/probe_informativeness.csv
- data/analysis/probe_informativeness_summary.json
"""
from __future__ import annotations
import csv
import json
import math
from collections import Counter
from pathlib import Path
from typing import Dict, List, Set, Tuple
import numpy as np
from psq_rag.retrieval.state import get_tfidf_tag_vectors
# Repository root: this script lives in scripts/, so parents[1] is the project root.
REPO = Path(__file__).resolve().parents[1]
# Input artifacts.
COUNTS_CSV = REPO / "fluffyrock_3m.csv"  # global per-tag counts (tag in col 0, count in col 2)
SAMPLE_JSONL = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
WIKI_GROUPS_JSON = REPO / "data" / "tag_groups.json"
REGISTRY_CSV = REPO / "data" / "category_registry.csv"
CATEGORY_TAG_GROUP_MAP_CSV = REPO / "data" / "analysis" / "category_tag_group_map.csv"
# Outputs (overwritten in place; see module docstring).
OUT_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv"
OUT_SUMMARY = REPO / "data" / "analysis" / "probe_informativeness_summary.json"
# Filtering thresholds.
MIN_COUNT = 200  # ignore tags whose global count is below this
MIN_PROBE_IMAGES = 5  # a candidate probe must appear in at least this many sample images
MIN_GROUP_IMAGES = 20  # a group must cover at least this many sample images to stay active
SOFTMAX_TAU = 0.15  # temperature for the semantic group-assignment softmax
# MMR-style diversified-shortlist parameters.
MMR_LAMBDA = 0.35  # weight of the redundancy (mutual-information) penalty
MMR_TOP_POOL = 120  # candidate pool taken from the top-ranked probes
MMR_K = 15  # shortlist size
# Tags treated as domain jargon that likely needs a glossary entry in prompts.
DOMAIN_JARGON = {
    "solo", "duo", "trio", "anthro", "feral", "gynomorph", "andromorph", "maleherm",
    "topwear", "bottomwear", "legwear", "handwear", "headwear", "footwear",
    "leporid", "canid", "canis", "felid", "felis", "equid", "haplorhine",
    "zero_pictured", "male/female", "male/male", "female/female",
}
def load_counts(path: Path) -> Dict[str, int]:
    """Parse the tag-count CSV into a {tag: count} mapping.

    Rows with fewer than three columns are skipped; column 0 is the tag and
    column 2 the count. Empty or non-integer counts fall back to 0.
    """
    counts: Dict[str, int] = {}
    with path.open("r", encoding="utf-8", newline="") as fh:
        for record in csv.reader(fh):
            if len(record) < 3:
                continue
            tag, raw_count = record[0], record[2]
            try:
                counts[tag] = int(raw_count) if raw_count else 0
            except ValueError:
                # Non-numeric count cell: treat as zero rather than failing.
                counts[tag] = 0
    return counts
def load_image_tags(path: Path, counts: Dict[str, int], min_count: int) -> List[Set[str]]:
    """Load per-image tag sets from the eval-sample JSONL.

    Each line is a JSON object whose "tags_ground_truth_categorized" field is a
    JSON-encoded string of {category: [tag, ...]}. Only string tags whose
    global count is at least *min_count* are kept; images with no surviving
    tags are dropped.

    Fix vs. original: blank/whitespace lines (e.g. trailing newline padding in
    the JSONL) previously crashed with an uncaught json.JSONDecodeError; they
    are now skipped. The inner parse guard is narrowed from bare Exception to
    the errors json.loads can actually raise.
    """
    rows: List[Set[str]] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Skip blank lines instead of failing on json.loads("").
                continue
            obj = json.loads(line)
            raw = obj.get("tags_ground_truth_categorized", "")
            if not raw:
                continue
            try:
                d = json.loads(raw)
            except (TypeError, ValueError):
                # Malformed or non-string payload: skip this image.
                continue
            tags: Set[str] = set()
            if isinstance(d, dict):
                for vals in d.values():
                    if isinstance(vals, list):
                        for t in vals:
                            if isinstance(t, str) and counts.get(t, 0) >= min_count:
                                tags.add(t)
            if tags:
                rows.append(tags)
    return rows
def load_excluded_wiki_groups_from_policy(path: Path) -> Set[str]:
    """Read excluded wiki groups from the tag-group map file.

    Convention: a row marks a wiki group as excluded when enabled is truthy
    ("1"/"true"/"True"), category_name starts with 'ignored_' (case-
    insensitive), and tag_group names the group. Returns an empty set when
    the file does not exist.
    """
    names: Set[str] = set()
    if not path.is_file():
        return names
    truthy = {"1", "true", "True"}
    with path.open("r", encoding="utf-8", newline="") as fh:
        for record in csv.DictReader(fh):
            enabled = (record.get("enabled") or "").strip()
            if enabled not in truthy:
                continue
            category = (record.get("category_name") or "").strip().lower()
            group_name = (record.get("tag_group") or "").strip()
            if category.startswith("ignored_") and group_name:
                names.add(group_name)
    return names
def load_groups() -> Tuple[Dict[str, Set[str]], Set[str]]:
    """Assemble tag groups from the wiki JSON and the category registry.

    Returns (groups, excluded_wiki_groups): groups maps 'wiki:<name>' and
    'cat:<name>' keys to tag sets; wiki groups named by the exclusion policy
    file are dropped.
    """
    excluded = load_excluded_wiki_groups_from_policy(CATEGORY_TAG_GROUP_MAP_CSV)
    groups: Dict[str, Set[str]] = {}
    with WIKI_GROUPS_JSON.open("r", encoding="utf-8") as fh:
        for name, members in json.load(fh).items():
            if name in excluded:
                continue
            if isinstance(members, list):
                groups[f"wiki:{name}"] = {m for m in members if isinstance(m, str) and m}
    truthy = {"1", "true", "True"}
    with REGISTRY_CSV.open("r", encoding="utf-8", newline="") as fh:
        for record in csv.DictReader(fh):
            if (record.get("category_enabled") or "").strip() not in truthy:
                continue
            category = (record.get("category_name") or "").strip()
            tag = (record.get("tag") or "").strip()
            if category and tag:
                groups.setdefault(f"cat:{category}", set()).add(tag)
    return groups, excluded
def needs_glossary(tag: str) -> bool:
    """Heuristic: does this tag likely need a glossary entry in prompts?

    True for known domain jargon, tags containing slashes/parentheses/digits
    (compound or disambiguated e621 tag forms), and taxonomy-style suffixes.
    """
    if tag in DOMAIN_JARGON:
        return True
    if any(ch in tag for ch in "/()"):
        return True
    if any(ch.isdigit() for ch in tag):
        return True
    # Taxonomy-ish suffixes often need disambiguation in prompts.
    return tag.endswith(("id", "ine"))
def infer_probe_bundle(tag: str, semantic_top_group: str, strongest_group: str) -> str:
    """Map a tag to a coarse probe bundle via keyword heuristics.

    Exact-tag rules are checked first, then substring rules in priority
    order, then group-name fallbacks; unmatched tags land in 'other'.
    """
    group_hint = f"{semantic_top_group} {strongest_group}".lower()

    if tag in {"solo", "duo", "trio", "group", "zero_pictured"}:
        return "count_cardinality"
    if tag in {"anthro", "feral", "humanoid", "biped", "quadruped"}:
        return "body_type_presence"
    if tag in {"clothed", "clothing", "topless", "bottomless", "nude", "barefoot", "topwear", "bottomwear"}:
        return "clothing_state"

    # Ordered substring rules: earlier entries win.
    substring_rules = (
        ("species_taxonomy", ("canid", "canis", "felid", "felis", "equid", "leporid", "species", "mammal", "bird", "bear", "unicorn", "reptile", "dragon")),
        ("body_shape_breasts", ("breast", "thigh", "hips", "curvy", "muscular", "overweight", "chubby", "butt")),
        ("gaze_expression", ("look", "gaze", "eyes", "smile", "blush", "open_mouth", "eyes_closed")),
    )
    for bundle, needles in substring_rules:
        if any(needle in tag for needle in needles):
            return bundle

    if tag in {"text", "dialogue", "<3"} or any(needle in tag for needle in ("text", "dialogue", "logo", "symbol")):
        return "text_symbols"
    if any(needle in tag for needle in ("background", "outside", "inside", "indoors", "outdoors", "standing", "sitting")):
        return "scene_pose"

    # Group-based fallbacks when the tag itself is uninformative.
    if "cat:clothing" in group_hint or "wiki:clothes" in group_hint:
        return "clothing_state"
    if "cat:count" in group_hint:
        return "count_cardinality"
    return "other"
def entropy_binary(p: float) -> float:
    """Shannon entropy (bits) of Bernoulli(p), with p clamped away from {0, 1}."""
    clamped = max(1e-12, min(p, 1 - 1e-12))
    complement = 1 - clamped
    return -clamped * math.log2(clamped) - complement * math.log2(complement)
def softmax(x: np.ndarray, tau: float) -> np.ndarray:
    """Temperature-scaled softmax, max-shifted for numerical stability.

    tau is floored at 1e-6 and the normalizer at 1e-12 to avoid division
    by zero.
    """
    scaled = x / max(tau, 1e-6)
    exps = np.exp(scaled - np.max(scaled))
    return exps / max(np.sum(exps), 1e-12)
def binary_mi(a_idx: Set[int], b_idx: Set[int], n: int) -> float:
    """Mutual information in bits between two indicator sets over n samples.

    Treats membership in a_idx and b_idx as two Bernoulli variables and sums
    p(a,b) * log2(p(a,b) / (p(a)p(b))) over the 2x2 contingency table,
    skipping empty cells. Clamped to be non-negative.
    """
    both = len(a_idx & b_idx)
    only_a = len(a_idx - b_idx)
    only_b = len(b_idx - a_idx)
    neither = n - both - only_a - only_b
    pa = (both + only_a) / n
    pb = (both + only_b) / n
    # (joint probability, marginal of A outcome, marginal of B outcome),
    # in the same (1,1), (1,0), (0,1), (0,0) order as the contingency table.
    cells = (
        (both / n, pa, pb),
        (only_a / n, pa, 1 - pb),
        (only_b / n, 1 - pa, pb),
        (neither / n, 1 - pa, 1 - pb),
    )
    total = 0.0
    for p_joint, marg_a, marg_b in cells:
        if p_joint > 0:
            total += p_joint * math.log2(p_joint / max(marg_a * marg_b, 1e-12))
    return max(total, 0.0)
def main() -> None:
    """Run the offline probe-informativeness pass and write the CSV/JSON outputs."""
    # --- Load inputs -------------------------------------------------------
    counts = load_counts(COUNTS_CSV)
    image_tags = load_image_tags(SAMPLE_JSONL, counts, MIN_COUNT)
    n_images = len(image_tags)
    if n_images == 0:
        raise RuntimeError("No image tags loaded.")
    groups_all, excluded_wiki_groups = load_groups()
    # Invert the sample: tag -> set of indices of images containing it.
    probe_to_images: Dict[str, Set[int]] = {}
    # NOTE(review): tag_occ is incremented but never read afterwards;
    # len(p_idxs) is used for occurrence counts instead.
    tag_occ = Counter()
    for i, tags in enumerate(image_tags):
        for t in tags:
            tag_occ[t] += 1
            probe_to_images.setdefault(t, set)().add(i) if False else probe_to_images.setdefault(t, set()).add(i)
    # Group -> indices of images containing any member tag; sparse groups dropped.
    group_to_images: Dict[str, Set[int]] = {}
    for g, members in groups_all.items():
        idxs: Set[int] = set()
        for i, tags in enumerate(image_tags):
            if tags & members:
                idxs.add(i)
        if len(idxs) >= MIN_GROUP_IMAGES:
            group_to_images[g] = idxs
    active_groups = sorted(group_to_images.keys())
    if not active_groups:
        raise RuntimeError("No active groups after MIN_GROUP_IMAGES filter.")
    # Semantic centroids for active groups.
    vec = get_tfidf_tag_vectors()
    mat = vec["reduced_matrix_norm"]
    tag_to_row = vec["tag_to_row_index"]
    group_centroids: Dict[str, np.ndarray] = {}
    for g in active_groups:
        rows = [tag_to_row[t] for t in groups_all[g] if t in tag_to_row]
        if len(rows) < 2:
            # A centroid of fewer than two member vectors is not meaningful.
            continue
        c = mat[rows].mean(axis=0)
        n = np.linalg.norm(c)
        if n > 0:
            # Unit-normalize so C @ v below behaves as cosine similarity
            # (assumes mat rows are normalized — the key name suggests so; verify).
            group_centroids[g] = c / n
    semantic_groups = sorted(group_centroids.keys())
    C = np.stack([group_centroids[g] for g in semantic_groups], axis=0) if semantic_groups else None
    # Baseline group prevalences and the mass of the five most common groups.
    baseline_group_probs = {g: len(group_to_images[g]) / n_images for g in active_groups}
    baseline_top5_mass = sum(sorted(baseline_group_probs.values(), reverse=True)[:5])
    # --- Per-probe metrics -------------------------------------------------
    rows_out: List[Dict[str, str]] = []
    probe_scores: Dict[str, float] = {}
    for p, p_idxs in probe_to_images.items():
        if len(p_idxs) < MIN_PROBE_IMAGES:
            continue
        q = len(p_idxs) / n_images  # probe prevalence in the sample
        if q <= 0.0 or q >= 1.0:
            # Degenerate probes (absent or ubiquitous) carry no information.
            continue
        ig_sum = 0.0
        ig_vals = []
        mean_abs_log_lift = 0.0
        lifts: Dict[str, float] = {}
        p1_group_probs: Dict[str, float] = {}
        for g in active_groups:
            g_idxs = group_to_images[g]
            pg = len(g_idxs) / n_images  # P(group)
            pg1 = len(p_idxs & g_idxs) / len(p_idxs)  # P(group | probe present)
            p0 = n_images - len(p_idxs)
            # P(group | probe absent). NOTE(review): set(range(n_images)) is
            # rebuilt for every (probe, group) pair — hoisting it would save work.
            pg0 = len((set(range(n_images)) - p_idxs) & g_idxs) / p0 if p0 > 0 else pg
            # Information gain of the probe about this group's indicator, in bits.
            ig = entropy_binary(pg) - (q * entropy_binary(pg1) + (1 - q) * entropy_binary(pg0))
            ig = max(ig, 0.0)
            ig_vals.append(ig)
            ig_sum += ig
            lift = (pg1 + 1e-9) / (pg + 1e-9)  # epsilon-smoothed lift
            lifts[g] = lift
            p1_group_probs[g] = pg1
            mean_abs_log_lift += abs(math.log2(lift + 1e-12))
        mean_abs_log_lift /= len(active_groups)
        ig_mean = float(np.mean(ig_vals)) if ig_vals else 0.0
        # How much does conditioning on the probe reshape top-5 group mass?
        top5_mass_p1 = sum(sorted(p1_group_probs.values(), reverse=True)[:5])
        delta_top5_mass = top5_mass_p1 - baseline_top5_mass
        # Group with the largest absolute log-lift shift.
        strongest_group = max(lifts.items(), key=lambda kv: abs(math.log2(kv[1] + 1e-12)))
        strongest_group_name = strongest_group[0]
        strongest_group_lift = strongest_group[1]
        semantic_top_group = ""
        semantic_margin = 0.0
        # Default is maximal entropy (no focus) when no vector is available.
        semantic_entropy_norm = 1.0
        if C is not None and p in tag_to_row:
            sims = C @ mat[tag_to_row[p]]
            order = np.argsort(sims)[::-1]
            i1 = int(order[0])
            i2 = int(order[1]) if len(order) > 1 else i1
            semantic_top_group = semantic_groups[i1]
            semantic_margin = float(sims[i1] - sims[i2])
            probs = softmax(sims, SOFTMAX_TAU)
            h = -float(np.sum(probs * np.log2(np.maximum(probs, 1e-12))))
            # Normalize entropy to [0, 1] by log2 of the number of groups.
            semantic_entropy_norm = h / math.log2(len(probs)) if len(probs) > 1 else 0.0
        prevalence_balance = math.sqrt(q * (1 - q))  # peaks at q = 0.5
        focus = max(0.0, 1.0 - semantic_entropy_norm)
        combined_score = ig_sum * prevalence_balance * (0.5 + 0.5 * focus)
        probe_scores[p] = combined_score
        rows_out.append(
            {
                "tag": p,
                "sample_occurrences": str(len(p_idxs)),
                "fluffyrock_count": str(counts.get(p, 0)),
                "prevalence": f"{q:.6f}",
                "ig_sum_bits": f"{ig_sum:.6f}",
                "ig_mean_bits": f"{ig_mean:.6f}",
                "delta_top5_mass": f"{delta_top5_mass:.6f}",
                "mean_abs_log2_lift": f"{mean_abs_log_lift:.6f}",
                "semantic_top_group": semantic_top_group,
                "semantic_margin": f"{semantic_margin:.6f}",
                "semantic_entropy_norm": f"{semantic_entropy_norm:.6f}",
                "strongest_group_by_lift": strongest_group_name,
                "strongest_group_lift": f"{strongest_group_lift:.6f}",
                "suggested_probe_bundle": infer_probe_bundle(p, semantic_top_group, strongest_group_name),
                "needs_glossary": "1" if needs_glossary(p) else "0",
                "combined_score": f"{combined_score:.6f}",
            }
        )
    # Add an actionability score that downweights very common probes and favors
    # probes that noticeably reshape top-group mass.
    for r in rows_out:
        q = float(r["prevalence"])
        ig = float(r["ig_sum_bits"])
        delta_top5 = max(0.0, float(r["delta_top5_mass"]))
        semantic_focus = max(0.0, 1.0 - float(r["semantic_entropy_norm"]))
        prevalence_penalty = max(0.0, 1.0 - abs(2 * q - 1.0))  # triangular, peaks at q = 0.5
        actionable_score = ig * prevalence_penalty * delta_top5 * (0.5 + 0.5 * semantic_focus)
        r["actionable_score"] = f"{actionable_score:.6f}"
    rows_out.sort(key=lambda r: float(r["combined_score"]), reverse=True)
    # Diversified shortlist via MMR-like greedy on top pool.
    top_pool = [r["tag"] for r in rows_out[:MMR_TOP_POOL]]
    selected: List[str] = []
    while len(selected) < MMR_K and top_pool:
        best_tag = None
        best_val = -1e9
        for t in top_pool:
            rel = probe_scores.get(t, 0.0)
            if not selected:
                val = rel
            else:
                # Redundancy = mean mutual information with already-selected probes.
                red = float(np.mean([binary_mi(probe_to_images[t], probe_to_images[s], n_images) for s in selected]))
                val = rel - MMR_LAMBDA * red
            if val > best_val:
                best_val = val
                best_tag = t
        if best_tag is None:
            break
        selected.append(best_tag)
        top_pool.remove(best_tag)
    # --- Write per-probe CSV ----------------------------------------------
    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with OUT_CSV.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "tag",
                "sample_occurrences",
                "fluffyrock_count",
                "prevalence",
                "ig_sum_bits",
                "ig_mean_bits",
                "delta_top5_mass",
                "mean_abs_log2_lift",
                "semantic_top_group",
                "semantic_margin",
                "semantic_entropy_norm",
                "strongest_group_by_lift",
                "strongest_group_lift",
                "suggested_probe_bundle",
                "needs_glossary",
                "combined_score",
                "actionable_score",
            ],
        )
        writer.writeheader()
        writer.writerows(rows_out)
    # Aggregate bundle-level utility using top actionable tags per bundle.
    by_bundle: Dict[str, List[Dict[str, str]]] = {}
    for r in rows_out:
        by_bundle.setdefault(r["suggested_probe_bundle"], []).append(r)
    bundle_scores = []
    for b, items in by_bundle.items():
        items_sorted = sorted(items, key=lambda x: float(x["actionable_score"]), reverse=True)
        top_items = items_sorted[:5]
        score = sum(float(x["actionable_score"]) for x in top_items)
        glossary_rate = sum(1 for x in top_items if x["needs_glossary"] == "1") / len(top_items) if top_items else 0.0
        bundle_scores.append(
            {
                "bundle": b,
                "bundle_score_top5_actionable": round(score, 6),
                "top_tags": [x["tag"] for x in top_items],
                "glossary_rate_top5": round(glossary_rate, 3),
            }
        )
    bundle_scores.sort(key=lambda x: x["bundle_score_top5_actionable"], reverse=True)
    top_actionable = sorted(rows_out, key=lambda r: float(r["actionable_score"]), reverse=True)
    # Mid-prevalence probes are surfaced separately for manual review.
    top_mid_prevalence = [
        r for r in top_actionable if 0.03 <= float(r["prevalence"]) <= 0.35
    ][:40]
    # --- Write summary JSON ------------------------------------------------
    summary = {
        "config": {
            "min_count": MIN_COUNT,
            "min_probe_images": MIN_PROBE_IMAGES,
            "min_group_images": MIN_GROUP_IMAGES,
            "softmax_tau": SOFTMAX_TAU,
            "mmr_lambda": MMR_LAMBDA,
            "mmr_top_pool": MMR_TOP_POOL,
            "mmr_k": MMR_K,
        },
        "n_images": n_images,
        "n_candidate_probes": len(rows_out),
        "n_active_groups": len(active_groups),
        "excluded_wiki_groups": sorted(excluded_wiki_groups),
        "top_probes_by_combined_score": rows_out[:25],
        "top_probes_by_actionable_score": top_actionable[:25],
        "top_actionable_mid_prevalence_for_manual_review": top_mid_prevalence,
        "bundle_scores": bundle_scores[:20],
        "diversified_probe_shortlist": selected,
        "outputs": {
            "csv": str(OUT_CSV),
            "summary_json": str(OUT_SUMMARY),
        },
    }
    with OUT_SUMMARY.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"Images: {n_images}")
    print(f"Active groups: {len(active_groups)}")
    print(f"Candidate probes: {len(rows_out)}")
    print(f"Top probes: {[r['tag'] for r in rows_out[:10]]}")
    print(f"Diversified shortlist: {selected}")
    print(f"Outputs: {OUT_CSV}, {OUT_SUMMARY}")
# Script entry point.
if __name__ == "__main__":
    main()