# Hugging Face Spaces: Running on Zero (ZeroGPU hardware)
| import os | |
| def _set_writable_env_dir(name: str, default: str, fallback: str) -> None: | |
| path = os.environ.setdefault(name, default) | |
| try: | |
| os.makedirs(path, exist_ok=True) | |
| except OSError: | |
| os.environ[name] = fallback | |
| os.makedirs(fallback, exist_ok=True) | |
# Cache/config directories must be set before the Hugging Face and plotting
# libraries are imported below. Each entry is (variable, preferred, fallback).
for _env_name, _env_default, _env_fallback in (
    ("HF_HOME", "/data/.cache/huggingface", "/tmp/huggingface"),
    ("HF_MODULES_CACHE", "/tmp/hf_modules", "/tmp/hf_modules"),
    ("MPLCONFIGDIR", "/tmp/matplotlib", "/tmp/matplotlib"),
):
    _set_writable_env_dir(_env_name, _env_default, _env_fallback)
del _env_name, _env_default, _env_fallback
os.environ.setdefault("GRADIO_SSR_MODE", "false")
| import json | |
| import math | |
| import re | |
| import time | |
| from copy import deepcopy | |
| from typing import Any | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers.models.auto.modeling_auto import AutoModelForMultimodalLM | |
| from transformers.models.auto.processing_auto import AutoProcessor | |
| from transformers.models.diffusion_gemma.generation_diffusion_gemma import ( | |
| DiffusionGemmaGenerationConfig, | |
| EntropyBoundSamplerConfig, | |
| ) | |
# Hub checkpoint served by this Space.
MODEL_ID = "google/diffusiongemma-26B-A4B-it"
# Choices for the processor's ``max_soft_tokens`` image option; respond()
# rejects any other value.
IMAGE_TOKEN_BUDGETS = [70, 140, 280, 560, 1120]
DEFAULT_SYSTEM_PROMPT = (
    "You are DiffusionGemma, a precise multimodal assistant."
)
# Token ids at which generated output is cut for display.
PAD_TOKEN_ID = 0
EOS_TOKEN_IDS = {1, 50, 106}
# Processor and model are loaded once at import time and shared by every
# request handled by respond().
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",  # use the precision recorded with the checkpoint
    low_cpu_mem_usage=True,
)
model = model.to("cuda")
model.eval()  # inference only: disable training-mode layers
def _estimate_gpu_seconds(
    prompt: str,
    image: Image.Image | None,
    chat_history: list[dict[str, Any]] | None,
    model_history: list[dict[str, Any]] | None,
    system_prompt: str,
    enable_thinking: bool,
    max_new_tokens: int,
    max_denoising_steps: int,
    entropy_bound: float,
    t_min: float,
    t_max: float,
    confidence_threshold: float,
    stability_threshold: int,
    image_token_budget: int,
    show_thinking: bool,
    *args,
    **kwargs,
) -> int:
    """Estimate how many GPU seconds one ``respond`` call needs.

    Takes the same parameters as ``respond`` (plus ``*args``/``**kwargs``),
    presumably so it can serve as a ZeroGPU dynamic ``duration`` callable.
    Only ``image``, ``enable_thinking``, ``max_new_tokens`` and
    ``max_denoising_steps`` influence the result. One canvas is counted per
    256 requested tokens.

    Returns:
        Whole seconds, clamped to the range [30, 180].
    """
    new_tokens = int(max_new_tokens)
    if enable_thinking:
        # respond() raises the token budget to at least 512 when thinking is
        # enabled; mirror that so the estimate covers the canvases decoded.
        new_tokens = max(new_tokens, 512)
    canvases = max(1, math.ceil(new_tokens / 256))
    image_cost = 12 if image is not None else 0
    thinking_cost = 20 if enable_thinking else 0
    denoising_cost = canvases * max(1, int(max_denoising_steps)) * 0.2
    return min(180, max(30, math.ceil(12 + image_cost + thinking_cost + denoising_cost)))
