Spaces:
Running on Zero
Running on Zero
File size: 13,633 Bytes
7d99dfd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 | """
Gemma Diffusion — live website builder (gradio.Server backend + custom frontend).
ZeroGPU port. `gradio.Server` (a FastAPI subclass) gives us Gradio's queue + SSE
streaming while we serve our own hand-written HTML/CSS/JS frontend. The single
streaming endpoint `/generate` runs the block-diffusion model and yields JSON frames
(one per denoising step) that the frontend renders side-by-side: the raw HTML canvas
diffusing on the left, the live rendered page on the right.
ZeroGPU specifics:
- `import spaces` happens before `torch`.
- The model is loaded once at module scope with `.to("cuda")` (ZeroGPU registers it).
- The actual `model.generate` call lives inside the `@spaces.GPU` function `_gpu_stream`;
the `gradio.Server` endpoint only marshals picklable CPU tensors in/out of it.
Refs:
- https://huggingface.co/blog/introducing-gradio-server
- https://huggingface.co/docs/hub/spaces-zerogpu
"""
import glob
import os
import subprocess
import sys
# Set before torch is imported (transformers pulls torch in). An explicit value
# already present in the environment always wins over this default.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import spaces # must precede torch so ZeroGPU can patch it
def _ensure_transformers():
"""Install the bundled custom DiffusionGemma `transformers` wheel at runtime.
Spaces installs `requirements.txt` *before* copying the repo files into the image,
so the wheel can't be referenced by local path there. By the time this app runs the
file is present in the working directory, so we install it here (only if a stock /
no transformers is importable) before importing torch/transformers below.
"""
try:
import transformers # noqa: F401
if hasattr(transformers, "DiffusionGemmaForBlockDiffusion") or hasattr(
getattr(transformers, "models", object), "diffusion_gemma"
):
return
except Exception:
pass
wheels = sorted(glob.glob(os.path.join(os.path.dirname(os.path.abspath(__file__)), "transformers-*.whl")))
if not wheels:
return
print(f"[gdiff] Installing bundled transformers wheel: {os.path.basename(wheels[0])}", flush=True)
subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-cache-dir", wheels[0]])
import importlib
importlib.invalidate_caches()
_ensure_transformers()
import json
import queue as queue_lib
import re
import threading
import time as _time
import torch
from fastapi.responses import HTMLResponse
from gradio import Server
from transformers import AutoTokenizer, DiffusionGemmaForBlockDiffusion
from transformers.generation.streamers import BaseStreamer
HERE = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.environ.get("GDIFF_MODEL_PATH", "google/diffusiongemma-26B-A4B-it")
HF_TOKEN = os.environ.get("HF_TOKEN")
MAX_ITERS_CAP = 120 # hard cap on denoising steps per block
# ZeroGPU: the 26B checkpoint (~49 GB bf16) needs the full backing card.
GPU_SIZE = os.environ.get("GDIFF_GPU_SIZE", "xlarge")
SYSTEM_PROMPT = (
"You are an expert front-end web developer with great visual taste. When asked to "
"build or change a web page, respond with a SINGLE, complete, self-contained HTML5 "
"document. Put all CSS in a <style> tag and any JavaScript in a <script> tag inside "
"the document. Do not load external assets. When asked to modify an existing page, "
"return the FULL updated HTML document with the change applied. Do not include "
"explanations or markdown code fences — output only raw HTML, starting with "
"<!DOCTYPE html>."
)
_MARKER_RE = re.compile(
r"<\|?(?:channel|turn|think|image|audio|video|tool(?:_call|_response)?)\|?>"
)
_FENCE_RE = re.compile(r"```(?:html)?\s*(.*?)\s*```", re.DOTALL)
# --------------------------------------------------------------------------- #
# Model (loaded once at module scope; ZeroGPU registers .to("cuda") tensors)
# --------------------------------------------------------------------------- #
# "cuda" whenever torch reports a GPU, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[gdiff] Loading model from {MODEL_PATH} on {DEVICE} ...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_TOKEN)
# Loaded in bfloat16 and moved to DEVICE here, at import time, so the weights are
# already placed before any `@spaces.GPU` call runs.
model = DiffusionGemmaForBlockDiffusion.from_pretrained(
    MODEL_PATH,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
).to(DEVICE)
model.eval()
# Canvas size in tokens, from the model config. Used to size/pad the warm-start
# canvas and to estimate the number of blocks for the GPU time budget.
CANVAS_LEN = model.config.canvas_length
# Right-padding id for the warm-start canvas; falls back to 0 when the tokenizer
# defines no pad token (pad_token_id is None).
PAD_ID = tokenizer.pad_token_id or 0
print(f"[gdiff] Model ready | canvas_length={CANVAS_LEN}", flush=True)
# Cache of the last *cleaned* page so a follow-up tweak can warm-start in place.
# Written by `generate` after each run; read by `warm_canvas_from_cache`.
# NOTE(review): this lives on the single module-level model object, so it is shared
# by every request this process serves rather than being per conversation.
model._last_clean_html = None
# --------------------------------------------------------------------------- #
# Helpers (CPU-only; safe to run in the gradio.Server main process)
# --------------------------------------------------------------------------- #
def warm_canvas_from_cache():
    """Build the first-block starting canvas from the previously cleaned page.

    The result is a CPU ``torch.long`` tensor with one row of ``CANVAS_LEN`` ids;
    it is pickled across the ZeroGPU process boundary and moved to the model's
    device inside the GPU worker. The cleaned HTML is re-tokenized (instead of
    reusing raw output tokens) so a mangled header can't compound across tweaks.
    Pages longer than one canvas are truncated; shorter ones are right-padded
    with ``PAD_ID``.

    Returns ``None`` when no page is cached or it tokenizes to nothing.
    """
    cached_page = getattr(model, "_last_clean_html", None)
    if not cached_page:
        return None

    token_ids = tokenizer(cached_page, add_special_tokens=False).input_ids[:CANVAS_LEN]
    if not token_ids:
        return None

    # The slice above guarantees len(token_ids) <= CANVAS_LEN, so this never pads negatively.
    padded = token_ids + [PAD_ID] * (CANVAS_LEN - len(token_ids))
    return torch.tensor([padded], dtype=torch.long)
