File size: 13,633 Bytes
7d99dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
"""
Gemma Diffusion — live website builder (gradio.Server backend + custom frontend).

ZeroGPU port. `gradio.Server` (a FastAPI subclass) gives us Gradio's queue + SSE
streaming while we serve our own hand-written HTML/CSS/JS frontend. The single
streaming endpoint `/generate` runs the block-diffusion model and yields JSON frames
(one per denoising step) that the frontend renders side-by-side: the raw HTML canvas
diffusing on the left, the live rendered page on the right.

ZeroGPU specifics:
- `import spaces` happens before `torch`.
- The model is loaded once at module scope with `.to("cuda")` (ZeroGPU registers it).
- The actual `model.generate` call lives inside the `@spaces.GPU` function `_gpu_stream`;
  the `gradio.Server` endpoint only marshals picklable CPU tensors in/out of it.

Refs:
- https://huggingface.co/blog/introducing-gradio-server
- https://huggingface.co/docs/hub/spaces-zerogpu
"""

import glob
import os
import subprocess
import sys

# Set before torch is imported (transformers pulls torch in).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import spaces  # must precede torch so ZeroGPU can patch it


def _drop_cached_modules(package):
    """Remove *package* and every ``package.*`` submodule from ``sys.modules``.

    Lets a later ``import package`` load freshly installed files instead of reusing the
    module objects that were imported before the install.
    """
    # Snapshot the keys first so a concurrent import cannot resize the dict mid-iteration.
    stale = [name for name in list(sys.modules) if name == package or name.startswith(package + ".")]
    for name in stale:
        sys.modules.pop(name, None)


def _ensure_transformers():
    """Install the bundled custom DiffusionGemma `transformers` wheel at runtime.

    Spaces installs `requirements.txt` *before* copying the repo files into the image,
    so the wheel can't be referenced by local path there. By the time this app runs the
    file is present in the working directory, so we install it here (only if a stock /
    no transformers is importable) before importing torch/transformers below.

    The probe below may already have imported a stock ``transformers`` into
    ``sys.modules``. After the wheel is installed those cached modules are dropped. That
    way the ``from transformers import DiffusionGemmaForBlockDiffusion`` further down
    imports the new package rather than failing against the stale one.
    """
    try:
        import transformers  # noqa: F401

        if hasattr(transformers, "DiffusionGemmaForBlockDiffusion") or hasattr(
            getattr(transformers, "models", object), "diffusion_gemma"
        ):
            return
    except Exception:
        # Missing or unimportable transformers: fall through and install the bundled wheel.
        pass
    here = os.path.dirname(os.path.abspath(__file__))
    wheels = sorted(glob.glob(os.path.join(here, "transformers-*.whl")))
    if not wheels:
        return
    print(f"[gdiff] Installing bundled transformers wheel: {os.path.basename(wheels[0])}", flush=True)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-cache-dir", wheels[0]])
    import importlib

    # Refresh path finders, then evict any stock transformers cached by the probe above.
    importlib.invalidate_caches()
    _drop_cached_modules("transformers")


# Must run before the torch/transformers imports below so they resolve to the bundled wheel.
_ensure_transformers()

import json
import queue as queue_lib
import re
import threading
import time as _time

import torch
from fastapi.responses import HTMLResponse
from gradio import Server
from transformers import AutoTokenizer, DiffusionGemmaForBlockDiffusion
from transformers.generation.streamers import BaseStreamer

# Directory containing this file; used to locate the frontend's ``index.html``.
HERE = os.path.dirname(os.path.abspath(__file__))
# Checkpoint id/path passed to ``from_pretrained``; overridable via GDIFF_MODEL_PATH.
MODEL_PATH = os.environ.get("GDIFF_MODEL_PATH", "google/diffusiongemma-26B-A4B-it")
# Optional Hub token forwarded to the tokenizer/model ``from_pretrained`` calls.
HF_TOKEN = os.environ.get("HF_TOKEN")
MAX_ITERS_CAP = 120  # hard cap on denoising steps per block
# ZeroGPU: the 26B checkpoint (~49 GB bf16) needs the full backing card.
# Passed as ``size=`` to ``spaces.GPU``; overridable via GDIFF_GPU_SIZE.
GPU_SIZE = os.environ.get("GDIFF_GPU_SIZE", "xlarge")

