aesthetic-annotators / src /aamcq /instance_plan.py
lanczos's picture
deploy: labeling server
871ff87 verified
"""Deterministic plan of (profile, base_prompt, backend, seed, strategy) tuples.
The plan is produced once and written to `data/plan.jsonl`. Image generation and
MCQ construction both consume this file, so as long as the plan is stable the
entire dataset is reproducible bit-for-bit.
"""
from __future__ import annotations
import itertools
from dataclasses import asdict, dataclass
from pathlib import Path
import numpy as np
import yaml
from aamcq.distractors import STRATEGIES, Strategy
from aamcq.profile import VisualProfile, enumerate_profiles
from aamcq.utils.seeding import item_seed
# Base-prompt categories the plan draws from. The tuple order is significant:
# prompts are loaded per category in this order and plan items are assigned
# categories by cycling through it.
BASE_PROMPT_CATEGORIES = ("portrait", "landscape", "animal", "still_life", "architecture")
@dataclass(frozen=True)
class PlanItem:
    """A single immutable entry of the generation plan.

    Instances are frozen, so fields cannot be reassigned after construction.
    The declaration order of the fields is also the key order of the mapping
    returned by `to_dict`.
    """

    # Identifier of this plan entry.
    item_id: str
    # Ground-truth profile as a string-to-string mapping.
    gt_profile: dict[str, str]
    # Prompt text chosen for this entry.
    base_prompt: str
    # Category the base prompt was drawn from.
    base_prompt_category: str
    # Name of the generation backend assigned to this entry.
    backend: str
    # Integer seed attached to this entry.
    seed: int
    # Name of the distractor strategy assigned to this entry.
    distractor_strategy: str
    # Difficulty label of this entry.
    difficulty: str

    def to_dict(self) -> dict:
        """Return the fields as a plain dict (nested containers are copied)."""
        as_mapping = asdict(self)
        return as_mapping
