"""Oracle simulation for probe-selection policies before implementation.

Compares fixed and adaptive probe policies for ranking tag groups/categories.
This uses perfect probe answers from ground-truth tags (oracle), so results are
an optimistic upper bound on policy usefulness.

Compact outputs (overwrite each run):
- data/analysis/probe_policy_simulation.csv
- data/analysis/probe_policy_simulation_summary.json
"""

from __future__ import annotations

import csv
import json
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np

REPO = Path(__file__).resolve().parents[1]
COUNTS_CSV = REPO / "fluffyrock_3m.csv"
SAMPLE_JSONL = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
WIKI_GROUPS_JSON = REPO / "data" / "tag_groups.json"
REGISTRY_CSV = REPO / "data" / "category_registry.csv"
CATEGORY_TAG_GROUP_MAP_CSV = REPO / "data" / "analysis" / "category_tag_group_map.csv"
PROBE_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv"
PROBE_SUMMARY_JSON = REPO / "data" / "analysis" / "probe_informativeness_summary.json"
OUT_CSV = REPO / "data" / "analysis" / "probe_policy_simulation.csv"
OUT_JSON = REPO / "data" / "analysis" / "probe_policy_simulation_summary.json"

MIN_COUNT = 200
MIN_GROUP_IMAGES = 20
MIN_PROBE_IMAGES = 5
PROBE_POOL_SIZE = 120
PREVALENCE_MIN = 0.02
PREVALENCE_MAX = 0.60
BUDGETS = [3, 5, 8]
TOP_M_VALUES = [5, 8]
MODES = ["cold_start", "warm_start_easy2"]
LAPLACE_ALPHA = 1.0
ENTROPY_EPS = 1e-12

# The (budget, top_m) cell reported in the summary's "likely useful" snapshot.
# The summary key names below embed "5", so keep these in sync with them.
SNAPSHOT_BUDGET = 5
SNAPSHOT_TOP_M = 5

# Spellings accepted as "enabled" in the CSV flag columns.
_ENABLED_VALUES = frozenset({"1", "true", "True"})


def load_counts(path: Path) -> Dict[str, int]:
    """Read per-tag counts from a headerless CSV.

    Column 0 is used as the tag and column 2 as its count. Rows with fewer
    than three columns are skipped; an empty or non-integer count is stored
    as 0. If a tag appears more than once, the last row wins.
    """
    out: Dict[str, int] = {}
    with path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 3:
                continue
            try:
                out[row[0]] = int(row[2]) if row[2] else 0
            except ValueError:
                out[row[0]] = 0
    return out


def load_images(path: Path, counts: Dict[str, int], min_count: int) -> List[Set[str]]:
    """Load one tag set per image from a JSONL sample file.

    Each record's ``tags_ground_truth_categorized`` field is expected to be a
    JSON-encoded string holding a ``{category: [tag, ...]}`` mapping. Only
    string tags whose count in ``counts`` is at least ``min_count`` are kept,
    and images left with no tags are dropped.

    Blank lines and records that are not JSON objects are skipped. A line
    that is not valid JSON still raises, so file corruption is not hidden.
    """
    images: List[Set[str]] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if not isinstance(obj, dict):
                continue
            raw = obj.get("tags_ground_truth_categorized", "")
            if not raw:
                continue
            try:
                categorized = json.loads(raw)
            except (TypeError, ValueError):
                # Field is not a JSON string (or not valid JSON): skip record.
                continue
            if not isinstance(categorized, dict):
                continue
            tags: Set[str] = set()
            for vals in categorized.values():
                if not isinstance(vals, list):
                    continue
                tags.update(
                    t for t in vals
                    if isinstance(t, str) and counts.get(t, 0) >= min_count
                )
            if tags:
                images.append(tags)
    return images


def load_excluded_wiki_groups(path: Path) -> Set[str]:
    """Return wiki tag groups mapped to an enabled ``ignored_*`` category.

    Reads a CSV with ``enabled``, ``category_name`` and ``tag_group`` columns.
    Returns an empty set when ``path`` is not an existing file.
    """
    excluded: Set[str] = set()
    if not path.is_file():
        return excluded
    with path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            if (row.get("enabled") or "").strip() not in _ENABLED_VALUES:
                continue
            category = (row.get("category_name") or "").strip().lower()
            group = (row.get("tag_group") or "").strip()
            if group and category.startswith("ignored_"):
                excluded.add(group)
    return excluded


