"""Deterministic plan of (profile, base_prompt, backend, seed, strategy) tuples.
The plan is produced once and written to `data/plan.jsonl`. Image generation and
MCQ construction both consume this file, so as long as the plan is stable the
entire dataset is reproducible bit-for-bit.
"""
from __future__ import annotations
import itertools
from dataclasses import asdict, dataclass
from pathlib import Path
import numpy as np
import yaml
from aamcq.distractors import STRATEGIES, Strategy
from aamcq.profile import VisualProfile, enumerate_profiles
from aamcq.utils.seeding import item_seed
# Fixed subject categories; build_plan cycles through these (with a seeded
# random offset) when assigning a base prompt category to each plan item.
BASE_PROMPT_CATEGORIES = ("portrait", "landscape", "animal", "still_life", "architecture")
@dataclass(frozen=True)
class PlanItem:
    """One fully specified dataset item: what to generate and how to build
    its MCQ distractors. Frozen so plan records are immutable once built.
    """

    item_id: str  # globally unique id, e.g. "ab_mcq_00003_rnd_0003"
    gt_profile: dict[str, str]  # ground-truth visual profile as a plain dict
    base_prompt: str  # subject text the profile is rendered against
    base_prompt_category: str  # one of BASE_PROMPT_CATEGORIES
    backend: str  # image-generation backend name (key of generation_mix)
    seed: int  # deterministic per-item generation seed (from item_seed)
    distractor_strategy: str  # distractor strategy name (key of STRATEGIES)
    difficulty: str  # difficulty label; build_plan currently emits "medium"

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of this item (for plan.jsonl)."""
        return asdict(self)
def _load_base_prompts(path: str | Path) -> dict[str, list[str]]:
    """Load `clean`-quality prompts per category from base_prompts.yaml.

    For every category in BASE_PROMPT_CATEGORIES, collects the prompt entries
    tagged ``quality: clean``; if a category has none, falls back to that
    category's single ``recommended`` prompt.

    Args:
        path: Location of the base_prompts.yaml file.

    Returns:
        Mapping of category name -> non-empty list of prompt strings.

    Raises:
        ValueError: if a category has neither clean nor recommended prompts.
    """
    # Explicit encoding: prompt files are UTF-8; never rely on the locale default.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    cats = data.get("categories", {})
    recommended = data.get("recommended", {})
    out: dict[str, list[str]] = {}
    for category in BASE_PROMPT_CATEGORIES:
        entries = cats.get(category, {}).get("prompts", [])
        prompts = [e["text"] for e in entries if e.get("quality") == "clean"]
        # Fallback keeps the plan buildable when a category has no clean entries.
        if not prompts and recommended.get(category):
            prompts = [recommended[category]]
        if not prompts:
            raise ValueError(f"base_prompts.yaml missing category {category!r}")
        out[category] = prompts
    return out
def _allocate_mix(
n: int, mix: dict[str, float], rng: np.random.Generator, what: str
) -> list[str]:
if abs(sum(mix.values()) - 1.0) > 1e-6:
raise ValueError(f"{what} mix must sum to 1.0, got {sum(mix.values())}")
counts = {name: int(round(frac * n)) for name, frac in mix.items()}
drift = n - sum(counts.values())
names = list(mix.keys())
i = 0
while drift != 0:
name = names[i % len(names)]
if drift > 0:
counts[name] += 1
drift -= 1
elif counts[name] > 0:
counts[name] -= 1
drift += 1
i += 1
slots: list[str] = []
for name, count in counts.items():
slots.extend([name] * count)
rng.shuffle(slots)
return slots
def stratified_sample_by_style(
    vocab: dict[str, list[str]],
    n_target: int,
    rng: np.random.Generator,
    small_pool_cap: int = 50,
) -> list[VisualProfile]:
    """Proportional stratified sample over art_style groups.

    Any style whose pool is smaller than `small_pool_cap` is taken in full,
    so every profile of a minority style appears at least once; larger pools
    contribute a share proportional to their size. The final count is roughly
    `n_target` but can drift by a few items due to rounding and the
    small-pool expansion.
    """
    universe = list(enumerate_profiles(vocab))
    grouped: dict[str, list[VisualProfile]] = {}
    for profile in universe:
        grouped.setdefault(profile.art_style, []).append(profile)
    n_universe = len(universe)
    picked: list[VisualProfile] = []
    for pool in grouped.values():
        # Minority styles: keep the entire pool; otherwise take a
        # size-proportional share (capped at the pool size).
        if len(pool) < small_pool_cap:
            take = len(pool)
        else:
            take = min(round(len(pool) / n_universe * n_target), len(pool))
        chosen = rng.choice(len(pool), size=take, replace=False)
        picked.extend(pool[int(j)] for j in chosen)
    rng.shuffle(picked)
    return picked
def build_plan(
    vocab: dict[str, list[str]],
    base_prompts_path: str | Path,
    distractor_policy: dict,
    generation_mix: dict[str, float],
    n_random: int,
    master_seed: int = 202,
    stratified: bool = False,
) -> list[PlanItem]:
    """Build the deterministic list of PlanItems from `master_seed`.

    All randomness flows from a single numpy Generator plus per-item
    `item_seed` derivations, so the same arguments always yield the same
    plan. The exact order of rng calls below is load-bearing: reordering
    any of them changes every downstream item.

    Args:
        vocab: attribute vocabulary passed to profile enumeration/sampling.
        base_prompts_path: location of base_prompts.yaml.
        distractor_policy: must contain a "strategy_mix" fraction mapping
            whose keys are all in STRATEGIES.
        generation_mix: backend-name -> fraction mapping (must sum to 1.0).
        n_random: number of items to plan; <= 0 yields an empty plan.
        master_seed: root seed for all derived randomness.
        stratified: if True, sample profiles stratified by art style;
            otherwise sample uniformly without replacement.

    Returns:
        List of fully specified PlanItems (difficulty fixed to "medium").

    Raises:
        ValueError: unknown strategy name, or a mix not summing to 1.0.
    """
    rng = np.random.default_rng(master_seed)
    base_prompts = _load_base_prompts(base_prompts_path)
    if n_random <= 0:
        return []
    # Validate strategy names up front so a bad config fails fast.
    for name in distractor_policy["strategy_mix"]:
        if name not in STRATEGIES:
            raise ValueError(f"unknown strategy {name!r}; expected {STRATEGIES}")
    if stratified:
        profiles = stratified_sample_by_style(vocab, n_random, rng)
    else:
        all_profiles = list(enumerate_profiles(vocab))
        idx = rng.choice(len(all_profiles), size=n_random, replace=False)
        profiles = [all_profiles[int(i)] for i in idx]
    # Stratified sampling may return slightly more/fewer than n_random,
    # so all per-item allocations use the actual profile count.
    n_items = len(profiles)
    sources: list[tuple[str, VisualProfile]] = [
        (f"rnd_{k:04d}", p) for k, p in enumerate(profiles)
    ]
    strategies = _allocate_mix(n_items, distractor_policy["strategy_mix"], rng, "strategy")
    backends = _allocate_mix(n_items, generation_mix, rng, "backend")
    # Cycle prompt categories with a random starting offset so the
    # category sequence isn't identical across different master seeds.
    prompt_cycle = itertools.cycle(BASE_PROMPT_CATEGORIES)
    for _ in range(int(rng.integers(0, len(BASE_PROMPT_CATEGORIES)))):
        next(prompt_cycle)
    plan: list[PlanItem] = []
    for (iid, profile), strat, backend in zip(sources, strategies, backends):
        category = next(prompt_cycle)
        pool = base_prompts[category]
        # Per-item rng derived from (iid, master_seed) so prompt choice is
        # independent of iteration order elsewhere.
        prompt_rng = np.random.default_rng(item_seed(iid, master_seed, "prompt"))
        base_prompt = str(pool[int(prompt_rng.integers(0, len(pool)))])
        item_id = f"ab_mcq_{len(plan):05d}_{iid}"
        plan.append(
            PlanItem(
                item_id=item_id,
                gt_profile=profile.to_dict(),
                base_prompt=base_prompt,
                base_prompt_category=category,
                backend=backend,
                seed=item_seed(item_id, master_seed, "gen"),
                distractor_strategy=strat,
                difficulty="medium",
            )
        )
    return plan
|