Spaces:
Running on Zero
Running on Zero
| """ | |
| Gemma Diffusion — text → 3D asset builder (gradio.Server backend + custom frontend). | |
| ZeroGPU port. The block-diffusion model designs a standalone SVG illustration of a | |
| described asset; the custom frontend extrudes that SVG into a live, spinning Three.js | |
| 3D scene. `gradio.Server` (a FastAPI subclass) provides Gradio's queue + SSE streaming | |
| under our hand-written HTML/CSS/JS frontend. The single streaming endpoint `/generate` | |
| yields one JSON frame per denoising step: the raw SVG canvas diffusing on the left, the | |
| extruded 3D object rendering on the right. | |
| ZeroGPU specifics: | |
| - `import spaces` happens before `torch`. | |
| - The model is loaded once at module scope with `.to("cuda")` (ZeroGPU registers it). | |
| - The actual `model.generate` call lives inside the `@spaces.GPU` function `_gpu_stream`; | |
| the `gradio.Server` endpoint only marshals picklable CPU tensors in/out of it. | |
| Refs: | |
| - https://huggingface.co/blog/introducing-gradio-server | |
| - https://huggingface.co/docs/hub/spaces-zerogpu | |
| """ | |
| import glob | |
| import os | |
| import subprocess | |
| import sys | |
| # Set before torch is imported (transformers pulls torch in). | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| import spaces # must precede torch so ZeroGPU can patch it | |
def _ensure_transformers():
    """Install the bundled custom DiffusionGemma `transformers` wheel at runtime.

    Spaces installs `requirements.txt` *before* copying the repo files into the image,
    so the wheel can't be referenced by local path there. By the time this app runs the
    file is present in the working directory, so we install it here (only if a stock /
    no transformers is importable) before importing torch/transformers below.

    The function returns without doing anything when an importable `transformers`
    already exposes DiffusionGemma, or when no ``transformers-*.whl`` sits next to
    this file.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits with a non-zero status.
    """
    try:
        import transformers  # noqa: F401

        if hasattr(transformers, "DiffusionGemmaForBlockDiffusion") or hasattr(
            getattr(transformers, "models", object), "diffusion_gemma"
        ):
            return
    except Exception:
        # Deliberate best-effort probe: a missing or broken install just means we
        # fall through and try the bundled wheel.
        pass

    here = os.path.dirname(os.path.abspath(__file__))
    wheels = sorted(glob.glob(os.path.join(here, "transformers-*.whl")))
    if not wheels:
        return

    print(f"[gdiff] Installing bundled transformers wheel: {os.path.basename(wheels[0])}", flush=True)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-cache-dir", wheels[0]])

    import importlib

    # The probe above may have loaded a stock `transformers` (and submodules) into
    # sys.modules. invalidate_caches() only refreshes the path finders; it does not
    # evict already-imported modules, so without this purge the later
    # `from transformers import ...` would still resolve against the stale package.
    for mod_name in list(sys.modules):
        if mod_name == "transformers" or mod_name.startswith("transformers."):
            del sys.modules[mod_name]
    importlib.invalidate_caches()