def load_groups(
    excluded_wiki_groups: Set[str],
    wiki_groups_json: Path = WIKI_GROUPS_JSON,
    registry_csv: Path = REGISTRY_CSV,
) -> Dict[str, Set[str]]:
    """Build the group-name -> member-tags mapping from both group sources.

    Wiki groups come from ``wiki_groups_json`` (a ``{group: [tag, ...]}``
    object) and are keyed ``wiki:<group>``; groups listed in
    ``excluded_wiki_groups`` and non-list values are skipped. Registry
    categories come from the enabled rows of ``registry_csv`` and are keyed
    ``cat:<category_name>``.
    """
    groups: Dict[str, Set[str]] = {}

    with wiki_groups_json.open("r", encoding="utf-8") as f:
        wiki = json.load(f)
    for g, tags in wiki.items():
        if g in excluded_wiki_groups:
            continue
        if isinstance(tags, list):
            groups[f"wiki:{g}"] = {t for t in tags if isinstance(t, str) and t}

    with registry_csv.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            if (row.get("category_enabled") or "").strip() not in _ENABLED_VALUES:
                continue
            c = (row.get("category_name") or "").strip()
            t = (row.get("tag") or "").strip()
            if c and t:
                groups.setdefault(f"cat:{c}", set()).add(t)
    return groups


def load_probe_candidates(
    probe_csv: Path = PROBE_CSV,
    summary_json: Path = PROBE_SUMMARY_JSON,
) -> Tuple[List[Dict[str, str]], List[str]]:
    """Load the probe pool and the diversified shortlist restricted to it.

    Rows of ``probe_csv`` are sorted by ``actionable_score`` (descending),
    filtered by ``MIN_PROBE_IMAGES`` and the prevalence window, and truncated
    to ``PROBE_POOL_SIZE``. The second return value is the
    ``diversified_probe_shortlist`` from ``summary_json`` (empty if that file
    does not exist), keeping only tags that survived the filter, in
    shortlist order.
    """
    with probe_csv.open("r", encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))
    rows.sort(key=lambda r: float(r["actionable_score"]), reverse=True)
    filtered = [
        r for r in rows
        if int(r["sample_occurrences"]) >= MIN_PROBE_IMAGES
        and PREVALENCE_MIN <= float(r["prevalence"]) <= PREVALENCE_MAX
    ][:PROBE_POOL_SIZE]

    mmr: List[str] = []
    if summary_json.is_file():
        s = json.loads(summary_json.read_text(encoding="utf-8"))
        mmr = list(s.get("diversified_probe_shortlist", []))

    # Build the membership set once instead of once per shortlist entry.
    candidate_set = {r["tag"] for r in filtered}
    mmr_filtered = [t for t in mmr if t in candidate_set]
    return filtered, mmr_filtered


def entropy(p: np.ndarray) -> float:
    """Return the Shannon entropy of ``p`` in bits.

    Entries are floored at ``ENTROPY_EPS`` so zeros do not produce NaN.
    """
    p = np.maximum(p, ENTROPY_EPS)
    return float(-np.sum(p * np.log2(p)))


def normalize_probs(logp: np.ndarray) -> np.ndarray:
    """Convert unnormalized log-probabilities into a probability vector.

    The maximum is subtracted before exponentiating for numerical stability.
    """
    e = np.exp(logp - np.max(logp))
    return e / max(np.sum(e), ENTROPY_EPS)


def ndcg_binary(ranked_true_flags: List[int], m: int, n_true: int) -> float:
    """Return NDCG@m for binary relevance.

    ``ranked_true_flags`` holds the relevance flag of each ranked item, best
    first; ``n_true`` is the total number of relevant items and sets the
    ideal DCG. Returns 0.0 when ``m <= 0`` or there are no relevant items.
    """
    if m <= 0:
        return 0.0
    ideal_k = min(n_true, m)
    if ideal_k == 0:
        return 0.0
    dcg = 0.0
    for i, rel in enumerate(ranked_true_flags[:m]):
        if rel:
            dcg += 1.0 / math.log2(i + 2)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_k))
    return dcg / max(idcg, ENTROPY_EPS)


