import os


def _set_writable_env_dir(name: str, default: str, fallback: str) -> None:
    """Point environment variable *name* at a usable directory.

    An existing value of *name* in the environment is kept; otherwise it is
    set to *default*. If that directory cannot be created, or it exists but
    this process cannot write to it, *name* is overwritten with *fallback*
    and *fallback* is created. Errors while creating *fallback* propagate.
    """
    path = os.environ.setdefault(name, default)
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        usable = False
    else:
        # makedirs(exist_ok=True) also succeeds on an existing read-only
        # directory, so writability has to be checked explicitly.
        usable = os.access(path, os.W_OK | os.X_OK)
    if not usable:
        os.environ[name] = fallback
        os.makedirs(fallback, exist_ok=True)


# These variables are set before the library imports below, which may read
# them at import time (e.g. the Hugging Face cache location).
_set_writable_env_dir("HF_HOME", "/data/.cache/huggingface", "/tmp/huggingface")
_set_writable_env_dir("HF_MODULES_CACHE", "/tmp/hf_modules", "/tmp/hf_modules")
_set_writable_env_dir("MPLCONFIGDIR", "/tmp/matplotlib", "/tmp/matplotlib")
os.environ.setdefault("GRADIO_SSR_MODE", "false")

# Imports deliberately follow the environment setup above.
import json
import math
import re
import time
from copy import deepcopy
from typing import Any

import spaces
import gradio as gr
import torch
from PIL import Image
from transformers.models.auto.modeling_auto import AutoModelForMultimodalLM
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.models.diffusion_gemma.generation_diffusion_gemma import (
    DiffusionGemmaGenerationConfig,
    EntropyBoundSamplerConfig,
)

MODEL_ID = "google/diffusiongemma-26B-A4B-it"
# Choices offered by the "Image tokens" control.
IMAGE_TOKEN_BUDGETS = [70, 140, 280, 560, 1120]
DEFAULT_SYSTEM_PROMPT = "You are DiffusionGemma, a precise multimodal assistant."
PAD_TOKEN_ID = 0
EOS_TOKEN_IDS = {1, 50, 106}
# Floor applied to `max_new_tokens` when thinking is enabled; shared by
# `respond` and `_estimate_gpu_seconds` so the GPU budget matches the run.
THINKING_MIN_NEW_TOKENS = 512

# Regexes used by `_clean_generated_text`, compiled once.
# A complete thought section: `<|channel>thought\n ... <channel|>`.
# NOTE(review): the closing `<channel|>` marker is assumed from the Gemma chat
# format — confirm against the processor's chat template.
_THOUGHT_BLOCK_RE = re.compile(r"<\|channel>thought\n.*?<channel\|>", re.DOTALL)
# A thought header left without its closing marker (e.g. truncated output).
_THOUGHT_HEADER_RE = re.compile(r"<\|channel>thought\n")
# Leftover control markers: pipe-delimited tokens such as `<|turn>` / `<turn|>`
# plus plain `<bos>` / `<eos>` / `<pad>`. Named markers must contain a `|` so
# ordinary HTML like `<div>` or `<b>` inside an answer is left untouched.
_CONTROL_MARKER_RE = re.compile(r"<\|[A-Za-z_]+\|?>|<[A-Za-z_]+\|>|<(?:bos|eos|pad)>")

# Loaded once at import time and shared by every request.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    low_cpu_mem_usage=True,
).to("cuda")
model.eval()


