lanczos's picture
deploy: labeling server
871ff87 verified
"""VisualProfile — 4-axis aesthetic profile: art_style × color × art_medium × lighting."""
from __future__ import annotations
import itertools
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator
# Canonical axis names, listed in the same order as the VisualProfile fields.
AXES: tuple[str, ...] = (
    "art_style",
    "color",
    "art_medium",
    "lighting",
)
# Third ancestor directory of this file's resolved path (parents[2]).
# NOTE(review): assumed to be the repository root; this depends on where the
# module sits in the tree — confirm if the file is moved.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Vocabulary file read by load_vocab() when no explicit path is given.
DEFAULT_VOCAB_PATH = REPO_ROOT.joinpath("configs", "profile_vocab.json")
@dataclass(frozen=True)
class VisualProfile:
    """Immutable 4-axis aesthetic profile, one string value per axis in ``AXES``.

    Fields are declared in ``AXES`` order, so positional construction
    (``VisualProfile(*values)``) lines up with that ordering.
    """

    art_style: str
    color: str
    art_medium: str
    lighting: str

    def to_dict(self) -> dict[str, str]:
        """Return a ``{axis_name: value}`` mapping, keyed in ``AXES`` order."""
        return {name: getattr(self, name) for name in AXES}

    def to_tuple(self) -> tuple[str, str, str, str]:
        """Return the four field values as ``(art_style, color, art_medium, lighting)``."""
        return self.art_style, self.color, self.art_medium, self.lighting

    @classmethod
    def from_dict(cls, data: dict[str, str]) -> "VisualProfile":
        """Build a profile from a mapping that has every axis as a key.

        Keys outside ``AXES`` are ignored; a missing axis raises ``KeyError``.
        """
        field_values = {name: data[name] for name in AXES}
        return cls(**field_values)

    def validate(self, vocab: dict[str, list[str]]) -> list[str]:
        """Return one error string per axis whose value is absent from ``vocab``.

        An axis missing from ``vocab`` is treated as having no allowed values.
        An empty list means the profile is valid.
        """
        return [
            f"{name}={getattr(self, name)!r} not in {vocab.get(name)}"
            for name in AXES
            if getattr(self, name) not in vocab.get(name, [])
        ]

    def differs_on(self, other: "VisualProfile") -> list[str]:
        """List the axis names, in ``AXES`` order, where ``other`` has a different value."""
        return [
            name
            for name in AXES
            if getattr(self, name) != getattr(other, name)
        ]

    def hamming(self, other: "VisualProfile") -> int:
        """Count how many axes differ between this profile and ``other``."""
        mismatched_axes = self.differs_on(other)
        return len(mismatched_axes)
def load_vocab(path: str | Path | None = None) -> dict[str, list[str]]:
    """Load the per-axis vocabulary from a JSON file.

    Args:
        path: JSON file to read. ``None`` selects ``DEFAULT_VOCAB_PATH``.

    Returns:
        The decoded JSON object, mapping every axis in ``AXES`` to a list of
        allowed string values (extra keys are passed through untouched).

    Raises:
        ValueError: if the top-level JSON value is not an object, if any axis
            in ``AXES`` is missing, or if an axis does not map to a list of
            strings.
    """
    if path is None:
        path = DEFAULT_VOCAB_PATH
    # JSON text is UTF-8; relying on the platform default encoding (e.g.
    # cp1252 on Windows) would mis-decode non-ASCII labels.
    with open(path, encoding="utf-8") as f:
        vocab = json.load(f)
    if not isinstance(vocab, dict):
        raise ValueError(f"vocab must be a JSON object, got {type(vocab).__name__}")
    missing = [axis for axis in AXES if axis not in vocab]
    if missing:
        raise ValueError(f"vocab missing axes: {missing}")
    # A bare string here would not fail loudly downstream: `in` checks would
    # do substring matching and iterating it would yield single characters.
    malformed = [
        axis
        for axis in AXES
        if not isinstance(vocab[axis], list)
        or not all(isinstance(value, str) for value in vocab[axis])
    ]
    if malformed:
        raise ValueError(f"vocab axes must map to lists of strings: {malformed}")
    return vocab
# Style × medium compatibility rules, keyed by art_style. A style listed here
# may only be paired with the media in its set; styles absent from the table
# accept any medium. The intent is to drop combinations where the medium would
# override the style's signature look. enumerate_profiles() skips such
# combinations when compat_filter is true.
_STYLE_MEDIUM_ALLOWLIST: dict[str, set[str]] = {
    "Photorealism": {
        "Digital Painting",
    },
    "Anime": {
        "Digital Painting",
        "Pixel Art",
        "Watercolor",
        "Ink Drawing",
    },
}
def is_compatible(profile: VisualProfile) -> bool:
    """Return whether the profile's style/medium pairing passes the allowlist.

    A style with no entry in ``_STYLE_MEDIUM_ALLOWLIST`` accepts every medium.
    A style with an entry accepts only the media listed for it.
    """
    permitted_media = _STYLE_MEDIUM_ALLOWLIST.get(profile.art_style)
    return permitted_media is None or profile.art_medium in permitted_media
def enumerate_profiles(
    vocab: dict[str, list[str]],
    compat_filter: bool = True,
) -> Iterator[VisualProfile]:
    """Lazily generate profiles from the Cartesian product of the vocab.

    Axes are combined in ``AXES`` order with ``itertools.product``, so the last
    axis varies fastest and the sequence is deterministic for a given vocab.
    With ``compat_filter=True`` (the default), profiles rejected by
    ``is_compatible`` (the ``_STYLE_MEDIUM_ALLOWLIST`` rules) are skipped.
    """
    candidates = (
        VisualProfile(*combo)
        for combo in itertools.product(*(vocab[axis] for axis in AXES))
    )
    selected = filter(is_compatible, candidates) if compat_filter else candidates
    yield from selected