def last_assistant_html(history_json: str):
    """Return the content of the most recent non-empty assistant turn, or ``None``.

    ``history_json`` is the JSON chat history passed to the ``/generate`` API: a
    list of ``{"role": ..., "content": ...}`` objects. Because it is caller
    supplied, malformed JSON, a non-list payload and non-object entries are
    treated as "no assistant turn" instead of raising.
    """
    try:
        history = json.loads(history_json) if history_json else []
    except (json.JSONDecodeError, TypeError):
        return None
    if not isinstance(history, list):
        return None
    for turn in reversed(history):
        if isinstance(turn, dict) and turn.get("role") == "assistant" and turn.get("content"):
            return turn["content"]
    return None
def clean_text(text: str) -> str:
    """Remove chat-template marker tokens from ``text``, then drop leading whitespace."""
    without_markers = _MARKER_RE.sub("", text)
    return without_markers.lstrip()
def _find_open_tag(lower: str, tag: str) -> int:
    """Index of the first ``<tag`` opening tag in ``lower``, or -1.

    The tag name must be followed by whitespace, ``>``, ``/`` or end of text, so
    e.g. ``<head`` does not match ``<header``.
    """
    match = re.search(rf"<{tag}(?=[\s>/]|$)", lower)
    return match.start() if match else -1