def _estimate_gpu_seconds(
    prompt: str,
    image: Image.Image | None,
    chat_history: list[dict[str, Any]] | None,
    model_history: list[dict[str, Any]] | None,
    system_prompt: str,
    enable_thinking: bool,
    max_new_tokens: int,
    max_denoising_steps: int,
    entropy_bound: float,
    t_min: float,
    t_max: float,
    confidence_threshold: float,
    stability_threshold: int,
    image_token_budget: int,
    show_thinking: bool,
    *args,
    **kwargs,
) -> int:
    """Estimate the GPU time budget, in seconds, for one `respond` call.

    Used as the `duration` callable of `respond`'s `spaces.GPU` decorator, so
    it mirrors `respond`'s parameters; `*args`/`**kwargs` absorb the rest
    (such as the progress tracker). Only the image, the thinking flag, the
    token budget and the denoising steps affect the result, which is clamped
    to the range [30, 180].
    """
    effective_tokens = int(max_new_tokens)
    if enable_thinking:
        # `respond` raises the budget to this floor, so estimate for it too.
        effective_tokens = max(effective_tokens, THINKING_MIN_NEW_TOKENS)
    # Heuristic: cost scales with 256-token canvases times denoising steps.
    canvases = max(1, math.ceil(effective_tokens / 256))
    image_cost = 12 if image is not None else 0
    thinking_cost = 20 if enable_thinking else 0
    denoising_cost = canvases * max(1, int(max_denoising_steps)) * 0.2
    return min(180, max(30, math.ceil(12 + image_cost + thinking_cost + denoising_cost)))


@spaces.GPU(duration=1)
def _zerogpu_probe() -> str:
    """Return ``"ready"`` from a minimal GPU-decorated function.

    Not called anywhere in this file; presumably kept so ZeroGPU registers a
    GPU function at startup — confirm before removing.
    """
    return "ready"


def _as_text(value: Any) -> str:
    """Coerce *value* to ``str``; ``None`` becomes the empty string."""
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    return str(value)


def _clean_generated_text(text: str) -> str:
    """Remove thought sections and chat control markers, then strip whitespace."""
    text = _THOUGHT_BLOCK_RE.sub("", text)
    text = _THOUGHT_HEADER_RE.sub("", text)
    text = _CONTROL_MARKER_RE.sub("", text)
    return text.strip()


def _trim_generated_tail(tokens: torch.Tensor) -> torch.Tensor:
    """Flatten *tokens* and cut them at the first pad or EOS id (exclusive)."""
    tokens = tokens.flatten()
    for index, token_id in enumerate(tokens.tolist()):
        if token_id == PAD_TOKEN_ID or token_id in EOS_TOKEN_IDS:
            return tokens[:index]
    return tokens


def _parse_generated(new_tokens: torch.Tensor) -> tuple[str, str, str]:
    """Split generated tokens into ``(answer, thinking, tool_calls_json)``.

    `processor.parse_response` is tried first. If it raises, does not return
    a dict, or yields nothing non-empty, the tokens before the first pad/EOS
    (or all tokens, when that prefix is empty) are decoded with special
    tokens kept and cleaned. In that fallback, thinking and tool calls are
    returned as empty strings.
    """
    display_tokens = _trim_generated_tail(new_tokens)
    fallback_tokens = display_tokens if display_tokens.numel() else new_tokens
    try:
        parsed = processor.parse_response(new_tokens)
    except Exception:
        # Best effort: any parser failure falls through to the raw decode below.
        parsed = None
    if isinstance(parsed, dict):
        answer = _clean_generated_text(_as_text(parsed.get("content")))
        thinking = _clean_generated_text(_as_text(parsed.get("thinking")))
        tool_calls = parsed.get("tool_calls") or []
        tool_text = json.dumps(tool_calls, indent=2) if tool_calls else ""
        if answer or thinking or tool_text:
            return answer, thinking, tool_text
    raw = processor.decode(fallback_tokens, skip_special_tokens=False)
    return _clean_generated_text(raw), "", ""


