# Prompt_Squirrel_RAG/scripts/simulate_probe_policy.py
# Food Desert — "Consolidate probe configs and eval artifacts on main" (commit 6e50f4d)
"""Oracle simulation for probe-selection policies before implementation.
Compares fixed and adaptive probe policies for ranking tag groups/categories.
This uses perfect probe answers from ground-truth tags (oracle), so results are
an optimistic upper bound on policy usefulness.
Compact outputs (overwrite each run):
- data/analysis/probe_policy_simulation.csv
- data/analysis/probe_policy_simulation_summary.json
"""
from __future__ import annotations
import csv
import json
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple
import numpy as np
# Repository root: this script lives in <repo>/scripts/.
REPO = Path(__file__).resolve().parents[1]
# --- Input artifacts ---
COUNTS_CSV = REPO / "fluffyrock_3m.csv"  # tag -> global usage-count table
SAMPLE_JSONL = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"  # eval sample with GT tags
WIKI_GROUPS_JSON = REPO / "data" / "tag_groups.json"  # wiki-derived tag groups
REGISTRY_CSV = REPO / "data" / "category_registry.csv"  # curated category registry
CATEGORY_TAG_GROUP_MAP_CSV = REPO / "data" / "analysis" / "category_tag_group_map.csv"  # category<->wiki-group map (drives exclusions)
PROBE_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv"  # scored probe candidates
PROBE_SUMMARY_JSON = REPO / "data" / "analysis" / "probe_informativeness_summary.json"  # includes MMR-diversified shortlist
# --- Outputs (overwritten on each run) ---
OUT_CSV = REPO / "data" / "analysis" / "probe_policy_simulation.csv"
OUT_JSON = REPO / "data" / "analysis" / "probe_policy_simulation_summary.json"
# --- Filtering thresholds ---
MIN_COUNT = 200  # drop tags with global usage below this
MIN_GROUP_IMAGES = 20  # a group must cover at least this many sample images to stay "active"
MIN_PROBE_IMAGES = 5  # a probe tag must occur in at least this many sample images
PROBE_POOL_SIZE = 120  # cap on the probe candidate pool
PREVALENCE_MIN = 0.02  # prevalence band: exclude near-never tags ...
PREVALENCE_MAX = 0.60  # ... and near-always tags (uninformative probes)
# --- Simulation grid ---
BUDGETS = [3, 5, 8]  # number of probes allowed per image
TOP_M_VALUES = [5, 8]  # ranking cutoffs for metrics
MODES = ["cold_start", "warm_start_easy2"]  # empty vs. seeded initial evidence
# --- Numerics ---
LAPLACE_ALPHA = 1.0  # additive smoothing for priors/likelihoods
ENTROPY_EPS = 1e-12  # floor to keep log/division finite
def load_counts(path: Path) -> Dict[str, int]:
    """Read tag usage counts from a CSV whose rows look like (tag, _, count, ...).

    Rows with fewer than three columns are skipped; an empty or
    non-numeric count column maps to 0.
    """
    counts: Dict[str, int] = {}
    with path.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.reader(handle):
            if len(record) < 3:
                continue
            tag, raw_count = record[0], record[2]
            try:
                counts[tag] = int(raw_count) if raw_count else 0
            except ValueError:
                counts[tag] = 0
    return counts
