Spaces:
Sleeping
Sleeping
| """Deterministic plan of (profile, base_prompt, backend, seed, strategy) tuples. | |
| The plan is produced once and written to `data/plan.jsonl`. Image generation and | |
| MCQ construction both consume this file, so as long as the plan is stable the | |
| entire dataset is reproducible bit-for-bit. | |
| """ | |
| from __future__ import annotations | |
| import itertools | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| import numpy as np | |
| import yaml | |
| from aamcq.distractors import STRATEGIES, Strategy | |
| from aamcq.profile import VisualProfile, enumerate_profiles | |
| from aamcq.utils.seeding import item_seed | |
| BASE_PROMPT_CATEGORIES = ("portrait", "landscape", "animal", "still_life", "architecture") | |
@dataclass
class PlanItem:
    """A single planned dataset item.

    One row of `data/plan.jsonl`: everything needed to generate the image and
    build the MCQ for one item. Fix: the `@dataclass` decorator was missing,
    so the keyword construction in `build_plan` and the `asdict()` call below
    would both raise `TypeError`.
    """

    item_id: str                 # unique id, e.g. "ab_mcq_00000_rnd_0000"
    gt_profile: dict[str, str]   # ground-truth profile as a plain dict
    base_prompt: str             # prompt text the image is generated from
    base_prompt_category: str    # one of BASE_PROMPT_CATEGORIES
    backend: str                 # image-generation backend name
    seed: int                    # per-item generation seed
    distractor_strategy: str     # must be a member of STRATEGIES
    difficulty: str              # currently always "medium" (see build_plan)

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of all fields."""
        return asdict(self)
| def _load_base_prompts(path: str | Path) -> dict[str, list[str]]: | |
| """Load `clean`-quality prompts per category from base_prompts.yaml.""" | |
| with open(path) as f: | |
| data = yaml.safe_load(f) | |
| cats = data.get("categories", {}) | |
| recommended = data.get("recommended", {}) | |
| out: dict[str, list[str]] = {} | |
| for category in BASE_PROMPT_CATEGORIES: | |
| entries = cats.get(category, {}).get("prompts", []) | |
| prompts = [e["text"] for e in entries if e.get("quality") == "clean"] | |
| if not prompts and recommended.get(category): | |
| prompts = [recommended[category]] | |
| if not prompts: | |
| raise ValueError(f"base_prompts.yaml missing category {category!r}") | |
| out[category] = prompts | |
| return out | |
| def _allocate_mix( | |
| n: int, mix: dict[str, float], rng: np.random.Generator, what: str | |
| ) -> list[str]: | |
| if abs(sum(mix.values()) - 1.0) > 1e-6: | |
| raise ValueError(f"{what} mix must sum to 1.0, got {sum(mix.values())}") | |
| counts = {name: int(round(frac * n)) for name, frac in mix.items()} | |
| drift = n - sum(counts.values()) | |
| names = list(mix.keys()) | |
| i = 0 | |
| while drift != 0: | |
| name = names[i % len(names)] | |
| if drift > 0: | |
| counts[name] += 1 | |
| drift -= 1 | |
| elif counts[name] > 0: | |
| counts[name] -= 1 | |
| drift += 1 | |
| i += 1 | |
| slots: list[str] = [] | |
| for name, count in counts.items(): | |
| slots.extend([name] * count) | |
| rng.shuffle(slots) | |
| return slots | |
def stratified_sample_by_style(
    vocab: dict[str, list[str]],
    n_target: int,
    rng: np.random.Generator,
    small_pool_cap: int = 50,
) -> list[VisualProfile]:
    """Stratified sample over art_style groups, proportional to group size.

    Any style whose pool has fewer than `small_pool_cap` profiles is taken in
    full, so minority styles are always represented. Because of rounding and
    that small-pool expansion, the returned size is only approximately
    `n_target`.
    """
    universe = list(enumerate_profiles(vocab))
    groups: dict[str, list[VisualProfile]] = {}
    for profile in universe:
        groups.setdefault(profile.art_style, []).append(profile)
    n_total = len(universe)
    chosen: list[VisualProfile] = []
    for pool in groups.values():
        pool_size = len(pool)
        if pool_size < small_pool_cap:
            # Minority style: keep every profile.
            take = pool_size
        else:
            take = min(round(pool_size / n_total * n_target), pool_size)
        picks = rng.choice(pool_size, size=take, replace=False)
        chosen.extend(pool[int(j)] for j in picks)
    rng.shuffle(chosen)
    return chosen
def build_plan(
    vocab: dict[str, list[str]],
    base_prompts_path: str | Path,
    distractor_policy: dict,
    generation_mix: dict[str, float],
    n_random: int,
    master_seed: int = 202,
    stratified: bool = False,
) -> list[PlanItem]:
    """Build the deterministic list of plan items.

    All randomness is derived from `master_seed` (directly via `rng`, and per
    item via `item_seed`), so re-running with the same inputs reproduces the
    identical plan. The order of rng consumption — profile sampling, strategy
    allocation, backend allocation, prompt-cycle offset — is part of that
    contract and must not be reordered.
    """
    rng = np.random.default_rng(master_seed)
    base_prompts = _load_base_prompts(base_prompts_path)
    if n_random <= 0:
        return []
    # Validate the strategy mix up front against the known strategy registry.
    for strat_name in distractor_policy["strategy_mix"]:
        if strat_name not in STRATEGIES:
            raise ValueError(f"unknown strategy {strat_name!r}; expected {STRATEGIES}")
    # Pick ground-truth profiles: stratified over art styles, or uniform.
    if stratified:
        profiles = stratified_sample_by_style(vocab, n_random, rng)
    else:
        universe = list(enumerate_profiles(vocab))
        picks = rng.choice(len(universe), size=n_random, replace=False)
        profiles = [universe[int(j)] for j in picks]
    n_items = len(profiles)
    sources: list[tuple[str, VisualProfile]] = [
        (f"rnd_{k:04d}", prof) for k, prof in enumerate(profiles)
    ]
    strategies = _allocate_mix(n_items, distractor_policy["strategy_mix"], rng, "strategy")
    backends = _allocate_mix(n_items, generation_mix, rng, "backend")
    # Rotate the category cycle by a random offset so the first categories
    # aren't always the same across plans with different seeds.
    prompt_cycle = itertools.cycle(BASE_PROMPT_CATEGORIES)
    offset = int(rng.integers(0, len(BASE_PROMPT_CATEGORIES)))
    for _ in range(offset):
        next(prompt_cycle)
    plan: list[PlanItem] = []
    for pos, ((iid, profile), strat, backend) in enumerate(
        zip(sources, strategies, backends)
    ):
        category = next(prompt_cycle)
        pool = base_prompts[category]
        # Per-item generator keyed by the item id keeps prompt choice stable
        # even if items are added or removed elsewhere in the plan.
        prompt_rng = np.random.default_rng(item_seed(iid, master_seed, "prompt"))
        base_prompt = str(pool[int(prompt_rng.integers(0, len(pool)))])
        item_id = f"ab_mcq_{pos:05d}_{iid}"
        plan.append(
            PlanItem(
                item_id=item_id,
                gt_profile=profile.to_dict(),
                base_prompt=base_prompt,
                base_prompt_category=category,
                backend=backend,
                seed=item_seed(item_id, master_seed, "gen"),
                distractor_strategy=strat,
                difficulty="medium",
            )
        )
    return plan