def _message_text(message: dict[str, Any]) -> str:
    """Return the plain text of a chat message.

    String content is returned as-is. For list content, the stripped ``text``
    of each ``{"type": "text"}`` item is joined with newlines, skipping empty
    items. Anything else is stringified.
    """
    content = message.get("content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                parts.append(_as_text(item.get("text")).strip())
        return "\n".join(part for part in parts if part)
    return _as_text(content)


def _trim_history(messages: list[dict[str, Any]], max_turns: int = 6) -> list[dict[str, Any]]:
    """Return deep copies of the messages in the last *max_turns* turns.

    A turn begins at each ``"user"`` message and runs until the next one;
    any messages before the first user message form their own turn.
    Returns ``[]`` when *max_turns* <= 0.
    """
    if max_turns <= 0:
        return []
    turns: list[list[dict[str, Any]]] = []
    current: list[dict[str, Any]] = []
    for message in messages:
        if message.get("role") == "user" and current:
            turns.append(current)
            current = []
        current.append(message)
    if current:
        turns.append(current)
    return [deepcopy(message) for turn in turns[-max_turns:] for message in turn]


def _build_user_content(prompt: str, image: Image.Image | None) -> str | list[dict[str, Any]]:
    """Build user-message content: the stripped prompt, or image + text parts."""
    prompt = prompt.strip()
    if image is None:
        return prompt
    content: list[dict[str, Any]] = [{"type": "image", "image": image}]
    if prompt:
        content.append({"type": "text", "text": prompt})
    return content


def _build_messages(
    prompt: str,
    image: Image.Image | None,
    model_history: list[dict[str, Any]] | None,
    system_prompt: str,
) -> list[dict[str, Any]]:
    """Assemble the chat-template messages for the next generation.

    Order: optional system message (skipped if blank), the user/assistant
    messages from the trimmed history, then the new user message.
    """
    messages: list[dict[str, Any]] = []
    system_prompt = system_prompt.strip()
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # `_trim_history` already returns deep copies, so no further copy is needed.
    for message in _trim_history(model_history or []):
        if message.get("role") in {"user", "assistant"}:
            messages.append(message)
    messages.append({"role": "user", "content": _build_user_content(prompt, image)})
    return messages


def _generation_config(
    max_new_tokens: int,
    max_denoising_steps: int,
    entropy_bound: float,
    t_min: float,
    t_max: float,
    confidence_threshold: float,
    stability_threshold: int,
) -> DiffusionGemmaGenerationConfig:
    """Build the diffusion generation config from UI values, coercing types."""
    return DiffusionGemmaGenerationConfig(
        max_new_tokens=int(max_new_tokens),
        max_denoising_steps=int(max_denoising_steps),
        sampler_config=EntropyBoundSamplerConfig(entropy_bound=float(entropy_bound)),
        t_min=float(t_min),
        t_max=float(t_max),
        confidence_threshold=float(confidence_threshold),
        stability_threshold=int(stability_threshold),
        # Same ids `_trim_generated_tail` uses to cut the displayed output.
        pad_token_id=PAD_TOKEN_ID,
        eos_token_id=sorted(EOS_TOKEN_IDS),
    )


def _to_model_device(inputs: Any) -> Any:
    """Move *inputs* to ``model.device``.

    Accepts anything with a ``.to`` method, or a dict whose values are moved
    when they have ``.to``. Other values are returned unchanged.
    """
    if hasattr(inputs, "to"):
        return inputs.to(model.device)
    if isinstance(inputs, dict):
        return {key: value.to(model.device) if hasattr(value, "to") else value for key, value in inputs.items()}
    return inputs


def _extract_sequences(outputs: Any) -> torch.Tensor:
    """Return generated ids as a tensor indexable as ``[row, position]``.

    Handles an object with ``.sequences``, a bare tensor, or an indexable
    container whose first element holds the ids. A 1-D result is promoted to
    a single row so ``sequences[0, start:]`` stays valid.
    """
    if hasattr(outputs, "sequences"):
        sequences = outputs.sequences
    elif torch.is_tensor(outputs):
        # A bare tensor already holds the rows; indexing [0] would drop a dim.
        sequences = outputs
    else:
        sequences = outputs[0]
    if sequences.dim() == 1:
        sequences = sequences.unsqueeze(0)
    return sequences


def _format_stats(
    elapsed: float,
    displayed_tokens: torch.Tensor,
    new_tokens: torch.Tensor,
    outputs: Any,
) -> str:
    """Format the run-stats text shown in the UI."""
    tokens_per_forward = getattr(outputs, "tokens_per_forward", None)
    if isinstance(tokens_per_forward, torch.Tensor):
        tokens_per_forward_text = f"{tokens_per_forward.float().mean().item():.2f}"
    else:
        tokens_per_forward_text = "n/a"
    return (
        f"Elapsed: {elapsed:.1f}s\n\n"
        f"Displayed tokens: {int(displayed_tokens.numel())}\n\n"
        f"Canvas tokens: {int(new_tokens.numel())}\n\n"
        f"Tokens per forward: {tokens_per_forward_text}"
    )


@spaces.GPU(duration=_estimate_gpu_seconds, size="xlarge")
def respond(
    prompt: str,
    image: Image.Image | None,
    chat_history: list[dict[str, Any]] | None,
    model_history: list[dict[str, Any]] | None,
    system_prompt: str,
    enable_thinking: bool,
    max_new_tokens: int,
    max_denoising_steps: int,
    entropy_bound: float,
    t_min: float,
    t_max: float,
    confidence_threshold: float,
    stability_threshold: int,
    image_token_budget: int,
    show_thinking: bool,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], str, str, str, None]:
    """Run one chat turn and return the updated UI values.

    Returns ``(chat_history, model_history, thought_markdown, tool_markdown,
    stats, None)``; the trailing ``None`` clears the image input.

    Raises:
        gr.Error: if both prompt and image are empty, the image token budget
            is not in ``IMAGE_TOKEN_BUDGETS``, or ``t_max <= t_min``.
    """
    prompt = prompt.strip()
    chat_history = list(chat_history or [])
    model_history = list(model_history or [])

    if not prompt and image is None:
        raise gr.Error("Enter a prompt or attach an image.")
    if image_token_budget not in IMAGE_TOKEN_BUDGETS:
        raise gr.Error("Select a supported image token budget.")
    if t_max <= t_min:
        raise gr.Error("Start temperature must be greater than end temperature.")
    if enable_thinking and max_new_tokens < THINKING_MIN_NEW_TOKENS:
        # Presumably so the thought trace does not crowd out the final answer.
        max_new_tokens = THINKING_MIN_NEW_TOKENS

    progress(0.05, desc="Preparing inputs")
    messages = _build_messages(prompt, image, model_history, system_prompt)
    processor_kwargs = {"images_kwargs": {"max_soft_tokens": int(image_token_budget)}}
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=bool(enable_thinking),
        processor_kwargs=processor_kwargs,
    )
    inputs = _to_model_device(inputs)

    progress(0.2, desc="Generating")
    started_at = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            generation_config=_generation_config(
                max_new_tokens=max_new_tokens,
                max_denoising_steps=max_denoising_steps,
                entropy_bound=entropy_bound,
                t_min=t_min,
                t_max=t_max,
                confidence_threshold=confidence_threshold,
                stability_threshold=stability_threshold,
            ),
        )
    elapsed = time.perf_counter() - started_at

    # Keep only the tokens generated after the prompt, on CPU.
    sequences = _extract_sequences(outputs)
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = sequences[0, prompt_length:].detach().cpu()
    displayed_tokens = _trim_generated_tail(new_tokens)
    answer, thinking, tool_text = _parse_generated(new_tokens)
    answer = answer or "(No final answer was generated.)"

    # The chatbot shows a text-only transcript; the model history keeps the
    # actual image so later turns can still refer to it.
    user_display = prompt if prompt else "(image only)"
    if image is not None:
        user_display = f"{user_display}\n\n[image attached]"
    chat_history.append({"role": "user", "content": user_display})
    chat_history.append({"role": "assistant", "content": answer})

    model_history.append({"role": "user", "content": _build_user_content(prompt, image)})
    model_history.append({"role": "assistant", "content": answer})
    model_history = _trim_history(model_history)

    thought_markdown = thinking if show_thinking and thinking else ""
    tool_markdown = f"```json\n{tool_text}\n```" if tool_text else ""
    stats = _format_stats(elapsed, displayed_tokens, new_tokens, outputs)

    progress(1.0, desc="Done")
    return chat_history, model_history, thought_markdown, tool_markdown, stats, None


