Wauplin's picture
Wauplin HF Staff
pin transformers to PR #45547 merge commit (auto-detect disable_mmap on hf-mount)
e39cda7 verified
import os
import sys
import time
import traceback
# Line-buffer both standard streams so every log line is emitted as soon as
# it is written.
for _stream in (sys.stdout, sys.stderr):
    _stream.reconfigure(line_buffering=True)

# Wall-clock reference point; later timings are measured relative to it.
T_START = time.time()
print(f"[app] __main__ start t={T_START:.2f}", flush=True)

# Benchmark configuration. LOAD_STRATEGY and MODEL_SOURCE can be overridden
# through environment variables of the same name.
LOAD_STRATEGY = os.getenv("LOAD_STRATEGY", "normal")
MODEL_ID = "google/gemma-4-E2B-it"
PROMPT = "Write a short haiku about mountains."
MODEL_SOURCE = os.getenv("MODEL_SOURCE", MODEL_ID)

# Shared result record: seeded with the configuration here and extended with
# measurements as the benchmark progresses.
STATS = {
    "strategy": LOAD_STRATEGY,
    "model_id": MODEL_ID,
    "model_source": MODEL_SOURCE,
    "prompt": PROMPT,
    "app_import_epoch": T_START,
}
def _profile():
    """Run the one-shot benchmark and record its measurements in ``STATS``.

    Steps, in order: import torch and transformers, load the processor, load
    the model, then generate a reply to ``PROMPT``. Each measurement is
    written to ``STATS`` as soon as it is available, so entries recorded
    before a failure stay in place. Returns nothing; exceptions are not
    caught here and propagate to the caller.
    """
    # The heavy imports happen inside the function; ``imports_seconds`` below
    # covers everything from ``T_START`` up to the end of these imports.
    print("[app] profile: importing torch...", flush=True)
    import torch

    print("[app] profile: importing transformers...", flush=True)
    import transformers
    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(
        f"[app] profile: torch={torch.__version__}, "
        f"transformers={transformers.__version__}",
        flush=True,
    )
    STATS["transformers_version"] = transformers.__version__
    STATS["torch_version"] = torch.__version__
    STATS["torch_cuda_available"] = torch.cuda.is_available()
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0)
    else:
        device_name = None
    STATS["torch_device_name"] = device_name
    STATS["imports_seconds"] = time.time() - T_START

    # --- processor -------------------------------------------------------
    print(f"[app] loading processor from {MODEL_SOURCE!r}", flush=True)
    load_started = time.time()
    processor = AutoProcessor.from_pretrained(MODEL_SOURCE)
    processor_ready = time.time()
    processor_elapsed = processor_ready - load_started
    print(f"[app] processor loaded in {processor_elapsed:.2f}s", flush=True)
    STATS["processor_load_seconds"] = processor_elapsed

    # --- model -----------------------------------------------------------
    print(f"[app] loading model from {MODEL_SOURCE!r}", flush=True)
    model_started = time.time()
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_SOURCE,
        dtype="auto",
        device_map="auto",
    )
    model_ready = time.time()
    model_elapsed = model_ready - model_started
    print(f"[app] model loaded in {model_elapsed:.2f}s", flush=True)
    STATS["model_load_seconds"] = model_elapsed
    # Total spans from the start of the processor load to the end of the
    # model load, including the short gap between the two.
    STATS["total_load_seconds"] = model_ready - load_started

    # --- generation ------------------------------------------------------
    chat = [{"role": "user", "content": [{"type": "text", "text": PROMPT}]}]
    encoded = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = encoded.to(model.device)

    print("[app] generating...", flush=True)
    generate_started = time.time()
    with torch.inference_mode():
        generated = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    generate_finished = time.time()
    generate_elapsed = generate_finished - generate_started
    print(f"[app] generate done in {generate_elapsed:.2f}s", flush=True)
    STATS["predict_seconds"] = generate_elapsed

    # Keep only the tokens that come after the prompt's length in the first
    # returned sequence, then decode them to text.
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = generated[0][prompt_length:]
    STATS["response"] = processor.decode(completion_ids, skip_special_tokens=True)
    STATS["predict_done_epoch"] = generate_finished
    STATS["wall_seconds_import_to_predict_done"] = generate_finished - T_START
# Run the benchmark once while the module is being executed. A failure is
# recorded in STATS instead of propagating, so the code that follows still
# runs and can display the error.
try:
    _profile()
except Exception as e:
    STATS["status"] = "error"
    STATS["error"] = repr(e)
    STATS["traceback"] = traceback.format_exc()
    # Every other step is logged to the process output; log the failure too
    # so it is visible without having to open the stats view.
    print(f"[app] profile failed: {e!r}", flush=True)
    print(STATS["traceback"], file=sys.stderr, flush=True)
else:
    STATS["status"] = "ok"
import gradio as gr
def get_stats():
    """Return the module-level ``STATS`` dict (the same object, not a copy)."""
    return STATS
# Read-only UI: shows the contents of STATS and offers a button (also exposed
# under the API name "get_stats") that returns the same dict again.
with gr.Blocks(title=f"Gemma-4-E2B-it bench — {LOAD_STRATEGY}") as demo:
    gr.Markdown(f"# Gemma-4-E2B-it benchmark — `{LOAD_STRATEGY}` strategy")
    # Two spaces before "\n" form a Markdown hard line break. With a single
    # space the source, prompt and explanation merge into one paragraph.
    gr.Markdown(
        f"**Model source:** `{MODEL_SOURCE}`  \n"
        f"**Prompt:** `{PROMPT}`  \n"
        "Prediction runs automatically once at container boot. "
        "The JSON below shows the timings from that single boot-time run."
    )
    gr.JSON(value=STATS, label="Boot-time stats")
    btn = gr.Button("Fetch stats again (same result)")
    out = gr.JSON(label="Stats")
    btn.click(fn=get_stats, inputs=None, outputs=out, api_name="get_stats")

# Enable the request queue and serve on all interfaces, port 7860.
demo.queue().launch(server_name="0.0.0.0", server_port=7860)