merve's picture
merve HF Staff
Add Gemma Diffusion 3D asset builder with google/diffusiongemma-26B-A4B-it
7f3234f verified
"""
Gemma Diffusion — text → 3D asset builder (gradio.Server backend + custom frontend).
ZeroGPU port. The block-diffusion model designs a standalone SVG illustration of a
described asset; the custom frontend extrudes that SVG into a live, spinning Three.js
3D scene. `gradio.Server` (a FastAPI subclass) provides Gradio's queue + SSE streaming
under our hand-written HTML/CSS/JS frontend. The single streaming endpoint `/generate`
yields one JSON frame per denoising step: the raw SVG canvas diffusing on the left, the
extruded 3D object rendering on the right.
ZeroGPU specifics:
- `import spaces` happens before `torch`.
- The model is loaded once at module scope with `.to("cuda")` (ZeroGPU registers it).
- The actual `model.generate` call lives inside the `@spaces.GPU` function `_gpu_stream`;
the `gradio.Server` endpoint only marshals picklable CPU tensors in/out of it.
Refs:
- https://huggingface.co/blog/introducing-gradio-server
- https://huggingface.co/docs/hub/spaces-zerogpu
"""
import glob
import os
import subprocess
import sys
# Set before torch is imported (transformers pulls torch in).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces # must precede torch so ZeroGPU can patch it
def _purge_cached_modules(package: str) -> None:
    """Drop *package* and every ``package.*`` submodule from ``sys.modules``."""
    prefix = package + "."
    stale = [name for name in sys.modules if name == package or name.startswith(prefix)]
    for name in stale:
        sys.modules.pop(name, None)


def _ensure_transformers():
    """Install the bundled custom DiffusionGemma `transformers` wheel at runtime.

    Spaces installs `requirements.txt` *before* copying the repo files into the image,
    so the wheel can't be referenced by local path there. By the time this app runs the
    file is present in the working directory, so we install it here (only if a stock /
    no transformers is importable) before importing torch/transformers below.

    Returns ``None`` in every case. Raises ``subprocess.CalledProcessError`` if the
    ``pip install`` of the wheel fails.
    """
    try:
        import transformers  # noqa: F401

        if hasattr(transformers, "DiffusionGemmaForBlockDiffusion") or hasattr(
            getattr(transformers, "models", object), "diffusion_gemma"
        ):
            return
    except Exception:
        # Deliberately best-effort: a missing or broken install can fail in many ways
        # while importing; any failure simply means "try the bundled wheel".
        pass
    wheels = sorted(glob.glob(os.path.join(os.path.dirname(os.path.abspath(__file__)), "transformers-*.whl")))
    if not wheels:
        return
    print(f"[gdiff] Installing bundled transformers wheel: {os.path.basename(wheels[0])}", flush=True)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-cache-dir", wheels[0]])
    import importlib

    # The probe above may have loaded a stock `transformers`; it would stay cached in
    # `sys.modules` and shadow the wheel we just installed, so evict it before the
    # real import happens further down the file.
    _purge_cached_modules("transformers")
    importlib.invalidate_caches()


