# Mike0021's picture
# Tune ZeroGPU duration estimate
# a67bfb0 verified
import os
def _set_writable_env_dir(name: str, default: str, fallback: str) -> None:
path = os.environ.setdefault(name, default)
try:
os.makedirs(path, exist_ok=True)
except OSError:
os.environ[name] = fallback
os.makedirs(fallback, exist_ok=True)
# Cache/config directories, configured before the gradio/transformers imports
# further down (presumably so those libraries see these values when imported).
_set_writable_env_dir(name="HF_HOME", default="/data/.cache/huggingface", fallback="/tmp/huggingface")
_set_writable_env_dir(name="HF_MODULES_CACHE", default="/tmp/hf_modules", fallback="/tmp/hf_modules")
_set_writable_env_dir(name="MPLCONFIGDIR", default="/tmp/matplotlib", fallback="/tmp/matplotlib")
# Respect an explicit GRADIO_SSR_MODE from the environment; otherwise turn SSR off.
if "GRADIO_SSR_MODE" not in os.environ:
    os.environ["GRADIO_SSR_MODE"] = "false"
import json
import math
import re
import time
from copy import deepcopy
from typing import Any
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers.models.auto.modeling_auto import AutoModelForMultimodalLM
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.models.diffusion_gemma.generation_diffusion_gemma import (
DiffusionGemmaGenerationConfig,
EntropyBoundSamplerConfig,
)
# Hub checkpoint served by this Space; processor and model below load from it.
MODEL_ID = "google/diffusiongemma-26B-A4B-it"
# The only image token budgets accepted: respond() rejects any other value and
# the UI offers exactly these as radio options.
IMAGE_TOKEN_BUDGETS = [70, 140, 280, 560, 1120]
DEFAULT_SYSTEM_PROMPT = "You are DiffusionGemma, a precise multimodal assistant."
# Token ids at which _trim_generated_tail cuts the generated canvas for display.
PAD_TOKEN_ID = 0
EOS_TOKEN_IDS = {1, 50, 106}
# Loaded once at import time; every request reuses these module-level objects.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForMultimodalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    low_cpu_mem_usage=True,
# NOTE(review): moved to "cuda" at import time; on ZeroGPU this presumably
# relies on the `spaces` package being imported before torch (it is, above).
).to("cuda")
# Inference only: put the model in eval mode.
model.eval()
def _estimate_gpu_seconds(
prompt: str,
image: Image.Image | None,
chat_history: list[dict[str, Any]] | None,
model_history: list[dict[str, Any]] | None,
system_prompt: str,
enable_thinking: bool,
max_new_tokens: int,
max_denoising_steps: int,
entropy_bound: float,
t_min: float,
t_max: float,
confidence_threshold: float,
stability_threshold: int,
image_token_budget: int,
show_thinking: bool,
*args,
**kwargs,
) -> int:
canvases = max(1, math.ceil(int(max_new_tokens) / 256))
image_cost = 12 if image is not None else 0
thinking_cost = 20 if enable_thinking else 0
denoising_cost = canvases * max(1, int(max_denoising_steps)) * 0.2
return min(180, max(30, math.ceil(12 + image_cost + thinking_cost + denoising_cost)))
@spaces.GPU(duration=1)
def _zerogpu_probe() -> str:
    """Return the constant status string ``"ready"`` from a 1-second GPU slot.

    Nothing in the visible code calls this. Presumably it exists so ZeroGPU
    registers a GPU-decorated function at startup; confirm before removing.
    """
    status = "ready"
    return status
def _as_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
return str(value)
def _clean_generated_text(text: str) -> str:
text = re.sub(r"<\|channel\>thought\n.*?<channel\|>", "", text, flags=re.DOTALL)
for marker in ("<turn|>", "<eos>", "<pad>", "<bos>", "<start_of_turn>", "<end_of_turn>"):
text = text.replace(marker, "")
return text.strip()
def _trim_generated_tail(tokens: torch.Tensor) -> torch.Tensor:
    """Flatten *tokens* and cut them just before the first pad or EOS id.

    Returns the whole flattened tensor when no stop id occurs, and an empty
    tensor when the very first token is a stop id.
    """
    flat = tokens.flatten()
    stop_ids = {PAD_TOKEN_ID, *EOS_TOKEN_IDS}
    cut = next(
        (position for position, token_id in enumerate(flat.tolist()) if token_id in stop_ids),
        None,
    )
    return flat if cut is None else flat[:cut]