def load_images(path: Path, counts: Dict[str, int], min_count: int) -> List[Set[str]]:
    """Load per-image ground-truth tag sets from a JSONL sample file.

    Each line holds a record whose "tags_ground_truth_categorized" field is a
    JSON-encoded dict of category -> tag list.  Tags whose global count is
    below *min_count* are dropped; images left with no tags are skipped.
    """
    result: List[Set[str]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            record = json.loads(line)
            encoded = record.get("tags_ground_truth_categorized", "")
            if not encoded:
                continue
            try:
                categorized = json.loads(encoded)
            except Exception:
                # Best-effort: skip rows whose inner payload is not valid JSON.
                continue
            kept: Set[str] = set()
            if isinstance(categorized, dict):
                for tag_list in categorized.values():
                    if not isinstance(tag_list, list):
                        continue
                    kept.update(
                        t
                        for t in tag_list
                        if isinstance(t, str) and counts.get(t, 0) >= min_count
                    )
            if kept:
                result.append(kept)
    return result
def load_excluded_wiki_groups(path: Path) -> Set[str]:
    """Collect wiki tag-group names that map to "ignored_*" categories.

    Only rows whose "enabled" column is 1/true/True count.  Returns an empty
    set when the mapping CSV does not exist.
    """
    out: Set[str] = set()
    if not path.is_file():
        return out
    with path.open("r", encoding="utf-8", newline="") as handle:
        for row in csv.DictReader(handle):
            enabled = (row.get("enabled") or "").strip()
            if enabled not in {"1", "true", "True"}:
                continue
            category = (row.get("category_name") or "").strip().lower()
            group = (row.get("tag_group") or "").strip()
            if group and category.startswith("ignored_"):
                out.add(group)
    return out
def load_groups(excluded_wiki_groups: Set[str]) -> Dict[str, Set[str]]:
    """Build tag groups from the wiki JSON plus enabled registry categories.

    Wiki groups are keyed "wiki:<name>" (minus any excluded names) and
    registry categories "cat:<name>".
    """
    merged: Dict[str, Set[str]] = {}
    with WIKI_GROUPS_JSON.open("r", encoding="utf-8") as handle:
        wiki = json.load(handle)
    for name, members in wiki.items():
        if name in excluded_wiki_groups:
            continue
        if isinstance(members, list):
            merged[f"wiki:{name}"] = {m for m in members if isinstance(m, str) and m}
    with REGISTRY_CSV.open("r", encoding="utf-8", newline="") as handle:
        for row in csv.DictReader(handle):
            if (row.get("category_enabled") or "").strip() not in {"1", "true", "True"}:
                continue
            category = (row.get("category_name") or "").strip()
            tag = (row.get("tag") or "").strip()
            if category and tag:
                merged.setdefault(f"cat:{category}", set()).add(tag)
    return merged
def load_probe_candidates() -> Tuple[List[Dict[str, str]], List[str]]:
    """Load the filtered probe pool and the MMR-diversified shortlist.

    Returns (pool_rows, mmr_tags): pool_rows are the top PROBE_POOL_SIZE
    rows by actionable_score that pass the occurrence and prevalence
    filters; mmr_tags is the diversified shortlist restricted to tags
    present in the pool.
    """
    with PROBE_CSV.open("r", encoding="utf-8", newline="") as handle:
        all_rows = list(csv.DictReader(handle))
    all_rows.sort(key=lambda row: float(row["actionable_score"]), reverse=True)
    pool: List[Dict[str, str]] = []
    for row in all_rows:
        if int(row["sample_occurrences"]) < MIN_PROBE_IMAGES:
            continue
        if not (PREVALENCE_MIN <= float(row["prevalence"]) <= PREVALENCE_MAX):
            continue
        pool.append(row)
        if len(pool) >= PROBE_POOL_SIZE:
            break
    shortlist: List[str] = []
    if PROBE_SUMMARY_JSON.is_file():
        summary = json.loads(PROBE_SUMMARY_JSON.read_text(encoding="utf-8"))
        shortlist = list(summary.get("diversified_probe_shortlist", []))
    pool_tags = {row["tag"] for row in pool}
    return pool, [t for t in shortlist if t in pool_tags]
def entropy(p: np.ndarray, eps: float | None = None) -> float:
    """Return the Shannon entropy (in bits) of probability vector *p*.

    Probabilities are floored at *eps* (defaults to ENTROPY_EPS) so that
    log2 never sees zero.  *p* is assumed to be (approximately) normalized.
    """
    if eps is None:
        eps = ENTROPY_EPS
    p = np.maximum(p, eps)
    return float(-np.sum(p * np.log2(p)))
def normalize_probs(logp: np.ndarray, eps: float | None = None) -> np.ndarray:
    """Softmax over unnormalized log-probabilities, numerically stabilized.

    Shifts by max(logp) before exponentiating (so the largest term is 1);
    the denominator is floored at *eps* (defaults to ENTROPY_EPS) to guard
    against division by zero.
    """
    if eps is None:
        eps = ENTROPY_EPS
    shifted = logp - np.max(logp)
    unnorm = np.exp(shifted)
    return unnorm / max(np.sum(unnorm), eps)
def ndcg_binary(
    ranked_true_flags: List[int], m: int, n_true: int, eps: float | None = None
) -> float:
    """Binary NDCG@m for a ranked list of 0/1 relevance flags.

    ranked_true_flags: relevance per ranked position (1 = true group).
    m: ranking cutoff; n_true: total number of true items (caps ideal DCG).
    eps: division-guard floor (defaults to ENTROPY_EPS).
    Returns 0.0 when m <= 0 or there are no true items.
    """
    if eps is None:
        eps = ENTROPY_EPS
    if m <= 0:
        return 0.0
    # Hoisted early-exit: no true items means DCG would be discarded anyway.
    ideal_k = min(n_true, m)
    if ideal_k == 0:
        return 0.0
    dcg = sum(
        1.0 / math.log2(i + 2)
        for i, rel in enumerate(ranked_true_flags[:m])
        if rel
    )
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_k))
    return dcg / max(idcg, eps)