| def _zerogpu_probe() -> str: | |
| return "ready" | |
| def _as_text(value: Any) -> str: | |
| if value is None: | |
| return "" | |
| if isinstance(value, str): | |
| return value | |
| return str(value) | |
| def _clean_generated_text(text: str) -> str: | |
| text = re.sub(r"<\|channel\>thought\n.*?<channel\|>", "", text, flags=re.DOTALL) | |
| for marker in ("<turn|>", "<eos>", "<pad>", "<bos>", "<start_of_turn>", "<end_of_turn>"): | |
| text = text.replace(marker, "") | |
| return text.strip() | |
def _trim_generated_tail(tokens: torch.Tensor) -> torch.Tensor:
    """Flatten *tokens* and cut it just before the first pad or EOS id.

    The whole flattened tensor is returned when no pad/EOS id occurs.
    """
    flat = tokens.flatten()
    stop_ids = {PAD_TOKEN_ID, *EOS_TOKEN_IDS}
    cut = next(
        (position for position, token_id in enumerate(flat.tolist()) if token_id in stop_ids),
        None,
    )
    return flat if cut is None else flat[:cut]
def _parse_generated(new_tokens: torch.Tensor) -> tuple[str, str, str]:
    """Split generated canvas tokens into ``(answer, thinking, tool_json)``.

    Uses ``processor.parse_response`` first. If it returns a dict with a
    non-empty answer, thinking trace or tool-call list, those are returned
    (tool calls serialized as indented JSON). Otherwise, including when
    parsing raises, the tokens up to the first pad/EOS are decoded and
    cleaned; if that prefix is empty, all tokens are decoded instead. The
    fallback returns empty thinking and tool text.
    """
    trimmed = _trim_generated_tail(new_tokens)
    to_decode = trimmed if trimmed.numel() else new_tokens

    try:
        parsed = processor.parse_response(new_tokens)
    except Exception:  # best-effort: fall back to raw decoding below
        parsed = None

    if isinstance(parsed, dict):
        answer_text = _clean_generated_text(_as_text(parsed.get("content")))
        thought_text = _clean_generated_text(_as_text(parsed.get("thinking")))
        calls = parsed.get("tool_calls") or []
        calls_json = json.dumps(calls, indent=2) if calls else ""
        if any((answer_text, thought_text, calls_json)):
            return answer_text, thought_text, calls_json

    decoded = processor.decode(to_decode, skip_special_tokens=False)
    return _clean_generated_text(decoded), "", ""
| def _message_text(message: dict[str, Any]) -> str: | |
| content = message.get("content", "") | |
| if isinstance(content, str): | |
| return content | |
| if isinstance(content, list): | |
| parts = [] | |
| for item in content: | |
| if isinstance(item, dict) and item.get("type") == "text": | |
| parts.append(_as_text(item.get("text")).strip()) | |
| return "\n".join(part for part in parts if part) | |
| return _as_text(content) | |
| def _trim_history(messages: list[dict[str, Any]], max_turns: int = 6) -> list[dict[str, Any]]: | |
| if max_turns <= 0: | |
| return [] | |
| turns: list[list[dict[str, Any]]] = [] | |
| current: list[dict[str, Any]] = [] | |
| for message in messages: | |
| if message.get("role") == "user" and current: | |
| turns.append(current) | |
| current = [] | |
| current.append(message) | |
| if current: | |
| turns.append(current) | |
| return [deepcopy(message) for turn in turns[-max_turns:] for message in turn] | |
def _build_user_content(prompt: str, image: Image.Image | None) -> str | list[dict[str, Any]]:
    """Build the ``content`` field of a user message.

    Without an image this is just the stripped prompt string. With an image
    it is a list of content parts: the image first, followed by a text part
    only when the stripped prompt is non-empty.
    """
    text = prompt.strip()
    if image is None:
        return text
    text_parts = [{"type": "text", "text": text}] if text else []
    return [{"type": "image", "image": image}, *text_parts]