def _parse_generated(new_tokens: torch.Tensor) -> tuple[str, str, str]:
    """Split generated tokens into ``(answer, thinking, tool_calls_json)`` strings.

    First tries ``processor.parse_response`` on the untrimmed tokens. Its result
    is used when it is a dict yielding any non-empty content, thinking, or tool
    calls. Otherwise the tokens up to the first pad/EOS id are decoded as the
    answer; if that prefix is empty, all tokens are decoded instead. In that
    fallback case thinking and tool text are empty strings.
    """
    try:
        parsed = processor.parse_response(new_tokens)
    except Exception:
        # Deliberate best effort: any parser failure drops to the raw decode below.
        parsed = None
    if isinstance(parsed, dict):
        answer = _clean_generated_text(_as_text(parsed.get("content")))
        thinking = _clean_generated_text(_as_text(parsed.get("thinking")))
        tool_calls = parsed.get("tool_calls") or []
        # ensure_ascii=False keeps non-ASCII arguments readable in the Markdown
        # panel instead of rendering them as \uXXXX escapes.
        tool_text = json.dumps(tool_calls, indent=2, ensure_ascii=False) if tool_calls else ""
        if answer or thinking or tool_text:
            return answer, thinking, tool_text
    display_tokens = _trim_generated_tail(new_tokens)
    fallback_tokens = display_tokens if display_tokens.numel() else new_tokens
    raw = processor.decode(fallback_tokens, skip_special_tokens=False)
    return _clean_generated_text(raw), "", ""
def _message_text(message: dict[str, Any]) -> str:
    """Return the plain text of a chat message's ``content``.

    String content is returned as-is. For list content, the stripped ``text`` of
    every ``{"type": "text"}`` part is joined with newlines, skipping empty
    parts. Any other content is coerced with ``_as_text``. Not called anywhere
    in the visible code.
    """
    content = message.get("content", "")
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return _as_text(content)
    pieces = [
        _as_text(part.get("text")).strip()
        for part in content
        if isinstance(part, dict) and part.get("type") == "text"
    ]
    return "\n".join(filter(None, pieces))
def _trim_history(messages: list[dict[str, Any]], max_turns: int = 6) -> list[dict[str, Any]]:
if max_turns <= 0:
return []
turns: list[list[dict[str, Any]]] = []
current: list[dict[str, Any]] = []
for message in messages:
if message.get("role") == "user" and current:
turns.append(current)
current = []
current.append(message)
if current:
turns.append(current)
return [deepcopy(message) for turn in turns[-max_turns:] for message in turn]
def _build_user_content(prompt: str, image: Image.Image | None) -> str | list[dict[str, Any]]:
prompt = prompt.strip()
if image is None:
return prompt
content: list[dict[str, Any]] = [{"type": "image", "image": image}]
if prompt:
content.append({"type": "text", "text": prompt})
return content
def _build_messages(
    prompt: str,
    image: Image.Image | None,
    model_history: list[dict[str, Any]] | None,
    system_prompt: str,
) -> list[dict[str, Any]]:
    """Assemble the chat-template message list for one generation.

    Order: an optional system message (only when *system_prompt* has
    non-whitespace text), then the user/assistant messages of the trimmed
    *model_history*, then the new user turn built from *prompt* and *image*.
    """
    messages: list[dict[str, Any]] = []
    system_text = (system_prompt or "").strip()
    if system_text:
        messages.append({"role": "system", "content": system_text})
    # _trim_history already returns deep copies, so re-copying here only
    # duplicated work, including any PIL images stored in earlier user turns.
    messages.extend(
        message
        for message in _trim_history(model_history or [])
        if message.get("role") in {"user", "assistant"}
    )
    messages.append({"role": "user", "content": _build_user_content(prompt, image)})
    return messages