def extract_html(text: str) -> str:
    """Pull a usable HTML document out of the (possibly mangled) model output.

    Anchor on the first intact structural tag and rebuild whatever the diffused tweak ate
    off the front, so the result is always a valid document (never quirks mode and never a
    broken ``DOCTYPE>`` / ``html lang=`` header). Anchors are tried in order: doctype,
    ``<html>``, ``<head>``, ``<body>``; if none is present the cleaned text is returned
    as-is (stripped). If the text contains a markdown code fence, only its body is used.
    """
    text = clean_text(text)
    fenced = _FENCE_RE.search(text)
    if fenced:
        text = fenced.group(1)
    lower = text.lower()

    dt = lower.find("<!doctype")
    if dt != -1:
        return text[dt:].strip()
    h = _find_open_tag(lower, "html")
    if h != -1:
        return "<!DOCTYPE html>\n" + text[h:].strip()
    # Boundary-aware: a plain find("<head") would also hit "<header" inside the body
    # and rebuild the page from there, dropping the <body> tag.
    hd = _find_open_tag(lower, "head")
    if hd != -1:
        return '<!DOCTYPE html>\n<html lang="en">\n' + text[hd:].strip()
    bd = _find_open_tag(lower, "body")
    if bd != -1:
        return (
            '<!DOCTYPE html>\n<html lang="en">\n<head><meta charset="UTF-8">'
            '<meta name="viewport" content="width=device-width, initial-scale=1.0"></head>\n'
            + text[bd:].strip()
        )
    return text.strip()
class QueueDiffusionStreamer(BaseStreamer):
    """Streamer that forwards block-diffusion progress into a ``queue.Queue``.

    Every event is pushed as a ``(kind, text, block, step)`` tuple:

    * ``"commit"``: a block was finalized. ``text`` is the decoded text of all
      committed tokens so far, and ``step`` resets to 0.
    * ``"draft"``: an intermediate denoising step of the block currently being
      written (reported as block number ``block + 1``). ``text`` is the committed
      tokens followed by the current draft.
    * ``"end"``: generation finished. ``text`` is the committed text.

    The first ``put`` call is treated as the prompt and dropped.
    NOTE(review): this assumes DiffusionGemma's ``generate()`` follows the usual
    streamer convention of sending the prompt first; please confirm.
    """

    def __init__(self, tok, q: "queue_lib.Queue"):
        self.tok = tok
        self.q = q
        self.confirmed_ids: list[int] = []  # tokens of every committed block, in order
        self.prompt_skipped = False
        self.block = 0  # number of committed blocks
        self.step = 0  # draft step within the block being denoised

    @staticmethod
    def _first_row(value):
        """Token ids of the first batch row (or of a 1-D tensor) as a plain list."""
        row = value[0] if value.dim() > 1 else value
        return row.tolist()

    def _decode(self, ids):
        return self.tok.decode(ids, skip_special_tokens=True)

    def _emit(self, kind, ids, block):
        self.q.put((kind, self._decode(ids), block, self.step))

    def put(self, value):
        ids = self._first_row(value)
        if not self.prompt_skipped:
            # First call carries the prompt: remember that and drop it.
            self.prompt_skipped = True
            return
        self.confirmed_ids += ids
        self.block += 1
        self.step = 0
        self._emit("commit", self.confirmed_ids, self.block)

    def put_draft(self, value):
        self.step += 1
        draft_ids = self._first_row(value)
        self._emit("draft", self.confirmed_ids + draft_ids, self.block + 1)

    def end(self):
        self._emit("end", self.confirmed_ids, self.block)
def build_messages(history_json: str, prompt: str):
    """Assemble the chat messages for one generation request.

    The result is the system prompt, then every prior ``user``/``assistant`` turn
    with non-empty content from ``history_json``, then the new user ``prompt``.
    ``history_json`` is supplied by the API caller, so malformed JSON, a non-list
    payload and non-object entries are ignored instead of raising.
    """
    try:
        history = json.loads(history_json) if history_json else []
    except (json.JSONDecodeError, TypeError):
        history = []
    if not isinstance(history, list):
        history = []

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history:
        if not isinstance(turn, dict):
            continue
        role = turn.get("role")
        content = turn.get("content", "")
        if role in ("user", "assistant") and content:
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": prompt})
    return messages