def _build_messages(
    prompt: str,
    image: Image.Image | None,
    model_history: list[dict[str, Any]] | None,
    system_prompt: str,
) -> list[dict[str, Any]]:
    """Assemble the message list passed to the processor's chat template.

    Args:
        prompt: Current user text.
        image: Optional image attached to the current turn.
        model_history: Earlier messages in model format; only the most recent
            turns (per ``_trim_history``'s default) with ``user``/``assistant``
            roles are kept.
        system_prompt: System text; omitted when blank or ``None``.

    Returns:
        ``[system?] + trimmed history + current user message``.
    """
    messages: list[dict[str, Any]] = []
    system_text = (system_prompt or "").strip()
    if system_text:
        messages.append({"role": "system", "content": system_text})
    # _trim_history already returns deep copies, so a second deepcopy (which
    # would duplicate any attached images again) is unnecessary.
    messages.extend(
        message
        for message in _trim_history(model_history or [])
        if message.get("role") in {"user", "assistant"}
    )
    messages.append({"role": "user", "content": _build_user_content(prompt, image)})
    return messages
def _generation_config(
    max_new_tokens: int,
    max_denoising_steps: int,
    entropy_bound: float,
    t_min: float,
    t_max: float,
    confidence_threshold: float,
    stability_threshold: int,
) -> DiffusionGemmaGenerationConfig:
    """Build the DiffusionGemma generation config from UI slider values.

    Every argument is coerced to ``int``/``float`` because UI controls may
    deliver numbers of either type. The pad/EOS ids come from the module
    constants that ``_trim_generated_tail`` also uses, so decoding and
    display stop on the same tokens.
    """
    return DiffusionGemmaGenerationConfig(
        max_new_tokens=int(max_new_tokens),
        max_denoising_steps=int(max_denoising_steps),
        sampler_config=EntropyBoundSamplerConfig(entropy_bound=float(entropy_bound)),
        t_min=float(t_min),
        t_max=float(t_max),
        confidence_threshold=float(confidence_threshold),
        stability_threshold=int(stability_threshold),
        pad_token_id=PAD_TOKEN_ID,
        # sorted() gives a deterministic list (the set has no defined order).
        eos_token_id=sorted(EOS_TOKEN_IDS),
    )
def _to_model_device(inputs: Any) -> Any:
    """Move *inputs* to ``model.device``.

    Objects with a ``.to`` method are moved as a whole. For a plain dict, a
    new dict is returned in which only values having ``.to`` are moved. Any
    other object is returned unchanged.
    """
    if hasattr(inputs, "to"):
        return inputs.to(model.device)
    if not isinstance(inputs, dict):
        return inputs
    moved: dict[Any, Any] = {}
    for key, value in inputs.items():
        moved[key] = value.to(model.device) if hasattr(value, "to") else value
    return moved
def _extract_sequences(outputs: Any) -> torch.Tensor:
    """Return the 2-D ``(batch, seq)`` token tensor from ``model.generate`` output."""
    if hasattr(outputs, "sequences"):
        sequences = outputs.sequences
    elif isinstance(outputs, torch.Tensor):
        # A bare tensor already holds the sequences; outputs[0] would drop the batch dim.
        sequences = outputs
    else:
        sequences = outputs[0]
    # Callers index as [batch, position]; promote a single sequence to a batch of one.
    return sequences.unsqueeze(0) if sequences.dim() == 1 else sequences


def _format_run_stats(
    elapsed: float,
    displayed_count: int,
    canvas_count: int,
    tokens_per_forward: Any,
) -> str:
    """Render the run-statistics text shown in the stats box."""
    if isinstance(tokens_per_forward, torch.Tensor):
        tokens_per_forward_text = f"{tokens_per_forward.float().mean().item():.2f}"
    else:
        tokens_per_forward_text = "n/a"
    return (
        f"Elapsed: {elapsed:.1f}s\n\n"
        f"Displayed tokens: {displayed_count}\n\n"
        f"Canvas tokens: {canvas_count}\n\n"
        f"Tokens per forward: {tokens_per_forward_text}"
    )


# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU function runs;
# the duration is estimated per call from the same arguments.
@spaces.GPU(duration=_estimate_gpu_seconds)
def respond(
    prompt: str,
    image: Image.Image | None,
    chat_history: list[dict[str, Any]] | None,
    model_history: list[dict[str, Any]] | None,
    system_prompt: str,
    enable_thinking: bool,
    max_new_tokens: int,
    max_denoising_steps: int,
    entropy_bound: float,
    t_min: float,
    t_max: float,
    confidence_threshold: float,
    stability_threshold: int,
    image_token_budget: int,
    show_thinking: bool,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], str, str, str, None]:
    """Run one chat turn through DiffusionGemma and return updated UI state.

    Args:
        prompt: User text (may be empty when an image is attached).
        image: Optional image for this turn.
        chat_history: Display messages shown in the Chatbot.
        model_history: Messages in model format (may contain images).
        system_prompt: System text; blank disables it.
        enable_thinking: Enable the thinking channel; raises the token budget to >= 512.
        max_new_tokens, max_denoising_steps, entropy_bound, t_min, t_max,
        confidence_threshold, stability_threshold: Sampler settings.
        image_token_budget: One of ``IMAGE_TOKEN_BUDGETS``.
        show_thinking: Whether to return the thought trace for display.
        progress: Gradio progress tracker.

    Returns:
        ``(chat_history, model_history, thought_markdown, tool_markdown,
        stats, None)``; the trailing ``None`` clears the image input.

    Raises:
        gr.Error: if there is neither prompt nor image, the image budget is
            unsupported, or ``t_max <= t_min``.
    """
    prompt = (prompt or "").strip()
    system_prompt = system_prompt or ""
    chat_history = list(chat_history or [])
    model_history = list(model_history or [])
    if not prompt and image is None:
        raise gr.Error("Enter a prompt or attach an image.")
    if image_token_budget not in IMAGE_TOKEN_BUDGETS:
        raise gr.Error("Select a supported image token budget.")
    if t_max <= t_min:
        raise gr.Error("Start temperature must be greater than end temperature.")
    if enable_thinking and max_new_tokens < 512:
        # Presumably a floor so a thought trace still leaves room for an answer.
        max_new_tokens = 512

    progress(0.05, desc="Preparing inputs")
    messages = _build_messages(prompt, image, model_history, system_prompt)
    processor_kwargs = {"images_kwargs": {"max_soft_tokens": int(image_token_budget)}}
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=bool(enable_thinking),
        processor_kwargs=processor_kwargs,
    )
    inputs = _to_model_device(inputs)

    progress(0.2, desc="Generating")
    started_at = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            generation_config=_generation_config(
                max_new_tokens=max_new_tokens,
                max_denoising_steps=max_denoising_steps,
                entropy_bound=entropy_bound,
                t_min=t_min,
                t_max=t_max,
                confidence_threshold=confidence_threshold,
                stability_threshold=stability_threshold,
            ),
        )
    elapsed = time.perf_counter() - started_at

    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = _extract_sequences(outputs)[0, prompt_length:].detach().cpu()
    displayed_tokens = _trim_generated_tail(new_tokens)
    answer, thinking, tool_text = _parse_generated(new_tokens)
    display_answer = answer or "(No final answer was generated.)"

    user_display = prompt if prompt else "(image only)"
    if image is not None:
        user_display = f"{user_display}\n\n[image attached]"
    chat_history.append({"role": "user", "content": user_display})
    chat_history.append({"role": "assistant", "content": display_answer})
    model_history.append({"role": "user", "content": _build_user_content(prompt, image)})
    # Keep what the model actually produced: the placeholder is UI-only and
    # would otherwise be replayed to the model as its own previous reply.
    model_history.append({"role": "assistant", "content": answer})
    model_history = _trim_history(model_history)

    thought_markdown = thinking if show_thinking and thinking else ""
    tool_markdown = f"```json\n{tool_text}\n```" if tool_text else ""
    stats = _format_run_stats(
        elapsed,
        int(displayed_tokens.numel()),
        int(new_tokens.numel()),
        getattr(outputs, "tokens_per_forward", None),
    )
    progress(1.0, desc="Done")
    return chat_history, model_history, thought_markdown, tool_markdown, stats, None