def _generation_config(
    max_new_tokens: int,
    max_denoising_steps: int,
    entropy_bound: float,
    t_min: float,
    t_max: float,
    confidence_threshold: float,
    stability_threshold: int,
) -> DiffusionGemmaGenerationConfig:
    """Build the diffusion generation config from (possibly float) UI values.

    Each value is cast to the type the config expects. An entropy-bound
    sampler is used. The pad/EOS ids come from the module-level constants that
    ``_trim_generated_tail`` also uses, so generation stopping and display
    trimming stay in sync.
    """
    return DiffusionGemmaGenerationConfig(
        max_new_tokens=int(max_new_tokens),
        max_denoising_steps=int(max_denoising_steps),
        sampler_config=EntropyBoundSamplerConfig(entropy_bound=float(entropy_bound)),
        t_min=float(t_min),
        t_max=float(t_max),
        confidence_threshold=float(confidence_threshold),
        stability_threshold=int(stability_threshold),
        pad_token_id=PAD_TOKEN_ID,
        # EOS_TOKEN_IDS is a set; sorted() gives a deterministic list (id 1 first).
        eos_token_id=sorted(EOS_TOKEN_IDS),
    )
def _to_model_device(inputs: Any) -> Any:
    """Move *inputs* onto ``model.device``.

    - Objects with a ``.to`` method (e.g. a tensor) are moved as a whole.
    - Plain dicts yield a new dict in which values with ``.to`` are moved and
      all other values pass through.
    - Anything else is returned unchanged.
    """
    if hasattr(inputs, "to"):
        return inputs.to(model.device)
    if not isinstance(inputs, dict):
        return inputs
    moved: dict = {}
    for key, value in inputs.items():
        moved[key] = value.to(model.device) if hasattr(value, "to") else value
    return moved