def main() -> None:
    """Run the oracle probe-policy simulation and write CSV/JSON artifacts.

    Pipeline: load sample + groups -> build Laplace-smoothed group priors and
    per-probe likelihoods -> simulate each (mode, budget, strategy) cell over
    all images with oracle probe answers -> aggregate ranking metrics ->
    write OUT_CSV / OUT_JSON.
    """
    counts = load_counts(COUNTS_CSV)
    images = load_images(SAMPLE_JSONL, counts, MIN_COUNT)
    if not images:
        raise RuntimeError("No images loaded.")
    n_images = len(images)
    excluded_wiki_groups = load_excluded_wiki_groups(CATEGORY_TAG_GROUP_MAP_CSV)
    groups_all = load_groups(excluded_wiki_groups)
    # Keep active groups only.
    group_image_idxs: Dict[str, Set[int]] = {}
    for g, members in groups_all.items():
        idxs = {i for i, tags in enumerate(images) if tags & members}
        if len(idxs) >= MIN_GROUP_IMAGES:
            group_image_idxs[g] = idxs
    group_names = sorted(group_image_idxs.keys())
    n_groups = len(group_names)
    if n_groups == 0:
        raise RuntimeError("No active groups.")
    group_idx = {g: i for i, g in enumerate(group_names)}
    # Laplace-smoothed prior over groups from empirical image coverage,
    # renormalized to sum to 1.
    group_priors = np.zeros(n_groups, dtype=np.float64)
    for g, idxs in group_image_idxs.items():
        group_priors[group_idx[g]] = (len(idxs) + LAPLACE_ALPHA) / (n_images + LAPLACE_ALPHA * n_groups)
    group_priors /= np.sum(group_priors)
    # Per-image true groups for evaluation.
    true_groups_by_image: List[Set[str]] = []
    for tags in images:
        true_g = {g for g in group_names if tags & groups_all[g]}
        true_groups_by_image.append(true_g)
    probe_rows, mmr_shortlist = load_probe_candidates()
    if not probe_rows:
        raise RuntimeError("No probe candidates from probe_informativeness.csv.")
    candidate_tags = [r["tag"] for r in probe_rows]
    # NOTE(review): candidate_set is assigned but never read below.
    candidate_set = set(candidate_tags)
    top_actionable = candidate_tags
    # Fill mmr shortlist to have enough probes for larger budgets.
    mmr_full = list(mmr_shortlist)
    for t in top_actionable:
        if t not in mmr_full:
            mmr_full.append(t)
    # Tags flagged as not needing a glossary ("easy"); used to seed warm starts.
    easy_known_tags = {r["tag"] for r in probe_rows if r.get("needs_glossary", "0") == "0"}
    # Probe state precompute: presence by image.
    probe_present_by_image: Dict[str, np.ndarray] = {}
    for t in candidate_tags:
        arr = np.zeros(n_images, dtype=np.int8)
        for i, tags in enumerate(images):
            if t in tags:
                arr[i] = 1
        probe_present_by_image[t] = arr
    # Likelihoods: P(probe=1 | group), smoothed.
    p1_given_group: Dict[str, np.ndarray] = {}
    for t in candidate_tags:
        arr = np.zeros(n_groups, dtype=np.float64)
        t_present = probe_present_by_image[t]
        for g, g_i in group_idx.items():
            idxs = group_image_idxs[g]
            n_g = len(idxs)
            n_tg = int(np.sum([t_present[i] for i in idxs]))
            # Add-one smoothing over the two probe outcomes (present/absent).
            arr[g_i] = (n_tg + LAPLACE_ALPHA) / (n_g + 2 * LAPLACE_ALPHA)
        # Clip away exact 0/1 so log() below stays finite.
        p1_given_group[t] = np.clip(arr, 1e-6, 1 - 1e-6)

    def posterior_from_evidence(evidence: Dict[str, int]) -> np.ndarray:
        # Naive-Bayes posterior over groups given observed probe answers
        # (evidence maps probe tag -> 0/1 observation).
        logp = np.log(np.maximum(group_priors, ENTROPY_EPS))
        for t, v in evidence.items():
            if t not in p1_given_group:
                continue
            p1 = p1_given_group[t]
            if v == 1:
                logp += np.log(p1)
            else:
                logp += np.log(1 - p1)
        return normalize_probs(logp)

    def init_evidence(mode: str, image_i: int) -> Dict[str, int]:
        # Build the starting evidence dict for a simulation mode.
        if mode == "cold_start":
            return {}
        if mode == "warm_start_easy2":
            tags = images[image_i]
            # Approximate "already known from prompt" with up to 2 easy tags present.
            present_easy = [t for t in top_actionable if t in easy_known_tags and t in tags]
            return {t: 1 for t in present_easy[:2]}
        raise ValueError(f"Unknown mode: {mode}")

    def choose_adaptive_entropy(image_i: int, budget: int, evidence: Dict[str, int]) -> List[str]:
        # Greedy expected-information-gain probe selection; mutates *evidence*
        # in place with oracle answers after each pick.
        chosen: List[str] = []
        asked = set(evidence.keys())
        for _ in range(budget):
            p = posterior_from_evidence(evidence)
            h0 = entropy(p)
            best_t = None
            best_gain = -1e9
            for t in candidate_tags:
                if t in asked:
                    continue
                p1g = p1_given_group[t]
                # Predictive probability that this probe answers "yes".
                p_t1 = float(np.sum(p * p1g))
                # posterior if t=1
                p1 = normalize_probs(np.log(np.maximum(p, ENTROPY_EPS)) + np.log(p1g))
                h1 = entropy(p1)
                # posterior if t=0
                p0 = normalize_probs(np.log(np.maximum(p, ENTROPY_EPS)) + np.log(1 - p1g))
                h2 = entropy(p0)
                # Expected posterior entropy after asking t.
                exp_h = p_t1 * h1 + (1 - p_t1) * h2
                gain = h0 - exp_h
                if gain > best_gain:
                    best_gain = gain
                    best_t = t
            if best_t is None:
                break
            chosen.append(best_t)
            asked.add(best_t)
            # Oracle observation.
            evidence[best_t] = int(probe_present_by_image[best_t][image_i])
        return chosen

    def choose_fixed(order: List[str], image_i: int, budget: int, evidence: Dict[str, int]) -> List[str]:
        # Ask probes in a fixed order, skipping already-answered ones;
        # mutates *evidence* in place with oracle answers.
        chosen = []
        asked = set(evidence.keys())
        for t in order:
            if t in asked:
                continue
            chosen.append(t)
            asked.add(t)
            evidence[t] = int(probe_present_by_image[t][image_i])  # oracle observation
            if len(chosen) >= budget:
                break
        return chosen

    strategy_orders = {
        "fixed_top_actionable": top_actionable,
        "fixed_mmr": mmr_full,
    }
    # Simulate every (mode, budget, strategy) cell over all images.
    metric_rows: List[Dict[str, str]] = []
    for mode in MODES:
        for budget in BUDGETS:
            for strategy in ["baseline_no_probe", "fixed_top_actionable", "fixed_mmr", "adaptive_entropy"]:
                # Metric accumulators per top-m cutoff.
                per_top_m = {m: defaultdict(float) for m in TOP_M_VALUES}
                for i in range(n_images):
                    ev = init_evidence(mode, i)
                    if strategy == "baseline_no_probe":
                        pass
                    elif strategy == "adaptive_entropy":
                        choose_adaptive_entropy(i, budget, ev)
                    else:
                        choose_fixed(strategy_orders[strategy], i, budget, ev)
                    post = posterior_from_evidence(ev)
                    # Rank groups by descending posterior mass.
                    ranking = np.argsort(post)[::-1]
                    ranked_groups = [group_names[j] for j in ranking]
                    true_g = true_groups_by_image[i]
                    for m in TOP_M_VALUES:
                        topm = ranked_groups[:m]
                        n_true = len(true_g)
                        n_hit = len(set(topm) & true_g)
                        hit = 1.0 if n_hit > 0 else 0.0
                        rec = n_hit / n_true if n_true > 0 else 0.0
                        prec = n_hit / m if m > 0 else 0.0
                        flags = [1 if g in true_g else 0 for g in topm]
                        ndcg = ndcg_binary(flags, m, n_true)
                        # Posterior probability mass on true groups (all, and in top-m).
                        true_mass = float(np.sum([post[group_idx[g]] for g in true_g])) if true_g else 0.0
                        topm_true_mass = float(np.sum([post[group_idx[g]] for g in topm if g in true_g]))
                        per_top_m[m]["hit"] += hit
                        per_top_m[m]["rec"] += rec
                        per_top_m[m]["prec"] += prec
                        per_top_m[m]["ndcg"] += ndcg
                        per_top_m[m]["true_mass"] += true_mass
                        per_top_m[m]["topm_true_mass"] += topm_true_mass
                # Average accumulators over images into one output row per top-m.
                for m in TOP_M_VALUES:
                    agg = per_top_m[m]
                    metric_rows.append(
                        {
                            "mode": mode,
                            "strategy": strategy,
                            "budget": str(budget),
                            "top_m": str(m),
                            "hit_at_m": f"{agg['hit'] / n_images:.6f}",
                            "recall_at_m": f"{agg['rec'] / n_images:.6f}",
                            "precision_at_m": f"{agg['prec'] / n_images:.6f}",
                            "ndcg_at_m": f"{agg['ndcg'] / n_images:.6f}",
                            "true_mass": f"{agg['true_mass'] / n_images:.6f}",
                            "topm_true_mass": f"{agg['topm_true_mass'] / n_images:.6f}",
                        }
                    )
    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with OUT_CSV.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "mode",
                "strategy",
                "budget",
                "top_m",
                "hit_at_m",
                "recall_at_m",
                "precision_at_m",
                "ndcg_at_m",
                "true_mass",
                "topm_true_mass",
            ],
        )
        writer.writeheader()
        writer.writerows(metric_rows)
    # Quick "is it likely useful?" summary at top_m=5, budget=5.
    lookup = {
        (r["mode"], r["strategy"], r["budget"], r["top_m"]): r
        for r in metric_rows
    }
    key = lambda mode, strategy: lookup[(mode, strategy, "5", "5")]
    likely_useful = []
    for mode in MODES:
        b = key(mode, "baseline_no_probe")
        a = key(mode, "adaptive_entropy")
        t = key(mode, "fixed_top_actionable")
        likely_useful.append(
            {
                "mode": mode,
                "baseline_ndcg_at_5": float(b["ndcg_at_m"]),
                "fixed_top_ndcg_at_5": float(t["ndcg_at_m"]),
                "adaptive_ndcg_at_5": float(a["ndcg_at_m"]),
                "adaptive_minus_fixed_top_ndcg_at_5": float(a["ndcg_at_m"]) - float(t["ndcg_at_m"]),
                "adaptive_minus_baseline_ndcg_at_5": float(a["ndcg_at_m"]) - float(b["ndcg_at_m"]),
            }
        )
    # Summary JSON: config echo, dataset sizes, pool heads, and the snapshot.
    summary = {
        "config": {
            "min_count": MIN_COUNT,
            "min_group_images": MIN_GROUP_IMAGES,
            "min_probe_images": MIN_PROBE_IMAGES,
            "probe_pool_size": PROBE_POOL_SIZE,
            "prevalence_min": PREVALENCE_MIN,
            "prevalence_max": PREVALENCE_MAX,
            "budgets": BUDGETS,
            "top_m_values": TOP_M_VALUES,
            "modes": MODES,
            "laplace_alpha": LAPLACE_ALPHA,
            "note": "Oracle probe answers from GT tags; optimistic upper bound.",
        },
        "n_images": n_images,
        "n_active_groups": n_groups,
        "n_candidate_probes": len(candidate_tags),
        "excluded_wiki_groups": sorted(excluded_wiki_groups),
        "probe_pool_head": candidate_tags[:30],
        "mmr_head": mmr_full[:30],
        "likely_useful_snapshot_budget5_top5": likely_useful,
        "outputs": {
            "csv": str(OUT_CSV),
            "summary_json": str(OUT_JSON),
        },
    }
    with OUT_JSON.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"Images: {n_images}")
    print(f"Active groups: {n_groups}")
    print(f"Candidate probes: {len(candidate_tags)}")
    print("Snapshot budget=5 top_m=5:", likely_useful)
    print(f"Outputs: {OUT_CSV}, {OUT_JSON}")
# Script entry point: run the full oracle simulation when executed directly.
if __name__ == "__main__":
    main()