# System message placed first in every conversation built by ``build_messages``.
SYSTEM_PROMPT = (
    "You are an expert front-end web developer with great visual taste. When asked to "
    "build or change a web page, respond with a SINGLE, complete, self-contained HTML5 "
    "document. Put all CSS in a <style> tag and any JavaScript in a <script> tag inside "
    "the document. Do not load external assets. When asked to modify an existing page, "
    "return the FULL updated HTML document with the change applied. Do not include "
    "explanations or markdown code fences — output only raw HTML, starting with "
    "<!DOCTYPE html>."
)

# Chat-template control markers that may survive decoding, e.g. ``<|turn>``, ``<turn|>``,
# ``<|think|>``, ``<|tool_call>``. At least one ``|`` is required, so genuine HTML tags
# such as ``<video>``, ``<audio>`` or SVG ``<image>`` in the generated page are kept.
_MARKER_NAMES = r"(?:channel|turn|think|image|audio|video|tool(?:_call|_response)?)"
_MARKER_RE = re.compile(rf"<(?:\|{_MARKER_NAMES}\|?|{_MARKER_NAMES}\|)>")
# A markdown code fence (optionally tagged ``html``, any case); group 1 is its content.
_FENCE_RE = re.compile(r"```(?:html)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)


# --------------------------------------------------------------------------- #
# Model (loaded once at module scope; ZeroGPU registers .to("cuda") tensors)
# --------------------------------------------------------------------------- #
# Use CUDA when torch reports it available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[gdiff] Loading model from {MODEL_PATH} on {DEVICE} ...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_TOKEN)
model = DiffusionGemmaForBlockDiffusion.from_pretrained(
    MODEL_PATH,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
).to(DEVICE)
model.eval()  # inference only
# Tokens per diffusion canvas (block); sizes warm-start canvases and the GPU time budget.
CANVAS_LEN = model.config.canvas_length
# Padding id for warm-start canvases; 0 when the tokenizer defines no pad token.
PAD_ID = tokenizer.pad_token_id or 0
print(f"[gdiff] Model ready | canvas_length={CANVAS_LEN}", flush=True)

# Cache of the last *cleaned* page so a follow-up tweak can warm-start in place.
# NOTE(review): this is one process-wide value shared by every client, not per-session.
model._last_clean_html = None


# --------------------------------------------------------------------------- #
# Helpers (CPU-only; safe to run in the gradio.Server main process)
# --------------------------------------------------------------------------- #
def warm_canvas_from_cache():
    """Build a first-block warm-start canvas from the last cleaned page, or return None.

    The page stored on ``model._last_clean_html`` is re-tokenized rather than reusing raw
    output tokens, so a mangled header cannot compound across tweaks. The ids are
    truncated to ``CANVAS_LEN`` and right-padded with ``PAD_ID``. The result is a CPU
    ``(1, CANVAS_LEN)`` long tensor because it is pickled across the ZeroGPU process
    boundary; the GPU worker moves it to the model's device.
    """
    cached = getattr(model, "_last_clean_html", None)
    if cached:
        token_ids = tokenizer(cached, add_special_tokens=False).input_ids[:CANVAS_LEN]
        if token_ids:
            padding = [PAD_ID] * max(0, CANVAS_LEN - len(token_ids))
            return torch.tensor(token_ids + padding, dtype=torch.long).unsqueeze(0)
    return None


def last_assistant_html(history_json: str):
    """Return the content of the most recent non-empty assistant turn, or None.

    *history_json* is the client-supplied JSON array of ``{"role", "content"}`` objects.
    Unparseable JSON, non-array JSON (``null``, an object, a scalar) and entries that are
    not objects yield None or are skipped instead of raising.
    """
    try:
        history = json.loads(history_json) if history_json else []
    except json.JSONDecodeError:
        return None
    if not isinstance(history, list):
        return None
    for turn in reversed(history):
        if isinstance(turn, dict) and turn.get("role") == "assistant" and turn.get("content"):
            return turn["content"]
    return None


def clean_text(text: str) -> str:
    """Drop chat-template control markers from *text*, then trim leading whitespace."""
    without_markers = _MARKER_RE.sub("", text)
    return without_markers.lstrip()


def extract_html(text: str) -> str:
    """Pull a usable HTML document out of the (possibly mangled) model output.

    Anchor on the first intact structural tag (``<!doctype``, ``<html``, ``<head``,
    ``<body``, checked in that order) and rebuild whatever the diffused tweak ate off the
    front. The result is then never in quirks mode and never has a broken ``DOCTYPE>`` /
    ``html lang=`` header. When none of those tags is present, the cleaned text is returned
    stripped.

    A markdown code fence is unwrapped only when it opens before any markup, i.e. when the
    model fenced its whole answer. A fence that appears *inside* the document (for example
    a page displaying a code sample) is left alone; otherwise the whole page would be
    replaced by that snippet.
    """
    text = clean_text(text)
    fenced = _FENCE_RE.search(text)
    if fenced and "<" not in text[: fenced.start()]:
        text = fenced.group(1)
    lower = text.lower()
    dt = lower.find("<!doctype")
    if dt != -1:
        return text[dt:].strip()
    h = lower.find("<html")
    if h != -1:
        return "<!DOCTYPE html>\n" + text[h:].strip()
    hd = lower.find("<head")
    if hd != -1:
        return '<!DOCTYPE html>\n<html lang="en">\n' + text[hd:].strip()
    bd = lower.find("<body")
    if bd != -1:
        return (
            '<!DOCTYPE html>\n<html lang="en">\n<head><meta charset="UTF-8">'
            '<meta name="viewport" content="width=device-width, initial-scale=1.0"></head>\n'
            + text[bd:].strip()
        )
    return text.strip()


class QueueDiffusionStreamer(BaseStreamer):
    """Streamer that forwards block-diffusion progress into a queue as decoded text.

    Every event pushed onto ``q`` is a ``(kind, text, block, step)`` tuple:

    - ``"commit"`` (from ``put``): ``text`` decodes all tokens received via ``put`` so far,
      ``block`` is the number of such calls, ``step`` is 0.
    - ``"draft"`` (from ``put_draft``): ``text`` decodes committed + draft tokens, tagged
      with the *next* block number and the draft step counter.
    - ``"end"`` (from ``end``): ``text`` decodes the committed tokens.

    The very first ``put`` call is treated as the prompt and ignored.
    """

    def __init__(self, tok, q: "queue_lib.Queue"):
        self.tok = tok
        self.q = q
        self.confirmed_ids: list[int] = []  # every committed token id, in order
        self.prompt_skipped = False  # flipped by the first put(), which carries the prompt
        self.block = 0  # number of committed blocks so far
        self.step = 0  # draft steps seen since the last commit

    @staticmethod
    def _first_row(value):
        # Batch row 0 of a multi-dimensional tensor; a 1-D tensor is used as-is.
        return (value[0] if value.dim() > 1 else value).tolist()

    def _decode(self, ids):
        return self.tok.decode(ids, skip_special_tokens=True)

    def _emit(self, kind, ids, block):
        self.q.put((kind, self._decode(ids), block, self.step))

    def put(self, value):
        ids = self._first_row(value)
        if self.prompt_skipped:
            self.confirmed_ids.extend(ids)
            self.block += 1
            self.step = 0
            self._emit("commit", self.confirmed_ids, self.block)
        else:
            self.prompt_skipped = True

    def put_draft(self, value):
        self.step += 1
        self._emit("draft", self.confirmed_ids + self._first_row(value), self.block + 1)

    def end(self):
        self._emit("end", self.confirmed_ids, self.block)


def build_messages(history_json: str, prompt: str):
    """Assemble the chat message list: system prompt, prior turns, then the new prompt.

    *history_json* is the client-supplied JSON array of ``{"role", "content"}`` objects.
    Unparseable JSON, non-array JSON (``null``, an object, a scalar) and entries that are
    not objects are ignored rather than raising. Only ``user``/``assistant`` turns with
    non-empty content are kept.
    """
    try:
        history = json.loads(history_json) if history_json else []
    except json.JSONDecodeError:
        history = []
    if not isinstance(history, list):
        history = []
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history:
        if not isinstance(turn, dict):
            continue
        role = turn.get("role")
        content = turn.get("content", "")
        if role in ("user", "assistant") and content:
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": prompt})
    return messages


# --------------------------------------------------------------------------- #
# GPU work — runs in a forked ZeroGPU worker process.
# Inputs/outputs cross the boundary via pickle, so only CPU tensors / plain
# Python objects go in and out (no CUDA tensors are returned).
# --------------------------------------------------------------------------- #
def _estimate_duration(input_ids, max_new_tokens=2048, max_iters=64, full_denoise=False, canvas_ids=None):
    """Return the GPU reservation in seconds for one ``_gpu_stream`` call.

    Takes the same arguments as ``_gpu_stream`` (it is that function's ``duration=``
    callback). Budget: 30 s fixed overhead plus 0.3 s per denoising step per canvas block,
    capped at 120 s.
    """
    canvas = max(1, CANVAS_LEN)
    # Ceiling division: a partial trailing block still needs its own denoising pass.
    blocks = max(1, -(-int(max_new_tokens) // canvas))
    secs = 30 + blocks * int(max_iters) * 0.3
    return int(min(120, secs))  # xlarge internally doubles this for the quota check


@spaces.GPU(duration=_estimate_duration, size=GPU_SIZE)
def _gpu_stream(input_ids, max_new_tokens, max_iters, full_denoise, canvas_ids):
    """Run ``model.generate`` on the GPU and yield ``(kind, text, block, step)`` tuples.

    Args:
        input_ids: prompt token ids (CPU tensor from ``apply_chat_template``); moved to
            ``model.device`` here.
        max_new_tokens: forwarded to ``model.generate``.
        max_iters: forwarded as ``max_denoising_steps``; also the stability threshold
            when ``full_denoise`` is set.
        full_denoise: when true, passes a near-zero confidence threshold and a stability
            threshold of ``max_iters`` to ``model.generate``.
        canvas_ids: optional warm-start canvas (CPU tensor) or None.

    Yields:
        The ``("draft"|"commit", text, block, step)`` events produced by
        ``QueueDiffusionStreamer``. If generation raises, a single
        ``("error", "Type: message", 0, 0)`` is yielded and the generator stops. An
        ``"end"`` item stops the generator without being yielded.
    """
    input_ids = input_ids.to(model.device)
    gen_kwargs = dict(max_new_tokens=int(max_new_tokens), max_denoising_steps=int(max_iters))
    if full_denoise:
        # NOTE(review): presumably forces the sampler to run every denoising step;
        # confirm against the DiffusionGemma generate() documentation.
        gen_kwargs["confidence_threshold"] = 1e-9
        gen_kwargs["stability_threshold"] = int(max_iters)
    if canvas_ids is not None:
        # Warm-start canvas arrives as a CPU tensor; move it next to the model.
        gen_kwargs["canvas_ids"] = canvas_ids.to(model.device)

    q: "queue_lib.Queue" = queue_lib.Queue()
    streamer = QueueDiffusionStreamer(tokenizer, q)
    err = {}  # side channel from worker to consumer carrying "Type: message"

    def worker():
        # model.generate runs in a background thread so this generator can relay the
        # streamer's queue items while generation is still in progress.
        try:
            with torch.inference_mode():
                model.generate(input_ids, streamer=streamer, **gen_kwargs)
        except Exception as exc:  # surface to the endpoint
            err["msg"] = f"{type(exc).__name__}: {exc}"
            q.put(("error", str(exc), 0, 0))
        finally:
            q.put(("end", "", 0, 0))  # always unblock the consumer

    thread = threading.Thread(target=worker)
    thread.start()
    try:
        while True:
            kind, text, block, step = q.get()
            if kind == "error":
                # Prefer the message that includes the exception type name.
                yield ("error", err.get("msg", text), 0, 0)
                return
            if kind == "end":
                # Stop at the first "end": the streamer's end() or the worker's sentinel.
                return
            yield (kind, text, block, step)
    finally:
        # Blocks until model.generate returns, even if the consumer stopped early.
        thread.join()


# --------------------------------------------------------------------------- #
# Server
# --------------------------------------------------------------------------- #
# gradio.Server (a FastAPI subclass) provides Gradio's queue + SSE streaming; routes follow.
app = Server(title="Gemma Diffusion Website Builder")


def _warm_canvas_from_html(html):
    """Tokenize *html* into a first-block warm-start canvas, or return None.

    The ids are truncated to ``CANVAS_LEN`` and right-padded with ``PAD_ID``. The result
    is a CPU ``(1, CANVAS_LEN)`` long tensor because it is pickled into the ZeroGPU
    worker; ``_gpu_stream`` moves it to the model's device.
    """
    if not html:
        return None
    ids = tokenizer(html, add_special_tokens=False).input_ids[:CANVAS_LEN]
    if not ids:
        return None
    ids = ids + [PAD_ID] * (CANVAS_LEN - len(ids))
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)


@app.api(name="generate", concurrency_limit=1, time_limit=600, stream_every=0.05)
def generate(
    prompt: str,
    history_json: str = "[]",
    max_new_tokens: int = 2048,
    max_iters: int = 64,
    full_denoise: bool = False,
    anim_delay: float = 0.0,
    warm_start: bool = True,
) -> str:
    """Stream the diffusion generation as JSON frames (one per denoising step).

    The model writes a self-contained HTML document; the frontend renders it live.

    Frames (JSON strings):
    - ``{"kind": "error", "message": ...}`` for an empty prompt or a generation failure;
      ends the stream.
    - ``{"kind": "draft"|"commit", "source", "block", "step", "canvas", "max_iters",
      "warming"}``, one per streamer event.
    - ``{"kind": "done", "source": ...}`` with the final cleaned HTML document.

    Tweak warm-start seeds the first canvas from *this request's* previous assistant page
    (taken from ``history_json``), never from another client's output.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        yield json.dumps({"kind": "error", "message": "Empty prompt."})
        return

    messages = build_messages(history_json, prompt)
    max_iters = max(1, min(int(max_iters), MAX_ITERS_CAP))
    # `anim_delay` comes from the client and time.sleep() raises ValueError on negatives.
    # Clamp to >= 0; the `> 0` comparison also maps NaN to 0.
    anim_delay = float(anim_delay or 0.0)
    if not anim_delay > 0:
        anim_delay = 0.0

    # Tweak warm-start: seed the diffusion's first canvas with the previous page's own
    # tokens (native `canvas_ids` API) so the model edits the existing page in place.
    # The page is this conversation's last assistant turn, cleaned with extract_html.
    # It is not the process-wide `model._last_clean_html`, which holds whichever client
    # generated last.
    prev_html = last_assistant_html(history_json)
    canvas_ids = None
    if warm_start and isinstance(prev_html, str) and prev_html.strip():
        canvas_ids = _warm_canvas_from_html(extract_html(prev_html))
    warming = canvas_ids is not None

    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )["input_ids"]

    last_text = ""
    for kind, text, block, step in _gpu_stream(
        input_ids, int(max_new_tokens), max_iters, bool(full_denoise), canvas_ids
    ):
        if kind == "error":
            yield json.dumps({"kind": "error", "message": text})
            return
        last_text = text
        yield json.dumps(
            {
                "kind": "draft" if kind == "draft" else "commit",
                "source": clean_text(text),
                "block": block,
                "step": step,
                "canvas": CANVAS_LEN,
                "max_iters": max_iters,
                "warming": warming,
            }
        )
        if anim_delay and kind == "draft":
            _time.sleep(anim_delay)

    final_source = extract_html(last_text)
    # Still record the last cleaned page for `warm_canvas_from_cache` and other readers.
    if final_source.strip():
        model._last_clean_html = final_source
    yield json.dumps({"kind": "done", "source": final_source})


@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the hand-written frontend page (``index.html`` next to this file)."""
    index_path = os.path.join(HERE, "index.html")
    with open(index_path, "r", encoding="utf-8") as fh:
        page = fh.read()
    return page


# HF Spaces' gradio runtime looks for a top-level `demo` (or `app`) to launch.
demo = app

if __name__ == "__main__":
    # Direct-run entry point; bind address is overridable via GDIFF_HOST / GDIFF_PORT.
    host = os.environ.get("GDIFF_HOST", "0.0.0.0")
    port = int(os.environ.get("GDIFF_PORT", "7860"))
    app.launch(server_name=host, server_port=port, show_error=True)