Spaces:
Running on Zero
Running on Zero
| """Gemma 4 E4B — ZeroGPU, OpenAI-compatible chat-completions endpoint. | |
| Serves an uncensored Gemma 4 E4B checkpoint behind an OpenAI-shaped | |
| ``POST /v1/chat/completions`` route so the image-edit bot's | |
| ``GemmaLLM`` client (apps/image-edit-bot/src/gemma-client.ts) can use it | |
| by only setting env vars. GPU is acquired on demand via ``@spaces.GPU`` | |
| (ephemeral), which fits the bursty one-off script-generation use case. | |
| Gradio owns the server (so the Spaces gradio SDK launches/health-checks | |
| it on port 7860). The OpenAI route is added as ASGI middleware by | |
| patching ``gradio.routes.App.create_app`` so it runs *before* gradio's | |
| catch-all SPA route. | |
| """ | |
| import os | |
| import time | |
| import uuid | |
| from typing import Optional | |
| import torch | |
| import spaces | |
| import gradio as gr | |
| from fastapi.responses import JSONResponse | |
| from starlette.middleware import Middleware | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.requests import Request | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Runtime configuration; every value can be overridden through the environment.

# Hugging Face repo id of the checkpoint to serve.
MODEL_ID = os.getenv("MODEL_ID", "OBLITERATUS/gemma-4-E4B-it-OBLITERATED")

# Shared secret expected from API clients; an empty value disables the check.
ENDPOINT_API_KEY = os.getenv("ENDPOINT_API_KEY", "").strip()

# Upper bound applied to the client-requested ``max_tokens``.
MAX_NEW_TOKENS_CAP = int(os.getenv("MAX_NEW_TOKENS_CAP", "2048"))

