"""FastAPI proxy exposing a local llama.cpp server over an OpenAI-compatible API."""
| from __future__ import annotations | |
| import logging | |
| import sys | |
| from typing import Any, AsyncIterator | |
| import httpx | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
# Console logging: timestamped records to stdout so process supervisors and
# container runtimes capture them.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
logger = logging.getLogger("gemma4")
# Upstream llama.cpp server (serves an OpenAI-compatible HTTP API).
LLAMA_BASE = "http://127.0.0.1:8080"

app = FastAPI(title="Gemma 4 API", version="1.0.0")

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm whether credentials
# are actually needed, or pin explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared async HTTP client: 10 s connect budget, up to 300 s per request
# (generation can be slow). One client keeps upstream connections pooled.
_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=10.0))
@app.middleware("http")  # original defined this but never registered it — TODO confirm
async def log_requests(request: Request, call_next):
    """Log each request line before dispatch and its status code after.

    Middleware for every HTTP request handled by the app: emits
    "METHOD /path" on entry and "METHOD /path -> status" once the
    downstream handler returns.
    """
    logger.info("%s %s", request.method, request.url.path)
    response = await call_next(request)
    logger.info("%s %s -> %d", request.method, request.url.path, response.status_code)
    return response
async def _proxy_stream(url: str, payload: dict[str, Any]) -> AsyncIterator[bytes]:
    """POST *payload* to *url* and yield the response body chunk-by-chunk.

    On a non-200 upstream status, yields the full error body once and stops.
    NOTE: by the time this generator runs, the caller's StreamingResponse has
    already committed its own (200) status — the upstream error can only be
    surfaced in the body, not the status line.
    """
    async with _client.stream("POST", url, json=payload) as upstream:
        if upstream.status_code != 200:
            body = await upstream.aread()
            yield body
            return
        async for piece in upstream.aiter_bytes():
            yield piece
| # --------------------------------------------------------------------------- | |
| # Pydantic models | |
| # --------------------------------------------------------------------------- | |
class SimpleChatRequest(BaseModel):
    """Body for the simplified chat endpoint."""

    # OpenAI-style message dicts ({"role": ..., "content": ...}); forwarded
    # to llama.cpp as-is, shape not validated here.
    messages: list[dict[str, Any]]
    max_tokens: int = 2048    # generation cap forwarded upstream
    temperature: float = 0.7  # sampling temperature forwarded upstream
    stream: bool = False      # stream SSE chunks instead of one JSON body
class GenerateRequest(BaseModel):
    """Body for the single-prompt generation endpoint."""

    # Raw prompt text; wrapped into one user message by the handler.
    prompt: str
    max_tokens: int = 2048    # generation cap forwarded upstream
    temperature: float = 0.7  # sampling temperature forwarded upstream
    stream: bool = False      # stream SSE chunks instead of one JSON body
class VisionRequest(BaseModel):
    """Body for the vision endpoint: one text prompt plus one image."""

    prompt: str
    # Either an http(s) URL or bare base64 image data (the handler wraps
    # base64 input in a data: URL, assuming JPEG).
    image: str
    max_tokens: int = 2048    # generation cap forwarded upstream
    temperature: float = 0.7  # sampling temperature forwarded upstream
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
@app.get("/health")  # route path inferred from the handler's purpose — TODO confirm
async def health() -> dict[str, Any]:
    """Liveness check for this proxy and the upstream llama.cpp server.

    Returns {"status": "ok", "llama": <upstream health JSON>}.
    Raises HTTP 503 when the upstream server is unreachable.
    """
    try:
        resp = await _client.get(f"{LLAMA_BASE}/health", timeout=5.0)
    except httpx.TransportError as exc:
        # Chain the cause so the transport failure appears in tracebacks/logs.
        raise HTTPException(status_code=503, detail="llama.cpp server unreachable") from exc
    try:
        llama_status = resp.json() if resp.status_code == 200 else {"status": "error"}
    except ValueError:
        # A reachable server returning a 200 with a non-JSON body is still
        # reported as an error rather than crashing the health endpoint.
        llama_status = {"status": "error"}
    return {"status": "ok", "llama": llama_status}
@app.get("/v1/models")  # path mirrors the upstream endpoint being proxied — TODO confirm
async def list_models() -> Any:
    """Proxy the upstream OpenAI-compatible model listing.

    Raises HTTP 503 when upstream is unreachable; propagates any non-200
    upstream status instead of blindly parsing an error body as JSON
    (keeps behavior consistent with the other proxy handlers).
    """
    try:
        resp = await _client.get(f"{LLAMA_BASE}/v1/models", timeout=10.0)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/v1/chat/completions")  # OpenAI-compatible passthrough path — TODO confirm
async def chat_completions(request: Request) -> Any:
    """Forward an OpenAI-style chat-completions payload to llama.cpp verbatim.

    If the payload sets "stream", the upstream SSE stream is piped through;
    otherwise the upstream JSON response is returned. Raises HTTP 503 when
    upstream is unreachable and propagates non-200 upstream statuses.
    """
    # Raw body passthrough: no validation, so any OpenAI-compatible field works.
    payload = await request.json()
    if payload.get("stream", False):
        return StreamingResponse(
            _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
            media_type="text/event-stream",
        )
    try:
        resp = await _client.post(
            f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0
        )
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/chat")  # route path inferred from the handler name — TODO confirm
async def chat(req: SimpleChatRequest) -> Any:
    """Simplified chat endpoint: validated body, forwarded to llama.cpp.

    Streams SSE when req.stream is set; otherwise returns the upstream JSON.
    Raises HTTP 503 when upstream is unreachable and propagates non-200
    upstream statuses.
    """
    payload: dict[str, Any] = {
        "messages": req.messages,
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": req.stream,
    }
    if req.stream:
        return StreamingResponse(
            _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
            media_type="text/event-stream",
        )
    try:
        resp = await _client.post(
            f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0
        )
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/generate")  # route path inferred from the handler name — TODO confirm
async def generate(req: GenerateRequest) -> Any:
    """Single-prompt completion: wraps the prompt as one user message.

    Streams SSE when req.stream is set; otherwise returns the upstream JSON.
    Raises HTTP 503 when upstream is unreachable and propagates non-200
    upstream statuses.
    """
    payload: dict[str, Any] = {
        "messages": [{"role": "user", "content": req.prompt}],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": req.stream,
    }
    if req.stream:
        return StreamingResponse(
            _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
            media_type="text/event-stream",
        )
    try:
        resp = await _client.post(
            f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0
        )
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/vision")  # route path inferred from the handler name — TODO confirm
async def vision(req: VisionRequest) -> Any:
    """Single-turn multimodal request: one text prompt plus one image.

    ``req.image`` may be an http(s) URL or bare base64 data; base64 input is
    wrapped in a data: URL. NOTE(review): the mime type is hard-coded to
    image/jpeg — confirm upstream tolerates the mismatch for PNG/WebP input.
    Streaming is intentionally disabled for vision requests. Raises HTTP 503
    when upstream is unreachable and propagates non-200 upstream statuses.
    """
    if req.image.startswith(("http://", "https://")):
        image_url = req.image
    else:
        # Bare base64 payload: wrap in a data URL for the OpenAI image_url shape.
        image_url = f"data:image/jpeg;base64,{req.image}"
    image_content: dict[str, Any] = {"type": "image_url", "image_url": {"url": image_url}}
    payload: dict[str, Any] = {
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": req.prompt}, image_content],
            }
        ],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": False,
    }
    try:
        resp = await _client.post(
            f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0
        )
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()