# gemme4 / main.py
# d3evil4's picture
# feat: huh
# e536cd5
from __future__ import annotations
import logging
import sys
from typing import Any, AsyncIterator
import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
# Send INFO-and-above log records to stdout, one line per record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger("gemma4")

# Base URL of the upstream server that the endpoints below forward to
# (the /health handler refers to it as the llama.cpp server).
LLAMA_BASE = "http://127.0.0.1:8080"

app = FastAPI(title="Gemma 4 API", version="1.0.0")

# CORS is fully open: any origin, method and header, with credentials enabled.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True -- confirm this is intended
# before exposing the service beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Single shared async HTTP client for all upstream calls: 300 s default
# timeout, 10 s for establishing the connection.
# NOTE(review): nothing in the visible part of this file closes the client;
# consider calling `await _client.aclose()` on application shutdown.
_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=10.0))
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every HTTP request on arrival and again once a response exists.

    Args:
        request: The incoming request.
        call_next: Callable that hands the request to the rest of the
            application and returns its response.

    Returns:
        The downstream response, unmodified.
    """
    verb = request.method
    route = request.url.path
    logger.info("%s %s", verb, route)

    outcome = await call_next(request)
    # An exception raised by call_next propagates, so this second line is only
    # written for requests that actually produced a response.
    logger.info("%s %s -> %d", verb, route, outcome.status_code)
    return outcome
async def _proxy_stream(url: str, payload: dict[str, Any]) -> AsyncIterator[bytes]:
    """Stream the body of an upstream POST back to the caller chunk by chunk.

    Args:
        url: Upstream endpoint to POST to.
        payload: JSON-serialisable request body.

    Yields:
        Raw byte chunks of the upstream response. For a non-200 upstream
        response the complete error body is yielded as a single chunk. If the
        upstream cannot be reached, or the connection fails mid-stream, one
        final SSE ``data:`` event carrying an error object is yielded instead
        of letting the exception escape.
    """
    try:
        async with _client.stream("POST", url, json=payload) as resp:
            if resp.status_code != 200:
                # Forward the upstream error body verbatim and stop.
                yield await resp.aread()
                return
            async for chunk in resp.aiter_bytes():
                yield chunk
    except httpx.TransportError:
        # Callers hand this generator to StreamingResponse, whose status code
        # is fixed when it is constructed, so a failure here cannot be turned
        # into a 503 the way the non-streaming endpoints do. Report it in-band
        # and end the stream cleanly instead of aborting the connection.
        logger.exception("Upstream streaming request to %s failed", url)
        yield (
            b'data: {"error": {"message": "upstream llama.cpp request failed", '
            b'"type": "upstream_error", "code": 503}}\n\n'
        )
# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------
class SimpleChatRequest(BaseModel):
    # Request body accepted by POST /chat. Described with comments rather
    # than a docstring because pydantic copies a model docstring into the
    # generated JSON schema.

    # Chat messages, forwarded upstream exactly as received.
    messages: list[dict[str, Any]]
    # Forwarded upstream as "max_tokens".
    max_tokens: int = Field(default=2048)
    # Forwarded upstream as "temperature".
    temperature: float = Field(default=0.7)
    # True selects a streamed (text/event-stream) reply.
    stream: bool = Field(default=False)
class GenerateRequest(BaseModel):
    # Request body accepted by POST /generate. Described with comments rather
    # than a docstring because pydantic copies a model docstring into the
    # generated JSON schema.

    # Text that becomes the content of a single "user" message.
    prompt: str
    # Forwarded upstream as "max_tokens".
    max_tokens: int = Field(default=2048)
    # Forwarded upstream as "temperature".
    temperature: float = Field(default=0.7)
    # True selects a streamed (text/event-stream) reply.
    stream: bool = Field(default=False)
class VisionRequest(BaseModel):
    # Request body accepted by POST /vision. Described with comments rather
    # than a docstring because pydantic copies a model docstring into the
    # generated JSON schema.

    # Text part of the user message.
    prompt: str
    # Either an http(s) URL or base64-encoded image data.
    image: str
    # Forwarded upstream as "max_tokens".
    max_tokens: int = Field(default=2048)
    # Forwarded upstream as "temperature".
    temperature: float = Field(default=0.7)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
async def health() -> dict[str, Any]:
    """Report that this API is up, together with the upstream server's health.

    Returns:
        ``{"status": "ok", "llama": ...}`` where ``llama`` is the upstream
        ``/health`` JSON body when it answers 200, or ``{"status": "error"}``
        when it answers with any other status or with a body that is not
        valid JSON.

    Raises:
        HTTPException: 503 when the upstream server cannot be reached.
    """
    try:
        resp = await _client.get(f"{LLAMA_BASE}/health", timeout=5.0)
    except httpx.TransportError as exc:
        raise HTTPException(
            status_code=503, detail="llama.cpp server unreachable"
        ) from exc

    llama_status: Any = {"status": "error"}
    if resp.status_code == 200:
        try:
            llama_status = resp.json()
        except ValueError:
            # A 200 with a non-JSON body used to escape as an unhandled
            # exception (HTTP 500); treat it like any other unhealthy answer.
            logger.warning("Upstream /health returned a non-JSON body")
    return {"status": "ok", "llama": llama_status}
@app.get("/v1/models")
async def list_models() -> Any:
    """Return the upstream server's ``/v1/models`` listing.

    Returns:
        The decoded JSON body of the upstream response.

    Raises:
        HTTPException: 503 when the upstream cannot be reached; otherwise the
            upstream's own status code (with its body as detail) when it does
            not answer 200.
    """
    try:
        resp = await _client.get(f"{LLAMA_BASE}/v1/models", timeout=10.0)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    # Propagate upstream failures the same way the completion endpoints do,
    # instead of returning an upstream error body with a 200 status.
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Any:
    """Forward a chat-completions request body to the upstream unchanged.

    When the body has a truthy ``"stream"`` value the upstream response is
    relayed as ``text/event-stream``; otherwise the upstream JSON is returned.

    Args:
        request: Incoming request whose body must be a JSON object.

    Returns:
        A ``StreamingResponse`` for streamed requests, else the decoded
        upstream JSON body.

    Raises:
        HTTPException: 400 when the body is not valid JSON or is not a JSON
            object; 503 when the upstream cannot be reached; otherwise the
            upstream's status code when it does not answer 200.
    """
    try:
        payload = await request.json()
    except ValueError as exc:
        # Malformed (or empty) bodies used to surface as an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Request body is not valid JSON"
        ) from exc
    if not isinstance(payload, dict):
        # e.g. a JSON list or string: `.get` below would raise AttributeError.
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    if payload.get("stream", False):
        return StreamingResponse(
            _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
            media_type="text/event-stream",
        )
    try:
        resp = await _client.post(
            f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0
        )
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/chat")
async def chat(req: SimpleChatRequest) -> Any:
    """Run a chat completion from a simplified request body.

    Args:
        req: Messages plus sampling options; ``req.stream`` selects a
            streamed reply.

    Returns:
        A ``StreamingResponse`` when streaming was requested, else the
        decoded upstream JSON body.

    Raises:
        HTTPException: 503 when the upstream cannot be reached; otherwise the
            upstream's status code when it does not answer 200.
    """
    upstream_url = f"{LLAMA_BASE}/v1/chat/completions"
    body: dict[str, Any] = dict(
        messages=req.messages,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stream=req.stream,
    )

    if not req.stream:
        try:
            upstream = await _client.post(upstream_url, json=body, timeout=300.0)
        except httpx.TransportError as exc:
            raise HTTPException(status_code=503, detail=str(exc))
        if upstream.status_code == 200:
            return upstream.json()
        raise HTTPException(status_code=upstream.status_code, detail=upstream.text)

    # Streaming: relay the upstream bytes as server-sent events.
    return StreamingResponse(
        _proxy_stream(upstream_url, body),
        media_type="text/event-stream",
    )
@app.post("/generate")
async def generate(req: GenerateRequest) -> Any:
    """Complete a bare prompt by sending it as a single user message.

    Args:
        req: Prompt plus sampling options; ``req.stream`` selects a streamed
            reply.

    Returns:
        A ``StreamingResponse`` when streaming was requested, else the
        decoded upstream JSON body.

    Raises:
        HTTPException: 503 when the upstream cannot be reached; otherwise the
            upstream's status code when it does not answer 200.
    """
    target = f"{LLAMA_BASE}/v1/chat/completions"
    user_turn = {"role": "user", "content": req.prompt}
    payload: dict[str, Any] = {
        "messages": [user_turn],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": req.stream,
    }

    if req.stream:
        chunks = _proxy_stream(target, payload)
        return StreamingResponse(chunks, media_type="text/event-stream")

    try:
        resp = await _client.post(target, json=payload, timeout=300.0)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc))
    else:
        if resp.status_code != 200:
            raise HTTPException(status_code=resp.status_code, detail=resp.text)
        return resp.json()
@app.post("/vision")
async def vision(req: VisionRequest) -> Any:
    """Ask the model about an image (never streamed).

    Args:
        req: Prompt, image and sampling options. ``req.image`` may be an
            http(s) URL, a complete ``data:`` URI, or bare base64 data.

    Returns:
        The decoded upstream JSON body.

    Raises:
        HTTPException: 503 when the upstream cannot be reached; otherwise the
            upstream's status code when it does not answer 200.
    """
    # URLs and ready-made data URIs are forwarded as-is. Wrapping a data URI
    # again would produce "data:image/jpeg;base64,data:..." which is invalid.
    # Anything else is treated as bare base64 and labelled JPEG, as before.
    if req.image.startswith(("http://", "https://", "data:")):
        image_url = req.image
    else:
        image_url = f"data:image/jpeg;base64,{req.image}"
    image_content: dict[str, Any] = {
        "type": "image_url",
        "image_url": {"url": image_url},
    }

    payload: dict[str, Any] = {
        "messages": [{
            "role": "user",
            "content": [{"type": "text", "text": req.prompt}, image_content],
        }],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": False,
    }
    try:
        resp = await _client.post(
            f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0
        )
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()