Spaces:
Running
Running
File size: 6,252 Bytes
"""Content variation engine for generating combinatorial batches.

Given a prompt template and a character, the variation engine produces
multiple generation jobs with different combinations of poses, outfits,
emotions, camera angles, and other variable attributes.
"""
from __future__ import annotations
import itertools
import random
import uuid
from dataclasses import dataclass, field
from typing import Any
from content_engine.services.template_engine import TemplateEngine
@dataclass
class CharacterProfile:
    """A character definition, as loaded from YAML configuration.

    Only the four identity/LoRA fields are required; everything else has a
    default so sparse YAML entries still construct cleanly.
    """

    # --- required -----------------------------------------------------
    id: str
    name: str
    trigger_word: str
    lora_filename: str

    # --- optional -----------------------------------------------------
    # Single strength value for the character LoRA.
    lora_strength: float = 0.85
    default_checkpoint: str | None = None
    # Extra LoRA specs; default_factory gives each instance its own list.
    style_loras: list[dict[str, Any]] = field(default_factory=list)
    description: str = ""
    physical_traits: dict[str, str] = field(default_factory=dict)
@dataclass
class VariationJob:
    """One generation job emitted by the variation engine.

    All jobs created by a single batch-generation call carry the same
    ``batch_id``; ``job_id`` is unique per job.
    """

    # Identifiers.
    job_id: str
    batch_id: str

    # What to render.
    character: CharacterProfile
    template_id: str
    content_rating: str
    # Template variable values chosen for this particular job.
    variables: dict[str, str]

    # How to render it.
    seed: int
    # LoRA specs to apply: the character LoRA first, then any style LoRAs.
    loras: list[dict[str, Any]]
