File size: 6,252 Bytes
ed37502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""Content variation engine for generating combinatorial batches.

Given a prompt template and a character, the variation engine produces
multiple generation jobs with different combinations of poses, outfits,
emotions, camera angles, and other variable attributes.
"""

from __future__ import annotations

import itertools
import math
import random
import uuid
from dataclasses import dataclass, field
from typing import Any

from content_engine.services.template_engine import TemplateEngine


@dataclass
class CharacterProfile:
    """Character configuration loaded from YAML.

    ``VariationEngine.generate_batch`` reads the trigger word and LoRA
    settings from this profile and attaches them to every job it builds.
    """

    # Character identifier.
    id: str
    # Human-readable character name.
    name: str
    # Injected into each job's variables as ``character_trigger``.
    trigger_word: str
    # Character LoRA: the first entry of each job's LoRA list, and also
    # injected into each job's variables as ``character_lora``.
    lora_filename: str
    # Used as both the model and the CLIP strength of the character LoRA.
    lora_strength: float = 0.85
    # Not read by the variation engine; presumably consumed by the generation
    # backend — TODO confirm.
    default_checkpoint: str | None = None
    # Extra LoRA entries appended, as given, after the character LoRA.
    # Presumably shaped like the character entry (name / strength_model /
    # strength_clip) — verify against the YAML schema.
    style_loras: list[dict[str, Any]] = field(default_factory=list)
    # Free-text description; not read by the variation engine.
    description: str = ""
    # Trait name -> value; not read by the variation engine.
    physical_traits: dict[str, str] = field(default_factory=dict)


@dataclass
class VariationJob:
    """A single generation job produced by the variation engine.

    Bundles one combination of template variable values with the seed and
    LoRA stack to use for that generation.
    """

    # UUID4 string assigned to this job.
    job_id: str
    # UUID4 string shared by every job from the same generate_batch call.
    batch_id: str
    # Profile the job was generated for.
    character: CharacterProfile
    # Prompt template the variables belong to.
    template_id: str
    # Rating passed through from the batch request (documented as "sfw"/"nsfw").
    content_rating: str
    # Template variable values, plus injected "character_trigger" and
    # "character_lora" entries.
    variables: dict[str, str]
    # Seed chosen by the batch's seed strategy.
    seed: int
    # LoRA entries: the character LoRA first, then the character's style LoRAs.
    loras: list[dict[str, Any]]


class VariationEngine:
    """Generates batches of variation jobs from templates.

    Each job pairs one combination of template variable values with the
    character's LoRA stack and a seed picked by the requested strategy.
    Random choices use the module-level ``random`` generator, so a batch is
    reproducible only if the caller seeds ``random`` beforehand.
    """

    def __init__(self, template_engine: TemplateEngine):
        # Only ``template_engine.get(template_id)`` is used; the returned
        # template must expose a ``variables`` mapping.
        self.template_engine = template_engine

    def generate_batch(
        self,
        template_id: str,
        character: CharacterProfile,
        *,
        content_rating: str = "sfw",
        count: int = 10,
        variation_mode: str = "random",  # curated | random | exhaustive
        pin: dict[str, str] | None = None,
        seed_strategy: str = "random",  # random | sequential | fixed
        base_seed: int | None = None,
    ) -> list[VariationJob]:
        """Generate a batch of variation jobs.

        Args:
            template_id: Which prompt template to use.
            character: Character profile for LoRA and trigger word.
            content_rating: "sfw" or "nsfw"; copied onto every job.
            count: Number of variations to generate. "random" and "curated"
                modes produce exactly this many (none if it is <= 0);
                "exhaustive" produces at most this many.
            variation_mode: How to select variable combinations. Unknown
                values behave like "random".
            pin: Variables to keep fixed across all variations. Only names
                the template declares are applied.
            seed_strategy: How to assign seeds. "fixed" and "sequential"
                need ``base_seed``; without it, seeds are random.
            base_seed: Starting seed for "sequential", or the single seed
                used by "fixed".

        Returns:
            One ``VariationJob`` per combination, all sharing a fresh batch
            id. Every job owns its ``variables`` dict and its LoRA dicts, so
            mutating one job never affects another.
        """
        pin = pin or {}
        batch_id = str(uuid.uuid4())

        # _select_combinations performs the template lookup itself, so the
        # template is not fetched a second time here.
        combos = self._select_combinations(template_id, count, variation_mode, pin)

        # Inject character-specific variables (each combo is a distinct dict).
        for combo in combos:
            combo["character_trigger"] = character.trigger_word
            combo["character_lora"] = character.lora_filename

        # Character LoRA first, then the character's style LoRAs as given.
        base_loras: list[dict[str, Any]] = [
            {
                "name": character.lora_filename,
                "strength_model": character.lora_strength,
                "strength_clip": character.lora_strength,
            },
            *character.style_loras,
        ]

        jobs: list[VariationJob] = []
        for index, combo in enumerate(combos):
            seed = self._get_seed(seed_strategy, base_seed, index)
            jobs.append(
                VariationJob(
                    job_id=str(uuid.uuid4()),
                    batch_id=batch_id,
                    character=character,
                    template_id=template_id,
                    content_rating=content_rating,
                    variables=combo,
                    seed=seed,
                    # Copy every entry so editing one job's LoRA settings
                    # cannot leak into sibling jobs or character.style_loras.
                    loras=[dict(lora) for lora in base_loras],
                )
            )
        return jobs

    def _select_combinations(
        self,
        template_id: str,
        count: int,
        mode: str,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Select variable combinations for ``template_id`` based on ``mode``.

        "exhaustive" enumerates the cartesian product of the template's
        variables. Every other mode, including the not-yet-implemented
        "curated", draws random combinations.
        """
        template = self.template_engine.get(template_id)

        if mode == "exhaustive":
            return self._exhaustive_combos(template.variables, count, pin)
        # "random", "curated" (not implemented yet) and unknown modes.
        return self._random_combos(template.variables, count, pin)

    def _random_combos(
        self,
        variables: dict,
        count: int,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Generate ``count`` independently drawn combinations.

        Each value in ``variables`` must expose ``type``, ``options`` and
        ``default`` attributes. For each variable, in template order:
        a pinned value wins; otherwise a ``choice`` variable with options
        gets a uniformly random option; otherwise a truthy ``default`` is
        used. Variables matching none of these are omitted. Combinations
        may repeat.
        """
        combos = []
        for _ in range(count):
            combo: dict[str, str] = {}
            for var_name, var_def in variables.items():
                if var_name in pin:
                    combo[var_name] = pin[var_name]
                elif var_def.type == "choice" and var_def.options:
                    combo[var_name] = random.choice(var_def.options)
                elif var_def.default:
                    combo[var_name] = var_def.default
            combos.append(combo)
        return combos

    def _exhaustive_combos(
        self,
        variables: dict,
        count: int,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Generate cartesian-product combinations, capped at ``count``.

        Pinned variables contribute their pinned value. ``choice`` variables
        with options contribute one entry per option. Any other variable with
        a truthy ``default`` gets that default in every combination, matching
        ``_random_combos``.

        If nothing is pinned or enumerable, ``count`` independent copies of
        the defaults-only combination are returned. Otherwise the full
        product (in ``itertools.product`` order) is returned when it has at
        most ``count`` entries; if it is larger, ``count`` distinct entries
        are chosen at random. A ``count`` <= 0 yields an empty list.
        """
        count = max(count, 0)
        axes: list[list[tuple[str, str]]] = []
        has_product_axis = False
        for var_name, var_def in variables.items():
            if var_name in pin:
                axes.append([(var_name, pin[var_name])])
                has_product_axis = True
            elif var_def.type == "choice" and var_def.options:
                axes.append([(var_name, opt) for opt in var_def.options])
                has_product_axis = True
            elif var_def.default:
                # A single-entry axis adds the default to every combination
                # without changing the product size, and keeps keys in
                # template order.
                axes.append([(var_name, var_def.default)])

        if not has_product_axis:
            # Still emit ``count`` combinations. Each one is a fresh dict so
            # the per-job mutation in generate_batch stays isolated.
            fixed = dict(axis[0] for axis in axes)
            return [dict(fixed) for _ in range(count)]

        total = math.prod(len(axis) for axis in axes)
        if total <= count:
            return [dict(combo) for combo in itertools.product(*axes)]

        # Sample product indices rather than materializing every combination.
        # random.sample chooses indices from len(population) alone, so this
        # selects the same combinations as sampling the full product list.
        return [self._combo_at(axes, index) for index in random.sample(range(total), count)]

    @staticmethod
    def _combo_at(axes: list[list[tuple[str, str]]], index: int) -> dict[str, str]:
        """Return the ``index``-th element of ``itertools.product(*axes)`` as a dict."""
        picks: list[tuple[str, str]] = []
        for axis in reversed(axes):
            # product() advances the last axis fastest, so peel it off first.
            index, offset = divmod(index, len(axis))
            picks.append(axis[offset])
        picks.reverse()
        return dict(picks)

    def _get_seed(
        self, strategy: str, base_seed: int | None, index: int
    ) -> int:
        """Generate a seed for the ``index``-th job of a batch.

        "fixed" returns ``base_seed`` and "sequential" returns
        ``base_seed + index``. Any other strategy, or a missing
        ``base_seed``, yields a random seed in ``[0, 2**32 - 1]``.
        """
        if strategy == "fixed" and base_seed is not None:
            return base_seed
        elif strategy == "sequential" and base_seed is not None:
            return base_seed + index
        else:
            return random.randint(0, 2**32 - 1)