Spaces:
Running
Running
| """Oracle simulation for probe-selection policies before implementation. | |
| Compares fixed and adaptive probe policies for ranking tag groups/categories. | |
| This uses perfect probe answers from ground-truth tags (oracle), so results are | |
| an optimistic upper bound on policy usefulness. | |
| Compact outputs (overwrite each run): | |
| - data/analysis/probe_policy_simulation.csv | |
| - data/analysis/probe_policy_simulation_summary.json | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import json | |
| import math | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Dict, List, Set, Tuple | |
| import numpy as np | |
# --- File locations (paths resolved relative to the repository root) ---
REPO = Path(__file__).resolve().parents[1]  # repo root; script lives one directory down
COUNTS_CSV = REPO / "fluffyrock_3m.csv"  # tag-count CSV: tag in col 0, count in col 2
SAMPLE_JSONL = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"  # eval sample with GT tags
WIKI_GROUPS_JSON = REPO / "data" / "tag_groups.json"  # wiki tag-group name -> member tag list
REGISTRY_CSV = REPO / "data" / "category_registry.csv"  # enabled category -> tag rows
CATEGORY_TAG_GROUP_MAP_CSV = REPO / "data" / "analysis" / "category_tag_group_map.csv"  # ignored_* rows here exclude wiki groups
PROBE_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv"  # probe candidates with actionable_score/prevalence
PROBE_SUMMARY_JSON = REPO / "data" / "analysis" / "probe_informativeness_summary.json"  # carries the diversified (MMR) shortlist
OUT_CSV = REPO / "data" / "analysis" / "probe_policy_simulation.csv"  # per-config metric rows (overwritten each run)
OUT_JSON = REPO / "data" / "analysis" / "probe_policy_simulation_summary.json"  # config + snapshot summary (overwritten each run)
# --- Simulation parameters ---
MIN_COUNT = 200  # tags with a global count below this are ignored when loading images
MIN_GROUP_IMAGES = 20  # a group must cover at least this many sample images to be "active"
MIN_PROBE_IMAGES = 5  # a probe must occur in at least this many sample images
PROBE_POOL_SIZE = 120  # cap on the probe-candidate pool after filtering
PREVALENCE_MIN = 0.02  # probes outside this prevalence band are filtered out
PREVALENCE_MAX = 0.60
BUDGETS = [3, 5, 8]  # number of probes allowed per image
TOP_M_VALUES = [5, 8]  # ranking cutoffs used by the reported metrics
MODES = ["cold_start", "warm_start_easy2"]  # no starting evidence vs. up to two easy known tags
LAPLACE_ALPHA = 1.0  # additive smoothing for group priors and probe likelihoods
ENTROPY_EPS = 1e-12  # numerical floor to keep logs and divisions finite
def load_counts(path: Path) -> Dict[str, int]:
    """Read a tag -> count mapping from a CSV with the tag in column 0 and
    the count in column 2.

    Rows with fewer than three columns are skipped; an empty or non-numeric
    count column is recorded as 0.
    """
    counts: Dict[str, int] = {}
    with path.open("r", encoding="utf-8", newline="") as handle:
        for record in csv.reader(handle):
            if len(record) < 3:
                continue
            tag, raw_count = record[0], record[2]
            try:
                counts[tag] = int(raw_count) if raw_count else 0
            except ValueError:
                counts[tag] = 0
    return counts
def load_images(path: Path, counts: Dict[str, int], min_count: int) -> List[Set[str]]:
    """Load per-image ground-truth tag sets from a JSONL sample.

    Each line must contain a JSON object whose "tags_ground_truth_categorized"
    field is itself a JSON-encoded dict of category -> tag list. Only string
    tags whose global count reaches *min_count* are kept; images with no
    surviving tags are dropped.
    """
    result: List[Set[str]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            record = json.loads(line)
            payload = record.get("tags_ground_truth_categorized", "")
            if not payload:
                continue
            try:
                categorized = json.loads(payload)
            except Exception:
                continue  # malformed embedded JSON: skip this image
            if not isinstance(categorized, dict):
                continue
            kept = {
                tag
                for values in categorized.values()
                if isinstance(values, list)
                for tag in values
                if isinstance(tag, str) and counts.get(tag, 0) >= min_count
            }
            if kept:
                result.append(kept)
    return result
def load_excluded_wiki_groups(path: Path) -> Set[str]:
    """Return wiki tag-group names mapped to ignored_* categories.

    Only rows whose "enabled" column is 1/true/True count; the category name
    is compared case-insensitively. A missing file yields an empty set.
    """
    if not path.is_file():
        return set()
    skip: Set[str] = set()
    with path.open("r", encoding="utf-8", newline="") as handle:
        for row in csv.DictReader(handle):
            if (row.get("enabled") or "").strip() not in {"1", "true", "True"}:
                continue
            category_name = (row.get("category_name") or "").strip().lower()
            group_name = (row.get("tag_group") or "").strip()
            if group_name and category_name.startswith("ignored_"):
                skip.add(group_name)
    return skip
def load_groups(excluded_wiki_groups: Set[str]) -> Dict[str, Set[str]]:
    """Assemble group-name -> member-tag sets from two sources.

    Wiki groups come from WIKI_GROUPS_JSON (keys namespaced "wiki:<name>",
    minus *excluded_wiki_groups*); registry categories come from enabled
    rows of REGISTRY_CSV (keys namespaced "cat:<name>").
    """
    combined: Dict[str, Set[str]] = {}
    with WIKI_GROUPS_JSON.open("r", encoding="utf-8") as handle:
        wiki_groups = json.load(handle)
    for name, members in wiki_groups.items():
        if name in excluded_wiki_groups:
            continue
        if isinstance(members, list):
            combined[f"wiki:{name}"] = {m for m in members if isinstance(m, str) and m}
    with REGISTRY_CSV.open("r", encoding="utf-8", newline="") as handle:
        for row in csv.DictReader(handle):
            if (row.get("category_enabled") or "").strip() not in {"1", "true", "True"}:
                continue
            category = (row.get("category_name") or "").strip()
            tag = (row.get("tag") or "").strip()
            if category and tag:
                combined.setdefault(f"cat:{category}", set()).add(tag)
    return combined
def load_probe_candidates() -> Tuple[List[Dict[str, str]], List[str]]:
    """Load probe candidates and the diversified (MMR) shortlist.

    Returns:
        (filtered_rows, mmr_filtered): filtered_rows are the top
        PROBE_POOL_SIZE rows from PROBE_CSV, sorted by descending
        actionable_score and restricted to MIN_PROBE_IMAGES occurrences and
        the [PREVALENCE_MIN, PREVALENCE_MAX] prevalence band; mmr_filtered is
        the summary's diversified shortlist limited to those rows (order kept).
    """
    with PROBE_CSV.open("r", encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))
    rows.sort(key=lambda r: float(r["actionable_score"]), reverse=True)
    filtered = [
        r for r in rows
        if int(r["sample_occurrences"]) >= MIN_PROBE_IMAGES
        and PREVALENCE_MIN <= float(r["prevalence"]) <= PREVALENCE_MAX
    ][:PROBE_POOL_SIZE]
    mmr: List[str] = []
    if PROBE_SUMMARY_JSON.is_file():
        s = json.loads(PROBE_SUMMARY_JSON.read_text(encoding="utf-8"))
        mmr = list(s.get("diversified_probe_shortlist", []))
    # Build the membership set once: the original evaluated set(candidate_tags)
    # inside the comprehension, i.e. once per shortlist element (O(n^2)).
    candidate_set = {r["tag"] for r in filtered}
    mmr_filtered = [t for t in mmr if t in candidate_set]
    return filtered, mmr_filtered
def entropy(p: np.ndarray) -> float:
    """Shannon entropy (in bits) of *p*, with values floored at ENTROPY_EPS
    so the logarithm stays finite."""
    clamped = np.maximum(p, ENTROPY_EPS)
    return float(-(clamped * np.log2(clamped)).sum())
def normalize_probs(logp: np.ndarray) -> np.ndarray:
    """Softmax over log-probabilities, subtracting the max for numerical
    stability and flooring the normalizer at ENTROPY_EPS."""
    weights = np.exp(logp - np.max(logp))
    total = np.sum(weights)
    return weights / max(total, ENTROPY_EPS)
def ndcg_binary(ranked_true_flags: List[int], m: int, n_true: int) -> float:
    """NDCG@m for a binary-relevance ranking.

    *ranked_true_flags* holds 0/1 relevance in ranked order; *n_true* is the
    total number of relevant items (caps the ideal DCG). Returns 0.0 when
    m <= 0 or there are no relevant items.
    """
    if m <= 0:
        return 0.0
    ideal_k = min(n_true, m)
    if ideal_k == 0:
        return 0.0
    # Precompute the log-discount for each rank position once.
    discounts = [1.0 / math.log2(rank + 2) for rank in range(m)]
    dcg = sum(d for d, rel in zip(discounts, ranked_true_flags) if rel)
    idcg = sum(discounts[:ideal_k])
    return dcg / max(idcg, ENTROPY_EPS)
def main() -> None:
    """Run the oracle probe-policy simulation and write CSV/JSON outputs.

    Pipeline: load counts/images/groups -> estimate a Laplace-smoothed group
    prior and per-probe Bernoulli likelihoods from the sample -> for every
    (mode, budget, strategy) combination, simulate probing per image (probe
    answers taken directly from ground truth, i.e. an oracle), rank groups by
    the resulting posterior, and aggregate ranking metrics -> write a metric
    row per (mode, strategy, budget, top_m) to OUT_CSV and a config/snapshot
    summary to OUT_JSON.
    """
    counts = load_counts(COUNTS_CSV)
    images = load_images(SAMPLE_JSONL, counts, MIN_COUNT)
    if not images:
        raise RuntimeError("No images loaded.")
    n_images = len(images)
    excluded_wiki_groups = load_excluded_wiki_groups(CATEGORY_TAG_GROUP_MAP_CSV)
    groups_all = load_groups(excluded_wiki_groups)
    # Keep active groups only: a group is active when it overlaps at least
    # MIN_GROUP_IMAGES sample images; remember those image indices.
    group_image_idxs: Dict[str, Set[int]] = {}
    for g, members in groups_all.items():
        idxs = {i for i, tags in enumerate(images) if tags & members}
        if len(idxs) >= MIN_GROUP_IMAGES:
            group_image_idxs[g] = idxs
    group_names = sorted(group_image_idxs.keys())
    n_groups = len(group_names)
    if n_groups == 0:
        raise RuntimeError("No active groups.")
    group_idx = {g: i for i, g in enumerate(group_names)}
    # Laplace-smoothed group prior from sample frequency, renormalized to sum 1.
    group_priors = np.zeros(n_groups, dtype=np.float64)
    for g, idxs in group_image_idxs.items():
        group_priors[group_idx[g]] = (len(idxs) + LAPLACE_ALPHA) / (n_images + LAPLACE_ALPHA * n_groups)
    group_priors /= np.sum(group_priors)
    # Per-image true groups for evaluation.
    true_groups_by_image: List[Set[str]] = []
    for tags in images:
        true_g = {g for g in group_names if tags & groups_all[g]}
        true_groups_by_image.append(true_g)
    probe_rows, mmr_shortlist = load_probe_candidates()
    if not probe_rows:
        raise RuntimeError("No probe candidates from probe_informativeness.csv.")
    candidate_tags = [r["tag"] for r in probe_rows]
    candidate_set = set(candidate_tags)  # NOTE(review): currently unused below
    top_actionable = candidate_tags  # fixed order: descending actionable score
    # Fill mmr shortlist to have enough probes for larger budgets.
    mmr_full = list(mmr_shortlist)
    for t in top_actionable:
        if t not in mmr_full:
            mmr_full.append(t)
    # Tags answerable without a glossary; used to seed warm-start evidence.
    easy_known_tags = {r["tag"] for r in probe_rows if r.get("needs_glossary", "0") == "0"}
    # Probe state precompute: 0/1 presence of each candidate tag per image.
    probe_present_by_image: Dict[str, np.ndarray] = {}
    for t in candidate_tags:
        arr = np.zeros(n_images, dtype=np.int8)
        for i, tags in enumerate(images):
            if t in tags:
                arr[i] = 1
        probe_present_by_image[t] = arr
    # Likelihoods: P(probe=1 | group), Laplace-smoothed and clipped away from
    # exact 0/1 so the log-likelihood updates stay finite.
    p1_given_group: Dict[str, np.ndarray] = {}
    for t in candidate_tags:
        arr = np.zeros(n_groups, dtype=np.float64)
        t_present = probe_present_by_image[t]
        for g, g_i in group_idx.items():
            idxs = group_image_idxs[g]
            n_g = len(idxs)
            n_tg = int(np.sum([t_present[i] for i in idxs]))
            arr[g_i] = (n_tg + LAPLACE_ALPHA) / (n_g + 2 * LAPLACE_ALPHA)
        p1_given_group[t] = np.clip(arr, 1e-6, 1 - 1e-6)

    def posterior_from_evidence(evidence: Dict[str, int]) -> np.ndarray:
        """Posterior over groups given {probe_tag: 0/1} answers (naive-Bayes
        update of the prior; unknown tags in *evidence* are ignored)."""
        logp = np.log(np.maximum(group_priors, ENTROPY_EPS))
        for t, v in evidence.items():
            if t not in p1_given_group:
                continue
            p1 = p1_given_group[t]
            if v == 1:
                logp += np.log(p1)
            else:
                logp += np.log(1 - p1)
        return normalize_probs(logp)

    def init_evidence(mode: str, image_i: int) -> Dict[str, int]:
        """Starting evidence per mode: empty for cold start, or up to two
        glossary-free tags that are truly present on the image."""
        if mode == "cold_start":
            return {}
        if mode == "warm_start_easy2":
            tags = images[image_i]
            # Approximate "already known from prompt" with up to 2 easy tags present.
            present_easy = [t for t in top_actionable if t in easy_known_tags and t in tags]
            return {t: 1 for t in present_easy[:2]}
        raise ValueError(f"Unknown mode: {mode}")

    def choose_adaptive_entropy(image_i: int, budget: int, evidence: Dict[str, int]) -> List[str]:
        """Greedily ask the probe with maximal expected entropy reduction,
        *budget* times; mutates *evidence* with oracle answers in place."""
        chosen: List[str] = []
        asked = set(evidence.keys())
        for _ in range(budget):
            p = posterior_from_evidence(evidence)
            h0 = entropy(p)
            best_t = None
            best_gain = -1e9
            for t in candidate_tags:
                if t in asked:
                    continue
                p1g = p1_given_group[t]
                p_t1 = float(np.sum(p * p1g))  # marginal P(answer is "yes")
                # posterior if t=1
                p1 = normalize_probs(np.log(np.maximum(p, ENTROPY_EPS)) + np.log(p1g))
                h1 = entropy(p1)
                # posterior if t=0
                p0 = normalize_probs(np.log(np.maximum(p, ENTROPY_EPS)) + np.log(1 - p1g))
                h2 = entropy(p0)
                exp_h = p_t1 * h1 + (1 - p_t1) * h2
                gain = h0 - exp_h  # expected information gain of asking t
                if gain > best_gain:
                    best_gain = gain
                    best_t = t
            if best_t is None:
                break
            chosen.append(best_t)
            asked.add(best_t)
            # Oracle observation.
            evidence[best_t] = int(probe_present_by_image[best_t][image_i])
        return chosen

    def choose_fixed(order: List[str], image_i: int, budget: int, evidence: Dict[str, int]) -> List[str]:
        """Ask probes in the fixed *order*, skipping ones already answered;
        mutates *evidence* with oracle answers in place."""
        chosen = []
        asked = set(evidence.keys())
        for t in order:
            if t in asked:
                continue
            chosen.append(t)
            asked.add(t)
            evidence[t] = int(probe_present_by_image[t][image_i])  # oracle observation
            if len(chosen) >= budget:
                break
        return chosen

    # Fixed-order strategies share choose_fixed with different orderings.
    strategy_orders = {
        "fixed_top_actionable": top_actionable,
        "fixed_mmr": mmr_full,
    }
    metric_rows: List[Dict[str, str]] = []
    for mode in MODES:
        for budget in BUDGETS:
            for strategy in ["baseline_no_probe", "fixed_top_actionable", "fixed_mmr", "adaptive_entropy"]:
                # Metric accumulators, one defaultdict per top_m cutoff.
                per_top_m = {m: defaultdict(float) for m in TOP_M_VALUES}
                for i in range(n_images):
                    ev = init_evidence(mode, i)
                    if strategy == "baseline_no_probe":
                        pass  # rank from the initial evidence only
                    elif strategy == "adaptive_entropy":
                        choose_adaptive_entropy(i, budget, ev)
                    else:
                        choose_fixed(strategy_orders[strategy], i, budget, ev)
                    post = posterior_from_evidence(ev)
                    ranking = np.argsort(post)[::-1]  # groups by descending posterior
                    ranked_groups = [group_names[j] for j in ranking]
                    true_g = true_groups_by_image[i]
                    for m in TOP_M_VALUES:
                        topm = ranked_groups[:m]
                        n_true = len(true_g)
                        n_hit = len(set(topm) & true_g)
                        hit = 1.0 if n_hit > 0 else 0.0
                        rec = n_hit / n_true if n_true > 0 else 0.0
                        prec = n_hit / m if m > 0 else 0.0
                        flags = [1 if g in true_g else 0 for g in topm]
                        ndcg = ndcg_binary(flags, m, n_true)
                        # Posterior mass on true groups, overall and within top-m.
                        true_mass = float(np.sum([post[group_idx[g]] for g in true_g])) if true_g else 0.0
                        topm_true_mass = float(np.sum([post[group_idx[g]] for g in topm if g in true_g]))
                        per_top_m[m]["hit"] += hit
                        per_top_m[m]["rec"] += rec
                        per_top_m[m]["prec"] += prec
                        per_top_m[m]["ndcg"] += ndcg
                        per_top_m[m]["true_mass"] += true_mass
                        per_top_m[m]["topm_true_mass"] += topm_true_mass
                # Average accumulators over images into one CSV row per top_m.
                for m in TOP_M_VALUES:
                    agg = per_top_m[m]
                    metric_rows.append(
                        {
                            "mode": mode,
                            "strategy": strategy,
                            "budget": str(budget),
                            "top_m": str(m),
                            "hit_at_m": f"{agg['hit'] / n_images:.6f}",
                            "recall_at_m": f"{agg['rec'] / n_images:.6f}",
                            "precision_at_m": f"{agg['prec'] / n_images:.6f}",
                            "ndcg_at_m": f"{agg['ndcg'] / n_images:.6f}",
                            "true_mass": f"{agg['true_mass'] / n_images:.6f}",
                            "topm_true_mass": f"{agg['topm_true_mass'] / n_images:.6f}",
                        }
                    )
    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with OUT_CSV.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "mode",
                "strategy",
                "budget",
                "top_m",
                "hit_at_m",
                "recall_at_m",
                "precision_at_m",
                "ndcg_at_m",
                "true_mass",
                "topm_true_mass",
            ],
        )
        writer.writeheader()
        writer.writerows(metric_rows)
    # Quick "is it likely useful?" summary at top_m=5, budget=5.
    lookup = {
        (r["mode"], r["strategy"], r["budget"], r["top_m"]): r
        for r in metric_rows
    }
    # Both "5"s must exist: 5 is in BUDGETS and in TOP_M_VALUES above.
    key = lambda mode, strategy: lookup[(mode, strategy, "5", "5")]
    likely_useful = []
    for mode in MODES:
        b = key(mode, "baseline_no_probe")
        a = key(mode, "adaptive_entropy")
        t = key(mode, "fixed_top_actionable")
        likely_useful.append(
            {
                "mode": mode,
                "baseline_ndcg_at_5": float(b["ndcg_at_m"]),
                "fixed_top_ndcg_at_5": float(t["ndcg_at_m"]),
                "adaptive_ndcg_at_5": float(a["ndcg_at_m"]),
                "adaptive_minus_fixed_top_ndcg_at_5": float(a["ndcg_at_m"]) - float(t["ndcg_at_m"]),
                "adaptive_minus_baseline_ndcg_at_5": float(a["ndcg_at_m"]) - float(b["ndcg_at_m"]),
            }
        )
    # Summary JSON: configuration echo plus run-size stats and the snapshot.
    summary = {
        "config": {
            "min_count": MIN_COUNT,
            "min_group_images": MIN_GROUP_IMAGES,
            "min_probe_images": MIN_PROBE_IMAGES,
            "probe_pool_size": PROBE_POOL_SIZE,
            "prevalence_min": PREVALENCE_MIN,
            "prevalence_max": PREVALENCE_MAX,
            "budgets": BUDGETS,
            "top_m_values": TOP_M_VALUES,
            "modes": MODES,
            "laplace_alpha": LAPLACE_ALPHA,
            "note": "Oracle probe answers from GT tags; optimistic upper bound.",
        },
        "n_images": n_images,
        "n_active_groups": n_groups,
        "n_candidate_probes": len(candidate_tags),
        "excluded_wiki_groups": sorted(excluded_wiki_groups),
        "probe_pool_head": candidate_tags[:30],
        "mmr_head": mmr_full[:30],
        "likely_useful_snapshot_budget5_top5": likely_useful,
        "outputs": {
            "csv": str(OUT_CSV),
            "summary_json": str(OUT_JSON),
        },
    }
    with OUT_JSON.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"Images: {n_images}")
    print(f"Active groups: {n_groups}")
    print(f"Candidate probes: {len(candidate_tags)}")
    print("Snapshot budget=5 top_m=5:", likely_useful)
    print(f"Outputs: {OUT_CSV}, {OUT_JSON}")
# Run the simulation only when executed as a script (safe to import).
if __name__ == "__main__":
    main()