class VariationEngine:
    """Generates batches of variation jobs from templates."""

    # Inclusive upper bound for randomly drawn seeds.
    _MAX_SEED = 2**32 - 1

    def __init__(self, template_engine: TemplateEngine):
        self.template_engine = template_engine

    def generate_batch(
        self,
        template_id: str,
        character: CharacterProfile,
        *,
        content_rating: str = "sfw",
        count: int = 10,
        variation_mode: str = "random",  # curated | random | exhaustive
        pin: dict[str, str] | None = None,
        seed_strategy: str = "random",  # random | sequential | fixed
        base_seed: int | None = None,
    ) -> list[VariationJob]:
        """Generate a batch of variation jobs.

        Args:
            template_id: Which prompt template to use.
            character: Character profile for LoRA and trigger word.
            content_rating: "sfw" or "nsfw".
            count: Number of variations to generate. In "exhaustive" mode the
                batch is smaller when the template has fewer combinations.
            variation_mode: How to select variable combinations. "random" and
                "exhaustive" are implemented; "curated" and unrecognised values
                currently behave like "random".
            pin: Variables to keep fixed across all variations.
            seed_strategy: How to assign seeds: "random", "sequential" or
                "fixed". Unrecognised values behave like "random".
            base_seed: The seed for "fixed", or the first seed for
                "sequential". If omitted for either strategy, one random base
                seed is drawn for the whole batch.

        Returns:
            One VariationJob per selected combination, all sharing a batch id.
        """
        pin = pin or {}
        batch_id = str(uuid.uuid4())

        combos = self._select_combinations(template_id, count, variation_mode, pin)

        # Without a base seed, "fixed"/"sequential" would otherwise degrade to
        # an independent random seed per job; pick one base for the batch.
        if seed_strategy in ("fixed", "sequential") and base_seed is None:
            base_seed = random.randint(0, self._MAX_SEED)

        # The character LoRA is applied at the same strength to model and clip,
        # followed by any style LoRAs configured on the character.
        base_loras: list[dict[str, Any]] = [
            {
                "name": character.lora_filename,
                "strength_model": character.lora_strength,
                "strength_clip": character.lora_strength,
            },
            *character.style_loras,
        ]

        jobs: list[VariationJob] = []
        for index, combo in enumerate(combos):
            # Build a fresh dict per job so no two jobs share (and mutate) the
            # same variables mapping.
            variables = {
                **combo,
                "character_trigger": character.trigger_word,
                "character_lora": character.lora_filename,
            }
            jobs.append(
                VariationJob(
                    job_id=str(uuid.uuid4()),
                    batch_id=batch_id,
                    character=character,
                    template_id=template_id,
                    content_rating=content_rating,
                    variables=variables,
                    seed=self._get_seed(seed_strategy, base_seed, index),
                    # Copy each spec so editing one job's LoRAs cannot leak into
                    # other jobs or back into character.style_loras.
                    loras=[dict(lora) for lora in base_loras],
                )
            )
        return jobs

    def _select_combinations(
        self,
        template_id: str,
        count: int,
        mode: str,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Select variable combinations for ``template_id`` based on ``mode``."""
        variables = self.template_engine.get(template_id).variables
        if mode == "exhaustive":
            return self._exhaustive_combos(variables, count, pin)
        # "random", plus "curated" (not implemented yet) and any unrecognised
        # mode, all use random selection.
        return self._random_combos(variables, count, pin)

    def _random_combos(
        self,
        variables: dict[str, Any],
        count: int,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Generate ``count`` independently drawn combinations.

        For each variable: a pinned value wins; otherwise a "choice" variable
        with options gets a uniformly random option; otherwise a truthy
        ``default`` is used. Variables matching none of these are omitted.
        """
        combos: list[dict[str, str]] = []
        for _ in range(count):
            combo: dict[str, str] = {}
            for var_name, var_def in variables.items():
                if var_name in pin:
                    combo[var_name] = pin[var_name]
                elif var_def.type == "choice" and var_def.options:
                    combo[var_name] = random.choice(var_def.options)
                elif var_def.default:
                    combo[var_name] = var_def.default
            combos.append(combo)
        return combos

    def _exhaustive_combos(
        self,
        variables: dict[str, Any],
        count: int,
        pin: dict[str, str],
    ) -> list[dict[str, str]]:
        """Generate cartesian-product combinations, capped at ``count``.

        Pinned variables contribute one value and "choice" variables with
        options contribute every option. Other variables with a truthy
        ``default`` get that default in every combination, as in
        ``_random_combos``. If the product has more than ``count`` elements,
        ``count`` distinct combinations are sampled uniformly; otherwise all of
        them are returned, so the result may be shorter than ``count``.
        Returns an empty list when ``count`` is not positive.
        """
        if count <= 0:
            return []

        axes: list[list[tuple[str, str]]] = []
        defaults: dict[str, str] = {}
        for var_name, var_def in variables.items():
            if var_name in pin:
                axes.append([(var_name, pin[var_name])])
            elif var_def.type == "choice" and var_def.options:
                axes.append([(var_name, opt) for opt in var_def.options])
            elif var_def.default:
                defaults[var_name] = var_def.default

        if not axes:
            # Nothing to sweep: one independent dict per requested job
            # (``[{}] * count`` would alias a single dict ``count`` times).
            return [dict(defaults) for _ in range(count)]

        total = 1
        for axis in axes:
            total *= len(axis)

        if total <= count:
            return [{**defaults, **dict(pairs)} for pairs in itertools.product(*axes)]

        # The product can be astronomically large, so sample distinct product
        # indices from a lazy range and decode only the chosen ones.
        return [
            {**defaults, **self._combo_at(axes, index)}
            for index in random.sample(range(total), count)
        ]

    @staticmethod
    def _combo_at(axes: list[list[tuple[str, str]]], index: int) -> dict[str, str]:
        """Return element ``index`` of ``itertools.product(*axes)`` as a dict."""
        picked: list[tuple[str, str]] = []
        # Decode ``index`` as a mixed-radix number whose last digit belongs to
        # the last axis, matching product's "rightmost varies fastest" order.
        for axis in reversed(axes):
            index, position = divmod(index, len(axis))
            picked.append(axis[position])
        return dict(reversed(picked))

    def _get_seed(
        self, strategy: str, base_seed: int | None, index: int
    ) -> int:
        """Return the seed for the ``index``-th job of a batch.

        "fixed" returns ``base_seed`` and "sequential" returns
        ``base_seed + index``. Any other strategy, or a missing ``base_seed``,
        yields a fresh random seed in ``[0, 2**32 - 1]``.
        """
        if strategy == "fixed" and base_seed is not None:
            return base_seed
        if strategy == "sequential" and base_seed is not None:
            return base_seed + index
        return random.randint(0, self._MAX_SEED)
|