Spaces:
Paused
Paused
| import os | |
| import sys | |
| import time | |
| import traceback | |
# Line-buffer both standard streams so every log line is emitted as soon as
# it is written.
for _stream in (sys.stdout, sys.stderr):
    _stream.reconfigure(line_buffering=True)
del _stream

# Epoch timestamp taken right after the stdlib imports; later timings are
# measured relative to it.
T_START = time.time()
print(f"[app] __main__ start t={T_START:.2f}", flush=True)

# Fixed benchmark inputs.
MODEL_ID = "google/gemma-4-E2B-it"
PROMPT = "Write a short haiku about mountains."

# Environment overrides: a label for the loading strategy being measured and
# the location the model is loaded from (defaults to MODEL_ID).
LOAD_STRATEGY = os.getenv("LOAD_STRATEGY", "normal")
MODEL_SOURCE = os.getenv("MODEL_SOURCE", MODEL_ID)

# Shared results dictionary; the benchmark adds its measurements to it and the
# UI displays it.
STATS = dict(
    strategy=LOAD_STRATEGY,
    model_id=MODEL_ID,
    model_source=MODEL_SOURCE,
    prompt=PROMPT,
    app_import_epoch=T_START,
)
def _profile():
    """Run the one-shot load-and-generate benchmark and record it in STATS.

    Imports torch and transformers, loads the processor and the model from
    MODEL_SOURCE, generates up to 64 new tokens for PROMPT with greedy
    decoding, and stores library versions, CUDA information, per-phase
    durations (in seconds), and the decoded response in the module-level
    STATS dict. Progress is printed with an ``[app]`` prefix.

    Exceptions raised by any step propagate to the caller; STATS keeps
    whatever was recorded before the failure.
    """
    # The heavy imports happen here, after T_START, so their cost is measured.
    print("[app] profile: importing torch...", flush=True)
    import torch

    print("[app] profile: importing transformers...", flush=True)
    import transformers
    from transformers import AutoModelForImageTextToText, AutoProcessor

    print(
        f"[app] profile: torch={torch.__version__}, transformers={transformers.__version__}",
        flush=True,
    )
    STATS["transformers_version"] = transformers.__version__
    STATS["torch_version"] = torch.__version__
    has_cuda = torch.cuda.is_available()
    STATS["torch_cuda_available"] = has_cuda
    STATS["torch_device_name"] = torch.cuda.get_device_name(0) if has_cuda else None
    STATS["imports_seconds"] = time.time() - T_START

    # Phase 1: processor.
    print(f"[app] loading processor from {MODEL_SOURCE!r}", flush=True)
    load_began = time.time()
    processor = AutoProcessor.from_pretrained(MODEL_SOURCE)
    processor_ready = time.time()
    processor_elapsed = processor_ready - load_began
    print(f"[app] processor loaded in {processor_elapsed:.2f}s", flush=True)
    STATS["processor_load_seconds"] = processor_elapsed

    # Phase 2: model weights.
    print(f"[app] loading model from {MODEL_SOURCE!r}", flush=True)
    model_began = time.time()
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_SOURCE,
        dtype="auto",
        device_map="auto",
    )
    model_ready = time.time()
    model_elapsed = model_ready - model_began
    print(f"[app] model loaded in {model_elapsed:.2f}s", flush=True)
    STATS["model_load_seconds"] = model_elapsed
    # Total spans from the start of the processor load to the end of the
    # model load.
    STATS["total_load_seconds"] = model_ready - load_began

    # Phase 3: a single text-only user turn, tokenized through the chat template.
    user_turn = {"role": "user", "content": [{"type": "text", "text": PROMPT}]}
    encoded = processor.apply_chat_template(
        [user_turn],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = encoded.to(model.device)

    print("[app] generating...", flush=True)
    gen_began = time.time()
    with torch.inference_mode():
        generated = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    gen_done = time.time()
    gen_elapsed = gen_done - gen_began
    print(f"[app] generate done in {gen_elapsed:.2f}s", flush=True)
    STATS["predict_seconds"] = gen_elapsed

    # Drop the prompt tokens so only the newly generated tail is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    completion_ids = generated[0][prompt_len:]
    STATS["response"] = processor.decode(completion_ids, skip_special_tokens=True)
    STATS["predict_done_epoch"] = gen_done
    STATS["wall_seconds_import_to_predict_done"] = gen_done - T_START
# Run the benchmark once at import time. A failure must not stop the script:
# it is recorded in STATS so it can still be inspected afterwards, and it is
# also echoed to the log so the error is visible alongside the other "[app]"
# progress lines instead of living only inside the STATS dict.
try:
    _profile()
except Exception as e:
    STATS["status"] = "error"
    STATS["error"] = repr(e)
    STATS["traceback"] = traceback.format_exc()
    print(f"[app] profile failed: {e!r}", flush=True)
    traceback.print_exc()
else:
    STATS["status"] = "ok"
| import gradio as gr | |
def get_stats():
    """Return the module-level STATS dict (the same object, not a copy)."""
    current = STATS
    return current
# Single-page UI: a header, a short description, the STATS dict as JSON, and
# a button that returns the same STATS dict through get_stats (also exposed
# under the API name "get_stats").
demo = gr.Blocks(title=f"Gemma-4-E2B-it bench — {LOAD_STRATEGY}")
with demo:
    gr.Markdown(value=f"# Gemma-4-E2B-it benchmark — `{LOAD_STRATEGY}` strategy")
    gr.Markdown(
        value=(
            f"**Model source:** `{MODEL_SOURCE}` \n"
            f"**Prompt:** `{PROMPT}` \n"
            "Prediction runs automatically once at container boot. "
            "The JSON below shows the timings from that single boot-time run."
        )
    )
    gr.JSON(label="Boot-time stats", value=STATS)
    btn = gr.Button(value="Fetch stats again (same result)")
    out = gr.JSON(label="Stats")
    btn.click(get_stats, inputs=None, outputs=out, api_name="get_stats")

# Enable the request queue, then serve on all interfaces at port 7860.
demo.queue().launch(
    server_name="0.0.0.0",
    server_port=7860,
)