_ensure_transformers()
import json
import queue as queue_lib
import re
import threading
import time as _time
import torch
from fastapi.responses import HTMLResponse
from gradio import Server
from transformers import AutoTokenizer, DiffusionGemmaForBlockDiffusion
from transformers.generation.streamers import BaseStreamer
# Directory containing this file; `index.html` is served from here.
HERE = os.path.dirname(os.path.abspath(__file__))
# Checkpoint id/path and access token, both overridable through the environment.
MODEL_PATH = os.environ.get("GDIFF_MODEL_PATH", "google/diffusiongemma-26B-A4B-it")
HF_TOKEN = os.environ.get("HF_TOKEN")
MAX_ITERS_CAP = 120 # hard cap on denoising steps per block
# ZeroGPU: the 26B checkpoint (~49 GB bf16) needs the full backing card.
# Passed as `size=` to the `@spaces.GPU` decorator on `_gpu_stream`.
GPU_SIZE = os.environ.get("GDIFF_GPU_SIZE", "xlarge")
# System message placed first in every conversation by `build_messages`. It tells the
# model to answer with a single standalone, solid-filled SVG suitable for extrusion.
SYSTEM_PROMPT = (
"You are an expert vector artist. Given a TEXT description (usually a game asset — a "
"sword, shield, potion, coin, treasure chest, spaceship, robot, mushroom, key, gem, "
"etc.) you design an original, polished SVG illustration of it. The SVG will be extruded "
"into a spinning 3D object, so design it with clean, solid, extrudable shapes.\n"
"\n"
"Requirements:\n"
"- Output ONLY a single standalone SVG document: start your response with `<svg` and end "
"with `</svg>`. No HTML wrapper, no <?xml?> prologue, no markdown code fences, no "
"explanation.\n"
'- The opening tag must include xmlns="http://www.w3.org/2000/svg" and a square viewBox '
'(e.g. "0 0 100 100").\n'
"- Draw in a bold, readable, flat 'game asset / icon' style: several distinct shapes "
"(<path>, <rect>, <circle>, <polygon>) each with a SOLID fill color (the `fill` "
"attribute) and a coherent, attractive palette. Layer shapes to suggest detail (outline, "
"body, highlights, shading).\n"
"- Do NOT add a full-bleed background rectangle — keep the background transparent so each "
"shape becomes its own clean 3D piece against the dark scene.\n"
"- Use only solid filled shapes. Avoid gradients, filters, <text>, images, and "
"stroke-only / fill=\"none\" shapes — they do not extrude.\n"
"- Use enough shapes to look great while staying clean (roughly 6-16 shapes).\n"
"- When asked to modify the artwork, return the FULL updated SVG with the change applied, "
"keeping the same subject unless asked to change it.\n"
)
# Marker tags such as `<|channel|>`, `<turn>` or `<tool_call|>` (the pipes are optional).
# `clean_text` removes every match; presumably these are leaked chat-template tokens —
# confirm against the tokenizer's special-token list.
_MARKER_RE = re.compile(
r"<\|?(?:channel|turn|think|image|audio|video|tool(?:_call|_response)?)\|?>"
)
# A triple-backtick fence, optionally tagged html/svg/xml; group 1 is the fenced body.
_FENCE_RE = re.compile(r"```(?:html|svg|xml)?\s*(.*?)\s*```", re.DOTALL)
# Opening of a common SVG child element (case-insensitive). `extract_svg` uses the first
# match as the point to rebuild from when no intact `<svg` tag is found.
_SVG_CHILD_RE = re.compile(
r"<(?:path|rect|circle|ellipse|polygon|polyline|line|g|defs)\b", re.I
)
# Canonical opening tag that `extract_svg` prepends when it rebuilds the wrapper.
_SVG_OPEN = '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100">\n'
# --------------------------------------------------------------------------- #
# Model (loaded once at module scope; ZeroGPU registers .to("cuda") tensors)
# --------------------------------------------------------------------------- #
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[gdiff] Loading model from {MODEL_PATH} on {DEVICE} ...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_TOKEN)
# Load the checkpoint in bfloat16 and move it to DEVICE once, at import time.
model = DiffusionGemmaForBlockDiffusion.from_pretrained(
MODEL_PATH,
dtype=torch.bfloat16,
low_cpu_mem_usage=True,
token=HF_TOKEN,
).to(DEVICE)
model.eval()
# Canvas length read from the model config; used to size the warm-start canvas and
# to estimate the number of blocks in `_estimate_duration`.
CANVAS_LEN = model.config.canvas_length
# Padding id for short warm-start canvases; falls back to 0 when the tokenizer's
# pad_token_id is None (or already 0).
PAD_ID = tokenizer.pad_token_id or 0
print(f"[gdiff] Model ready | canvas_length={CANVAS_LEN}", flush=True)
# Cache of the last *cleaned* SVG so a follow-up tweak can warm-start in place.
# NOTE(review): this is one process-wide slot, not keyed by session — a tweak request
# warm-starts from whichever generation finished last, possibly another user's.
model._last_clean_html = None
# --------------------------------------------------------------------------- #
# Helpers (CPU-only; safe to run in the gradio.Server main process)
# --------------------------------------------------------------------------- #
def warm_canvas_from_cache():
    """Build the starting canvas (first block) from the cached *cleaned* SVG.

    Returns a CPU ``torch.long`` tensor of shape ``(1, CANVAS_LEN)``, or ``None`` when
    nothing is cached or the cached SVG tokenizes to nothing. The tensor stays on CPU
    because it is pickled across the ZeroGPU process boundary and moved to CUDA inside
    the GPU worker. The cleaned SVG is re-tokenized (rather than reusing raw output
    tokens) so a mangled ``<svg`` header can't compound across tweaks.
    """
    cached_svg = getattr(model, "_last_clean_html", None)
    if not cached_svg:
        return None

    # Keep at most one canvas worth of tokens.
    token_ids = tokenizer(cached_svg, add_special_tokens=False).input_ids[:CANVAS_LEN]
    if not token_ids:
        return None

    # Right-pad up to the canvas length; the padding list is empty when already full.
    padding = [PAD_ID] * (CANVAS_LEN - len(token_ids))
    canvas = torch.tensor(token_ids + padding, dtype=torch.long)
    return canvas.unsqueeze(0)
