import base64
import json
import threading
from dataclasses import dataclass, field, replace
from io import BytesIO
from typing import Any, Callable, Protocol, Sequence
from urllib import request

from budget import Card
from zerogpu import gpu

# Negative prompt sent with every art request, both in the Modal payload and in
# the local diffusers call, to steer the model away from text, frames and UI.
NEGATIVE_ART_PROMPT = (
    "text, letters, words, labels, ui, interface, card frame, border, watermark, logo, "
    "playing card, paper, white edge, cropped subject, cut off subject, split panel"
)


@dataclass(frozen=True)
class ModalArtClient:
    """Generate card art by calling a Modal SDXL endpoint; returns a data URI."""

    endpoint: str
    steps: int = 4
    guidance_scale: float = 0.0
    width: int = 512
    height: int = 320
    timeout_seconds: int = 120

    # POST one prompt to the Modal art endpoint and return its image data URI.
    def create_art(self, prompt: str) -> str:
        """Request one image for ``prompt`` and return the reply's ``"image"`` string.

        Raises:
            OSError: the HTTP request failed (``urllib.error.URLError`` /
                ``HTTPError``) or timed out.
            json.JSONDecodeError: the response body is not valid JSON.
            KeyError: the response JSON has no ``"image"`` key.
            ValueError: ``"image"`` is not a non-empty string (e.g. ``null``).
        """
        payload = {
            "prompt": prompt,
            "steps": self.steps,
            "guidance": self.guidance_scale,
            "width": self.width,
            "height": self.height,
            "negative_prompt": NEGATIVE_ART_PROMPT,
        }
        req = request.Request(
            self.endpoint,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with request.urlopen(req, timeout=self.timeout_seconds) as response:
            body = json.loads(response.read().decode("utf-8"))
        image = body["image"]
        # Reject null/empty/non-string values instead of str()-coercing them:
        # str(None) == "None" would be stored as the card's art URI and bypass
        # illustrate_card's fallback to the unillustrated card.
        if not isinstance(image, str) or not image:
            raise ValueError(
                f"art endpoint returned an invalid 'image' value of type {type(image).__name__}"
            )
        return image


class ArtClient(Protocol):
    """Structural interface implemented by every art backend in this module."""

    # Return a browser-renderable image URI for one prompt.
    def create_art(self, prompt: str) -> str:  # pragma: no cover
        ...


@dataclass
class LazyArtClient:
    """Defer building the real art client until the first ``create_art`` call."""

    # Zero-argument factory that builds the real backend.
    loader: Callable[[], ArtClient]
    # Guards the one-time call to ``loader``.
    lock: Any = field(default_factory=threading.Lock, repr=False, compare=False)
    # The loaded backend; None until the first successful load.
    client: ArtClient | None = None

    # Load the real art backend on first image request.
    def create_art(self, prompt: str) -> str:
        """Load the backend if needed, then delegate ``prompt`` to it.

        If ``loader`` raises, ``client`` stays None and the next call retries.
        """
        # Re-check under the lock so concurrent first calls run `loader` once.
        if self.client is None:
            with self.lock:
                if self.client is None:
                    self.client = self.loader()
        return self.client.create_art(prompt)


@dataclass(frozen=True)
class DiffusersImageClient:
    """Generate card art with an in-process diffusers pipeline via ``_run_art_pipe``."""

    pipe: Any
    steps: int = 1
    guidance_scale: float = 0.0
    width: int = 512
    height: int = 320
    lock: Any = field(default_factory=threading.Lock, repr=False, compare=False)

    # Set the active pipe global, then generate the image on a ZeroGPU allocation.
    # The @gpu worker reads the pipe from the global (inherited via fork) rather
    # than receiving the multi-GB pipeline as a pickled argument.
    def create_art(self, prompt: str) -> str:
        """Publish ``self.pipe`` as the active pipe and return a PNG data URI for ``prompt``."""
        global _art_pipe
        # NOTE(review): the lock is per instance and only covers this assignment,
        # not the _run_art_pipe call below; a second instance with a different
        # pipe could swap `_art_pipe` in between. Confirm only one instance is live.
        with self.lock:
            _art_pipe = self.pipe
        return _run_art_pipe(prompt, self.steps, self.guidance_scale, self.width, self.height)

    # Load a diffusers text-to-image pipeline (in the main process).
    @classmethod
    def load(
        cls,
        model_id: str,
        steps: int = 1,
        guidance_scale: float = 0.0,
        width: int = 512,
        height: int = 320,
    ) -> "DiffusersImageClient":  # pragma: no cover
        """Load ``model_id`` with ``AutoPipelineForText2Image`` and wrap it in a client."""
        import torch
        from diffusers import AutoPipelineForText2Image

        # Stay on CPU here: on ZeroGPU the move to CUDA happens in _run_art_pipe.
        pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=best_torch_dtype(torch))
        pipe.set_progress_bar_config(disable=True)
        return cls(pipe, steps=steps, guidance_scale=guidance_scale, width=width, height=height)


