Spaces:
Running
Running
File size: 6,047 Bytes
e2fc204 6c84960 e2fc204 6c84960 e2fc204 6c84960 e2fc204 6c84960 e2fc204 6c84960 e2fc204 e536cd5 35424c3 e536cd5 6c84960 e2fc204 6c84960 e536cd5 6c84960 e2fc204 6c84960 e2fc204 e536cd5 e2fc204 e536cd5 e2fc204 6c84960 e2fc204 6c84960 e2fc204 e6db69c e2fc204 5002f3a e2fc204 5002f3a e2fc204 e6db69c e2fc204 e536cd5 e2fc204 e536cd5 6c84960 e2fc204 0fe7743 e536cd5 e2fc204 e536cd5 e2fc204 0fe7743 e2fc204 e536cd5 e2fc204 5002f3a e2fc204 5002f3a e536cd5 e2fc204 6c84960 e2fc204 6c84960 e2fc204 e536cd5 e2fc204 e536cd5 e2fc204 5002f3a e2fc204 e536cd5 e2fc204 e536cd5 e2fc204 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 | from __future__ import annotations
import json
import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
# Process-wide logging: INFO and above go to stdout as
# "<time> <LEVEL> <logger name> <message>". Per the logging docs,
# basicConfig is a no-op if the root logger already has handlers.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s %(message)s"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=_LOG_FORMAT)

# Application logger used by the request-logging middleware.
logger = logging.getLogger("gemma4")
# Base URL of the llama.cpp server this API proxies to. Defaults to the local
# server; set the LLAMA_BASE environment variable to point elsewhere. A trailing
# slash is stripped because every caller appends a path that starts with "/".
LLAMA_BASE = os.environ.get("LLAMA_BASE", "http://127.0.0.1:8080").rstrip("/")

# Shared upstream client: 300 s read/write/pool timeout, but only 10 s to
# establish a connection so an unreachable llama.cpp fails fast.
_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=10.0))


@asynccontextmanager
async def _lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """Close the shared upstream HTTP client when the application shuts down."""
    try:
        yield
    finally:
        await _client.aclose()


app = FastAPI(title="Gemma 4 API", version="1.0.0", lifespan=_lifespan)