def last_assistant_html(history_json: str):
    """Return the content of the most recent non-empty assistant turn, or ``None``.

    ``history_json`` is client-supplied, so it is parsed defensively: anything that is
    not valid JSON, not a JSON array, or contains entries that are not objects is
    tolerated (bad entries are skipped) instead of raising.

    Args:
        history_json: JSON-encoded list of ``{"role": ..., "content": ...}`` turns.

    Returns:
        The ``content`` of the last turn whose role is ``"assistant"`` and whose
        content is truthy, or ``None`` when there is no such turn.
    """
    if not history_json:
        return None
    try:
        history = json.loads(history_json)
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError
        return None
    if not isinstance(history, list):
        return None
    for turn in reversed(history):
        if not isinstance(turn, dict):
            continue
        if turn.get("role") == "assistant" and turn.get("content"):
            return turn["content"]
    return None
def clean_text(text: str) -> str:
    """Remove every ``_MARKER_RE`` match from *text*, then strip leading whitespace."""
    without_markers = _MARKER_RE.sub("", text)
    return without_markers.lstrip()
def extract_svg(text: str) -> str:
    """Pull a clean standalone <svg>…</svg> out of the (possibly mangled) model output.

    Warm-start diffusion frequently chews the very front of the document — the opening
    ``<svg`` loses its ``<`` (``svg viewBox=…``) or a few more chars. If we can't find an
    intact ``<svg``, we rebuild a canonical wrapper around the first real child element so
    the output is valid (the 3D viewer auto-fits the camera, so a default viewBox is
    fine). Repairing here is essential: the cleaned result is what we cache for the next
    tweak's warm-start, which stops corruption from compounding across tweaks.

    Args:
        text: Raw decoded model output.

    Returns:
        The trimmed SVG document, or ``""`` when the output holds no text at all
        (so callers never cache a bare ``</svg>``).
    """
    text = clean_text(text)
    fenced = _FENCE_RE.search(text)
    if fenced:
        text = fenced.group(1)
    if not text.strip():
        return ""

    # Search the original string case-insensitively instead of indexing into
    # `text.lower()`: lowercasing can change the length of some Unicode text, which
    # would shift the offsets and cut the document at the wrong place.
    ascii_icase = re.IGNORECASE | re.ASCII
    opening = re.search(r"<svg", text, ascii_icase)
    if opening:
        text = text[opening.start():]
    else:
        child = _SVG_CHILD_RE.search(text)
        if child:
            text = _SVG_OPEN + text[child.start():]
        # else: nothing salvageable; fall through and just trim/close it

    # Cut after the LAST closing tag, if there is one.
    close_end = -1
    for closing in re.finditer(r"</svg>", text, ascii_icase):
        close_end = closing.end()
    if close_end != -1:
        text = text[:close_end]
    else:
        text = text.rstrip() + "\n</svg>"  # tail eaten mid-stream; close it
    return text.strip()
class QueueDiffusionStreamer(BaseStreamer):
    """Streamer that forwards diffusion progress to a queue as 4-tuples.

    Every message is ``(kind, decoded_text, block, step)`` where ``kind`` is one of
    ``"commit"``, ``"draft"`` or ``"end"``. The decoded text always covers all
    confirmed tokens so far (plus the current draft tokens for ``"draft"``).
    """

    def __init__(self, tok, q: "queue_lib.Queue"):
        self.tok = tok
        self.q = q
        self.confirmed_ids: list[int] = []  # tokens of every committed block so far
        self.prompt_skipped = False  # flips to True on the first `put` call
        self.block = 0  # number of committed blocks
        self.step = 0  # draft steps seen since the last commit

    @staticmethod
    def _as_id_list(value):
        """Flatten *value* to a plain list, taking row 0 when it has more than one dim."""
        if value.dim() > 1:
            return value[0].tolist()
        return value.tolist()

    def _decode(self, ids):
        """Decode token ids to text with special tokens skipped."""
        return self.tok.decode(ids, skip_special_tokens=True)

    def put(self, value):
        """Record a finished block and enqueue a ``"commit"`` message.

        The very first call is swallowed (presumably it carries the prompt tokens —
        confirm against the generate implementation).
        """
        new_ids = self._as_id_list(value)
        if not self.prompt_skipped:
            self.prompt_skipped = True
            return
        self.confirmed_ids.extend(new_ids)
        self.block += 1
        self.step = 0
        message = ("commit", self._decode(self.confirmed_ids), self.block, self.step)
        self.q.put(message)

    def put_draft(self, value):
        """Enqueue a ``"draft"`` message for the block currently being denoised."""
        self.step += 1
        draft_ids = self._as_id_list(value)
        preview = self._decode(self.confirmed_ids + draft_ids)
        # The draft belongs to the block after the last committed one.
        self.q.put(("draft", preview, self.block + 1, self.step))

    def end(self):
        """Enqueue the ``"end"`` message with the full confirmed text."""
        final_text = self._decode(self.confirmed_ids)
        self.q.put(("end", final_text, self.block, self.step))
