Spaces:
Sleeping
Sleeping
| import base64 | |
| import json | |
| import threading | |
| from dataclasses import dataclass, field, replace | |
| from io import BytesIO | |
| from typing import Any, Callable, Protocol, Sequence | |
| from urllib import request | |
| from budget import Card | |
| from zerogpu import gpu | |
# Terms sent as the negative prompt to every art backend in this module (the
# Modal endpoint payload and the local diffusers pipe call), steering the image
# away from text, UI chrome, card framing, and cropped or split compositions.
NEGATIVE_ART_PROMPT = ", ".join(
    (
        "text",
        "letters",
        "words",
        "labels",
        "ui",
        "interface",
        "card frame",
        "border",
        "watermark",
        "logo",
        "playing card",
        "paper",
        "white edge",
        "cropped subject",
        "cut off subject",
        "split panel",
    )
)
@dataclass
class ModalArtClient:
    """Generate card art by calling a Modal SDXL endpoint; returns a data URI.

    Attributes:
        endpoint: URL that receives a JSON POST and answers with a JSON object
            whose ``"image"`` field is returned verbatim (as ``str``).
        steps: Sent as the ``"steps"`` field.
        guidance_scale: Sent as the ``"guidance"`` field.
        width: Sent as the ``"width"`` field (pixels).
        height: Sent as the ``"height"`` field (pixels).
        timeout_seconds: Timeout passed to ``urllib.request.urlopen``.
    """

    # @dataclass generates __init__(endpoint, steps=..., ...); without it the
    # class could not be constructed with an endpoint at all.
    endpoint: str
    steps: int = 4
    guidance_scale: float = 0.0
    width: int = 512
    height: int = 320
    timeout_seconds: int = 120

    def create_art(self, prompt: str) -> str:
        """POST one prompt to the Modal art endpoint and return its image data URI.

        Args:
            prompt: Positive text prompt; the module's ``NEGATIVE_ART_PROMPT``
                is always sent alongside it.

        Returns:
            The ``"image"`` field of the JSON response, coerced to ``str``.

        Raises:
            Errors from ``urlopen`` (e.g. ``URLError``/``HTTPError``), from JSON
            decoding, or ``KeyError`` when the response has no ``"image"`` key
            propagate to the caller.
        """
        payload = {
            "prompt": prompt,
            "steps": self.steps,
            "guidance": self.guidance_scale,
            "width": self.width,
            "height": self.height,
            "negative_prompt": NEGATIVE_ART_PROMPT,
        }
        req = request.Request(
            self.endpoint,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with request.urlopen(req, timeout=self.timeout_seconds) as response:
            return str(json.loads(response.read().decode("utf-8"))["image"])
class ArtClient(Protocol):
    """Structural interface for any backend that turns a text prompt into card art."""

    def create_art(self, prompt: str) -> str:  # pragma: no cover
        """Return a browser-renderable image URI for one ``prompt``."""
| class LazyArtClient: | |
| loader: Callable[[], ArtClient] | |
| lock: Any = field(default_factory=threading.Lock, repr=False, compare=False) | |
| client: ArtClient | None = None | |
| # Load the real art backend on first image request. | |
| def create_art(self, prompt: str) -> str: | |
| if self.client is None: | |
| with self.lock: | |
| if self.client is None: | |
| self.client = self.loader() | |
| return self.client.create_art(prompt) | |
@dataclass
class DiffusersImageClient:
    """Art client backed by a local diffusers text-to-image pipeline.

    Attributes:
        pipe: The diffusers pipeline object (built by ``load``).
        steps: ``num_inference_steps`` for each generation.
        guidance_scale: ``guidance_scale`` for each generation.
        width: Output width in pixels.
        height: Output height in pixels.
        lock: Guards publishing ``pipe`` into the module global; excluded
            from repr and equality.
    """

    # @dataclass generates __init__(pipe, steps=..., ...), which ``load``
    # relies on, and makes ``lock`` a real per-instance Lock.
    pipe: Any
    steps: int = 1
    guidance_scale: float = 0.0
    width: int = 512
    height: int = 320
    lock: Any = field(default_factory=threading.Lock, repr=False, compare=False)

    def create_art(self, prompt: str) -> str:
        """Publish this client's pipe as the active global, then generate one image.

        The @gpu worker reads the pipe from the global (inherited via fork)
        rather than receiving the multi-GB pipeline as a pickled argument.

        Returns:
            A PNG data URI string produced by ``_run_art_pipe``.
        """
        global _art_pipe
        with self.lock:
            _art_pipe = self.pipe
        # NOTE(review): the lock is per-instance and is released before
        # generation starts, so two clients holding *different* pipes could
        # still race on the shared global. Confirm only one instance exists.
        return _run_art_pipe(prompt, self.steps, self.guidance_scale, self.width, self.height)

    @classmethod
    def load(
        cls,
        model_id: str,
        steps: int = 1,
        guidance_scale: float = 0.0,
        width: int = 512,
        height: int = 320,
    ) -> "DiffusersImageClient":  # pragma: no cover
        """Load a diffusers text-to-image pipeline in the main process and wrap it.

        Args:
            model_id: Model identifier passed to ``from_pretrained``.
            steps, guidance_scale, width, height: Stored on the returned client.

        Returns:
            A new client holding the loaded pipeline with its progress bar disabled.
        """
        import torch
        from diffusers import AutoPipelineForText2Image

        # Stay on CPU here: on ZeroGPU the move to CUDA happens in _run_art_pipe.
        pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=best_torch_dtype(torch))
        pipe.set_progress_bar_config(disable=True)
        return cls(pipe, steps=steps, guidance_scale=guidance_scale, width=width, height=height)