_ensure_transformers()
| import json | |
| import queue as queue_lib | |
| import re | |
| import threading | |
| import time as _time | |
| import torch | |
| from fastapi.responses import HTMLResponse | |
| from gradio import Server | |
| from transformers import AutoTokenizer, DiffusionGemmaForBlockDiffusion | |
| from transformers.generation.streamers import BaseStreamer | |
# Directory containing this file; index.html is resolved against it.
HERE = os.path.dirname(os.path.abspath(__file__))
# Checkpoint identifier/path handed to from_pretrained; override with GDIFF_MODEL_PATH.
MODEL_PATH = os.environ.get("GDIFF_MODEL_PATH", "google/diffusiongemma-26B-A4B-it")
# Optional access token (None when HF_TOKEN is unset); passed to both from_pretrained calls.
HF_TOKEN = os.environ.get("HF_TOKEN")
MAX_ITERS_CAP = 120  # hard cap on denoising steps per block
# ZeroGPU: the 26B checkpoint (~49 GB bf16) needs the full backing card.
# NOTE(review): no use of GPU_SIZE is visible in this view — presumably consumed by a
# `spaces.GPU` decorator; confirm.
GPU_SIZE = os.environ.get("GDIFF_GPU_SIZE", "xlarge")
# System message placed first in every conversation (see build_messages). It constrains
# the model to emit one standalone, solid-fill SVG so the frontend can extrude it.
SYSTEM_PROMPT = (
    "You are an expert vector artist. Given a TEXT description (usually a game asset — a "
    "sword, shield, potion, coin, treasure chest, spaceship, robot, mushroom, key, gem, "
    "etc.) you design an original, polished SVG illustration of it. The SVG will be extruded "
    "into a spinning 3D object, so design it with clean, solid, extrudable shapes.\n"
    "\n"
    "Requirements:\n"
    "- Output ONLY a single standalone SVG document: start your response with `<svg` and end "
    "with `</svg>`. No HTML wrapper, no <?xml?> prologue, no markdown code fences, no "
    "explanation.\n"
    '- The opening tag must include xmlns="http://www.w3.org/2000/svg" and a square viewBox '
    '(e.g. "0 0 100 100").\n'
    "- Draw in a bold, readable, flat 'game asset / icon' style: several distinct shapes "
    "(<path>, <rect>, <circle>, <polygon>) each with a SOLID fill color (the `fill` "
    "attribute) and a coherent, attractive palette. Layer shapes to suggest detail (outline, "
    "body, highlights, shading).\n"
    "- Do NOT add a full-bleed background rectangle — keep the background transparent so each "
    "shape becomes its own clean 3D piece against the dark scene.\n"
    "- Use only solid filled shapes. Avoid gradients, filters, <text>, images, and "
    "stroke-only / fill=\"none\" shapes — they do not extrude.\n"
    "- Use enough shapes to look great while staying clean (roughly 6-16 shapes).\n"
    "- When asked to modify the artwork, return the FULL updated SVG with the change applied, "
    "keeping the same subject unless asked to change it.\n"
)
# Chat-template control markers such as `<channel>`, `<|turn|>`, `<tool_call>`; the
# `|` on either side is optional. clean_text() strips every match from model output.
_MARKER_RE = re.compile(
    r"<\|?(?:channel|turn|think|image|audio|video|tool(?:_call|_response)?)\|?>"
)
# A markdown code fence, optionally tagged html/svg/xml; group(1) is the fenced payload
# with surrounding whitespace trimmed. DOTALL lets the payload span multiple lines.
_FENCE_RE = re.compile(r"```(?:html|svg|xml)?\s*(.*?)\s*```", re.DOTALL)
# Start of any SVG child element; extract_svg() uses it to locate where real content
# begins when the opening `<svg` tag is missing.
_SVG_CHILD_RE = re.compile(
    r"<(?:path|rect|circle|ellipse|polygon|polyline|line|g|defs)\b", re.I
)
# Canonical opening tag that extract_svg() prepends when it has to rebuild the wrapper.
_SVG_OPEN = '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100">\n'
# --------------------------------------------------------------------------- #
# Model (loaded once at module scope; ZeroGPU registers .to("cuda") tensors)
# --------------------------------------------------------------------------- #
# Falls back to CPU when no CUDA device is visible at import time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[gdiff] Loading model from {MODEL_PATH} on {DEVICE} ...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_TOKEN)
model = DiffusionGemmaForBlockDiffusion.from_pretrained(
    MODEL_PATH,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
).to(DEVICE)
model.eval()
# Length used to truncate/pad the warm-start canvas and reported to the frontend as
# "canvas"; presumably the number of tokens in one diffusion block — confirm against
# the model config documentation.
CANVAS_LEN = model.config.canvas_length
# Padding id for short warm-start canvases; falls back to 0 when the tokenizer's
# pad_token_id is None (or 0).
PAD_ID = tokenizer.pad_token_id or 0
print(f"[gdiff] Model ready | canvas_length={CANVAS_LEN}", flush=True)
# Cache of the last *cleaned* SVG so a follow-up tweak can warm-start in place.
# Stored as an ad-hoc attribute on the model object: written by generate(), read by
# warm_canvas_from_cache().
model._last_clean_html = None
| # --------------------------------------------------------------------------- # | |
| # Helpers (CPU-only; safe to run in the gradio.Server main process) | |
| # --------------------------------------------------------------------------- # | |
def warm_canvas_from_cache():
    """Build the warm-start canvas (first block) from the cached *cleaned* SVG.

    The cached SVG is re-tokenized (rather than reusing raw output tokens) so a
    mangled ``<svg`` header can't compound across tweaks. The token list is truncated
    to ``CANVAS_LEN`` and right-padded with ``PAD_ID`` up to that length.

    Returns:
        A ``(1, CANVAS_LEN)`` ``torch.long`` tensor created on the CPU (it is pickled
        across the ZeroGPU process boundary and moved to the model's device inside the
        GPU worker), or ``None`` when nothing is cached or the cached text tokenizes
        to nothing.
    """
    cached_svg = getattr(model, "_last_clean_html", None)
    if not cached_svg:
        return None

    encoded = tokenizer(cached_svg, add_special_tokens=False)
    token_ids = encoded.input_ids[:CANVAS_LEN]
    if not token_ids:
        return None

    # After the slice the list is never longer than CANVAS_LEN, so this is >= 0.
    shortfall = CANVAS_LEN - len(token_ids)
    if shortfall > 0:
        token_ids = token_ids + [PAD_ID] * shortfall

    # Wrapping in an outer list yields the leading batch dimension of size 1.
    return torch.tensor([token_ids], dtype=torch.long)