def _load_base_prompts(path: str | Path) -> dict[str, list[str]]:
    """Load `clean`-quality prompts per category from base_prompts.yaml.

    For every category in `BASE_PROMPT_CATEGORIES`, the `text` of each entry
    under `categories.<category>.prompts` whose `quality` equals `"clean"` is
    collected. When a category yields no clean prompt, the single value at
    `recommended.<category>` is used instead, if it is present and truthy.

    Args:
        path: Location of the YAML file to read.

    Returns:
        Mapping from category name to a non-empty list of prompts.

    Raises:
        ValueError: If a category has neither a clean prompt nor a
            recommended fallback (this includes an empty YAML document).
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # default locale encoding.
    with open(path, encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat it as "no data"
        # so the per-category check below reports the problem.
        data = yaml.safe_load(f) or {}

    # A key that is present but null (e.g. `categories:` with no body) loads
    # as None, so fall back to empty containers instead of crashing on `.get`.
    cats = data.get("categories") or {}
    recommended = data.get("recommended") or {}

    out: dict[str, list[str]] = {}
    for category in BASE_PROMPT_CATEGORIES:
        entries = (cats.get(category) or {}).get("prompts") or []
        prompts = [e["text"] for e in entries if e.get("quality") == "clean"]
        if not prompts and recommended.get(category):
            prompts = [recommended[category]]
        if not prompts:
            raise ValueError(f"base_prompts.yaml missing category {category!r}")
        out[category] = prompts
    return out
def _allocate_mix(
n: int, mix: dict[str, float], rng: np.random.Generator, what: str
) -> list[str]:
if abs(sum(mix.values()) - 1.0) > 1e-6:
raise ValueError(f"{what} mix must sum to 1.0, got {sum(mix.values())}")
counts = {name: int(round(frac * n)) for name, frac in mix.items()}
drift = n - sum(counts.values())
names = list(mix.keys())
i = 0
while drift != 0:
name = names[i % len(names)]
if drift > 0:
counts[name] += 1
drift -= 1
elif counts[name] > 0:
counts[name] -= 1
drift += 1
i += 1
slots: list[str] = []
for name, count in counts.items():
slots.extend([name] * count)
rng.shuffle(slots)
return slots
def stratified_sample_by_style(
    vocab: dict[str, list[str]],
    n_target: int,
    rng: np.random.Generator,
    small_pool_cap: int = 50,
) -> list[VisualProfile]:
    """Sample profiles so each art_style is represented proportionally.

    All profiles enumerated from `vocab` are grouped by their `art_style`.
    A group with fewer than `small_pool_cap` members (e.g. Photorealism's 36
    under the current compat filter) is taken whole, so every profile of a
    minority style is included. Any other group contributes its proportional
    share of `n_target`, rounded and capped at the group size, drawn without
    replacement. The combined sample is shuffled before being returned.

    The result size is therefore only approximately `n_target`: rounding and
    the whole-group rule for small pools can shift it by a few items.
    """
    universe = list(enumerate_profiles(vocab))
    n_universe = len(universe)

    # Group in enumeration order; dict insertion order fixes the order in
    # which groups consume the random stream below.
    groups: dict[str, list[VisualProfile]] = {}
    for profile in universe:
        style = profile.art_style
        if style not in groups:
            groups[style] = []
        groups[style].append(profile)

    picked: list[VisualProfile] = []
    for members in groups.values():
        size = len(members)
        if size >= small_pool_cap:
            quota = min(round(size / n_universe * n_target), size)
        else:
            quota = size
        chosen = rng.choice(size, size=quota, replace=False)
        picked += [members[int(j)] for j in chosen]

    rng.shuffle(picked)
    return picked
def build_plan(
    vocab: dict[str, list[str]],
    base_prompts_path: str | Path,
    distractor_policy: dict,
    generation_mix: dict[str, float],
    n_random: int,
    master_seed: int = 202,
    stratified: bool = False,
) -> list[PlanItem]:
    """Build the deterministic list of plan items for a given master seed.

    Steps, in the order they consume the master random stream:

    1. Choose profiles: a stratified sample by art style when `stratified`
       is true, otherwise `n_random` profiles drawn without replacement
       from everything `enumerate_profiles(vocab)` yields.
    2. Allocate one distractor strategy per item from
       `distractor_policy["strategy_mix"]`.
    3. Allocate one backend per item from `generation_mix`.
    4. Draw a starting offset into `BASE_PROMPT_CATEGORIES`; items then take
       categories round-robin from that offset.

    Each item's base prompt is picked from its category's pool with a
    generator seeded from `item_seed(<rnd id>, master_seed, "prompt")`, and
    its own seed comes from `item_seed(item_id, master_seed, "gen")`.

    Returns an empty list when `n_random <= 0` (the base-prompt file is still
    loaded first). Raises `ValueError` if the strategy mix names a strategy
    that is not in `STRATEGIES`.
    """
    rng = np.random.default_rng(master_seed)
    prompts_by_category = _load_base_prompts(base_prompts_path)
    if n_random <= 0:
        return []

    strategy_mix = distractor_policy["strategy_mix"]
    for name in strategy_mix:
        if name not in STRATEGIES:
            raise ValueError(f"unknown strategy {name!r}; expected {STRATEGIES}")

    if stratified:
        profiles = stratified_sample_by_style(vocab, n_random, rng)
    else:
        candidates = list(enumerate_profiles(vocab))
        picks = rng.choice(len(candidates), size=n_random, replace=False)
        profiles = [candidates[int(i)] for i in picks]

    # The stratified sampler may return slightly more or fewer than n_random,
    # so size the allocations from the actual profile count.
    n_items = len(profiles)
    strategies = _allocate_mix(n_items, strategy_mix, rng, "strategy")
    backends = _allocate_mix(n_items, generation_mix, rng, "backend")

    # Random starting point in the category rotation; item k then takes the
    # category at (offset + k) modulo the number of categories.
    n_categories = len(BASE_PROMPT_CATEGORIES)
    offset = int(rng.integers(0, n_categories))

    plan: list[PlanItem] = []
    for k, (profile, strat, backend) in enumerate(zip(profiles, strategies, backends)):
        source_id = f"rnd_{k:04d}"
        category = BASE_PROMPT_CATEGORIES[(offset + k) % n_categories]
        pool = prompts_by_category[category]

        # Per-item generator: the prompt choice depends only on the source id
        # and master seed, not on how much of `rng` has been consumed.
        prompt_rng = np.random.default_rng(item_seed(source_id, master_seed, "prompt"))
        choice = int(prompt_rng.integers(0, len(pool)))
        base_prompt = str(pool[choice])

        item_id = f"ab_mcq_{k:05d}_{source_id}"
        item = PlanItem(
            item_id=item_id,
            gt_profile=profile.to_dict(),
            base_prompt=base_prompt,
            base_prompt_category=category,
            backend=backend,
            seed=item_seed(item_id, master_seed, "gen"),
            distractor_strategy=strat,
            difficulty="medium",
        )
        plan.append(item)
    return plan