# tabras / art.py
# Commit 8c897f9 (Codex): Add MODE switch: LOCAL (on-device, default) vs MODAL (GPU endpoints)
# File size: 6.53 kB
import base64
import json
import threading
from dataclasses import dataclass, field, replace
from io import BytesIO
from typing import Any, Callable, Protocol, Sequence
from urllib import request
from budget import Card
from zerogpu import gpu
# Comma-separated list of things the art model is asked NOT to draw; sent as
# the negative prompt by both the remote and the local generation paths.
NEGATIVE_ART_PROMPT = ", ".join(
    (
        "text",
        "letters",
        "words",
        "labels",
        "ui",
        "interface",
        "card frame",
        "border",
        "watermark",
        "logo",
        "playing card",
        "paper",
        "white edge",
        "cropped subject",
        "cut off subject",
        "split panel",
    )
)
@dataclass(frozen=True)
class ModalArtClient:
    """Generate card art by calling a Modal SDXL endpoint; returns a data URI.

    The instance is immutable; every field except ``endpoint`` has a default
    and is forwarded to the endpoint with each request.
    """

    endpoint: str  # URL that receives the JSON POST
    steps: int = 4
    guidance_scale: float = 0.0  # sent under the JSON key "guidance"
    width: int = 512
    height: int = 320
    timeout_seconds: int = 120  # passed to urlopen as its timeout

    def create_art(self, prompt: str) -> str:
        """POST one prompt to the art endpoint and return its ``image`` field.

        Args:
            prompt: Text prompt describing the image to generate.

        Returns:
            The ``"image"`` value of the JSON response, converted with ``str``
            (presumably an image data URI; the response schema is defined by
            the endpoint, not by this file).

        Raises:
            Whatever ``urllib.request.urlopen`` or JSON decoding raises, and
            ``KeyError`` when the response has no ``"image"`` key; nothing is
            caught here.
        """
        body = json.dumps(
            {
                "prompt": prompt,
                "steps": self.steps,
                "guidance": self.guidance_scale,
                "width": self.width,
                "height": self.height,
                "negative_prompt": NEGATIVE_ART_PROMPT,
            }
        ).encode("utf-8")
        http_request = request.Request(
            self.endpoint,
            data=body,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with request.urlopen(http_request, timeout=self.timeout_seconds) as response:
            reply = json.loads(response.read().decode("utf-8"))
        return str(reply["image"])
class ArtClient(Protocol):
    """Structural interface shared by every art backend in this module."""

    def create_art(self, prompt: str) -> str:  # pragma: no cover
        """Return a browser-renderable image URI for one prompt."""
        ...
@dataclass
class LazyArtClient:
    """Art client that defers building the real backend until first use."""

    loader: Callable[[], ArtClient]  # zero-argument factory for the real backend
    lock: Any = field(default_factory=threading.Lock, repr=False, compare=False)
    client: ArtClient | None = None  # cached backend; None until the first request

    def create_art(self, prompt: str) -> str:
        """Load the real art backend on the first request, then delegate to it.

        Args:
            prompt: Text prompt forwarded unchanged to the backend.

        Returns:
            Whatever the loaded backend's ``create_art`` returns.
        """
        # Fast path: the backend is already loaded, no lock needed.
        if self.client is not None:
            return self.client.create_art(prompt)
        with self.lock:
            # Check again under the lock: another caller may have finished
            # loading while this one was waiting.
            if self.client is None:
                self.client = self.loader()
        return self.client.create_art(prompt)
@dataclass(frozen=True)
class DiffusersImageClient:
    """Art client backed by a locally loaded diffusers text-to-image pipeline."""

    pipe: Any  # the loaded pipeline object
    steps: int = 1
    guidance_scale: float = 0.0
    width: int = 512
    height: int = 320
    # Per-instance lock; excluded from repr and equality.
    lock: Any = field(default_factory=threading.Lock, repr=False, compare=False)

    def create_art(self, prompt: str) -> str:
        """Publish this client's pipe as the active global, then generate.

        The @gpu worker reads the pipe from the module global (inherited via
        fork) rather than receiving the multi-GB pipeline as a pickled
        argument, so only small scalar settings are passed to it.

        Args:
            prompt: Text prompt for the image.

        Returns:
            The string returned by ``_run_art_pipe``.
        """
        global _art_pipe
        # The lock covers only the global assignment, not the generation call.
        with self.lock:
            _art_pipe = self.pipe
        settings = (self.steps, self.guidance_scale, self.width, self.height)
        return _run_art_pipe(prompt, *settings)

    @classmethod
    def load(
        cls,
        model_id: str,
        steps: int = 1,
        guidance_scale: float = 0.0,
        width: int = 512,
        height: int = 320,
    ) -> "DiffusersImageClient":  # pragma: no cover
        """Load a diffusers text-to-image pipeline (in the main process).

        Args:
            model_id: Identifier passed to ``from_pretrained``.
            steps: Inference step count stored on the client.
            guidance_scale: Guidance scale stored on the client.
            width: Output width stored on the client.
            height: Output height stored on the client.

        Returns:
            A client wrapping the freshly loaded pipeline.
        """
        # Heavy imports are deferred until a local pipeline is really needed.
        import torch
        from diffusers import AutoPipelineForText2Image

        # Stay on CPU here: on ZeroGPU the move to CUDA happens in _run_art_pipe.
        dtype = best_torch_dtype(torch)
        pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=dtype)
        pipe.set_progress_bar_config(disable=True)
        return cls(
            pipe,
            steps=steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
        )
# Pipeline currently selected for generation; set by
# DiffusersImageClient.create_art before each call to _run_art_pipe.
_art_pipe: Any = None