def last_assistant_html(history_json: str):
    """Return the content of the most recent non-empty assistant turn, if any.

    Args:
        history_json: JSON-encoded chat history, expected to be a list of
            ``{"role": ..., "content": ...}`` objects. It is client-supplied, so any
            other shape is tolerated rather than trusted.

    Returns:
        The ``content`` value of the last turn whose role is ``"assistant"`` and whose
        content is truthy; ``None`` when the history is empty, is not valid JSON, is
        not a list, or contains no such turn.
    """
    if not history_json:
        return None
    try:
        history = json.loads(history_json)
    except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        return None
    if not isinstance(history, list):
        return None

    for turn in reversed(history):
        # Skip anything that is not a JSON object instead of crashing on `.get`.
        if not isinstance(turn, dict):
            continue
        if turn.get("role") == "assistant" and turn.get("content"):
            return turn["content"]
    return None
def clean_text(text: str) -> str:
    """Remove every ``_MARKER_RE`` match from *text*, then strip leading whitespace."""
    without_markers = _MARKER_RE.sub("", text)
    return without_markers.lstrip()
def extract_svg(text: str) -> str:
    """Pull a clean standalone <svg>…</svg> out of the (possibly mangled) model output.

    Warm-start diffusion frequently chews the very front of the document — the opening
    ``<svg`` loses its ``<`` (``svg viewBox=…``) or a few more chars. If we can't find an
    intact ``<svg``, we rebuild a canonical wrapper around the first real child element so
    the output is always valid (the 3D viewer auto-fits the camera, so a default viewBox is
    fine). Repairing here is essential: the cleaned result is what we cache for the next
    tweak's warm-start, which stops corruption from compounding across tweaks.

    Steps: strip control markers, unwrap a markdown fence if present, cut everything
    before the first ``<svg`` (or prepend ``_SVG_OPEN`` before the first child element),
    then cut everything after the last ``</svg>`` (or append one when it is missing).

    Args:
        text: Raw decoded model output.

    Returns:
        The repaired document, stripped of surrounding whitespace. It always ends with
        ``</svg>``; when neither ``<svg`` nor a child element was found the remaining
        text is returned with only the closing tag appended.
    """
    text = clean_text(text)
    fenced = _FENCE_RE.search(text)
    if fenced:
        text = fenced.group(1)

    # Match case-insensitively on the ORIGINAL string. Searching `text.lower()` and
    # reusing its indices on `text` is unsafe because lower() can change the string
    # length (e.g. "İ" lowers to two code points), which shifts every later index.
    # re.ASCII keeps the case folding to A-Z only.
    flags = re.IGNORECASE | re.ASCII

    opening = re.search(r"<svg", text, flags)
    if opening:
        text = text[opening.start():]
    else:
        child = _SVG_CHILD_RE.search(text)
        if child:
            text = _SVG_OPEN + text[child.start():]
        # else: nothing salvageable; fall through and just trim/close it

    # Keep the last closing tag (rfind semantics): matches cannot overlap, so the
    # final finditer() match is the rightmost occurrence.
    closing = None
    for closing in re.finditer(r"</svg>", text, flags):
        pass
    if closing is not None:
        text = text[: closing.end()]
    else:
        text = text.rstrip() + "\n</svg>"  # tail eaten mid-stream; close it
    return text.strip()