def build_messages(history_json: str, prompt: str):
    """Assemble the chat messages: system prompt, prior turns, then the new user prompt.

    ``history_json`` is client-supplied, so it is parsed defensively: invalid JSON, a
    non-array JSON value, or entries that are not objects are ignored instead of
    raising (previously e.g. ``"42"`` or ``"[1]"`` crashed the endpoint).

    Args:
        history_json: JSON-encoded list of ``{"role": ..., "content": ...}`` turns.
        prompt: The new user request, appended as the final message.

    Returns:
        A list of ``{"role", "content"}`` dicts. Only history turns whose role is
        ``"user"`` or ``"assistant"`` and whose content is truthy are kept.
    """
    try:
        history = json.loads(history_json) if history_json else []
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError
        history = []
    if not isinstance(history, list):
        history = []
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history:
        if not isinstance(turn, dict):
            continue
        role = turn.get("role")
        content = turn.get("content", "")
        if role in ("user", "assistant") and content:
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": prompt})
    return messages
# --------------------------------------------------------------------------- #
# GPU work — runs in a forked ZeroGPU worker process.
# Inputs/outputs cross the boundary via pickle, so only CPU tensors / plain
# Python objects go in and out (no CUDA tensors are returned).
# --------------------------------------------------------------------------- #
def _estimate_duration(input_ids, max_new_tokens=2048, max_iters=64, full_denoise=False, canvas_ids=None):
    """Estimate the GPU seconds to request for a `_gpu_stream` call.

    Takes the same arguments as `_gpu_stream`; only ``max_new_tokens`` and
    ``max_iters`` influence the result. The estimate is a 30 s base plus 0.3 s per
    denoising step per block, capped at 120.
    """
    canvas = max(1, CANVAS_LEN)
    # NOTE(review): floor division — a trailing partial block is not counted.
    n_blocks = max(1, int(max_new_tokens) // canvas)
    estimate = 30 + n_blocks * int(max_iters) * 0.3
    # xlarge internally doubles this for the quota check
    return int(min(120, estimate))
@spaces.GPU(duration=_estimate_duration, size=GPU_SIZE)
def _gpu_stream(input_ids, max_new_tokens, max_iters, full_denoise, canvas_ids):
    """Run `model.generate` and yield its progress as ``(kind, text, block, step)``.

    Generation runs in a background thread that feeds a queue through
    `QueueDiffusionStreamer`; this generator drains the queue. ``"commit"`` and
    ``"draft"`` messages are passed through. A failure is reported as a single
    ``("error", "<ExcType>: <message>", 0, 0)`` tuple, after which the generator stops.
    The ``"end"`` message stops the generator without being yielded.

    Note that the ``finally`` joins the thread: if the consumer abandons the generator
    early, this still waits for generation to finish.
    """
    prompt_ids = input_ids.to(model.device)
    generate_args = {
        "max_new_tokens": int(max_new_tokens),
        "max_denoising_steps": int(max_iters),
    }
    if full_denoise:
        generate_args.update(
            confidence_threshold=1e-9,
            stability_threshold=int(max_iters),
        )
    if canvas_ids is not None:
        generate_args["canvas_ids"] = canvas_ids.to(model.device)

    frames: "queue_lib.Queue" = queue_lib.Queue()
    streamer = QueueDiffusionStreamer(tokenizer, frames)
    failure = {}

    def _run_generation():
        try:
            with torch.inference_mode():
                model.generate(prompt_ids, streamer=streamer, **generate_args)
        except Exception as exc:  # surface to the endpoint
            failure["msg"] = f"{type(exc).__name__}: {exc}"
            frames.put(("error", str(exc), 0, 0))
        finally:
            frames.put(("end", "", 0, 0))  # always unblock the consumer

    producer = threading.Thread(target=_run_generation)
    producer.start()
    try:
        while True:
            frame = frames.get()
            kind, text, block, step = frame
            if kind == "end":
                return
            if kind == "error":
                # Prefer the richer "<type>: <message>" string recorded by the thread.
                yield ("error", failure.get("msg", text), 0, 0)
                return
            yield (kind, text, block, step)
    finally:
        producer.join()
# --------------------------------------------------------------------------- #
# Server
# --------------------------------------------------------------------------- #
# The gradio.Server instance; the `generate` API endpoint and the `/` HTML route below
# are registered on it, and it is what gets launched.
app = Server(title="Gemma Diffusion 3D Asset Builder")
@app.api(name="generate", concurrency_limit=1, time_limit=600, stream_every=0.05)
def generate(
    prompt: str,
    history_json: str = "[]",
    max_new_tokens: int = 2048,
    max_iters: int = 64,
    full_denoise: bool = False,
    anim_delay: float = 0.0,
    warm_start: bool = True,
) -> str:
    """Stream the diffusion generation as JSON frames (one per denoising step).

    The model emits a raw SVG illustration; the frontend extrudes it into 3D with Three.js.

    Args:
        prompt: The user's request; blank prompts yield a single error frame.
        history_json: JSON-encoded list of prior ``{"role", "content"}`` turns.
        max_new_tokens: Token budget passed to generation.
        max_iters: Denoising steps per block, clamped to ``1..MAX_ITERS_CAP``.
        full_denoise: Forwarded to `_gpu_stream` to force every denoising step.
        anim_delay: Seconds to pause after each draft frame. Negative, NaN and
            non-numeric values are treated as 0; the value is capped at 600.
        warm_start: When true and the history holds an assistant turn, seed the first
            canvas from the cached previous SVG.

    Yields:
        JSON strings. Progress frames carry ``kind`` (``"draft"``/``"commit"``),
        ``source``, ``block``, ``step``, ``canvas``, ``max_iters`` and ``warming``.
        The stream ends with either ``{"kind": "error", "message": ...}`` or
        ``{"kind": "done", "source": <cleaned SVG>}``.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        yield json.dumps({"kind": "error", "message": "Empty prompt."})
        return
    messages = build_messages(history_json, prompt)
    max_iters = max(1, min(int(max_iters), MAX_ITERS_CAP))

    # Sanitize the client-supplied delay once, up front. `time.sleep` raises on
    # negative, NaN and infinite values, which used to abort the stream mid-generation.
    try:
        frame_delay = float(anim_delay or 0.0)
    except (TypeError, ValueError):
        frame_delay = 0.0
    if not frame_delay > 0.0:  # also catches NaN
        frame_delay = 0.0
    frame_delay = min(frame_delay, 600.0)  # never longer than the endpoint time_limit

    # Tweak warm-start: seed the diffusion's first canvas with the previous artwork's own
    # tokens (native `canvas_ids` API) so the model edits the existing SVG in place.
    # NOTE(review): the seed comes from the process-wide cache, not from the SVG in
    # `history_json`, so it is whatever generation finished last — confirm this is intended.
    is_tweak = bool(last_assistant_html(history_json))
    canvas_ids = warm_canvas_from_cache() if (warm_start and is_tweak) else None
    warming = canvas_ids is not None
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )["input_ids"]
    last_text = ""
    for kind, text, block, step in _gpu_stream(
        input_ids, int(max_new_tokens), max_iters, bool(full_denoise), canvas_ids
    ):
        if kind == "error":
            yield json.dumps({"kind": "error", "message": text})
            return
        last_text = text
        yield json.dumps(
            {
                "kind": "draft" if kind == "draft" else "commit",
                "source": clean_text(text),
                "block": block,
                "step": step,
                "canvas": CANVAS_LEN,
                "max_iters": max_iters,
                "warming": warming,
            }
        )
        if frame_delay and kind == "draft":
            _time.sleep(frame_delay)
    final_source = extract_svg(last_text)
    # Cache the *cleaned* SVG so the next tweak warm-starts from a valid header. Skip the
    # cache when the model produced no text, so a bare closing tag is never stored.
    if clean_text(last_text) and final_source.strip():
        model._last_clean_html = final_source
    yield json.dumps({"kind": "done", "source": final_source})
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the custom frontend: the contents of ``index.html`` next to this file."""
    index_path = os.path.join(HERE, "index.html")
    with open(index_path, "r", encoding="utf-8") as page:
        html = page.read()
    return html
# HF Spaces' gradio runtime looks for a top-level `demo` (or `app`) to launch.
demo = app

if __name__ == "__main__":
    # Local/manual run: host and port come from the environment, with defaults.
    _host = os.environ.get("GDIFF_HOST", "0.0.0.0")
    _port = int(os.environ.get("GDIFF_PORT", "7860"))
    app.launch(server_name=_host, server_port=_port, show_error=True)