# Pipeline used by the next _run_art_pipe call; assigned by DiffusersImageClient.create_art.
_art_pipe: Any = None


# Generate one image on a ZeroGPU allocation, reading the pipe from a module global
# so the forked GPU worker inherits it; returns a data URI string (picklable).
@gpu
def _run_art_pipe(prompt: str, steps: int, guidance_scale: float, width: int, height: int) -> str:
    """Run the active ``_art_pipe`` once and return the first image as a PNG data URI.

    Raises:
        RuntimeError: no pipeline has been published to ``_art_pipe`` yet.
    """
    from local_llm import best_device

    pipe = _art_pipe
    if pipe is None:
        # Fail with a clear message instead of "'NoneType' has no attribute 'to'".
        raise RuntimeError("no art pipeline is active; call DiffusersImageClient.create_art")
    pipe.to(best_device())
    result = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        negative_prompt=NEGATIVE_ART_PROMPT,
    )
    return image_data_uri(result.images[0])


# Return the best local dtype for SDXL generation.
def best_torch_dtype(torch: Any) -> Any:
    """Return ``torch.float16`` when CUDA or Apple MPS is available, else ``torch.float32``.

    ``torch`` is the torch module (or any object exposing the same attributes).
    """
    if torch.cuda.is_available():
        return torch.float16
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.float16
    return torch.float32


# Move a diffusers pipeline to the best available local device.
def move_pipe_to_device(pipe: Any, torch: Any) -> Any:
    """Return ``pipe.to("cuda")`` / ``pipe.to("mps")`` when available, else ``pipe`` unchanged."""
    if torch.cuda.is_available():
        return pipe.to("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return pipe.to("mps")
    return pipe


# Encode a PIL-like image as a PNG data URI.
def image_data_uri(image: Any) -> str:
    """Save ``image`` as PNG into memory and return it as a base64 ``data:`` URI.

    ``image`` only needs a ``save(fp, format=...)`` method.
    """
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/png;base64,{encoded}"


# Return the prompt sent to the art model for one card.
def card_art_prompt(card: Card) -> str:
    """Build the text-to-image prompt for ``card``.

    Uses ``card.art_prompt`` when it has non-whitespace content; otherwise the
    subject is derived from the card's name, school and rules text. The theme
    style clause is prepended and composition hints are appended.
    """
    # A whitespace-only art_prompt would otherwise leave the subject empty.
    custom = (card.art_prompt or "").strip()
    subject = custom or f"{card.name}, {card.school} magic, {card.rules_text()}"
    return (
        f"{theme_style(card.theme)}, {subject}, "
        "single dramatic spell scene, clear centered focal subject, full composition, no text, no card frame"
    )


# Return a visual style clause inferred from the user's world.
def theme_style(theme: str) -> str:
    """Map ``theme`` to a style clause via case-insensitive keyword matching.

    Keywords are checked in order (anime, dark/gothic, wuxia) and the first hit
    wins. Otherwise the stripped theme is embedded in a generic clause; an empty
    or whitespace-only theme is omitted rather than leaving ", , " in the prompt.
    """
    lowered = theme.lower()
    if "anime" in lowered:
        return "wide anime fantasy key visual, cel shaded, clean linework, luminous color, dynamic composition"
    if "dark" in lowered or "gothic" in lowered:
        return "wide dark fantasy spell illustration, dramatic chiaroscuro, painterly concept art"
    if "wuxia" in lowered:
        return "wide wuxia fantasy spell illustration, flowing silk, ink-wash energy, cinematic motion"
    cleaned = theme.strip()
    if not cleaned:
        return "wide fantasy spell illustration, painterly concept art"
    return f"wide fantasy spell illustration, {cleaned}, painterly concept art"


# Add generated art to one card, falling back to the unchanged card on failure.
def illustrate_card(client: ArtClient | None, card: Card) -> Card:
    """Return a copy of ``card`` with ``art_uri`` set, or ``card`` itself.

    Art generation is best-effort: with no client, or when the client raises,
    the original card is returned. Failures are logged with their traceback
    rather than silently discarded.
    """
    if client is None:
        return card
    try:
        return replace(card, art_uri=client.create_art(card_art_prompt(card)))
    except Exception:
        import logging

        logging.getLogger(__name__).warning(
            "art generation failed for card %r; keeping it unillustrated",
            card.name,
            exc_info=True,
        )
        return card


# Add generated art to a sequence of cards.
def illustrate_cards(client: ArtClient | None, cards: Sequence[Card]) -> tuple[Card, ...]:
    """Run ``illustrate_card`` over ``cards`` in order and return the results as a tuple.

    Each card is handled independently, so one failed image leaves only that
    card without art.
    """
    illustrated = [illustrate_card(client, each) for each in cards]
    return tuple(illustrated)