# GPU time budget in seconds (presumably the ``@spaces.GPU`` duration --
# TODO confirm where it is consumed).
GPU_DURATION = int(os.getenv("GPU_DURATION", "120"))
# Load the tokenizer and weights once at import time. Nothing here touches
# CUDA: the weights are loaded in bfloat16 and moved to the GPU later, inside
# the generation call.
print(f"[boot] loading tokenizer/model for {MODEL_ID} (CPU) ...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=True,
    dtype=torch.bfloat16,
).eval()  # ``eval()`` returns the module itself, so ``model`` is the loaded model
print("[boot] model loaded on CPU; GPU is acquired per request", flush=True)
@spaces.GPU(duration=GPU_DURATION)
def _generate(text: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Run one generation pass on the GPU and return only the newly generated text.

    The function is wrapped in ``@spaces.GPU`` so that ZeroGPU attaches a GPU
    for the duration of the call (bounded by ``GPU_DURATION`` seconds); CUDA
    is not available outside such a decorated function on ZeroGPU hardware.

    Args:
        text: Fully rendered prompt (normally the chat-template output).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A falsy or non-positive value
            selects greedy decoding.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Returns:
        The decoded continuation, without the prompt, without special tokens,
        and with surrounding whitespace stripped.
    """
    model.to("cuda")

    # A rendered chat template usually already starts with the BOS token.
    # Letting the tokenizer add its special tokens again would duplicate it,
    # so only ask for them when the prompt does not carry BOS itself.
    bos_token = getattr(tokenizer, "bos_token", None)
    prompt_has_bos = bool(bos_token) and text.startswith(bos_token)
    inputs = tokenizer(
        text,
        return_tensors="pt",
        add_special_tokens=not prompt_has_bos,
    ).to("cuda")
    input_len = inputs["input_ids"].shape[-1]

    gen_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if temperature and float(temperature) > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=float(top_p))
    else:
        gen_kwargs.update(do_sample=False)

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # ``generate`` returns prompt + continuation; slice the prompt off.
    return tokenizer.decode(out[0][input_len:], skip_special_tokens=True).strip()
def _flatten_content(content) -> str:
    """Collapse an OpenAI-style message ``content`` value into plain text.

    A list is treated as a sequence of content parts: the ``text`` of every
    dict part whose ``type`` is ``"text"`` is concatenated and all other
    parts are ignored. ``None`` becomes the empty string and any other value
    is converted with ``str``.
    """
    if content is None:
        return ""
    if not isinstance(content, list):
        return str(content)

    pieces = []
    for part in content:
        if isinstance(part, dict) and part.get("type") == "text":
            pieces.append(part.get("text", ""))
    return "".join(pieces)
def run_chat(messages, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Render a chat transcript with the tokenizer's template and generate a reply.

    Each message is reduced to ``{"role", "content"}``: a missing role
    defaults to ``"user"`` and the content is flattened to plain text. The
    template is rendered as a string with the generation prompt appended and
    ``enable_thinking=False`` forwarded to it, then handed to ``_generate``
    together with the sampling settings.
    """
    normalized = []
    for message in messages:
        role = message.get("role", "user")
        body = _flatten_content(message.get("content"))
        normalized.append({"role": role, "content": body})

    prompt = tokenizer.apply_chat_template(
        normalized,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    return _generate(prompt, max_new_tokens, temperature, top_p)
def _authorized(authorization: Optional[str]) -> bool:
    """Decide whether a request's ``Authorization`` header grants access.

    Args:
        authorization: Raw header value, or ``None`` when the header is absent.

    Returns:
        ``True`` when no ``ENDPOINT_API_KEY`` is configured (auth disabled) or
        when the supplied token matches it; ``False`` otherwise. The token may
        be sent bare or with a case-insensitive ``Bearer `` prefix, and
        surrounding whitespace is ignored.
    """
    import hmac  # local import keeps this block self-contained

    if not ENDPOINT_API_KEY:
        return True
    if not authorization:
        return False

    token = authorization[7:] if authorization.lower().startswith("bearer ") else authorization

    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Compare bytes: ``compare_digest`` rejects non-ASCII str.
    return hmac.compare_digest(
        token.strip().encode("utf-8"),
        ENDPOINT_API_KEY.encode("utf-8"),
    )
async def _handle_chat(request: Request) -> JSONResponse:
    """Serve one OpenAI-style ``POST /v1/chat/completions`` request.

    Request fields read from the JSON body: ``messages`` (required, non-empty
    list of objects), ``max_tokens`` (default 1024, clamped to
    ``1..MAX_NEW_TOKENS_CAP``), ``temperature`` (default 0.8), ``top_p``
    (default 0.95) and ``model`` (echoed back; defaults to ``MODEL_ID``).

    Returns:
        * 401 when the API key check fails.
        * 400 when the body is not a JSON object, ``messages`` is missing or
          malformed, or a numeric field cannot be parsed.
        * 500 with the exception text when generation fails.
        * 200 with a ``chat.completion`` payload otherwise. ``finish_reason``
          is always ``"stop"`` and the ``usage`` counters are always zero.
    """

    def _error(status: int, message: str) -> JSONResponse:
        # All failures share the same ``{"error": {"message": ...}}`` envelope.
        return JSONResponse(status_code=status, content={"error": {"message": message}})

    if not _authorized(request.headers.get("authorization")):
        return _error(401, "invalid api key")

    # A body that is not valid JSON is a client error, not a server crash.
    # JSON decoding failures (including bad UTF-8) are ValueError subclasses.
    try:
        body = await request.json()
    except ValueError:
        return _error(400, "request body must be valid JSON")
    if not isinstance(body, dict):
        return _error(400, "request body must be a JSON object")

    messages = body.get("messages") or []
    if not messages:
        return _error(400, "messages required")
    if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
        return _error(400, "messages must be a list of objects")

    try:
        max_tokens = max(1, min(int(body.get("max_tokens") or 1024), MAX_NEW_TOKENS_CAP))
        # An explicit ``null`` temperature is passed through unchanged.
        temperature = body.get("temperature", 0.8)
        if temperature is not None:
            temperature = float(temperature)
        # ``null`` top_p falls back to the default instead of crashing float().
        top_p = body.get("top_p", 0.95)
        top_p = 0.95 if top_p is None else float(top_p)
    except (TypeError, ValueError, OverflowError):
        return _error(400, "max_tokens, temperature and top_p must be numeric")

    model_name = body.get("model") or MODEL_ID

    # NOTE(review): run_chat is synchronous and is called directly from this
    # coroutine, so the event loop is blocked while the model generates.
    try:
        content = run_chat(messages, max_tokens, temperature, top_p)
    except Exception as exc:  # surface a clean error to the client
        return _error(500, str(exc))

    return JSONResponse(content={
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    })
class OpenAIRouteMiddleware(BaseHTTPMiddleware):
    """Answer the OpenAI-shaped routes itself and pass everything else on.

    ``POST /v1/chat/completions`` is delegated to ``_handle_chat`` and
    ``GET /health`` returns a small status document. Trailing slashes on the
    path are ignored. Any other request goes to the wrapped application.
    """

    async def dispatch(self, request: Request, call_next):
        # Normalise the path so "/health/" and "/health" are treated alike;
        # a bare "/" must stay "/" rather than collapse to the empty string.
        normalized = request.url.path.rstrip("/")
        if not normalized:
            normalized = "/"
        route = (request.method, normalized)

        if route == ("POST", "/v1/chat/completions"):
            return await _handle_chat(request)
        if route == ("GET", "/health"):
            return JSONResponse(content={"status": "ok", "model": MODEL_ID})
        return await call_next(request)
# The Spaces gradio runtime constructs and serves its own FastAPI app and
# ignores launch kwargs, so the middleware is attached by wrapping gradio's
# app factory: whatever app ends up being served carries the OpenAI routes.
import gradio.routes as _groutes  # noqa: E402

# Keep a handle on the real factory before it is replaced below.
_orig_create_app = _groutes.App.create_app


def _patched_create_app(*args, **kwargs):
    """Build the gradio app as usual, then add ``OpenAIRouteMiddleware`` to it."""
    app = _orig_create_app(*args, **kwargs)
    app.add_middleware(OpenAIRouteMiddleware)
    print("[boot] injected OpenAI middleware into gradio app", flush=True)
    return app


_groutes.App.create_app = staticmethod(_patched_create_app)
# Minimal manual-test UI: a single prompt box wired to ``run_chat`` with
# fixed settings (512 new tokens, temperature 0.8, top_p 0.95).
with gr.Blocks() as demo:
    gr.Markdown(
        "# Gemma 4 E4B — ZeroGPU\n"
        "OpenAI-compatible endpoint at `POST /v1/chat/completions`.\n\n"
        f"Model: `{MODEL_ID}`"
    )
    prompt = gr.Textbox(label="Prompt", lines=3)
    answer = gr.Textbox(label="Response", lines=8)

    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=lambda p: run_chat([{"role": "user", "content": p}], 512, 0.8, 0.95),
        inputs=prompt,
        outputs=answer,
    )
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from GRADIO_SERVER_PORT
    # (default 7860) and server-side rendering is switched off.
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, ssr_mode=False)