def main() -> None:
    """Run the simulation and write the CSV and JSON outputs.

    Raises:
        RuntimeError: if no images, no active groups, or no probe candidates
            are available.
    """
    counts = load_counts(COUNTS_CSV)
    images = load_images(SAMPLE_JSONL, counts, MIN_COUNT)
    if not images:
        raise RuntimeError("No images loaded.")
    n_images = len(images)

    excluded_wiki_groups = load_excluded_wiki_groups(CATEGORY_TAG_GROUP_MAP_CSV)
    groups_all = load_groups(excluded_wiki_groups)

    # Keep active groups only: those touching at least MIN_GROUP_IMAGES images.
    group_image_idxs: Dict[str, Set[int]] = {}
    for g, members in groups_all.items():
        idxs = {i for i, tags in enumerate(images) if tags & members}
        if len(idxs) >= MIN_GROUP_IMAGES:
            group_image_idxs[g] = idxs

    group_names = sorted(group_image_idxs)
    n_groups = len(group_names)
    if n_groups == 0:
        raise RuntimeError("No active groups.")
    group_idx = {g: i for i, g in enumerate(group_names)}

    # Laplace-smoothed group priors, renormalized to sum to one.
    group_priors = np.zeros(n_groups, dtype=np.float64)
    for g, idxs in group_image_idxs.items():
        group_priors[group_idx[g]] = (len(idxs) + LAPLACE_ALPHA) / (
            n_images + LAPLACE_ALPHA * n_groups
        )
    group_priors /= np.sum(group_priors)
    log_priors = np.log(np.maximum(group_priors, ENTROPY_EPS))

    # Per-image true groups for evaluation. An image belongs to an active
    # group exactly when its index is in that group's image set, so invert
    # the index instead of re-intersecting tag sets.
    true_groups_by_image: List[Set[str]] = [set() for _ in range(n_images)]
    for g, idxs in group_image_idxs.items():
        for i in idxs:
            true_groups_by_image[i].add(g)
    # Sorted index arrays give a deterministic summation order for true_mass.
    true_idx_by_image = [
        np.array(sorted(group_idx[g] for g in true_g), dtype=np.intp)
        for true_g in true_groups_by_image
    ]

    probe_rows, mmr_shortlist = load_probe_candidates()
    if not probe_rows:
        raise RuntimeError("No probe candidates from probe_informativeness.csv.")
    candidate_tags = [r["tag"] for r in probe_rows]
    top_actionable = candidate_tags

    # Fill mmr shortlist to have enough probes for larger budgets.
    mmr_full = list(mmr_shortlist)
    in_mmr = set(mmr_full)
    for t in top_actionable:
        if t not in in_mmr:
            in_mmr.add(t)
            mmr_full.append(t)

    easy_known_tags = {r["tag"] for r in probe_rows if r.get("needs_glossary", "0") == "0"}

    # Probe state precompute: presence by image.
    probe_present_by_image: Dict[str, np.ndarray] = {
        t: np.array([t in tags for tags in images], dtype=np.int8)
        for t in candidate_tags
    }

    # Likelihoods: P(probe=1 | group), smoothed and clipped away from 0/1 so
    # both log(p) and log(1 - p) stay finite.
    group_members = [
        np.array(sorted(group_image_idxs[g]), dtype=np.intp) for g in group_names
    ]
    group_sizes = np.array([len(m) for m in group_members], dtype=np.float64)
    p1_given_group: Dict[str, np.ndarray] = {}
    log_p1: Dict[str, np.ndarray] = {}
    log_p0: Dict[str, np.ndarray] = {}
    for t in candidate_tags:
        present = probe_present_by_image[t]
        hits = np.array([present[m].sum() for m in group_members], dtype=np.float64)
        p1 = np.clip(
            (hits + LAPLACE_ALPHA) / (group_sizes + 2 * LAPLACE_ALPHA),
            1e-6,
            1 - 1e-6,
        )
        p1_given_group[t] = p1
        log_p1[t] = np.log(p1)
        log_p0[t] = np.log(1 - p1)

    def posterior_from_evidence(evidence: Dict[str, int]) -> np.ndarray:
        """Naive-Bayes posterior over groups given observed probe answers."""
        logp = log_priors.copy()
        for t, v in evidence.items():
            if t not in log_p1:
                continue
            logp += log_p1[t] if v == 1 else log_p0[t]
        return normalize_probs(logp)

    def init_evidence(mode: str, image_i: int) -> Dict[str, int]:
        """Return the evidence known before any probe is asked."""
        if mode == "cold_start":
            return {}
        if mode == "warm_start_easy2":
            tags = images[image_i]
            # Approximate "already known from prompt" with up to 2 easy tags present.
            present_easy = [t for t in top_actionable if t in easy_known_tags and t in tags]
            return {t: 1 for t in present_easy[:2]}
        raise ValueError(f"Unknown mode: {mode}")

    def choose_adaptive_entropy(image_i: int, budget: int, evidence: Dict[str, int]) -> List[str]:
        """Greedily ask the probe with the largest expected entropy reduction.

        Mutates ``evidence`` with the oracle answer of each chosen probe and
        returns the chosen probes in order.
        """
        chosen: List[str] = []
        asked = set(evidence)
        for _ in range(budget):
            p = posterior_from_evidence(evidence)
            h_now = entropy(p)
            # Invariant across candidates, so compute it once per step.
            log_p = np.log(np.maximum(p, ENTROPY_EPS))
            best_t = None
            best_gain = -1e9
            for t in candidate_tags:
                if t in asked:
                    continue
                p_yes = float(np.sum(p * p1_given_group[t]))
                h_yes = entropy(normalize_probs(log_p + log_p1[t]))
                h_no = entropy(normalize_probs(log_p + log_p0[t]))
                gain = h_now - (p_yes * h_yes + (1 - p_yes) * h_no)
                # Strict comparison: on ties the earlier candidate wins.
                if gain > best_gain:
                    best_gain = gain
                    best_t = t
            if best_t is None:
                break
            chosen.append(best_t)
            asked.add(best_t)
            # Oracle observation.
            evidence[best_t] = int(probe_present_by_image[best_t][image_i])
        return chosen

    def choose_fixed(order: List[str], image_i: int, budget: int, evidence: Dict[str, int]) -> List[str]:
        """Ask the first ``budget`` not-yet-known probes from ``order``.

        Mutates ``evidence`` with the oracle answers and returns the probes
        asked. A non-positive budget asks nothing.
        """
        chosen: List[str] = []
        if budget <= 0:
            return chosen
        asked = set(evidence)
        for t in order:
            if t in asked:
                continue
            chosen.append(t)
            asked.add(t)
            evidence[t] = int(probe_present_by_image[t][image_i])  # oracle observation
            if len(chosen) >= budget:
                break
        return chosen

    strategy_orders = {
        "fixed_top_actionable": top_actionable,
        "fixed_mmr": mmr_full,
    }
    strategies = ["baseline_no_probe", "fixed_top_actionable", "fixed_mmr", "adaptive_entropy"]

    metric_rows: List[Dict[str, str]] = []
    for mode in MODES:
        for budget in BUDGETS:
            for strategy in strategies:
                per_top_m = {m: defaultdict(float) for m in TOP_M_VALUES}
                for i in range(n_images):
                    ev = init_evidence(mode, i)
                    if strategy == "adaptive_entropy":
                        choose_adaptive_entropy(i, budget, ev)
                    elif strategy != "baseline_no_probe":
                        choose_fixed(strategy_orders[strategy], i, budget, ev)

                    post = posterior_from_evidence(ev)
                    ranking = np.argsort(post)[::-1]
                    ranked_groups = [group_names[j] for j in ranking]
                    true_g = true_groups_by_image[i]
                    n_true = len(true_g)
                    # Independent of m; the sum over an empty index array is 0.0.
                    true_mass = float(post[true_idx_by_image[i]].sum())

                    for m in TOP_M_VALUES:
                        topm = ranked_groups[:m]
                        flags = [1 if g in true_g else 0 for g in topm]
                        n_hit = sum(flags)
                        topm_true_mass = float(
                            np.sum([post[group_idx[g]] for g in topm if g in true_g])
                        )
                        agg = per_top_m[m]
                        agg["hit"] += 1.0 if n_hit > 0 else 0.0
                        agg["rec"] += n_hit / n_true if n_true > 0 else 0.0
                        agg["prec"] += n_hit / m if m > 0 else 0.0
                        agg["ndcg"] += ndcg_binary(flags, m, n_true)
                        agg["true_mass"] += true_mass
                        agg["topm_true_mass"] += topm_true_mass

                for m in TOP_M_VALUES:
                    agg = per_top_m[m]
                    metric_rows.append(
                        {
                            "mode": mode,
                            "strategy": strategy,
                            "budget": str(budget),
                            "top_m": str(m),
                            "hit_at_m": f"{agg['hit'] / n_images:.6f}",
                            "recall_at_m": f"{agg['rec'] / n_images:.6f}",
                            "precision_at_m": f"{agg['prec'] / n_images:.6f}",
                            "ndcg_at_m": f"{agg['ndcg'] / n_images:.6f}",
                            "true_mass": f"{agg['true_mass'] / n_images:.6f}",
                            "topm_true_mass": f"{agg['topm_true_mass'] / n_images:.6f}",
                        }
                    )

    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with OUT_CSV.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "mode",
                "strategy",
                "budget",
                "top_m",
                "hit_at_m",
                "recall_at_m",
                "precision_at_m",
                "ndcg_at_m",
                "true_mass",
                "topm_true_mass",
            ],
        )
        writer.writeheader()
        writer.writerows(metric_rows)

    # Quick "is it likely useful?" summary at top_m=5, budget=5.
    lookup = {
        (r["mode"], r["strategy"], r["budget"], r["top_m"]): r
        for r in metric_rows
    }

    def snapshot_ndcg(mode: str, strategy: str) -> float:
        """NDCG of one strategy at the snapshot (budget, top_m) cell."""
        row = lookup[(mode, strategy, str(SNAPSHOT_BUDGET), str(SNAPSHOT_TOP_M))]
        return float(row["ndcg_at_m"])

    likely_useful = []
    # Skip the snapshot (rather than KeyError) if the cell was not simulated.
    if SNAPSHOT_BUDGET in BUDGETS and SNAPSHOT_TOP_M in TOP_M_VALUES:
        for mode in MODES:
            baseline = snapshot_ndcg(mode, "baseline_no_probe")
            adaptive = snapshot_ndcg(mode, "adaptive_entropy")
            fixed_top = snapshot_ndcg(mode, "fixed_top_actionable")
            likely_useful.append(
                {
                    "mode": mode,
                    "baseline_ndcg_at_5": baseline,
                    "fixed_top_ndcg_at_5": fixed_top,
                    "adaptive_ndcg_at_5": adaptive,
                    "adaptive_minus_fixed_top_ndcg_at_5": adaptive - fixed_top,
                    "adaptive_minus_baseline_ndcg_at_5": adaptive - baseline,
                }
            )

    summary = {
        "config": {
            "min_count": MIN_COUNT,
            "min_group_images": MIN_GROUP_IMAGES,
            "min_probe_images": MIN_PROBE_IMAGES,
            "probe_pool_size": PROBE_POOL_SIZE,
            "prevalence_min": PREVALENCE_MIN,
            "prevalence_max": PREVALENCE_MAX,
            "budgets": BUDGETS,
            "top_m_values": TOP_M_VALUES,
            "modes": MODES,
            "laplace_alpha": LAPLACE_ALPHA,
            "note": "Oracle probe answers from GT tags; optimistic upper bound.",
        },
        "n_images": n_images,
        "n_active_groups": n_groups,
        "n_candidate_probes": len(candidate_tags),
        "excluded_wiki_groups": sorted(excluded_wiki_groups),
        "probe_pool_head": candidate_tags[:30],
        "mmr_head": mmr_full[:30],
        "likely_useful_snapshot_budget5_top5": likely_useful,
        "outputs": {
            "csv": str(OUT_CSV),
            "summary_json": str(OUT_JSON),
        },
    }
    with OUT_JSON.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"Images: {n_images}")
    print(f"Active groups: {n_groups}")
    print(f"Candidate probes: {len(candidate_tags)}")
    print("Snapshot budget=5 top_m=5:", likely_useful)
    print(f"Outputs: {OUT_CSV}, {OUT_JSON}")


if __name__ == "__main__":
    main()