def clear_chat() -> tuple[list, list, str, str, str, None, str]:
    """Reset chat, model state, side panels, image and prompt."""
    return [], [], "", "", "", None, ""


css = """
.contain { max-width: 1280px; }
#control-panel textarea { min-height: 86px !important; }
#stats-box textarea { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }
"""

# NOTE(review): widget nesting below was reconstructed from a flattened source
# (in particular where the thought/tool/stats panels sit) — confirm the layout.
with gr.Blocks(
    title="DiffusionGemma 26B A4B",
) as demo:
    model_state = gr.State([])
    gr.Markdown("# DiffusionGemma 26B A4B")

    with gr.Row(equal_height=False):
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=560,
                buttons=["copy", "copy_all"],
            )
            prompt = gr.Textbox(
                label="Message",
                placeholder="Ask about text, code, reasoning, or an attached image.",
                lines=3,
                max_lines=8,
            )
            with gr.Row():
                submit = gr.Button("Generate", variant="primary")
                clear = gr.Button("Clear")

        with gr.Column(scale=4, elem_id="control-panel"):
            image = gr.Image(label="Image", type="pil", height=260)
            system_prompt = gr.Textbox(
                label="System",
                value=DEFAULT_SYSTEM_PROMPT,
                lines=3,
                max_lines=6,
            )
            enable_thinking = gr.Checkbox(label="Thinking", value=False)
            show_thinking = gr.Checkbox(label="Show thought trace", value=False)

            with gr.Accordion("Generation", open=True):
                max_new_tokens = gr.Slider(256, 1024, value=512, step=256, label="Max new tokens")
                max_denoising_steps = gr.Slider(8, 64, value=48, step=1, label="Denoising steps")
                image_token_budget = gr.Radio(
                    IMAGE_TOKEN_BUDGETS,
                    value=280,
                    label="Image tokens",
                )

            with gr.Accordion("Sampler", open=False):
                entropy_bound = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Entropy bound")
                t_max = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Start temperature")
                t_min = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="End temperature")
                confidence_threshold = gr.Slider(
                    0.001,
                    0.05,
                    value=0.005,
                    step=0.001,
                    label="Confidence threshold",
                )
                stability_threshold = gr.Slider(0, 4, value=1, step=1, label="Stability threshold")

            with gr.Accordion("Latest thought trace", open=False):
                thinking_box = gr.Markdown()
            with gr.Accordion("Tool calls", open=False):
                tool_box = gr.Markdown()
            stats_box = gr.Textbox(label="Run stats", lines=4, interactive=False, elem_id="stats-box")

    # Order must match `respond`'s positional parameters and return tuple.
    inputs = [
        prompt,
        image,
        chatbot,
        model_state,
        system_prompt,
        enable_thinking,
        max_new_tokens,
        max_denoising_steps,
        entropy_bound,
        t_min,
        t_max,
        confidence_threshold,
        stability_threshold,
        image_token_budget,
        show_thinking,
    ]
    outputs = [chatbot, model_state, thinking_box, tool_box, stats_box, image]

    submit.click(respond, inputs=inputs, outputs=outputs, api_name="generate", concurrency_limit=1)
    prompt.submit(respond, inputs=inputs, outputs=outputs, api_name=False, concurrency_limit=1)
    clear.click(
        clear_chat,
        outputs=[chatbot, model_state, thinking_box, tool_box, stats_box, image, prompt],
        api_name="clear",
    )

    gr.Examples(
        examples=[
            ["Explain why discrete diffusion can generate several tokens per forward pass.", None],
            ["Write a small Python function that topologically sorts a DAG.", None],
            ["Summarize the image and read any visible text.", None],
        ],
        inputs=[prompt, image],
    )


if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1, max_size=12).launch(
        css=css,
        theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green", neutral_hue="slate"),
    )