"""Gemma 4 E4B — ZeroGPU, OpenAI-compatible chat-completions endpoint. Serves an uncensored Gemma 4 E4B checkpoint behind an OpenAI-shaped ``POST /v1/chat/completions`` route so the image-edit bot's ``GemmaLLM`` client (apps/image-edit-bot/src/gemma-client.ts) can use it by only setting env vars. GPU is acquired on demand via ``@spaces.GPU`` (ephemeral), which fits the bursty one-off script-generation use case. Gradio owns the server (so the Spaces gradio SDK launches/health-checks it on port 7860). The OpenAI route is added as ASGI middleware via ``app_kwargs`` so it runs *before* gradio's catch-all SPA route. """ import os import time import uuid from typing import Optional import torch import spaces import gradio as gr from fastapi.responses import JSONResponse from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from transformers import AutoModelForCausalLM, AutoTokenizer MODEL_ID = os.environ.get("MODEL_ID", "OBLITERATUS/gemma-4-E4B-it-OBLITERATED") ENDPOINT_API_KEY = os.environ.get("ENDPOINT_API_KEY", "").strip() MAX_NEW_TOKENS_CAP = int(os.environ.get("MAX_NEW_TOKENS_CAP", "2048")) GPU_DURATION = int(os.environ.get("GPU_DURATION", "120")) print(f"[boot] loading tokenizer/model for {MODEL_ID} (CPU) ...", flush=True) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype=torch.bfloat16, low_cpu_mem_usage=True ) model.eval() print("[boot] model loaded on CPU; GPU is acquired per request", flush=True) @spaces.GPU(duration=GPU_DURATION) def _generate(text: str, max_new_tokens: int, temperature: float, top_p: float) -> str: model.to("cuda") inputs = tokenizer(text, return_tensors="pt").to("cuda") input_len = inputs["input_ids"].shape[-1] gen_kwargs = {"max_new_tokens": int(max_new_tokens)} if temperature and float(temperature) > 0: gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=float(top_p)) else: gen_kwargs.update(do_sample=False) with torch.inference_mode(): out = model.generate(**inputs, **gen_kwargs) return tokenizer.decode(out[0][input_len:], skip_special_tokens=True).strip() def _flatten_content(content) -> str: if isinstance(content, list): return "".join( p.get("text", "") for p in content if isinstance(p, dict) and p.get("type") == "text" ) return "" if content is None else str(content) def run_chat(messages, max_new_tokens: int, temperature: float, top_p: float) -> str: norm = [ {"role": m.get("role", "user"), "content": _flatten_content(m.get("content"))} for m in messages ] text = tokenizer.apply_chat_template( norm, tokenize=False, add_generation_prompt=True, enable_thinking=False ) return _generate(text, max_new_tokens, temperature, top_p) def _authorized(authorization: Optional[str]) -> bool: if not ENDPOINT_API_KEY: return True if not authorization: return False token = authorization[7:] if authorization.lower().startswith("bearer ") else authorization return token.strip() == ENDPOINT_API_KEY async def _handle_chat(request: Request) -> JSONResponse: if not _authorized(request.headers.get("authorization")): return JSONResponse(status_code=401, content={"error": {"message": "invalid api key"}}) body = await request.json() messages = body.get("messages") or [] if not messages: return JSONResponse(status_code=400, content={"error": {"message": "messages required"}}) max_tokens = max(1, min(int(body.get("max_tokens") or 1024), MAX_NEW_TOKENS_CAP)) temperature = body.get("temperature", 0.8) top_p = body.get("top_p", 0.95) model_name = body.get("model") or MODEL_ID try: content = run_chat(messages, max_tokens, temperature, top_p) except Exception as exc: # surface a clean error to the client return JSONResponse(status_code=500, content={"error": {"message": str(exc)}}) return JSONResponse(content={ "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion", "created": int(time.time()), "model": model_name, "choices": [ { "index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop", } ], "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, }) class OpenAIRouteMiddleware(BaseHTTPMiddleware): """Intercepts the OpenAI-shaped routes before gradio's own routing.""" async def dispatch(self, request: Request, call_next): path = request.url.path.rstrip("/") or "/" if path == "/v1/chat/completions" and request.method == "POST": return await _handle_chat(request) if path == "/health" and request.method == "GET": return JSONResponse(content={"status": "ok", "model": MODEL_ID}) return await call_next(request) # The Spaces gradio runtime builds and serves its own FastAPI app (ignoring # launch kwargs). Patch gradio's app factory so our OpenAI middleware is # injected into whatever app actually gets served. import gradio.routes as _groutes # noqa: E402 _orig_create_app = _groutes.App.create_app def _patched_create_app(*args, **kwargs): created = _orig_create_app(*args, **kwargs) created.add_middleware(OpenAIRouteMiddleware) print("[boot] injected OpenAI middleware into gradio app", flush=True) return created _groutes.App.create_app = staticmethod(_patched_create_app) with gr.Blocks() as demo: gr.Markdown( "# Gemma 4 E4B — ZeroGPU\n" "OpenAI-compatible endpoint at `POST /v1/chat/completions`.\n\n" f"Model: `{MODEL_ID}`" ) prompt = gr.Textbox(label="Prompt", lines=3) answer = gr.Textbox(label="Response", lines=8) gr.Button("Generate").click( lambda p: run_chat([{"role": "user", "content": p}], 512, 0.8, 0.95), prompt, answer, ) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")), ssr_mode=False, )