# GPU duration comes from _estimate_gpu_seconds, which mirrors this signature.
@spaces.GPU(duration=_estimate_gpu_seconds, size="xlarge")
def respond(
    prompt: str,
    image: Image.Image | None,
    chat_history: list[dict[str, Any]] | None,
    model_history: list[dict[str, Any]] | None,
    system_prompt: str,
    enable_thinking: bool,
    max_new_tokens: int,
    max_denoising_steps: int,
    entropy_bound: float,
    t_min: float,
    t_max: float,
    confidence_threshold: float,
    stability_threshold: int,
    image_token_budget: int,
    show_thinking: bool,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], str, str, str, None]:
    """Run one chat turn through the diffusion model and update both histories.

    Args:
        prompt: User message; surrounding whitespace is stripped.
        image: Optional attached image.
        chat_history: Display-only history shown in the Chatbot.
        model_history: History fed back to the model. User turns may hold
            image content.
        system_prompt: Optional system message.
        enable_thinking: Enables the template's thinking mode. Also raises
            max_new_tokens to at least 512.
        max_new_tokens, max_denoising_steps, entropy_bound, t_min, t_max,
        confidence_threshold, stability_threshold: Forwarded to
            _generation_config.
        image_token_budget: Must be one of IMAGE_TOKEN_BUDGETS; forwarded to
            the processor as max_soft_tokens.
        show_thinking: Whether to return the thought trace for display.
        progress: Gradio progress tracker.

    Returns:
        A 6-tuple, in order:
        - updated chat_history;
        - updated and trimmed model_history;
        - thought-trace Markdown (empty unless shown);
        - tool-call Markdown;
        - a stats string;
        - None, which the UI wires to the image input.

    Raises:
        gr.Error: if there is neither prompt nor image, the image budget is
            unsupported, or t_max <= t_min.
    """
    prompt = prompt.strip()
    # Copy so the caller's lists are not mutated by the appends below.
    chat_history = list(chat_history or [])
    model_history = list(model_history or [])
    if not prompt and image is None:
        raise gr.Error("Enter a prompt or attach an image.")
    if image_token_budget not in IMAGE_TOKEN_BUDGETS:
        raise gr.Error("Select a supported image token budget.")
    if t_max <= t_min:
        raise gr.Error("Start temperature must be greater than end temperature.")
    # Presumably gives the thought trace room before the answer; the GPU
    # duration estimator must account for this same floor.
    if enable_thinking and max_new_tokens < 512:
        max_new_tokens = 512
    progress(0.05, desc="Preparing inputs")
    messages = _build_messages(prompt, image, model_history, system_prompt)
    processor_kwargs = {"images_kwargs": {"max_soft_tokens": int(image_token_budget)}}
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=bool(enable_thinking),
        processor_kwargs=processor_kwargs,
    )
    inputs = _to_model_device(inputs)
    progress(0.2, desc="Generating")
    started_at = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            generation_config=_generation_config(
                max_new_tokens=max_new_tokens,
                max_denoising_steps=max_denoising_steps,
                entropy_bound=entropy_bound,
                t_min=t_min,
                t_max=t_max,
                confidence_threshold=confidence_threshold,
                stability_threshold=stability_threshold,
            ),
        )
    elapsed = time.perf_counter() - started_at
    # Accept either an output object with .sequences or an indexable result.
    sequences = outputs.sequences if hasattr(outputs, "sequences") else outputs[0]
    # NOTE(review): assumes the returned sequences begin with the prompt tokens,
    # so slicing past prompt_length leaves only generated ones; confirm for this generate().
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = sequences[0, prompt_length:].detach().cpu()
    displayed_tokens = _trim_generated_tail(new_tokens)
    answer, thinking, tool_text = _parse_generated(new_tokens)
    answer = answer or "(No final answer was generated.)"
    user_display = prompt if prompt else "(image only)"
    if image is not None:
        user_display = f"{user_display}\n\n[image attached]"
    chat_history.append({"role": "user", "content": user_display})
    chat_history.append({"role": "assistant", "content": answer})
    # The model-facing history keeps the real user content (including the image)
    # but only the cleaned answer; the thought trace is not stored.
    model_history.append({"role": "user", "content": _build_user_content(prompt, image)})
    model_history.append({"role": "assistant", "content": answer})
    model_history = _trim_history(model_history)
    # Optional statistic; reported as "n/a" when generate() does not provide a tensor.
    tokens_per_forward = getattr(outputs, "tokens_per_forward", None)
    if isinstance(tokens_per_forward, torch.Tensor):
        tokens_per_forward_text = f"{tokens_per_forward.float().mean().item():.2f}"
    else:
        tokens_per_forward_text = "n/a"
    thought_markdown = thinking if show_thinking and thinking else ""
    tool_markdown = f"```json\n{tool_text}\n```" if tool_text else ""
    # "Displayed" counts tokens up to the first pad/EOS; "Canvas" counts all generated positions.
    stats = (
        f"Elapsed: {elapsed:.1f}s\n\n"
        f"Displayed tokens: {int(displayed_tokens.numel())}\n\n"
        f"Canvas tokens: {int(new_tokens.numel())}\n\n"
        f"Tokens per forward: {tokens_per_forward_text}"
    )
    progress(1.0, desc="Done")
    return chat_history, model_history, thought_markdown, tool_markdown, stats, None
def clear_chat() -> tuple[list, list, str, str, str, None, str]:
    """Return reset values for the Clear button.

    The values are, in order: empty chat and model histories; empty
    thought-trace, tool-call and stats text; no image; and an empty prompt.
    """
    blank = ""
    return [], [], blank, blank, blank, None, blank
