"""Gradio demo: chat with DiffusionGemma and watch its block-diffusion canvas denoise live."""

import spaces
import torch
import gradio as gr
from threading import Thread
from queue import Queue
from transformers import DiffusionGemmaForBlockDiffusion, AutoProcessor, TextDiffusionStreamer

# Hugging Face Hub id of the checkpoint loaded by this demo.
MODEL_ID = "google/diffusiongemma-26B-A4B-it"

# Module-level processor and model, loaded once at import time and used by
# `generate_streaming`. The model is loaded in bfloat16, moved to the CUDA
# device and switched to eval mode.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = DiffusionGemmaForBlockDiffusion.from_pretrained(MODEL_ID, dtype=torch.bfloat16)
model.to("cuda")
model.eval()

# `canvas_length` from the model config, or 256 when the config has no such attribute.
# NOTE(review): not referenced elsewhere in the visible code — confirm it is still needed.
CANVAS_LENGTH = getattr(model.config, "canvas_length", 256)

# Unique marker placed on a streamer queue to tell the consumer that no more
# snapshots will follow (compared by identity).
_SENTINEL = object()

# Per-token denoising colors. DiffusionGemma uses random-token *renoising* (not [MASK]
# diffusion): the entropy sampler locks low-entropy positions while the rest are random
# noise each step. So a position's stability across steps is our confidence proxy.
# Keys are the segment labels produced by `CanvasStreamer._render` and
# `build_display`; the mapping is passed to `gr.HighlightedText` as `color_map`.
COLOR_MAP = {
    "done": "#66CC66",    # committed block (final)
    "stable": "#8FD18F",  # settled for several steps
    "mid": "#FFCC66",     # settling
    "noise": "#E8896B",   # just changed / still noisy
}
_STABLE_STEPS = 3  # unchanged for >= this many steps -> "stable"