class QueueDiffusionStreamer(BaseStreamer):
    """Streamer that forwards generation progress onto a queue as 4-tuples.

    Every frame has the shape ``(kind, text, block, step)`` where ``kind`` is one of
    ``"commit"``, ``"draft"`` or ``"end"`` and ``text`` is the decoded output so far.
    The callbacks are presumably invoked by ``model.generate`` — confirm against the
    DiffusionGemma streamer protocol.
    """

    def __init__(self, tok, q: "queue_lib.Queue"):
        self.tok = tok  # tokenizer used to decode ids into text
        self.q = q  # destination queue for frames
        self.confirmed_ids: list[int] = []  # ids accumulated from put() calls
        self.prompt_skipped = False  # flips to True on the first put()
        self.block = 0  # number of put() calls that added ids
        self.step = 0  # put_draft() calls since the last put()

    @staticmethod
    def _as_id_list(value):
        """Flatten a tensor to a plain list, taking row 0 when it has >1 dimension."""
        if value.dim() > 1:
            return value[0].tolist()
        return value.tolist()

    def _decode(self, ids):
        """Decode *ids* to text with special tokens removed."""
        return self.tok.decode(ids, skip_special_tokens=True)

    def put(self, value):
        """Record newly confirmed ids and emit a ``"commit"`` frame.

        The very first call is discarded (``prompt_skipped``) — presumably because it
        carries the prompt ids.
        """
        new_ids = self._as_id_list(value)
        if not self.prompt_skipped:
            self.prompt_skipped = True
            return
        self.confirmed_ids.extend(new_ids)
        self.block += 1
        self.step = 0
        frame = ("commit", self._decode(self.confirmed_ids), self.block, self.step)
        self.q.put(frame)

    def put_draft(self, value):
        """Emit a ``"draft"`` frame: confirmed text plus the in-progress ids.

        The frame is labelled with ``block + 1`` and the incremented step counter.
        """
        self.step += 1
        draft_ids = self._as_id_list(value)
        preview = self._decode(self.confirmed_ids + draft_ids)
        self.q.put(("draft", preview, self.block + 1, self.step))

    def end(self):
        """Emit the ``"end"`` frame carrying the confirmed text."""
        final_text = self._decode(self.confirmed_ids)
        self.q.put(("end", final_text, self.block, self.step))
def build_messages(history_json: str, prompt: str):
    """Assemble the chat message list: system prompt, prior turns, then the new prompt.

    Args:
        history_json: JSON-encoded chat history, expected to be a list of
            ``{"role": ..., "content": ...}`` objects. It is client-supplied; invalid
            JSON, a non-list payload, or non-object entries are ignored instead of
            raising.
        prompt: The new user request, appended as the final message.

    Returns:
        A list of ``{"role", "content"}`` dicts starting with the ``SYSTEM_PROMPT``
        system message. Only history turns whose role is ``"user"`` or ``"assistant"``
        and whose content is truthy are kept, in their original order.
    """
    try:
        history = json.loads(history_json) if history_json else []
    except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        history = []
    if not isinstance(history, list):
        history = []

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history:
        # Valid JSON can still hold scalars or nulls here; skip them rather than
        # fail on `.get`.
        if not isinstance(turn, dict):
            continue
        role = turn.get("role")
        content = turn.get("content", "")
        if role in ("user", "assistant") and content:
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": prompt})
    return messages