# Page styles passed to launch(css=...). They:
# - cap `.contain` at 1280px wide;
# - give textareas inside the "control-panel" column a minimum height;
# - use a monospace font in the "stats-box" run-stats textbox.
css = """
.contain { max-width: 1280px; }
#control-panel textarea { min-height: 86px !important; }
#stats-box textarea { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }
"""
# Gradio UI: conversation column on the left, generation controls on the right.
with gr.Blocks(
    title="DiffusionGemma 26B A4B",
) as demo:
    # Model-facing history that respond() reads and returns. User turns may hold
    # image content; the Chatbot below only holds display text.
    model_state = gr.State([])
    gr.Markdown("# DiffusionGemma 26B A4B")
    with gr.Row(equal_height=False):
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=560,
                buttons=["copy", "copy_all"],
            )
            prompt = gr.Textbox(
                label="Message",
                placeholder="Ask about text, code, reasoning, or an attached image.",
                lines=3,
                max_lines=8,
            )
            with gr.Row():
                submit = gr.Button("Generate", variant="primary")
                clear = gr.Button("Clear")
        # elem_id matches the "#control-panel" rule in css.
        with gr.Column(scale=4, elem_id="control-panel"):
            image = gr.Image(label="Image", type="pil", height=260)
            system_prompt = gr.Textbox(
                label="System",
                value=DEFAULT_SYSTEM_PROMPT,
                lines=3,
                max_lines=6,
            )
            enable_thinking = gr.Checkbox(label="Thinking", value=False)
            show_thinking = gr.Checkbox(label="Show thought trace", value=False)
            with gr.Accordion("Generation", open=True):
                # respond() raises this to at least 512 when Thinking is on.
                max_new_tokens = gr.Slider(256, 1024, value=512, step=256, label="Max new tokens")
                max_denoising_steps = gr.Slider(8, 64, value=48, step=1, label="Denoising steps")
                image_token_budget = gr.Radio(
                    IMAGE_TOKEN_BUDGETS,
                    value=280,
                    label="Image tokens",
                )
            with gr.Accordion("Sampler", open=False):
                entropy_bound = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Entropy bound")
                # respond() rejects runs where Start temperature <= End temperature.
                t_max = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Start temperature")
                t_min = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="End temperature")
                confidence_threshold = gr.Slider(
                    0.001,
                    0.05,
                    value=0.005,
                    step=0.001,
                    label="Confidence threshold",
                )
                stability_threshold = gr.Slider(0, 4, value=1, step=1, label="Stability threshold")
            # NOTE(review): SOURCE lost its indentation. These output panels are
            # reconstructed inside the right-hand column; confirm the intended nesting.
            with gr.Accordion("Latest thought trace", open=False):
                thinking_box = gr.Markdown()
            with gr.Accordion("Tool calls", open=False):
                tool_box = gr.Markdown()
            stats_box = gr.Textbox(label="Run stats", lines=4, interactive=False, elem_id="stats-box")
    # Order must match respond()'s positional parameters.
    inputs = [
        prompt,
        image,
        chatbot,
        model_state,
        system_prompt,
        enable_thinking,
        max_new_tokens,
        max_denoising_steps,
        entropy_bound,
        t_min,
        t_max,
        confidence_threshold,
        stability_threshold,
        image_token_budget,
        show_thinking,
    ]
    # Order must match respond()'s 6-tuple return; its trailing None goes to `image`.
    outputs = [chatbot, model_state, thinking_box, tool_box, stats_box, image]
    submit.click(respond, inputs=inputs, outputs=outputs, api_name="generate", concurrency_limit=1)
    # Same handler; api_name=False keeps it off the API, leaving "generate" as the endpoint.
    prompt.submit(respond, inputs=inputs, outputs=outputs, api_name=False, concurrency_limit=1)
    # clear_chat returns 7 values: the six respond() outputs plus "" for `prompt`.
    clear.click(
        clear_chat,
        outputs=[chatbot, model_state, thinking_box, tool_box, stats_box, image, prompt],
        api_name="clear",
    )
    # NOTE(review): the third example asks about an image but supplies None,
    # so it runs as a text-only prompt unless the user attaches one.
    gr.Examples(
        examples=[
            ["Explain why discrete diffusion can generate several tokens per forward pass.", None],
            ["Write a small Python function that topologically sorts a DAG.", None],
            ["Summarize the image and read any visible text.", None],
        ],
        inputs=[prompt, image],
    )
if __name__ == "__main__":
    # Queue: one concurrent run by default, at most 12 waiting requests.
    queued_app = demo.queue(default_concurrency_limit=1, max_size=12)
    app_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="green", neutral_hue="slate")
    queued_app.launch(css=css, theme=app_theme)