# Pipeline published by DiffusersImageClient.create_art for the GPU worker to read;
# None until a client has generated at least once.
_art_pipe: Any = None


@gpu
def _run_art_pipe(prompt: str, steps: int, guidance_scale: float, width: int, height: int) -> str:
    """Generate one image on a ZeroGPU allocation and return it as a PNG data URI.

    The pipe is read from the module global ``_art_pipe`` so the forked GPU
    worker inherits it, and only small picklable values cross the boundary.
    The return value is a ``str`` for the same reason.

    Raises:
        RuntimeError: if no pipeline has been published into ``_art_pipe``.
    """
    from local_llm import best_device

    pipe = _art_pipe
    if pipe is None:
        # Fail with a clear message instead of AttributeError on None.to(...).
        raise RuntimeError("no art pipeline is active; generate via DiffusersImageClient.create_art")
    pipe.to(best_device())
    result = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        negative_prompt=NEGATIVE_ART_PROMPT,
    )
    return image_data_uri(result.images[0])
def best_torch_dtype(torch: Any) -> Any:
    """Return ``torch.float16`` when CUDA or Apple MPS is available, else ``torch.float32``.

    ``torch`` is passed in rather than imported so callers control when the
    (heavy) import happens.
    """
    # CUDA is probed first; MPS is only consulted when CUDA is absent, and only
    # if this torch build exposes a ``backends.mps`` module at all.
    on_accelerator = torch.cuda.is_available() or (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    )
    return torch.float16 if on_accelerator else torch.float32
def move_pipe_to_device(pipe: Any, torch: Any) -> Any:
    """Move a diffusers pipeline to CUDA, else MPS, else leave it where it is.

    Returns:
        The result of ``pipe.to(device)`` when an accelerator is found,
        otherwise ``pipe`` itself unchanged.
    """
    target = None
    if torch.cuda.is_available():
        target = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        target = "mps"
    return pipe if target is None else pipe.to(target)
def image_data_uri(image: Any) -> str:
    """Encode a PIL-like image (anything with ``save(fp, format=...)``) as a PNG data URI."""
    with BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii")
def card_art_prompt(card: Card) -> str:
    """Return the prompt sent to the art model for one card.

    The subject is the card's ``art_prompt`` when truthy; otherwise it is
    assembled from the card's name, school and rules text. The subject is
    prefixed with the theme's style clause and followed by fixed
    composition guidance.
    """
    subject = card.art_prompt
    if not subject:
        subject = f"{card.name}, {card.school} magic, {card.rules_text()}"
    style = theme_style(card.theme)
    return (
        f"{style}, {subject}, "
        "single dramatic spell scene, clear centered focal subject, full composition, no text, no card frame"
    )
def theme_style(theme: str) -> str:
    """Return a visual style clause inferred from the user's world ``theme``.

    Keywords are matched case-insensitively as substrings, in priority order:
    "anime", then "dark"/"gothic", then "wuxia". Any other theme is embedded
    (whitespace-trimmed) in a generic clause; a blank theme yields the generic
    clause with no empty slot.
    """
    lowered = theme.lower()
    if "anime" in lowered:
        return "wide anime fantasy key visual, cel shaded, clean linework, luminous color, dynamic composition"
    if "dark" in lowered or "gothic" in lowered:
        return "wide dark fantasy spell illustration, dramatic chiaroscuro, painterly concept art"
    if "wuxia" in lowered:
        return "wide wuxia fantasy spell illustration, flowing silk, ink-wash energy, cinematic motion"
    cleaned = theme.strip()
    if not cleaned:
        # Without this, an empty theme produced "illustration, , painterly ...".
        return "wide fantasy spell illustration, painterly concept art"
    return f"wide fantasy spell illustration, {cleaned}, painterly concept art"
def illustrate_card(client: ArtClient | None, card: Card) -> Card:
    """Return a copy of ``card`` with ``art_uri`` set from generated art.

    Art is best-effort. With no ``client`` the card is returned unchanged.
    Any exception while building the prompt, generating the image, or copying
    the card is logged with its traceback, and the unchanged card is returned
    instead of propagating the error.
    """
    if client is None:
        return card
    try:
        return replace(card, art_uri=client.create_art(card_art_prompt(card)))
    except Exception:
        import logging

        # Deliberate fallback, but no longer silent: record why art is missing.
        logging.getLogger(__name__).warning(
            "Card art generation failed for %r; returning card without art",
            card.name,
            exc_info=True,
        )
        return card
def illustrate_cards(client: ArtClient | None, cards: Sequence[Card]) -> tuple[Card, ...]:
    """Return ``cards`` in their original order, each passed through ``illustrate_card``."""
    illustrated = [illustrate_card(client, each) for each in cards]
    return tuple(illustrated)