File size: 5,732 Bytes
871ff87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Deterministic plan of (profile, base_prompt, backend, seed, strategy) tuples.

The plan is produced once and written to `data/plan.jsonl`. Image generation and
MCQ construction both consume this file, so as long as the plan is stable the
entire dataset is reproducible bit-for-bit.
"""

from __future__ import annotations

import itertools
from dataclasses import asdict, dataclass
from pathlib import Path

import numpy as np
import yaml

from aamcq.distractors import STRATEGIES, Strategy
from aamcq.profile import VisualProfile, enumerate_profiles
from aamcq.utils.seeding import item_seed

BASE_PROMPT_CATEGORIES = ("portrait", "landscape", "animal", "still_life", "architecture")


@dataclass(frozen=True)
class PlanItem:
    """A single planned dataset item.

    Immutable record pairing a ground-truth visual profile with the base
    prompt, backend, seed and distractor strategy used to produce one
    MCQ item. Serialized one-per-line into `data/plan.jsonl`.
    """

    item_id: str  # stable id, formatted "ab_mcq_{index:05d}_{source_id}" in build_plan
    gt_profile: dict[str, str]  # ground-truth profile as a plain dict (VisualProfile.to_dict())
    base_prompt: str  # prompt text the image is generated from
    base_prompt_category: str  # one of BASE_PROMPT_CATEGORIES
    backend: str  # generation backend name, drawn from the generation mix
    seed: int  # deterministic per-item generation seed (item_seed(item_id, master_seed, "gen"))
    distractor_strategy: str  # one of aamcq.distractors.STRATEGIES
    difficulty: str  # currently always "medium" (set in build_plan)

    def to_dict(self) -> dict:
        """Return all fields as a plain dict, ready for JSONL serialization."""
        return asdict(self)


def _load_base_prompts(path: str | Path) -> dict[str, list[str]]:
    """Read base_prompts.yaml and return the `clean`-quality prompts per category.

    For each category in BASE_PROMPT_CATEGORIES, collects prompt entries
    marked quality == "clean"; if none exist, falls back to the single
    `recommended` prompt for that category.

    Raises:
        ValueError: when a category has neither clean nor recommended prompts.
    """
    with open(path) as handle:
        doc = yaml.safe_load(handle)

    categories = doc.get("categories", {})
    fallbacks = doc.get("recommended", {})

    result: dict[str, list[str]] = {}
    for cat in BASE_PROMPT_CATEGORIES:
        raw_entries = categories.get(cat, {}).get("prompts", [])
        clean = [entry["text"] for entry in raw_entries if entry.get("quality") == "clean"]
        if not clean and fallbacks.get(cat):
            clean = [fallbacks[cat]]
        if not clean:
            raise ValueError(f"base_prompts.yaml missing category {cat!r}")
        result[cat] = clean
    return result


def _allocate_mix(
    n: int, mix: dict[str, float], rng: np.random.Generator, what: str
) -> list[str]:
    if abs(sum(mix.values()) - 1.0) > 1e-6:
        raise ValueError(f"{what} mix must sum to 1.0, got {sum(mix.values())}")
    counts = {name: int(round(frac * n)) for name, frac in mix.items()}
    drift = n - sum(counts.values())
    names = list(mix.keys())
    i = 0
    while drift != 0:
        name = names[i % len(names)]
        if drift > 0:
            counts[name] += 1
            drift -= 1
        elif counts[name] > 0:
            counts[name] -= 1
            drift += 1
        i += 1
    slots: list[str] = []
    for name, count in counts.items():
        slots.extend([name] * count)
    rng.shuffle(slots)
    return slots


def stratified_sample_by_style(
    vocab: dict[str, list[str]],
    n_target: int,
    rng: np.random.Generator,
    small_pool_cap: int = 50,
) -> list[VisualProfile]:
    """Proportional stratified sample over art_style groups.

    Styles whose pool has fewer than `small_pool_cap` profiles (e.g.
    Photorealism's 36 under the current compat filter) are included in
    full, so every profile of a minority style appears at least once.
    Larger styles contribute proportionally to their share of the full
    enumeration. The result is shuffled; its length is approximately
    `n_target` but may differ slightly due to rounding and small-pool
    expansion.
    """
    profiles = list(enumerate_profiles(vocab))

    # Bucket profiles by their art style, preserving enumeration order.
    grouped: dict[str, list[VisualProfile]] = {}
    for profile in profiles:
        grouped.setdefault(profile.art_style, []).append(profile)

    n_total = len(profiles)
    chosen: list[VisualProfile] = []
    for group in grouped.values():
        size = len(group)
        if size < small_pool_cap:
            take = size  # minority style: keep the whole pool
        else:
            take = min(round(size / n_total * n_target), size)
        picks = rng.choice(size, size=take, replace=False)
        chosen.extend(group[int(j)] for j in picks)

    rng.shuffle(chosen)
    return chosen


def build_plan(
    vocab: dict[str, list[str]],
    base_prompts_path: str | Path,
    distractor_policy: dict,
    generation_mix: dict[str, float],
    n_random: int,
    master_seed: int = 202,
    stratified: bool = False,
) -> list[PlanItem]:
    """Build the deterministic list of PlanItems for the dataset.

    All randomness flows from `master_seed` (module-level rng) plus
    per-item seeds derived via `item_seed`, so repeated calls with the
    same arguments yield an identical plan.

    Args:
        vocab: attribute vocabulary passed to `enumerate_profiles`.
        base_prompts_path: path to base_prompts.yaml.
        distractor_policy: must contain a "strategy_mix" mapping of
            strategy name -> fraction (fractions sum to 1.0).
        generation_mix: backend name -> fraction of items.
        n_random: number of profiles to sample; values <= 0 return [].
        master_seed: seed for the plan-level RNG and per-item seeds.
        stratified: if True, sample profiles stratified by art style
            via `stratified_sample_by_style` instead of uniformly.

    Returns:
        One PlanItem per sampled profile.

    Raises:
        ValueError: on an unknown strategy name, or a strategy/backend
            mix that does not sum to 1.0 (raised by `_allocate_mix`).
    """
    rng = np.random.default_rng(master_seed)
    # NOTE(review): prompts are loaded before the n_random <= 0 early
    # return, so a missing/invalid YAML still fails fast even for an
    # empty plan.
    base_prompts = _load_base_prompts(base_prompts_path)

    if n_random <= 0:
        return []

    # Validate strategy names up front so a typo fails before any sampling.
    for name in distractor_policy["strategy_mix"]:
        if name not in STRATEGIES:
            raise ValueError(f"unknown strategy {name!r}; expected {STRATEGIES}")

    # Sample ground-truth profiles. In stratified mode the final count
    # may differ slightly from n_random (see stratified_sample_by_style).
    if stratified:
        profiles = stratified_sample_by_style(vocab, n_random, rng)
    else:
        all_profiles = list(enumerate_profiles(vocab))
        idx = rng.choice(len(all_profiles), size=n_random, replace=False)
        profiles = [all_profiles[int(i)] for i in idx]
    n_items = len(profiles)
    sources: list[tuple[str, VisualProfile]] = [
        (f"rnd_{k:04d}", p) for k, p in enumerate(profiles)
    ]

    # Pre-allocate one strategy and one backend slot per item, shuffled.
    strategies = _allocate_mix(n_items, distractor_policy["strategy_mix"], rng, "strategy")
    backends = _allocate_mix(n_items, generation_mix, rng, "backend")

    # Round-robin over prompt categories with a random starting offset,
    # so categories stay balanced across the plan.
    prompt_cycle = itertools.cycle(BASE_PROMPT_CATEGORIES)
    for _ in range(int(rng.integers(0, len(BASE_PROMPT_CATEGORIES)))):
        next(prompt_cycle)

    plan: list[PlanItem] = []
    for (iid, profile), strat, backend in zip(sources, strategies, backends):
        category = next(prompt_cycle)
        pool = base_prompts[category]
        # Per-item RNG keyed on (iid, master_seed, "prompt"): the prompt
        # choice depends only on the item's id, not on iteration order.
        prompt_rng = np.random.default_rng(item_seed(iid, master_seed, "prompt"))
        base_prompt = str(pool[int(prompt_rng.integers(0, len(pool)))])

        item_id = f"ab_mcq_{len(plan):05d}_{iid}"
        plan.append(
            PlanItem(
                item_id=item_id,
                gt_profile=profile.to_dict(),
                base_prompt=base_prompt,
                base_prompt_category=category,
                backend=backend,
                # Generation seed derives from the full item_id, distinct
                # from the "prompt" seed stream above.
                seed=item_seed(item_id, master_seed, "gen"),
                distractor_strategy=strat,
                difficulty="medium",
            )
        )
    return plan