class CanvasStreamer(TextDiffusionStreamer):
    """Pushes (committed_text, draft_segments, draft_plain) snapshots to a queue.

    `put_draft` fires every denoising step with the full argmax canvas of the block
    being denoised. We track, per position, how many consecutive steps its token has
    been unchanged ("settle" count) and color it accordingly, so the canvas visibly
    condenses from noise into settled text. `put` fires when a block is committed.

    Consumers read `self.queue` until `end()` posts the module-level `_SENTINEL`.
    Decoded token text is memoised per token id because the same ids are
    re-rendered on every denoising step.
    """

    # Minimum settle count for a non-special token to extend the rendered frontier.
    _FRONTIER_STEPS = 2

    def __init__(self, tokenizer, **kwargs):
        """Initialise the snapshot queue and the per-block tracking state.

        Args:
            tokenizer: Tokenizer used for decoding; must expose `all_special_ids`.
            **kwargs: Forwarded unchanged to `TextDiffusionStreamer.__init__`.
        """
        super().__init__(tokenizer, skip_special_tokens=True, **kwargs)
        self.queue = Queue()
        self.committed = ""      # text of all committed blocks so far
        self.last_draft = ""     # plain text of the most recent draft render
        self.started = False     # becomes True on the first denoising step
        # NOTE(review): not read in this class — presumably consulted by the
        # base class or the generation loop; confirm before removing.
        self._takes_logits = False
        self.prev_ids = None     # canvas ids seen at the previous step
        self.settle = None       # per-position count of consecutive unchanged steps
        self.special_ids = set(tokenizer.all_special_ids)
        self._piece_cache = {}   # token id -> decoded text

    def _decode_piece(self, tid):
        """Return the decoded text of one token id, caching the result."""
        piece = self._piece_cache.get(tid)
        if piece is None:
            piece = self.tokenizer.decode([tid], skip_special_tokens=True)
            self._piece_cache[tid] = piece
        return piece

    def _render(self, ids):
        """Color positions up to the furthest settled real token (the 'frontier').

        Args:
            ids: Token ids of the current canvas (one per position).

        Returns:
            `(segments, plain)`: `segments` is a list of `(text, label)` tuples
            in which adjacent tokens with the same label are merged, and
            `plain` is the concatenation of all rendered token texts.
        """
        frontier = -1
        for i, tid in enumerate(ids):
            if tid not in self.special_ids and self.settle[i] >= self._FRONTIER_STEPS:
                frontier = i

        segments = []
        plain = []
        cur_pieces = []
        cur_cls = None
        for tid, steps in zip(ids[: frontier + 1], self.settle):
            if tid in self.special_ids:
                continue
            piece = self._decode_piece(tid)
            if not piece:
                continue
            cls = "stable" if steps >= _STABLE_STEPS else ("mid" if steps >= 1 else "noise")
            plain.append(piece)
            if cls != cur_cls:
                # Label changed: flush the run collected so far and start a new one.
                if cur_pieces:
                    segments.append(("".join(cur_pieces), cur_cls))
                cur_pieces, cur_cls = [], cls
            cur_pieces.append(piece)
        if cur_pieces:
            segments.append(("".join(cur_pieces), cur_cls))
        return segments, "".join(plain)

    def put_draft(self, value, **kwargs):
        """Record one denoising step and enqueue a colored draft snapshot.

        Args:
            value: Tensor of canvas token ids; if it has more than one
                dimension only the first row is used.
            **kwargs: Accepted and ignored.
        """
        if len(value.shape) > 1:
            value = value[0]
        ids = value.tolist()
        self.started = True
        if self.prev_ids is None or len(self.prev_ids) != len(ids):
            # First step of a block (or canvas size changed): nothing has settled yet.
            self.settle = [0] * len(ids)
        else:
            self.settle = [
                count + 1 if cur == prev else 0
                for count, cur, prev in zip(self.settle, ids, self.prev_ids)
            ]
        self.prev_ids = ids
        segments, plain = self._render(ids)
        self.last_draft = plain
        self.queue.put((self.committed, segments, plain))

    def put(self, value):
        """Commit a finished block and enqueue a snapshot with an empty draft.

        Calls made before any denoising step are ignored (prompt context).

        Args:
            value: Tensor of token ids for the committed block.

        Raises:
            ValueError: If `value` carries a batch dimension larger than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("batch size 1 only")
        elif len(value.shape) > 1:
            value = value[0]
        if not self.started:  # prompt context, before any denoising step
            return
        self.committed += self.tokenizer.decode(value, skip_special_tokens=True)
        # The next block starts from a fresh canvas, so drop per-block state.
        self.last_draft = ""
        self.prev_ids = None
        self.settle = None
        self.queue.put((self.committed, [], ""))

    def end(self):
        """Signal end of generation by enqueueing the sentinel."""
        self.queue.put(_SENTINEL)


def build_display(committed, segments):
    """Return the canvas segment list for one snapshot.

    The committed text, when non-empty, comes first with the "done" label and
    is followed by the draft `segments` in their original order. A new list is
    returned; `segments` is not modified.
    """
    head = [(committed, "done")] if committed else []
    return [*head, *segments]


@spaces.GPU(duration=150, size="xlarge")
@torch.no_grad()
def generate_streaming(messages, max_new_tokens, max_denoising_steps, enable_thinking):
    """Run generation in a worker thread and stream canvas snapshots.

    Args:
        messages: Chat messages in the format accepted by
            `processor.apply_chat_template`.
        max_new_tokens: Upper bound on new tokens; coerced with `int()`.
        max_denoising_steps: Upper bound on denoising steps; coerced with `int()`.
        enable_thinking: Forwarded to the chat template as `enable_thinking`.

    Yields:
        `(display, text, final)` tuples. During generation `display` is the
        list of `(text, label)` canvas segments, `text` is the committed text
        plus the current draft, and `final` is `None`. The last tuple carries
        the stripped final text as both `text` and `final`; its `display` is a
        single "done" segment, or an empty list when the text is empty.

    Raises:
        Exception: Whatever `model.generate` raised in the worker thread is
            re-raised here after the thread has been joined.
    """
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=enable_thinking,
    ).to("cuda")

    streamer = CanvasStreamer(processor.tokenizer)
    result = {}

    def run():
        try:
            # Grad mode is thread-local, so the decorator on the enclosing
            # function does not cover this worker thread.
            with torch.no_grad():
                result["ids"] = model.generate(
                    **inputs,
                    max_new_tokens=int(max_new_tokens),
                    max_denoising_steps=int(max_denoising_steps),
                    streamer=streamer,
                )
        except Exception as e:
            result["error"] = e
        finally:
            # Always unblock the consumer loop, even if `streamer.end()` was
            # never reached. A second sentinel after a normal `end()` is
            # harmless because the consumer stops at the first one.
            streamer.queue.put(_SENTINEL)

    thread = Thread(target=run)
    thread.start()

    while True:
        item = streamer.queue.get()
        if item is _SENTINEL:
            break
        committed, segments, plain = item
        yield build_display(committed, segments), (committed + plain), None

    thread.join()
    if "error" in result:
        raise result["error"]

    final_text = (streamer.committed + streamer.last_draft).strip()
    yield [(final_text, "done")] if final_text else [], final_text, final_text


def _file_path(item):
    """Extract a local file path from a Gradio content part / file dict."""
    for key in ("file", "path", "url"):
        val = item.get(key)
        if isinstance(val, dict):
            val = val.get("path") or val.get("url")
        if val:
            return val
    return None


def to_model_messages(history):
    """Translate Gradio messages-format history into processor chat messages.

    Each history entry becomes `{"role": ..., "content": [parts]}` where every
    part is either `{"type": "text", "text": ...}` or
    `{"type": "image", "url": ...}`. Entries that produce no parts are dropped.
    """

    def text_part(text):
        return {"type": "text", "text": text}

    def image_part(location):
        return {"type": "image", "url": location}

    def convert(content):
        if isinstance(content, str):
            return [text_part(content)]
        if isinstance(content, tuple):  # legacy (path, alt) file tuple
            return [image_part(content[0])]
        if isinstance(content, dict):  # single file part, e.g. {"path": ...}
            location = _file_path(content)
            return [image_part(location)] if location else []
        if not isinstance(content, list):
            return []

        converted = []
        for entry in content:
            if not isinstance(entry, dict):
                converted.append(text_part(str(entry)))
                continue
            if entry.get("type") == "text" or "text" in entry:
                converted.append(text_part(entry.get("text", "")))
                continue
            location = _file_path(entry)
            if location:
                converted.append(image_part(location))
        return converted

    result = []
    for turn in history:
        role = turn["role"]
        pieces = convert(turn["content"])
        if pieces:
            result.append({"role": role, "content": pieces})
    return result


# Custom CSS handed to `gr.Blocks` in `create_demo`: hides elements with the
# `category-legend` class and sets the margin/height of the legend elements.
css = """
.category-legend{display:none}
.legend{margin-bottom: 5px}
.legend-item{height: 25px}
"""


def create_demo():
    """Build the Gradio Blocks UI for the chat and denoising-canvas demo.

    The layout has a chatbot with a multimodal input box on the left, a
    `HighlightedText` canvas on the right, a collapsed accordion of generation
    settings, and a button that clears the conversation.

    Returns:
        The constructed `gr.Blocks` instance (not launched).
    """
    with gr.Blocks(title="DiffusionGemma", css=css) as demo:
        gr.Markdown("# DiffusionGemma 26B-A4B — Block Diffusion Chat")
        gr.Markdown(
            "[model](https://huggingface.co/google/diffusiongemma-26B-A4B-it) · "
            "Watch the canvas denoise in real time on the right — text condenses from "
            "noise (orange) into settled output (green). Attach an image to ask about it."
        )

        with gr.Row():
            # Left column: conversation and input box.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Conversation")
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Type a message and/or attach an image…",
                    show_label=False,
                )
            # Right column: live canvas, colored via COLOR_MAP labels.
            with gr.Column(scale=2):
                canvas = gr.HighlightedText(
                    label="Denoising canvas",
                    combine_adjacent=False,
                    show_legend=True,
                    color_map=COLOR_MAP,
                )

        with gr.Accordion("Generation settings", open=False):
            with gr.Row():
                enable_thinking = gr.Checkbox(value=False, label="Thinking mode")
            with gr.Row():
                max_new_tokens = gr.Slider(64, 1024, value=256, step=64, label="Max new tokens")
                max_denoising_steps = gr.Slider(8, 64, value=48, step=4, label="Max denoising steps")

        clear_btn = gr.Button("Clear conversation")

        def add_message(message, history):
            """Append the submitted files and text to the history as user messages.

            Each attached file becomes its own user message (`{"path": f}`),
            followed by one message for the text if any. Returns the updated
            history and a cleared, non-interactive input box.
            """
            history = history or []
            for f in message.get("files", []):
                history.append({"role": "user", "content": {"path": f}})
            if message.get("text"):
                history.append({"role": "user", "content": message["text"]})
            return history, gr.MultimodalTextbox(value=None, interactive=False)

        def bot(history, max_new_tokens, max_denoising_steps, enable_thinking):
            """Stream the assistant reply, yielding `(history, canvas_state)` pairs.

            An empty assistant message is appended and its content is updated
            on every snapshot from `generate_streaming`. If generation raises,
            the message becomes "Error: ..." and the canvas shows the error
            text with the "noise" label.
            """
            if not history:
                yield history, []
                return
            # Convert before appending the empty assistant placeholder so the
            # model only sees the existing turns.
            messages = to_model_messages(history)
            history = history + [{"role": "assistant", "content": ""}]
            final = None
            try:
                for canvas_state, plain, text in generate_streaming(
                    messages, max_new_tokens, max_denoising_steps, enable_thinking
                ):
                    # `text` is None for intermediate snapshots and set on the final one.
                    if text is not None:
                        final = text
                    history[-1]["content"] = final if final is not None else plain
                    yield history, canvas_state
            except Exception as e:
                history[-1]["content"] = f"Error: {e}"
                yield history, [(str(e), "noise")]

        def reenable():
            """Return an update that makes the input box interactive again."""
            return gr.MultimodalTextbox(interactive=True)

        # Event chain: submit -> add_message -> bot (streaming) -> reenable input.
        chat_msg = chat_input.submit(
            add_message, [chat_input, chatbot], [chatbot, chat_input]
        )
        bot_msg = chat_msg.then(
            bot,
            [chatbot, max_new_tokens, max_denoising_steps, enable_thinking],
            [chatbot, canvas],
        )
        bot_msg.then(reenable, None, [chat_input])

        # Reset both the conversation and the canvas to empty lists.
        clear_btn.click(lambda: ([], []), None, [chatbot, canvas])

    return demo


# Build the app at import time so `demo` is available as a module attribute,
# then call `queue()` on it.
demo = create_demo()
demo.queue()

if __name__ == "__main__":
    # Launch the server only when this file is executed as a script.
    demo.launch()