@gpu
def _run_art_pipe(prompt: str, steps: int, guidance_scale: float, width: int, height: int) -> str:
    """Generate one image on a ZeroGPU allocation and return it as a data URI.

    The pipeline is read from the module global ``_art_pipe`` so the forked
    GPU worker inherits it; only small values cross the call boundary and the
    result is a plain string (picklable).

    Args:
        prompt: Text prompt for the image.
        steps: Number of inference steps.
        guidance_scale: Guidance scale passed to the pipeline.
        width: Requested image width.
        height: Requested image height.

    Returns:
        The first generated image encoded by ``image_data_uri``.
    """
    from local_llm import best_device

    # Move the pipeline onto the device chosen by the project's helper.
    device = best_device()
    _art_pipe.to(device)
    generation = _art_pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        negative_prompt=NEGATIVE_ART_PROMPT,
    )
    first_image = generation.images[0]
    return image_data_uri(first_image)
def best_torch_dtype(torch: Any) -> Any:
    """Return the best local dtype for SDXL generation.

    Args:
        torch: The imported ``torch`` module (passed in so callers control
            when the import happens).

    Returns:
        ``torch.float16`` when CUDA or MPS reports itself available,
        otherwise ``torch.float32``.
    """
    # CUDA is probed first; the MPS backend is only consulted when it exists.
    has_accelerator = torch.cuda.is_available() or (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    )
    return torch.float16 if has_accelerator else torch.float32
def move_pipe_to_device(pipe: Any, torch: Any) -> Any:
    """Move a diffusers pipeline to the best available local device.

    Args:
        pipe: Pipeline object exposing ``.to(device_name)``.
        torch: The imported ``torch`` module used to probe for devices.

    Returns:
        The result of ``pipe.to("cuda")`` or ``pipe.to("mps")`` when that
        device is available (CUDA is preferred), otherwise ``pipe`` itself,
        untouched.
    """
    if torch.cuda.is_available():
        target = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        target = "mps"
    else:
        # No accelerator: leave the pipeline where it is.
        return pipe
    return pipe.to(target)
def image_data_uri(image: Any) -> str:
    """Encode a PIL-like image as a PNG data URI.

    Args:
        image: Object exposing ``save(file_obj, format=...)``.

    Returns:
        A ``data:image/png;base64,...`` string holding the bytes that
        ``image.save`` wrote.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii")
def card_art_prompt(card: Card) -> str:
    """Return the prompt sent to the art model for one card.

    Args:
        card: Card whose ``art_prompt``, ``name``, ``school``, ``rules_text()``
            and ``theme`` are read.

    Returns:
        The theme's style clause, the subject, and fixed composition hints
        joined into one comma-separated prompt.
    """
    # Prefer the card's own art prompt; build one from its name, school and
    # rules text only when that prompt is empty or missing.
    subject = card.art_prompt
    if not subject:
        subject = f"{card.name}, {card.school} magic, {card.rules_text()}"
    style = theme_style(card.theme)
    composition = (
        "single dramatic spell scene, clear centered focal subject, full composition, no text, no card frame"
    )
    return f"{style}, {subject}, {composition}"
# Ordered (keywords, style clause) rules used by theme_style; the first rule
# with any keyword found in the lower-cased theme wins.
_THEME_STYLE_RULES = (
    (
        ("anime",),
        "wide anime fantasy key visual, cel shaded, clean linework, luminous color, dynamic composition",
    ),
    (
        ("dark", "gothic"),
        "wide dark fantasy spell illustration, dramatic chiaroscuro, painterly concept art",
    ),
    (
        ("wuxia",),
        "wide wuxia fantasy spell illustration, flowing silk, ink-wash energy, cinematic motion",
    ),
)


def theme_style(theme: str) -> str:
    """Return a visual style clause inferred from the user's world.

    Args:
        theme: Free-text theme; keyword matching is case-insensitive.

    Returns:
        A preset clause for anime, dark/gothic, or wuxia themes; otherwise a
        generic fantasy clause that embeds ``theme`` exactly as given.
    """
    lowered = theme.lower()
    for keywords, style in _THEME_STYLE_RULES:
        if any(keyword in lowered for keyword in keywords):
            return style
    return f"wide fantasy spell illustration, {theme}, painterly concept art"
def illustrate_card(client: ArtClient | None, card: Card) -> Card:
    """Add generated art to one card, falling back to the unchanged card on failure.

    Args:
        client: Art backend to use, or ``None`` to skip illustration.
        card: Card to illustrate; it is never mutated.

    Returns:
        A copy of ``card`` with ``art_uri`` set to the generated image, or
        ``card`` itself when ``client`` is ``None`` or generation fails.
    """
    if client is None:
        return card
    try:
        return replace(card, art_uri=client.create_art(card_art_prompt(card)))
    except Exception:
        # Art is best-effort: a failed generation must not break the card
        # flow, but the failure is logged so it is not invisible.
        import logging

        logging.getLogger(__name__).warning(
            "Art generation failed for card %r; returning it without art",
            getattr(card, "name", "<unknown>"),
            exc_info=True,
        )
        return card
def illustrate_cards(client: ArtClient | None, cards: Sequence[Card]) -> tuple[Card, ...]:
    """Add generated art to a sequence of cards.

    Args:
        client: Art backend to use, or ``None`` to leave every card unchanged.
        cards: Cards to illustrate, processed in order.

    Returns:
        A tuple holding the result of ``illustrate_card`` for each input
        card, in the same order.
    """
    illustrated = [illustrate_card(client, card) for card in cards]
    return tuple(illustrated)