# --------------------------------------------------------------------------- #
# GPU work — runs in a forked ZeroGPU worker process.
# Inputs/outputs cross the boundary via pickle, so only CPU tensors / plain
# Python objects go in and out (no CUDA tensors are returned).
# --------------------------------------------------------------------------- #
def _estimate_duration(input_ids, max_new_tokens=2048, max_iters=64, full_denoise=False, canvas_ids=None):
blocks = max(1, int(max_new_tokens) // max(1, CANVAS_LEN))
secs = 30 + blocks * int(max_iters) * 0.3
return int(min(120, secs)) # xlarge internally doubles this for the quota check
@spaces.GPU(duration=_estimate_duration, size=GPU_SIZE)
def _gpu_stream(input_ids, max_new_tokens, max_iters, full_denoise, canvas_ids):
    """Run ``model.generate`` under ``spaces.GPU`` and yield its streamed progress.

    Yields ``(kind, text, block, step)`` tuples as produced by
    ``QueueDiffusionStreamer`` (``"commit"`` / ``"draft"``). If generation raises,
    a single ``("error", "<ExcType>: <message>", 0, 0)`` is yielded instead and
    the generator stops. The generator returns on the first ``"end"`` event, so
    any text carried by an ``end`` event is not forwarded.

    ``model.generate`` runs on a worker thread that feeds a queue, and this
    generator drains that queue. The ``finally`` joins the thread, so if the
    consumer stops early, closing this generator still waits for
    ``model.generate`` to finish; nothing here cancels it.

    Args:
        input_ids: prompt ids; moved to ``model.device`` here.
        max_new_tokens / max_iters: forwarded as ``max_new_tokens`` /
            ``max_denoising_steps``.
        full_denoise: sets ``confidence_threshold=1e-9`` and
            ``stability_threshold=max_iters``. NOTE(review): presumably this forces
            every block to run all steps; confirm against DiffusionGemma's generate().
        canvas_ids: optional CPU warm-start canvas; moved to ``model.device`` here.
    """
    input_ids = input_ids.to(model.device)
    gen_kwargs = dict(max_new_tokens=int(max_new_tokens), max_denoising_steps=int(max_iters))
    if full_denoise:
        gen_kwargs["confidence_threshold"] = 1e-9
        gen_kwargs["stability_threshold"] = int(max_iters)
    if canvas_ids is not None:
        gen_kwargs["canvas_ids"] = canvas_ids.to(model.device)
    q: "queue_lib.Queue" = queue_lib.Queue()
    streamer = QueueDiffusionStreamer(tokenizer, q)
    # Filled by the worker with "<ExcType>: <message>" so the consumer can report it.
    err = {}

    def worker():
        try:
            # inference_mode is thread-local, so it must be entered on this thread.
            with torch.inference_mode():
                model.generate(input_ids, streamer=streamer, **gen_kwargs)
        except Exception as exc:  # surface to the endpoint
            err["msg"] = f"{type(exc).__name__}: {exc}"
            q.put(("error", str(exc), 0, 0))
        finally:
            q.put(("end", "", 0, 0))  # always unblock the consumer

    thread = threading.Thread(target=worker)
    thread.start()
    try:
        while True:
            kind, text, block, step = q.get()
            if kind == "error":
                yield ("error", err.get("msg", text), 0, 0)
                return
            if kind == "end":
                return
            yield (kind, text, block, step)
    finally:
        thread.join()
# --------------------------------------------------------------------------- #
# Server
# --------------------------------------------------------------------------- #
app = Server(title="Gemma Diffusion Website Builder")


@app.api(name="generate", concurrency_limit=1, time_limit=600, stream_every=0.05)
def generate(
    prompt: str,
    history_json: str = "[]",
    max_new_tokens: int = 2048,
    max_iters: int = 64,
    full_denoise: bool = False,
    anim_delay: float = 0.0,
    warm_start: bool = True,
) -> str:
    """Stream the diffusion generation as JSON frames (one per denoising step).

    The model writes a self-contained HTML document; the frontend renders it live.

    Each yielded string is a JSON object whose ``kind`` is one of:
    ``"draft"`` / ``"commit"`` (with ``source``, ``block``, ``step``, ``canvas``,
    ``max_iters``, ``warming``), ``"error"`` (with ``message``), or a final
    ``"done"`` (with the cleaned ``source``).

    Args:
        prompt: the user's request; blank prompts yield a single error frame.
        history_json: JSON list of prior ``{"role", "content"}`` turns.
        max_new_tokens: forwarded to the model.
        max_iters: denoising steps per block, clamped to ``[1, MAX_ITERS_CAP]``.
        full_denoise: forwarded to ``_gpu_stream``.
        anim_delay: optional pause in seconds after each draft frame; negative
            values are treated as 0.
        warm_start: when the history already contains a page, seed the first
            canvas with that page's tokens.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        yield json.dumps({"kind": "error", "message": "Empty prompt."})
        return

    messages = build_messages(history_json, prompt)
    max_iters = max(1, min(int(max_iters), MAX_ITERS_CAP))
    # time.sleep() raises ValueError on negative input, so clamp the caller's delay.
    delay = max(0.0, float(anim_delay or 0.0))

    # Tweak warm-start: seed the diffusion's first canvas with the previous page's own
    # tokens (native `canvas_ids` API) so the model edits the existing page in place.
    # The page is taken from *this* request's history rather than straight from
    # `model._last_clean_html`: that attribute is shared by every request this process
    # serves, so reading it directly could seed one user's tweak with another's page.
    prev_html = last_assistant_html(history_json)
    canvas_ids = None
    if warm_start and isinstance(prev_html, str) and prev_html:
        model._last_clean_html = extract_html(prev_html)
        canvas_ids = warm_canvas_from_cache()
    warming = canvas_ids is not None

    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )["input_ids"]

    last_text = ""
    for kind, text, block, step in _gpu_stream(
        input_ids, int(max_new_tokens), max_iters, bool(full_denoise), canvas_ids
    ):
        if kind == "error":
            yield json.dumps({"kind": "error", "message": text})
            return
        last_text = text
        yield json.dumps(
            {
                "kind": "draft" if kind == "draft" else "commit",
                "source": clean_text(text),
                "block": block,
                "step": step,
                "canvas": CANVAS_LEN,
                "max_iters": max_iters,
                "warming": warming,
            }
        )
        if delay and kind == "draft":
            _time.sleep(delay)

    final_source = extract_html(last_text)
    # Cache the *cleaned* output so the next tweak warm-starts from a valid header.
    if final_source.strip():
        model._last_clean_html = final_source
    yield json.dumps({"kind": "done", "source": final_source})
@app.get("/", response_class=HTMLResponse)
def homepage():
    """Serve the hand-written frontend (``index.html`` next to this module).

    This is a plain ``def`` rather than ``async def`` so FastAPI runs the blocking
    file read in its threadpool instead of stalling the event loop that also drives
    the streaming endpoint. The file is re-read on every request, so edits to
    ``index.html`` show up without a restart.
    """
    with open(os.path.join(HERE, "index.html"), "r", encoding="utf-8") as f:
        return f.read()
# HF Spaces' gradio runtime looks for a top-level `demo` (or `app`) to launch;
# alias the server under the name it checks first.
demo = app

if __name__ == "__main__":
    # Local runs: host/port are overridable via GDIFF_HOST / GDIFF_PORT.
    host = os.environ.get("GDIFF_HOST", "0.0.0.0")
    port = int(os.environ.get("GDIFF_PORT", "7860"))
    app.launch(server_name=host, server_port=port, show_error=True)
|