| # --------------------------------------------------------------------------- # | |
| # GPU work — runs in a forked ZeroGPU worker process. | |
| # Inputs/outputs cross the boundary via pickle, so only CPU tensors / plain | |
| # Python objects go in and out (no CUDA tensors are returned). | |
| # --------------------------------------------------------------------------- # | |
def _estimate_duration(input_ids, max_new_tokens=2048, max_iters=64, full_denoise=False, canvas_ids=None):
    """Estimate, in whole seconds, how long a generation call will run.

    The signature mirrors ``_gpu_stream``; only ``max_new_tokens`` and ``max_iters``
    affect the result. The estimate is 30 s plus 0.3 s per denoising step per block,
    capped at 120 s, where the block count is ``max_new_tokens // CANVAS_LEN`` (at
    least 1).
    """
    canvas = max(1, CANVAS_LEN)  # guard against division by zero
    n_blocks = max(1, int(max_new_tokens) // canvas)
    estimate = 30 + n_blocks * int(max_iters) * 0.3
    # xlarge internally doubles this for the quota check
    capped = min(120, estimate)
    return int(capped)
def _gpu_stream(input_ids, max_new_tokens, max_iters, full_denoise, canvas_ids):
    """Run ``model.generate`` on a background thread and yield its progress frames.

    NOTE(review): the module docstring describes this as the ``@spaces.GPU`` function,
    but no decorator is visible in this view — confirm.

    Args:
        input_ids: Prompt token tensor; moved to ``model.device`` here.
        max_new_tokens: Passed to ``generate`` as ``max_new_tokens``.
        max_iters: Passed to ``generate`` as ``max_denoising_steps``.
        full_denoise: When true, also passes ``confidence_threshold=1e-9`` and
            ``stability_threshold=max_iters``.
        canvas_ids: Optional warm-start canvas tensor; moved to ``model.device`` and
            passed as ``canvas_ids`` when not ``None``.

    Yields:
        ``(kind, text, block, step)`` tuples as produced by ``QueueDiffusionStreamer``
        (``"commit"`` / ``"draft"``), or a single ``("error", "<Type>: <msg>", 0, 0)``
        tuple if ``generate`` raised. The stream stops at the first ``"end"`` or
        ``"error"`` frame.
    """
    prompt_ids = input_ids.to(model.device)

    options = {
        "max_new_tokens": int(max_new_tokens),
        "max_denoising_steps": int(max_iters),
    }
    if full_denoise:
        options["confidence_threshold"] = 1e-9
        options["stability_threshold"] = int(max_iters)
    if canvas_ids is not None:
        options["canvas_ids"] = canvas_ids.to(model.device)

    frames: "queue_lib.Queue" = queue_lib.Queue()
    streamer = QueueDiffusionStreamer(tokenizer, frames)

    def run_generation():
        try:
            with torch.inference_mode():
                model.generate(prompt_ids, streamer=streamer, **options)
        except Exception as exc:  # surface to the endpoint
            frames.put(("error", f"{type(exc).__name__}: {exc}", 0, 0))
        finally:
            # Always enqueue a terminator so the consumer below can never block
            # forever, even if the streamer's own end() was not reached.
            frames.put(("end", "", 0, 0))

    producer = threading.Thread(target=run_generation)
    producer.start()
    try:
        while True:
            frame = frames.get()
            kind = frame[0]
            if kind == "end":
                return
            if kind == "error":
                yield ("error", frame[1], 0, 0)
                return
            yield frame
    finally:
        # Also runs when the consumer closes this generator early; the join then
        # waits for the generation thread to finish on its own.
        producer.join()
# --------------------------------------------------------------------------- #
# Server
# --------------------------------------------------------------------------- #
# The gradio.Server instance (described in the module docstring as a FastAPI subclass);
# it is also exported as `demo` and launched from the `__main__` guard at the bottom.
app = Server(title="Gemma Diffusion 3D Asset Builder")
def generate(
    prompt: str,
    history_json: str = "[]",
    max_new_tokens: int = 2048,
    max_iters: int = 64,
    full_denoise: bool = False,
    anim_delay: float = 0.0,
    warm_start: bool = True,
) -> str:
    """Stream the diffusion generation as JSON frames (one per denoising step).

    The model emits a raw SVG illustration; the frontend extrudes it into 3D with Three.js.

    NOTE(review): the module docstring calls this the `/generate` endpoint, but no
    route decorator is visible in this view — confirm how it is registered.

    Args:
        prompt: User request; surrounding whitespace is stripped. An empty prompt
            yields a single error frame.
        history_json: JSON-encoded prior turns (see ``build_messages``).
        max_new_tokens: Forwarded to ``_gpu_stream``.
        max_iters: Denoising steps per block, clamped to ``[1, MAX_ITERS_CAP]``.
        full_denoise: Forwarded to ``_gpu_stream``.
        anim_delay: Seconds to pause after each draft frame. Non-numeric, negative,
            NaN or infinite values are treated as 0 (no pause).
        warm_start: When true and the history contains an assistant turn, seed the
            first canvas from the cached cleaned SVG.

    Yields:
        JSON strings. Progress frames carry ``kind`` (``"draft"`` or ``"commit"``),
        ``source``, ``block``, ``step``, ``canvas``, ``max_iters`` and ``warming``.
        The stream ends with either ``{"kind": "done", "source": ...}`` or
        ``{"kind": "error", "message": ...}``.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        yield json.dumps({"kind": "error", "message": "Empty prompt."})
        return

    messages = build_messages(history_json, prompt)
    max_iters = max(1, min(int(max_iters), MAX_ITERS_CAP))

    # time.sleep() raises ValueError for negative/NaN and OverflowError for infinite
    # durations, which would abort the stream mid-generation; sanitize once up front.
    try:
        delay = float(anim_delay or 0.0)
    except (TypeError, ValueError):
        delay = 0.0
    if not 0.0 < delay < float("inf"):
        delay = 0.0

    # Tweak warm-start: seed the diffusion's first canvas with the previous artwork's own
    # tokens (native `canvas_ids` API) so the model edits the existing SVG in place.
    # NOTE(review): the cache is a single attribute on the module-level model object,
    # not keyed by session, so the warm-start SVG is whatever any request produced
    # last — confirm this is acceptable when several users share the Space.
    is_tweak = bool(last_assistant_html(history_json))
    canvas_ids = warm_canvas_from_cache() if (warm_start and is_tweak) else None
    warming = canvas_ids is not None

    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )["input_ids"]

    last_text = ""
    for kind, text, block, step in _gpu_stream(
        input_ids, int(max_new_tokens), max_iters, bool(full_denoise), canvas_ids
    ):
        if kind == "error":
            yield json.dumps({"kind": "error", "message": text})
            return
        last_text = text
        yield json.dumps(
            {
                "kind": "draft" if kind == "draft" else "commit",
                "source": clean_text(text),
                "block": block,
                "step": step,
                "canvas": CANVAS_LEN,
                "max_iters": max_iters,
                "warming": warming,
            }
        )
        if delay and kind == "draft":
            _time.sleep(delay)

    final_source = extract_svg(last_text)
    # Cache the *cleaned* SVG so the next tweak warm-starts from a valid header.
    # extract_svg() always appends `</svg>`, so a plain non-empty check would also
    # cache unsalvageable output; require the result to actually open with `<svg`.
    if final_source[:4].lower() == "<svg":
        model._last_clean_html = final_source
    yield json.dumps({"kind": "done", "source": final_source})
async def homepage():
    """Return the contents of ``index.html`` (located next to this file) as text."""
    index_path = os.path.join(HERE, "index.html")
    with open(index_path, "r", encoding="utf-8") as page:
        html = page.read()
    return html
# HF Spaces' gradio runtime looks for a top-level `demo` (or `app`) to launch.
demo = app
# Direct execution: launch the server ourselves. Host and port come from GDIFF_HOST /
# GDIFF_PORT and default to 0.0.0.0:7860.
if __name__ == "__main__":
    app.launch(
        server_name=os.environ.get("GDIFF_HOST", "0.0.0.0"),
        server_port=int(os.environ.get("GDIFF_PORT", "7860")),
        show_error=True,
    )