ZeroPointMonkey's picture
Disable gradio SSR so Python serves custom routes directly
c003c21 verified
"""Gemma 4 E4B — ZeroGPU, OpenAI-compatible chat-completions endpoint.
Serves an uncensored Gemma 4 E4B checkpoint behind an OpenAI-shaped
``POST /v1/chat/completions`` route so the image-edit bot's
``GemmaLLM`` client (apps/image-edit-bot/src/gemma-client.ts) can use it
by only setting env vars. GPU is acquired on demand via ``@spaces.GPU``
(ephemeral), which fits the bursty one-off script-generation use case.
Gradio owns the server (so the Spaces gradio SDK launches/health-checks
it on port 7860). The OpenAI route (plus ``GET /health``) is added as ASGI
middleware by patching gradio's app factory (``gradio.routes.App.create_app``)
so it runs *before* gradio's catch-all SPA route on whichever app is served.
"""
import asyncio
import functools
import hmac
import os
import time
import uuid
from typing import Optional

import torch
import spaces
import gradio as gr
from fastapi.responses import JSONResponse
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_ID = os.environ.get("MODEL_ID", "OBLITERATUS/gemma-4-E4B-it-OBLITERATED")
ENDPOINT_API_KEY = os.environ.get("ENDPOINT_API_KEY", "").strip()
MAX_NEW_TOKENS_CAP = int(os.environ.get("MAX_NEW_TOKENS_CAP", "2048"))
GPU_DURATION = int(os.environ.get("GPU_DURATION", "120"))
# Load tokenizer and weights once, at import time, on CPU. Per the boot log
# below, the GPU is only acquired per request (``@spaces.GPU`` on
# ``_generate``), which is where the model is moved to CUDA.
print(f"[boot] loading tokenizer/model for {MODEL_ID} (CPU) ...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# NOTE(review): ``dtype=`` is the newer spelling of ``torch_dtype=`` in
# transformers; confirm the pinned version accepts it, otherwise the weights
# may load in the default dtype instead of bfloat16.
# low_cpu_mem_usage lowers peak host RAM while the checkpoint is loaded.
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, dtype=torch.bfloat16, low_cpu_mem_usage=True
)
# Inference only: put modules (dropout etc.) into evaluation mode.
model.eval()
print("[boot] model loaded on CPU; GPU is acquired per request", flush=True)
@spaces.GPU(duration=GPU_DURATION)
def _generate(text: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate a completion for an already chat-templated prompt on the GPU.

    Args:
        text: Prompt rendered by ``tokenizer.apply_chat_template(...,
            tokenize=False)``. It therefore already contains the template's
            special tokens (e.g. ``<bos>``).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A falsy or non-positive value
            selects greedy decoding.
        top_p: Nucleus-sampling probability mass; only used when sampling.

    Returns:
        The decoded continuation (prompt excluded), with special tokens
        removed and surrounding whitespace stripped.
    """
    # The model is loaded on CPU at boot; the GPU is only available inside
    # this ``@spaces.GPU`` call, so move it here.
    model.to("cuda")
    # The chat template already emitted the special tokens; letting the
    # tokenizer add them again would duplicate them (a double <bos>).
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to("cuda")
    prompt_len = inputs["input_ids"].shape[-1]

    gen_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if temperature and float(temperature) > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=float(top_p))
    else:
        gen_kwargs.update(do_sample=False)

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    # Slice off the echoed prompt so only newly generated tokens are decoded.
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
def _flatten_content(content) -> str:
if isinstance(content, list):
return "".join(
p.get("text", "")
for p in content
if isinstance(p, dict) and p.get("type") == "text"
)
return "" if content is None else str(content)
def run_chat(messages, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Render OpenAI-style ``messages`` with the chat template and generate a reply.

    Each message is reduced to ``role`` (default ``"user"``) and flattened
    text ``content``; other message fields are dropped. The rendered prompt
    ends with the generation prompt, and ``enable_thinking=False`` is
    forwarded to the chat template as a template variable.
    """
    def as_template_turn(message):
        # Keep only the fields the chat template consumes.
        return {
            "role": message.get("role", "user"),
            "content": _flatten_content(message.get("content")),
        }

    turns = [as_template_turn(message) for message in messages]
    prompt = tokenizer.apply_chat_template(
        turns,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    return _generate(prompt, max_new_tokens, temperature, top_p)
def _authorized(authorization: Optional[str]) -> bool:
if not ENDPOINT_API_KEY:
return True
if not authorization:
return False
token = authorization[7:] if authorization.lower().startswith("bearer ") else authorization
return token.strip() == ENDPOINT_API_KEY
def _error_response(status_code: int, message: str) -> JSONResponse:
    """Return an OpenAI-shaped error body: ``{"error": {"message": ...}}``."""
    return JSONResponse(status_code=status_code, content={"error": {"message": message}})


async def _handle_chat(request: Request) -> JSONResponse:
    """Serve ``POST /v1/chat/completions`` with an OpenAI-shaped response.

    Body fields read: ``messages`` (required, non-empty list of objects),
    ``max_completion_tokens`` or ``max_tokens`` (default 1024, clamped to
    ``[1, MAX_NEW_TOKENS_CAP]``), ``temperature`` (default 0.8), ``top_p``
    (default 0.95) and ``model`` (only echoed back). A single non-streamed
    ``chat.completion`` object is always returned.

    Returns:
        401 for a bad API key, 400 for a malformed body, 500 if generation
        raises, otherwise 200 with exactly one choice.
    """
    if not _authorized(request.headers.get("authorization")):
        return _error_response(401, "invalid api key")

    try:
        body = await request.json()
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return _error_response(400, "request body must be valid JSON")
    if not isinstance(body, dict):
        return _error_response(400, "request body must be a JSON object")

    messages = body.get("messages") or []
    if not messages:
        return _error_response(400, "messages required")
    if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
        return _error_response(400, "messages must be a list of objects")

    # Accept the newer OpenAI parameter name as well as the legacy one.
    requested = body.get("max_completion_tokens") or body.get("max_tokens") or 1024
    # Explicit JSON nulls are treated as "not provided".
    temperature = body.get("temperature")
    top_p = body.get("top_p")
    try:
        max_tokens = max(1, min(int(requested), MAX_NEW_TOKENS_CAP))
        temperature = 0.8 if temperature is None else float(temperature)
        top_p = 0.95 if top_p is None else float(top_p)
    except (TypeError, ValueError, OverflowError):
        return _error_response(400, "max_tokens, temperature and top_p must be numbers")
    model_name = body.get("model") or MODEL_ID

    try:
        # Generation blocks for up to GPU_DURATION seconds; run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        content = await asyncio.to_thread(run_chat, messages, max_tokens, temperature, top_p)
    except Exception as exc:  # request boundary: surface a clean error to the client
        print(f"[chat] generation failed: {exc!r}", flush=True)
        return _error_response(500, str(exc))

    return JSONResponse(content={
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                # Truncation at max_tokens is not detected, so this is always "stop".
                "finish_reason": "stop",
            }
        ],
        # Token accounting is not implemented; zeros are placeholders.
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    })
class OpenAIRouteMiddleware(BaseHTTPMiddleware):
    """Answer the OpenAI-shaped routes before gradio's own routing sees them.

    Handles ``POST /v1/chat/completions`` and ``GET /health`` (trailing
    slashes ignored); every other request is passed on unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        normalized_path = request.url.path.rstrip("/") or "/"
        route = (request.method, normalized_path)

        if route == ("POST", "/v1/chat/completions"):
            return await _handle_chat(request)
        if route == ("GET", "/health"):
            return JSONResponse(content={"status": "ok", "model": MODEL_ID})

        return await call_next(request)
# The Spaces gradio runtime builds and serves its own FastAPI app (ignoring
# launch kwargs). Patch gradio's app factory so our OpenAI middleware is
# injected into whatever app actually gets served.
import gradio.routes as _groutes  # noqa: E402

# Attribute set on our wrapper so re-executing this module (e.g. a reload in
# the same namespace) does not wrap an already-wrapped factory.
_PATCH_MARKER = "_openai_middleware_injected"


def _install_openai_middleware(app_cls) -> None:
    """Wrap ``app_cls.create_app`` so every created app gets ``OpenAIRouteMiddleware``.

    The original factory is captured in a closure rather than a mutable
    module global: if the module is re-executed in place, a global would be
    rebound to the wrapper itself and the wrapper would recurse forever.
    """
    original = app_cls.create_app
    if getattr(original, _PATCH_MARKER, False):
        return  # already patched by an earlier execution of this module

    @functools.wraps(original)
    def patched(*args, **kwargs):
        created = original(*args, **kwargs)
        created.add_middleware(OpenAIRouteMiddleware)
        print("[boot] injected OpenAI middleware into gradio app", flush=True)
        return created

    setattr(patched, _PATCH_MARKER, True)
    # create_app is looked up on the class, so re-wrap as a staticmethod.
    app_cls.create_app = staticmethod(patched)


_install_openai_middleware(_groutes.App)
# Header text for the manual-test UI.
_UI_INTRO = (
    "# Gemma 4 E4B — ZeroGPU\n"
    "OpenAI-compatible endpoint at `POST /v1/chat/completions`.\n\n"
    f"Model: `{MODEL_ID}`"
)

with gr.Blocks() as demo:
    gr.Markdown(_UI_INTRO)
    prompt = gr.Textbox(label="Prompt", lines=3)
    answer = gr.Textbox(label="Response", lines=8)
    generate_btn = gr.Button("Generate")
    # Single-turn chat with fixed settings: 512 new tokens, temperature 0.8,
    # top_p 0.95. Kept as a lambda so gradio's derived endpoint name is unchanged.
    generate_btn.click(
        fn=lambda p: run_chat([{"role": "user", "content": p}], 512, 0.8, 0.95),
        inputs=prompt,
        outputs=answer,
    )
if __name__ == "__main__":
    # SSR is disabled so the Python server handles the custom routes directly.
    _port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=_port, ssr_mode=False)