"""Gemma 4 API: a thin FastAPI facade in front of a local llama.cpp server.

Every endpoint forwards to the llama.cpp HTTP server at ``LLAMA_BASE`` and
relays the upstream JSON (or SSE byte stream) back to the caller.
"""

from __future__ import annotations

import json
import logging
import sys
from typing import Any, AsyncIterator

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger("gemma4")

LLAMA_BASE = "http://127.0.0.1:8080"

app = FastAPI(title="Gemma 4 API", version="1.0.0")

# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation. Left unchanged so existing
# browser clients keep working -- confirm whether credentials are needed and,
# if so, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# One shared client so upstream connections are pooled across requests.
_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=10.0))


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each request's method and path before and after it is handled."""
    logger.info("%s %s", request.method, request.url.path)
    response = await call_next(request)
    logger.info(
        "%s %s -> %d", request.method, request.url.path, response.status_code
    )
    return response


def _sse_error(message: str) -> bytes:
    """Encode *message* as a single server-sent ``data:`` event with an error body."""
    body = json.dumps({"error": {"message": message, "type": "upstream_error"}})
    return f"data: {body}\n\n".encode("utf-8")


async def _proxy_stream(url: str, payload: dict[str, Any]) -> AsyncIterator[bytes]:
    """POST *payload* to *url* and yield the upstream response body as it arrives.

    A non-200 upstream response is relayed as one chunk containing the whole
    error body. The HTTP status of the outer response has already been sent by
    the time this generator runs, so a transport failure cannot be turned into
    an error status; it is logged and reported as a final SSE error event.
    """
    try:
        async with _client.stream("POST", url, json=payload) as resp:
            if resp.status_code != 200:
                yield await resp.aread()
                return
            async for chunk in resp.aiter_bytes():
                yield chunk
    except httpx.TransportError as exc:
        logger.warning("upstream stream to %s failed: %s", url, exc)
        yield _sse_error(str(exc))


def _stream_chat_completion(payload: dict[str, Any]) -> StreamingResponse:
    """Return an SSE response that relays the upstream chat-completion stream."""
    return StreamingResponse(
        _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
        media_type="text/event-stream",
    )


async def _post_chat_completion(payload: dict[str, Any]) -> Any:
    """POST *payload* to the upstream chat-completions route and return its JSON.

    Raises:
        HTTPException: 503 when the upstream cannot be reached, the upstream's
            own status code when it answers with anything other than 200, and
            502 when a 200 response does not contain valid JSON.
    """
    try:
        resp = await _client.post(
            f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0
        )
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    try:
        return resp.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502, detail="llama.cpp returned a non-JSON response"
        ) from exc


# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------


class SimpleChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    # Forwarded verbatim as the upstream "messages" field.
    messages: list[dict[str, Any]]
    max_tokens: int = 2048
    temperature: float = 0.7
    # When true the endpoint answers with a text/event-stream response.
    stream: bool = False


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Sent upstream as the content of a single "user" message.
    prompt: str
    max_tokens: int = 2048
    temperature: float = 0.7
    # When true the endpoint answers with a text/event-stream response.
    stream: bool = False


class VisionRequest(BaseModel):
    """Request body for ``POST /vision``."""

    prompt: str
    # An http(s) URL, a data: URI, or a bare base64 payload.
    image: str
    max_tokens: int = 2048
    temperature: float = 0.7


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.get("/health")
async def health() -> dict[str, Any]:
    """Report this service's status together with the upstream ``/health`` body.

    Raises:
        HTTPException: 503 when the llama.cpp server cannot be reached.
    """
    try:
        resp = await _client.get(f"{LLAMA_BASE}/health", timeout=5.0)
    except httpx.TransportError as exc:
        raise HTTPException(
            status_code=503, detail="llama.cpp server unreachable"
        ) from exc

    llama_status: Any = {"status": "error"}
    if resp.status_code == 200:
        try:
            llama_status = resp.json()
        except ValueError:
            # A 200 with an unparseable body is reported as an upstream error
            # instead of crashing the health check.
            logger.warning("llama.cpp /health returned a non-JSON body")
    return {"status": "ok", "llama": llama_status}


@app.get("/v1/models")
async def list_models() -> Any:
    """Return the upstream ``/v1/models`` JSON document.

    Raises:
        HTTPException: 503 when the upstream cannot be reached, the upstream's
            status code when it is not 200, and 502 for a non-JSON body.
    """
    try:
        resp = await _client.get(f"{LLAMA_BASE}/v1/models", timeout=10.0)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    try:
        return resp.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502, detail="llama.cpp returned a non-JSON response"
        ) from exc


@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Any:
    """Forward the caller's JSON body unchanged to the upstream chat route.

    A truthy ``"stream"`` key selects a text/event-stream response.

    Raises:
        HTTPException: 400 when the body is not a JSON object; otherwise the
            errors raised while contacting the upstream (503, upstream status,
            or 502).
    """
    try:
        payload = await request.json()
    except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    if payload.get("stream", False):
        return _stream_chat_completion(payload)
    return await _post_chat_completion(payload)


@app.post("/chat")
async def chat(req: SimpleChatRequest) -> Any:
    """Run a chat completion for the given message list."""
    payload: dict[str, Any] = {
        "messages": req.messages,
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": req.stream,
    }
    if req.stream:
        return _stream_chat_completion(payload)
    return await _post_chat_completion(payload)


@app.post("/generate")
async def generate(req: GenerateRequest) -> Any:
    """Run a chat completion with the prompt as a single user message."""
    payload: dict[str, Any] = {
        "messages": [{"role": "user", "content": req.prompt}],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": req.stream,
    }
    if req.stream:
        return _stream_chat_completion(payload)
    return await _post_chat_completion(payload)


@app.post("/vision")
async def vision(req: VisionRequest) -> Any:
    """Run a non-streaming chat completion over a text prompt plus one image.

    ``req.image`` is used as-is when it is an http(s) URL or already a
    ``data:`` URI; any other value is treated as a base64 payload and wrapped
    in a JPEG data URI.
    """
    if req.image.startswith(("http://", "https://", "data:")):
        image_url = req.image
    else:
        image_url = f"data:image/jpeg;base64,{req.image}"
    image_content: dict[str, Any] = {
        "type": "image_url",
        "image_url": {"url": image_url},
    }

    payload: dict[str, Any] = {
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": req.prompt}, image_content],
            }
        ],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": False,
    }
    return await _post_chat_completion(payload)