"""Deterministic plan of (profile, base_prompt, backend, seed, strategy) tuples. The plan is produced once and written to `data/plan.jsonl`. Image generation and MCQ construction both consume this file, so as long as the plan is stable the entire dataset is reproducible bit-for-bit. """ from __future__ import annotations import itertools from dataclasses import asdict, dataclass from pathlib import Path import numpy as np import yaml from aamcq.distractors import STRATEGIES, Strategy from aamcq.profile import VisualProfile, enumerate_profiles from aamcq.utils.seeding import item_seed BASE_PROMPT_CATEGORIES = ("portrait", "landscape", "animal", "still_life", "architecture") @dataclass(frozen=True) class PlanItem: item_id: str gt_profile: dict[str, str] base_prompt: str base_prompt_category: str backend: str seed: int distractor_strategy: str difficulty: str def to_dict(self) -> dict: return asdict(self) def _load_base_prompts(path: str | Path) -> dict[str, list[str]]: """Load `clean`-quality prompts per category from base_prompts.yaml.""" with open(path) as f: data = yaml.safe_load(f) cats = data.get("categories", {}) recommended = data.get("recommended", {}) out: dict[str, list[str]] = {} for category in BASE_PROMPT_CATEGORIES: entries = cats.get(category, {}).get("prompts", []) prompts = [e["text"] for e in entries if e.get("quality") == "clean"] if not prompts and recommended.get(category): prompts = [recommended[category]] if not prompts: raise ValueError(f"base_prompts.yaml missing category {category!r}") out[category] = prompts return out def _allocate_mix( n: int, mix: dict[str, float], rng: np.random.Generator, what: str ) -> list[str]: if abs(sum(mix.values()) - 1.0) > 1e-6: raise ValueError(f"{what} mix must sum to 1.0, got {sum(mix.values())}") counts = {name: int(round(frac * n)) for name, frac in mix.items()} drift = n - sum(counts.values()) names = list(mix.keys()) i = 0 while drift != 0: name = names[i % len(names)] if drift > 0: counts[name] += 1 drift -= 1 elif counts[name] > 0: counts[name] -= 1 drift += 1 i += 1 slots: list[str] = [] for name, count in counts.items(): slots.extend([name] * count) rng.shuffle(slots) return slots def stratified_sample_by_style( vocab: dict[str, list[str]], n_target: int, rng: np.random.Generator, small_pool_cap: int = 50, ) -> list[VisualProfile]: """Proportional stratified sample over art_style groups. Pools smaller than `small_pool_cap` (e.g. Photorealism's 36 under the current compat filter) are sampled in full so every profile in a minority style appears at least once. Final size is approximately `n_target` but may vary by a few due to rounding + small-pool expansion. """ all_profiles = list(enumerate_profiles(vocab)) by_style: dict[str, list[VisualProfile]] = {} for p in all_profiles: by_style.setdefault(p.art_style, []).append(p) total = len(all_profiles) sampled: list[VisualProfile] = [] for pool in by_style.values(): if len(pool) < small_pool_cap: want = len(pool) else: want = min(round(len(pool) / total * n_target), len(pool)) idx = rng.choice(len(pool), size=want, replace=False) sampled.extend(pool[int(i)] for i in idx) rng.shuffle(sampled) return sampled def build_plan( vocab: dict[str, list[str]], base_prompts_path: str | Path, distractor_policy: dict, generation_mix: dict[str, float], n_random: int, master_seed: int = 202, stratified: bool = False, ) -> list[PlanItem]: rng = np.random.default_rng(master_seed) base_prompts = _load_base_prompts(base_prompts_path) if n_random <= 0: return [] for name in distractor_policy["strategy_mix"]: if name not in STRATEGIES: raise ValueError(f"unknown strategy {name!r}; expected {STRATEGIES}") if stratified: profiles = stratified_sample_by_style(vocab, n_random, rng) else: all_profiles = list(enumerate_profiles(vocab)) idx = rng.choice(len(all_profiles), size=n_random, replace=False) profiles = [all_profiles[int(i)] for i in idx] n_items = len(profiles) sources: list[tuple[str, VisualProfile]] = [ (f"rnd_{k:04d}", p) for k, p in enumerate(profiles) ] strategies = _allocate_mix(n_items, distractor_policy["strategy_mix"], rng, "strategy") backends = _allocate_mix(n_items, generation_mix, rng, "backend") prompt_cycle = itertools.cycle(BASE_PROMPT_CATEGORIES) for _ in range(int(rng.integers(0, len(BASE_PROMPT_CATEGORIES)))): next(prompt_cycle) plan: list[PlanItem] = [] for (iid, profile), strat, backend in zip(sources, strategies, backends): category = next(prompt_cycle) pool = base_prompts[category] prompt_rng = np.random.default_rng(item_seed(iid, master_seed, "prompt")) base_prompt = str(pool[int(prompt_rng.integers(0, len(pool)))]) item_id = f"ab_mcq_{len(plan):05d}_{iid}" plan.append( PlanItem( item_id=item_id, gt_profile=profile.to_dict(), base_prompt=base_prompt, base_prompt_category=category, backend=backend, seed=item_seed(item_id, master_seed, "gen"), distractor_strategy=strat, difficulty="medium", ) ) return plan