Spaces:
Running on Zero
Running on Zero
| import spaces | |
| import torch | |
| import gradio as gr | |
| from threading import Thread | |
| from queue import Queue | |
| from transformers import DiffusionGemmaForBlockDiffusion, AutoProcessor, TextDiffusionStreamer | |
MODEL_ID = "google/diffusiongemma-26B-A4B-it"

# The processor and the bfloat16 weights are loaded once, at import time, and
# shared by every request; the model is moved to CUDA and put in eval mode.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = DiffusionGemmaForBlockDiffusion.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
)
model.to("cuda")
model.eval()

# Block ("canvas") length from the model config; 256 when the config has no
# `canvas_length` attribute.
CANVAS_LENGTH = getattr(model.config, "canvas_length", 256)

# Unique marker object pushed onto a streamer queue to signal end of stream.
_SENTINEL = object()
| # Per-token denoising colors. DiffusionGemma uses random-token *renoising* (not [MASK] | |
| # diffusion): the entropy sampler locks low-entropy positions while the rest are random | |
| # noise each step. So a position's stability across steps is our confidence proxy. | |
| COLOR_MAP = { | |
| "done": "#66CC66", # committed block (final) | |
| "stable": "#8FD18F", # settled for several steps | |
| "mid": "#FFCC66", # settling | |
| "noise": "#E8896B", # just changed / still noisy | |
| } | |
| _STABLE_STEPS = 3 # unchanged for >= this many steps -> "stable" | |
class CanvasStreamer(TextDiffusionStreamer):
    """Pushes (committed_text, draft_segments, draft_plain) snapshots to a queue.

    `put_draft` fires every denoising step with the full argmax canvas of the block
    being denoised. We track, per position, how many consecutive steps its token has
    been unchanged ("settle" count) and color it accordingly, so the canvas visibly
    condenses from noise into settled text. `put` fires when a block is committed.

    Committed blocks are accumulated as token ids and re-decoded together. Decoding
    each block on its own and concatenating the strings would garble any character
    whose byte tokens straddle a block boundary.
    """

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=True, **kwargs)
        self.queue = Queue()      # snapshots for the consumer; _SENTINEL ends the stream
        self.committed = ""       # decoded text of every committed block so far
        self.committed_ids = []   # token ids of every committed block so far
        self.last_draft = ""      # plain text of the most recent draft render
        self.started = False      # flipped by the first put_draft(); put() before it is the prompt
        self._takes_logits = False  # NOTE(review): presumably read by the base streamer — confirm
        self.prev_ids = None      # previous step's canvas ids, for settle tracking
        self.settle = None        # per-position count of consecutive unchanged steps
        self.special_ids = set(tokenizer.all_special_ids)

    def _render(self, ids):
        """Color positions up to the furthest settled real token (the 'frontier').

        Returns ``(segments, plain)``. ``segments`` is a list of ``(text, class)``
        runs in which adjacent tokens of the same class are merged. ``plain`` is the
        same text without coloring. Positions after the frontier (the last non-special
        token settled for 2+ steps) are not shown.
        """
        frontier = -1
        for i, tid in enumerate(ids):
            if tid not in self.special_ids and self.settle[i] >= 2:
                frontier = i
        segments = []
        plain = []
        cur_text = ""
        cur_cls = None
        for i, tid in enumerate(ids[: frontier + 1]):
            if tid in self.special_ids:
                continue
            piece = self.tokenizer.decode([tid], skip_special_tokens=True)
            if not piece:
                continue
            s = self.settle[i]
            cls = "stable" if s >= _STABLE_STEPS else ("mid" if s >= 1 else "noise")
            plain.append(piece)
            if cls == cur_cls:
                cur_text += piece
            else:
                if cur_text:
                    segments.append((cur_text, cur_cls))
                cur_text, cur_cls = piece, cls
        if cur_text:
            segments.append((cur_text, cur_cls))
        return segments, "".join(plain)

    def put_draft(self, value, **kwargs):
        """Record one denoising step's canvas, update settle counts, publish a snapshot."""
        if len(value.shape) > 1:
            value = value[0]
        ids = value.tolist()
        self.started = True
        if self.prev_ids is None or len(self.prev_ids) != len(ids):
            # First step of a new block (or a differently sized canvas): nothing has settled.
            self.settle = [0] * len(ids)
        else:
            self.settle = [
                s + 1 if tid == prev else 0
                for s, tid, prev in zip(self.settle, ids, self.prev_ids)
            ]
        self.prev_ids = ids
        segments, plain = self._render(ids)
        self.last_draft = plain
        self.queue.put((self.committed, segments, plain))

    def put(self, value):
        """Append a committed block and publish a snapshot with an empty draft.

        Raises:
            ValueError: if ``value`` carries a batch of more than one sequence.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("batch size 1 only")
        if len(value.shape) > 1:
            value = value[0]
        if not self.started:  # prompt context, before any denoising step
            return
        # Decode all committed ids together so cross-block characters stay intact.
        self.committed_ids.extend(value.reshape(-1).tolist())
        self.committed = self.tokenizer.decode(self.committed_ids, skip_special_tokens=True)
        self.last_draft = ""
        self.prev_ids = None
        self.settle = None
        self.queue.put((self.committed, [], ""))

    def end(self):
        """Signal the consumer that streaming is finished."""
        self.queue.put(_SENTINEL)
def build_display(committed, segments):
    """Return HighlightedText data for the canvas.

    The committed text comes first as a single "done" run (omitted when empty),
    followed by the draft ``segments`` unchanged.
    """
    head = [(committed, "done")] if committed else []
    return [*head, *segments]
# `spaces` is imported at the top of the file, but nothing else requests a ZeroGPU
# device. Decorating the GPU entry point attaches a GPU for the duration of the
# call (generator functions are supported).
@spaces.GPU
def generate_streaming(messages, max_new_tokens, max_denoising_steps, enable_thinking):
    """Run block-diffusion generation on a background thread and stream snapshots.

    Args:
        messages: processor chat-format messages (see ``to_model_messages``).
        max_new_tokens: generation budget; coerced to ``int``.
        max_denoising_steps: per-block denoising step cap; coerced to ``int``.
        enable_thinking: forwarded to the chat template.

    Yields:
        ``(canvas_data, text_so_far, None)`` while generating, then a final
        ``(canvas_data, final_text, final_text)``. The final text is stripped.

    Raises:
        Whatever ``model.generate`` raised, re-raised after the worker thread joins.
    """
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=enable_thinking,
    ).to("cuda")
    streamer = CanvasStreamer(processor.tokenizer)
    result = {}

    def run():
        try:
            result["ids"] = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                max_denoising_steps=int(max_denoising_steps),
                streamer=streamer,
            )
        except Exception as e:  # surfaced to the caller after join()
            result["error"] = e
        finally:
            # Always unblock the consumer loop, even if generate() returned without
            # calling streamer.end(). An extra sentinel is harmless: the loop below
            # stops at the first one, and the queue is private to this call.
            streamer.queue.put(_SENTINEL)

    thread = Thread(target=run)
    thread.start()
    while True:
        item = streamer.queue.get()
        if item is _SENTINEL:
            break
        committed, segments, plain = item
        yield build_display(committed, segments), committed + plain, None
    thread.join()
    if "error" in result:
        raise result["error"]
    final_text = (streamer.committed + streamer.last_draft).strip()
    yield ([(final_text, "done")] if final_text else []), final_text, final_text
| def _file_path(item): | |
| """Extract a local file path from a Gradio content part / file dict.""" | |
| for key in ("file", "path", "url"): | |
| val = item.get(key) | |
| if isinstance(val, dict): | |
| val = val.get("path") or val.get("url") | |
| if val: | |
| return val | |
| return None | |
def _content_parts(content):
    """Translate one Gradio message ``content`` value into a list of processor parts."""
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    if isinstance(content, tuple):  # legacy (path, alt) file tuple
        return [{"type": "image", "url": content[0]}]
    if isinstance(content, dict):  # single file part, e.g. {"path": ...}
        path = _file_path(content)
        return [{"type": "image", "url": path}] if path else []
    parts = []
    if isinstance(content, list):
        for item in content:
            if not isinstance(item, dict):
                parts.append({"type": "text", "text": str(item)})
            elif item.get("type") == "text" or "text" in item:
                parts.append({"type": "text", "text": item.get("text", "")})
            else:
                path = _file_path(item)
                if path:
                    parts.append({"type": "image", "url": path})
    return parts


def to_model_messages(history):
    """Convert Gradio (messages format) history into processor chat format with images.

    Each history entry is read as a dict with ``"role"`` and ``"content"``.
    Entries whose content yields no parts are dropped. Consecutive entries with the
    same role are merged into one message, in order. The UI stores an attached image
    and its accompanying text as separate user entries; merging keeps them in a
    single user turn, which is what role-alternating chat templates expect.

    Returns:
        A list of ``{"role": ..., "content": [parts...]}`` dicts.
    """
    messages = []
    for msg in history:
        role = msg["role"]
        parts = _content_parts(msg["content"])
        if not parts:
            continue
        if messages and messages[-1]["role"] == role:
            messages[-1]["content"].extend(parts)
        else:
            messages.append({"role": role, "content": parts})
    return messages
# Extra stylesheet for the Blocks app: hides the HighlightedText category legend
# and tightens legend spacing.
css = (
    "\n"
    ".category-legend{display:none}\n"
    ".legend{margin-bottom: 5px}\n"
    ".legend-item{height: 25px}\n"
)
def create_demo():
    """Build the Gradio UI: conversation on the left, live denoising canvas on the right.

    Event flow:
        1. Submitting the textbox appends the user turn(s) and locks the input (``add_message``).
        2. ``bot`` streams the reply into the chat and the canvas.
        3. The input is unlocked again (``reenable``).

    Returns:
        The configured ``gr.Blocks`` app.
    """
    with gr.Blocks(title="DiffusionGemma", css=css) as demo:
        gr.Markdown("# DiffusionGemma 26B-A4B — Block Diffusion Chat")
        gr.Markdown(
            "[model](https://huggingface.co/google/diffusiongemma-26B-A4B-it) · "
            "Watch the canvas denoise in real time on the right — text condenses from "
            "noise (orange) into settled output (green). Attach an image to ask about it."
        )
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Conversation")
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Type a message and/or attach an image…",
                    show_label=False,
                )
            with gr.Column(scale=2):
                canvas = gr.HighlightedText(
                    label="Denoising canvas",
                    combine_adjacent=False,
                    show_legend=True,
                    color_map=COLOR_MAP,
                )
        with gr.Accordion("Generation settings", open=False):
            with gr.Row():
                enable_thinking = gr.Checkbox(value=False, label="Thinking mode")
            with gr.Row():
                max_new_tokens = gr.Slider(64, 1024, value=256, step=64, label="Max new tokens")
                max_denoising_steps = gr.Slider(8, 64, value=48, step=4, label="Max denoising steps")
        clear_btn = gr.Button("Clear conversation")

        def add_message(message, history):
            """Append each attached file, then the text, as user turns; lock the input."""
            history = history or []
            for f in message.get("files") or []:
                history.append({"role": "user", "content": {"path": f}})
            if message.get("text"):
                history.append({"role": "user", "content": message["text"]})
            return history, gr.MultimodalTextbox(value=None, interactive=False)

        def bot(history, max_new_tokens, max_denoising_steps, enable_thinking):
            """Stream an assistant reply to the pending user turn into chat + canvas."""
            # An empty submit adds nothing, leaving the previous assistant reply as
            # the last turn; generating again would append a duplicate reply to an
            # unchanged prompt. Only answer when a user turn is pending.
            if not history or history[-1]["role"] != "user":
                yield history, []
                return
            messages = to_model_messages(history)
            history = history + [{"role": "assistant", "content": ""}]
            final = None
            try:
                for canvas_state, plain, text in generate_streaming(
                    messages, max_new_tokens, max_denoising_steps, enable_thinking
                ):
                    # `text` is None while streaming and the final string at the end.
                    if text is not None:
                        final = text
                    history[-1]["content"] = final if final is not None else plain
                    yield history, canvas_state
            except Exception as e:
                history[-1]["content"] = f"Error: {e}"
                yield history, [(str(e), "noise")]

        def reenable():
            return gr.MultimodalTextbox(interactive=True)

        chat_msg = chat_input.submit(
            add_message, [chat_input, chatbot], [chatbot, chat_input]
        )
        bot_msg = chat_msg.then(
            bot,
            [chatbot, max_new_tokens, max_denoising_steps, enable_thinking],
            [chatbot, canvas],
        )
        bot_msg.then(reenable, None, [chat_input])
        clear_btn.click(lambda: ([], []), None, [chatbot, canvas])
    return demo
# Build the app at import time so a host that imports this module gets a ready
# `demo`, with its request queue enabled; launch only when run as a script.
demo = create_demo().queue()

if __name__ == "__main__":
    demo.launch()