# NOTE(review): wildcard origins combined with allow_credentials=True —
# confirm that credentialed cross-origin requests are actually intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every HTTP request on arrival and again with its response status.

    Args:
        request: Incoming request; only its method and URL path are logged.
        call_next: Downstream handler chain supplied by the framework.

    Returns:
        The downstream response, unchanged.
    """
    verb, route = request.method, request.url.path
    logger.info("%s %s", verb, route)
    outcome = await call_next(request)
    logger.info("%s %s -> %d", verb, route, outcome.status_code)
    return outcome
async def _proxy_stream(url: str, payload: dict[str, Any]) -> AsyncIterator[bytes]:
    """Relay a streaming POST to the llama.cpp server, yielding raw bytes.

    Args:
        url: Upstream endpoint to POST to.
        payload: JSON body sent upstream unchanged.

    Yields:
        The upstream response body, chunk by chunk. If upstream answers with a
        non-200 status, its full body is yielded once and the stream ends. If
        the upstream connection fails, one SSE ``data:`` event holding an
        ``{"error": {...}}`` object is yielded instead of raising, because the
        enclosing streaming response may already have started sending its body
        and can no longer change its status code.
    """
    try:
        async with _client.stream("POST", url, json=payload) as resp:
            if resp.status_code != 200:
                body = await resp.aread()
                logger.warning("upstream %s returned %d", url, resp.status_code)
                yield body
                return
            async for chunk in resp.aiter_bytes():
                yield chunk
    except httpx.TransportError as exc:
        # Letting this propagate would abort a half-sent response with no
        # explanation; report the failure in-band as an event-stream message.
        logger.warning("upstream %s failed: %r", url, exc)
        message = str(exc) or type(exc).__name__  # some httpx errors have no text
        error = {"error": {"message": message, "type": "upstream_unreachable", "code": 503}}
        yield f"data: {json.dumps(error)}\n\n".encode()
# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------
# Request body for POST /chat: chat messages plus sampling options, which /chat
# forwards to llama.cpp's /v1/chat/completions.
class SimpleChatRequest(BaseModel):
    # Passed through verbatim; presumably OpenAI-style {"role", "content"} dicts —
    # this model does not validate their shape.
    messages: list[dict[str, Any]]
    # Generation length limit forwarded upstream; no range is enforced here.
    max_tokens: int = 2048
    # Sampling temperature forwarded upstream as-is.
    temperature: float = 0.7
    # True makes /chat relay the upstream event stream instead of one JSON body.
    stream: bool = False
# Request body for POST /generate: a single prompt that /generate wraps into one
# user chat message before forwarding to llama.cpp.
class GenerateRequest(BaseModel):
    # Text sent as the content of the sole user message.
    prompt: str
    # Generation length limit forwarded upstream; no range is enforced here.
    max_tokens: int = 2048
    # Sampling temperature forwarded upstream as-is.
    temperature: float = 0.7
    # True makes /generate relay the upstream event stream instead of one JSON body.
    stream: bool = False
# Request body for POST /vision: a text prompt plus one image. Has no stream
# field — /vision always requests a non-streaming completion.
class VisionRequest(BaseModel):
    # Text part of the multimodal user message.
    prompt: str
    # Either an http(s) URL, forwarded as-is, or a bare base64 payload, which
    # /vision wraps into a data:image/jpeg;base64 URI.
    image: str
    # Generation length limit forwarded upstream; no range is enforced here.
    max_tokens: int = 2048
    # Sampling temperature forwarded upstream as-is.
    temperature: float = 0.7
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
async def health() -> dict[str, Any]:
    """Report this API's liveness together with the llama.cpp server's health.

    Returns:
        ``{"status": "ok", "llama": <status>}``. ``<status>`` is the JSON body
        of llama.cpp's ``/health`` when it answers 200 with valid JSON, and
        ``{"status": "error"}`` otherwise.

    Raises:
        HTTPException: 503 if the llama.cpp server cannot be reached or does
        not answer within 5 s.
    """
    try:
        resp = await _client.get(f"{LLAMA_BASE}/health", timeout=5.0)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail="llama.cpp server unreachable") from exc
    llama_status: Any = {"status": "error"}
    if resp.status_code == 200:
        try:
            llama_status = resp.json()
        except ValueError:
            # A 200 with an unparseable body is reported as unhealthy instead of
            # escaping as an unhandled 500 from this endpoint.
            logger.warning("llama.cpp /health returned a non-JSON body")
    return {"status": "ok", "llama": llama_status}
@app.get("/v1/models")
async def list_models() -> Any:
    """Proxy llama.cpp's OpenAI-compatible ``/v1/models`` listing.

    Returns:
        The upstream JSON body, unchanged.

    Raises:
        HTTPException: 503 if llama.cpp is unreachable, or the upstream status
        code and body text when it answers with anything other than 200.
    """
    try:
        resp = await _client.get(f"{LLAMA_BASE}/v1/models", timeout=10.0)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    # Mirror the other proxy endpoints: surface upstream errors instead of
    # trying to parse a possibly non-JSON error body.
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Any:
    """Pass an OpenAI-style chat completion request straight through to llama.cpp.

    The raw JSON body is forwarded unchanged. When it contains a truthy
    ``"stream"`` key, the upstream bytes are relayed as ``text/event-stream``;
    otherwise the upstream JSON is returned.

    Raises:
        HTTPException: 400 if the body is not a JSON object; 503 if llama.cpp is
        unreachable; the upstream status code and body text when llama.cpp
        answers with anything other than 200.
    """
    try:
        payload = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="request body must be valid JSON") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")
    if payload.get("stream", False):
        return StreamingResponse(
            _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
            media_type="text/event-stream",
        )
    try:
        # No per-request timeout: a bare float here would replace the client's
        # default and raise its 10 s connect limit to 300 s.
        resp = await _client.post(f"{LLAMA_BASE}/v1/chat/completions", json=payload)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/chat")
async def chat(req: SimpleChatRequest) -> Any:
    """Run a chat completion from the simplified ``SimpleChatRequest`` body.

    Builds an OpenAI-style payload and forwards it to llama.cpp's
    ``/v1/chat/completions``. With ``req.stream`` set, the upstream bytes are
    relayed as ``text/event-stream``; otherwise the upstream JSON is returned.

    Raises:
        HTTPException: 503 if llama.cpp is unreachable, or the upstream status
        code and body text when it answers with anything other than 200.
    """
    payload: dict[str, Any] = {
        "messages": req.messages,
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": req.stream,
    }
    if req.stream:
        return StreamingResponse(
            _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
            media_type="text/event-stream",
        )
    try:
        # No per-request timeout: a bare float here would replace the client's
        # default and raise its 10 s connect limit to 300 s.
        resp = await _client.post(f"{LLAMA_BASE}/v1/chat/completions", json=payload)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/generate")
async def generate(req: GenerateRequest) -> Any:
    """Complete a single prompt by sending it to llama.cpp as one user message.

    With ``req.stream`` set, the upstream bytes are relayed as
    ``text/event-stream``; otherwise the upstream JSON is returned.

    Raises:
        HTTPException: 503 if llama.cpp is unreachable, or the upstream status
        code and body text when it answers with anything other than 200.
    """
    payload: dict[str, Any] = {
        "messages": [{"role": "user", "content": req.prompt}],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": req.stream,
    }
    if req.stream:
        return StreamingResponse(
            _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
            media_type="text/event-stream",
        )
    try:
        # No per-request timeout: a bare float here would replace the client's
        # default and raise its 10 s connect limit to 300 s.
        resp = await _client.post(f"{LLAMA_BASE}/v1/chat/completions", json=payload)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.post("/vision")
async def vision(req: VisionRequest) -> Any:
    """Ask the model about one image plus a text prompt (non-streaming).

    ``req.image`` may be an http(s) URL or a ready-made ``data:`` URI, both
    forwarded unchanged, or bare base64, which is wrapped as
    ``data:image/jpeg;base64,...``.

    Raises:
        HTTPException: 503 if llama.cpp is unreachable, or the upstream status
        code and body text when it answers with anything other than 200.
    """
    # Bare base64 never contains ":", so a "data:" prefix unambiguously marks a
    # complete data URI that must not be wrapped a second time.
    if req.image.startswith(("http://", "https://", "data:")):
        image_url = req.image
    else:
        image_url = f"data:image/jpeg;base64,{req.image}"
    payload: dict[str, Any] = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": req.prompt},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": False,
    }
    try:
        # No per-request timeout: a bare float here would replace the client's
        # default and raise its 10 s connect limit to 300 s.
        resp = await _client.post(f"{LLAMA_BASE}/v1/chat/completions", json=payload)
    except httpx.TransportError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
|