def clear_chat() -> tuple[list, list, str, str, str, None, str]:
    """Return reset values for chatbot, model state, thought/tool/stats boxes, image and prompt."""
    fresh_chat: list = []
    fresh_state: list = []
    blank = ""
    return fresh_chat, fresh_state, blank, blank, blank, None, blank
# Extra stylesheet handed to launch(); one rule per line.
_CSS_RULES = (
    ".contain { max-width: 1280px; }",
    "#control-panel textarea { min-height: 86px !important; }",
    "#stats-box textarea { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }",
)
css = "\n" + "\n".join(_CSS_RULES) + "\n"
with gr.Blocks(title="DiffusionGemma 26B A4B") as demo:
    # Per-session conversation in model format; the Chatbot holds the display copy.
    model_state = gr.State([])
    gr.Markdown("# DiffusionGemma 26B A4B")

    with gr.Row(equal_height=False):
        # Left column: conversation and message entry.
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Conversation", height=560, buttons=["copy", "copy_all"])
            prompt = gr.Textbox(
                label="Message",
                placeholder="Ask about text, code, reasoning, or an attached image.",
                lines=3,
                max_lines=8,
            )
            with gr.Row():
                submit = gr.Button("Generate", variant="primary")
                clear = gr.Button("Clear")

        # Right column: image, prompt options, sampler settings and run outputs.
        with gr.Column(scale=4, elem_id="control-panel"):
            image = gr.Image(label="Image", type="pil", height=260)
            system_prompt = gr.Textbox(label="System", value=DEFAULT_SYSTEM_PROMPT, lines=3, max_lines=6)
            enable_thinking = gr.Checkbox(label="Thinking", value=False)
            show_thinking = gr.Checkbox(label="Show thought trace", value=False)

            with gr.Accordion("Generation", open=True):
                max_new_tokens = gr.Slider(256, 1024, value=512, step=256, label="Max new tokens")
                max_denoising_steps = gr.Slider(8, 64, value=48, step=1, label="Denoising steps")
                image_token_budget = gr.Radio(IMAGE_TOKEN_BUDGETS, value=280, label="Image tokens")

            with gr.Accordion("Sampler", open=False):
                entropy_bound = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Entropy bound")
                t_max = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Start temperature")
                t_min = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="End temperature")
                confidence_threshold = gr.Slider(
                    0.001, 0.05, value=0.005, step=0.001, label="Confidence threshold"
                )
                stability_threshold = gr.Slider(0, 4, value=1, step=1, label="Stability threshold")

            with gr.Accordion("Latest thought trace", open=False):
                thinking_box = gr.Markdown()
            with gr.Accordion("Tool calls", open=False):
                tool_box = gr.Markdown()
            stats_box = gr.Textbox(label="Run stats", lines=4, interactive=False, elem_id="stats-box")

    # Order must match respond()'s positional parameters and return tuple.
    inputs = [
        prompt,
        image,
        chatbot,
        model_state,
        system_prompt,
        enable_thinking,
        max_new_tokens,
        max_denoising_steps,
        entropy_bound,
        t_min,
        t_max,
        confidence_threshold,
        stability_threshold,
        image_token_budget,
        show_thinking,
    ]
    outputs = [chatbot, model_state, thinking_box, tool_box, stats_box, image]
    generate_event = {"inputs": inputs, "outputs": outputs, "concurrency_limit": 1}
    submit.click(respond, api_name="generate", **generate_event)
    prompt.submit(respond, api_name=False, **generate_event)
    clear.click(clear_chat, outputs=[*outputs, prompt], api_name="clear")

    gr.Examples(
        examples=[
            ["Explain why discrete diffusion can generate several tokens per forward pass.", None],
            ["Write a small Python function that topologically sorts a DAG.", None],
            ["Summarize the image and read any visible text.", None],
        ],
        inputs=[prompt, image],
    )
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1, max_size=12)
    demo.launch(
        css=css,
        theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green", neutral_hue="slate"),
    )