Spaces:
Sleeping
Sleeping
| """VisualProfile — 4-axis aesthetic profile: art_style × color × art_medium × lighting.""" | |
| from __future__ import annotations | |
| import itertools | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterator | |
# Canonical axis order. VisualProfile declares its fields in this same order,
# and the per-axis helpers in this module iterate over it.
AXES: tuple[str, ...] = (
    "art_style",
    "color",
    "art_medium",
    "lighting",
)
# Repository root: two directory levels above the directory containing this file.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Default location of the per-axis vocabulary consumed by load_vocab().
DEFAULT_VOCAB_PATH = REPO_ROOT.joinpath("configs", "profile_vocab.json")
@dataclass
class VisualProfile:
    """4-axis aesthetic profile.

    Holds one string value per axis in ``AXES``. The field order matches
    ``AXES``, so positional construction (``VisualProfile(*values)``) assigns
    values in axis order.
    """

    art_style: str
    color: str
    art_medium: str
    lighting: str

    def to_dict(self) -> dict[str, str]:
        """Return the profile as an ``{axis: value}`` mapping, keyed in ``AXES`` order."""
        return {axis: getattr(self, axis) for axis in AXES}

    def to_tuple(self) -> tuple[str, str, str, str]:
        """Return ``(art_style, color, art_medium, lighting)``."""
        return (self.art_style, self.color, self.art_medium, self.lighting)

    @classmethod
    def from_dict(cls, data: dict[str, str]) -> "VisualProfile":
        """Build a profile from a mapping of axis name to value.

        Keys not listed in ``AXES`` are ignored.

        Raises:
            KeyError: if ``data`` lacks any axis listed in ``AXES``.
        """
        return cls(**{axis: data[axis] for axis in AXES})

    def validate(self, vocab: dict[str, list[str]]) -> list[str]:
        """Check each axis value against ``vocab``.

        Args:
            vocab: Mapping of axis name to its allowed values.

        Returns:
            One human-readable error string per axis whose value is not in
            ``vocab[axis]``. An axis absent from ``vocab`` counts as invalid.
            An empty list means the profile is valid.
        """
        errors: list[str] = []
        for axis in AXES:
            value = getattr(self, axis)
            if value not in vocab.get(axis, []):
                errors.append(f"{axis}={value!r} not in {vocab.get(axis)}")
        return errors

    def differs_on(self, other: "VisualProfile") -> list[str]:
        """Return the axes (in ``AXES`` order) whose values differ from ``other``."""
        return [axis for axis in AXES if getattr(self, axis) != getattr(other, axis)]

    def hamming(self, other: "VisualProfile") -> int:
        """Return the number of axes on which this profile differs from ``other``."""
        return len(self.differs_on(other))
def load_vocab(path: str | Path | None = None) -> dict[str, list[str]]:
    """Load the per-axis vocabulary from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``DEFAULT_VOCAB_PATH`` when None.

    Returns:
        The parsed JSON object. Every axis in ``AXES`` is guaranteed to be
        present as a key.

    Raises:
        ValueError: if the JSON top level is not an object, or if any axis in
            ``AXES`` is missing. ``json.JSONDecodeError`` (a ``ValueError``
            subclass) is raised for malformed JSON.
        OSError: if the file cannot be opened.
    """
    if path is None:
        path = DEFAULT_VOCAB_PATH
    # Explicit encoding: the platform default text encoding is not UTF-8 everywhere.
    with open(path, encoding="utf-8") as f:
        vocab = json.load(f)
    # Without this check, a JSON array listing the axis names would pass the
    # membership test below and be returned as if it were a mapping.
    if not isinstance(vocab, dict):
        raise ValueError(f"vocab must be a JSON object, got {type(vocab).__name__}")
    missing = [axis for axis in AXES if axis not in vocab]
    if missing:
        raise ValueError(f"vocab missing axes: {missing}")
    return vocab
| # Style × medium compatibility rules. Combinations where the medium would | |
| # override the style's signature are filtered out of profile enumeration. | |
| _STYLE_MEDIUM_ALLOWLIST: dict[str, set[str]] = { | |
| "Photorealism": {"Digital Painting"}, | |
| "Anime": {"Digital Painting", "Pixel Art", "Watercolor", "Ink Drawing"}, | |
| } | |
def is_compatible(profile: VisualProfile) -> bool:
    """Return whether ``profile.art_medium`` is permitted for ``profile.art_style``.

    Styles without an entry in ``_STYLE_MEDIUM_ALLOWLIST`` are unrestricted.
    Listed styles accept only the media in their allowlist set.
    """
    permitted_media = _STYLE_MEDIUM_ALLOWLIST.get(profile.art_style)
    if permitted_media is None:
        return True
    return profile.art_medium in permitted_media
def enumerate_profiles(
    vocab: dict[str, list[str]],
    compat_filter: bool = True,
) -> Iterator[VisualProfile]:
    """Lazily yield every cross-product combination of the per-axis vocab values.

    The order is ``itertools.product`` order over ``AXES`` (the last axis varies
    fastest), so it is fixed for a given vocab. With ``compat_filter=True``
    (the default), combinations rejected by ``is_compatible`` (the
    ``_STYLE_MEDIUM_ALLOWLIST`` rules) are skipped.
    """
    axis_values = [vocab[name] for name in AXES]
    candidates = (VisualProfile(*combo) for combo in itertools.product(*axis_values))
    if compat_filter:
        yield from filter(is_compatible, candidates)
    else:
        yield from candidates