"""Boot-time load/generate benchmark served through a small Gradio UI.

When executed, this script:

1. imports torch and transformers, loads the processor and model named by
   ``MODEL_SOURCE``, and runs one greedy generation for ``PROMPT``;
2. records versions, timings and the decoded response in ``STATS``;
3. starts a Gradio app on port 7860 that displays ``STATS``.

Environment variables:
    LOAD_STRATEGY: Label stored in ``STATS`` and shown in the UI title
        (default ``"normal"``). It is used only as a label in this file.
    MODEL_SOURCE: Value passed to ``from_pretrained`` for both the processor
        and the model (default: ``MODEL_ID``).
"""
import os
import sys
import time
import traceback

# Line-buffer both streams so log lines appear as soon as they are printed.
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)

# Wall-clock reference for every "since start" measurement below.
T_START = time.time()
print(f"[app] __main__ start t={T_START:.2f}", flush=True)

LOAD_STRATEGY = os.environ.get("LOAD_STRATEGY", "normal")
MODEL_ID = "google/gemma-4-E2B-it"
PROMPT = "Write a short haiku about mountains."
MODEL_SOURCE = os.environ.get("MODEL_SOURCE", MODEL_ID)

# Filled in by _profile() and by the boot-time try/except below; served
# as-is by the UI.
STATS = {
    "strategy": LOAD_STRATEGY,
    "model_id": MODEL_ID,
    "model_source": MODEL_SOURCE,
    "prompt": PROMPT,
    "app_import_epoch": T_START,
}


def _profile():
    """Load the processor and model, generate once, and record timings.

    Results are written into the module-level ``STATS`` dict; nothing is
    returned. Any exception raised by the imports, loading or generation
    propagates to the caller.
    """
    # Heavy imports live inside the function so their cost is part of the
    # measured "imports_seconds" and so an ImportError is handled by the
    # caller like any other failure.
    print("[app] profile: importing torch...", flush=True)
    import torch

    print("[app] profile: importing transformers...", flush=True)
    import transformers
    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(
        f"[app] profile: torch={torch.__version__}, "
        f"transformers={transformers.__version__}",
        flush=True,
    )
    cuda_available = torch.cuda.is_available()
    STATS["transformers_version"] = transformers.__version__
    STATS["torch_version"] = torch.__version__
    STATS["torch_cuda_available"] = cuda_available
    STATS["torch_device_name"] = (
        torch.cuda.get_device_name(0) if cuda_available else None
    )

    # Measured from T_START, so this also includes everything the script
    # did before _profile() was called.
    t_imports_done = time.time()
    STATS["imports_seconds"] = t_imports_done - T_START

    print(f"[app] loading processor from {MODEL_SOURCE!r}", flush=True)
    t0 = time.time()
    processor = AutoProcessor.from_pretrained(MODEL_SOURCE)
    t1 = time.time()
    print(f"[app] processor loaded in {t1-t0:.2f}s", flush=True)
    STATS["processor_load_seconds"] = t1 - t0

    print(f"[app] loading model from {MODEL_SOURCE!r}", flush=True)
    t2 = time.time()
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_SOURCE,
        dtype="auto",
        device_map="auto",
    )
    t3 = time.time()
    print(f"[app] model loaded in {t3-t2:.2f}s", flush=True)
    STATS["model_load_seconds"] = t3 - t2
    # Processor + model, including the logging between the two loads.
    STATS["total_load_seconds"] = t3 - t0

    messages = [{"role": "user", "content": [{"type": "text", "text": PROMPT}]}]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    print("[app] generating...", flush=True)
    t4 = time.time()
    with torch.inference_mode():
        out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    t5 = time.time()
    print(f"[app] generate done in {t5-t4:.2f}s", flush=True)
    STATS["predict_seconds"] = t5 - t4

    # Drop the prompt tokens so only the newly generated tokens are decoded.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = out[0][prompt_length:]
    STATS["response"] = processor.decode(new_tokens, skip_special_tokens=True)
    STATS["predict_done_epoch"] = t5
    STATS["wall_seconds_import_to_predict_done"] = t5 - T_START


# Best-effort boot-time run: a failure must not prevent the UI from starting,
# so the error is recorded in STATS instead of being re-raised.
try:
    _profile()
    STATS["status"] = "ok"
except Exception as e:
    STATS["status"] = "error"
    STATS["error"] = repr(e)
    STATS["traceback"] = traceback.format_exc()
    # Also surface the failure in the process logs; otherwise it would be
    # visible only to someone who opens the UI and reads the JSON.
    print(f"[app] profile failed: {e!r}", file=sys.stderr, flush=True)
    print(STATS["traceback"], file=sys.stderr, flush=True)

# Imported after the benchmark so that gradio's import time is not part of
# any of the timings recorded above.
import gradio as gr  # noqa: E402


def get_stats():
    """Return the module-level ``STATS`` dict recorded at boot."""
    return STATS


with gr.Blocks(title=f"Gemma-4-E2B-it bench — {LOAD_STRATEGY}") as demo:
    gr.Markdown(f"# Gemma-4-E2B-it benchmark — `{LOAD_STRATEGY}` strategy")
    # Two spaces before "\n" are a Markdown hard line break; with a single
    # space the three lines below would render as one paragraph line.
    gr.Markdown(
        f"**Model source:** `{MODEL_SOURCE}`  \n"
        f"**Prompt:** `{PROMPT}`  \n"
        f"Prediction runs automatically once at container boot. "
        f"The JSON below shows the timings from that single boot-time run."
    )
    gr.JSON(value=STATS, label="Boot-time stats")
    btn = gr.Button("Fetch stats again (same result)")
    out = gr.JSON(label="Stats")
    btn.click(fn=get_stats, inputs=None, outputs=out, api_name="get_stats")

demo.queue().launch(server_name="0.0.0.0", server_port=7860)