File size: 2,068 Bytes
ac35771 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import httpx
import json
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()
# Bearer-token scheme consumed by ``verify_key`` via ``Depends``.
security = HTTPBearer()
# Shared secret that ``verify_key`` compares the presented bearer token against.
# NOTE(review): the key is a literal committed in source — consider loading it
# from the environment or a secret store instead.
API_KEY = "connectkey"
# Base URL of the backend that every route in this module forwards to.
OLLAMA_BASE = "http://localhost:11434"
def verify_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the request's bearer token against ``API_KEY``.

    Intended to be used as a FastAPI dependency by the route handlers.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        The presented token, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 401 with detail ``"Invalid API key"`` when the token
            does not match.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list; ``secrets`` is standard library.
    import secrets

    presented = credentials.credentials
    # Timing-resistant comparison: a plain ``!=`` can return as soon as the
    # first differing character is found. Compare as bytes because
    # ``compare_digest`` raises TypeError for non-ASCII ``str`` arguments.
    if not secrets.compare_digest(
        presented.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key",
            # A 401 response should carry a challenge naming the auth scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return presented
@app.get("/api/version")
async def version(key: str = Depends(verify_key)):
    """Forward ``GET /api/version`` to the backend and return its JSON.

    Args:
        key: Unused in the body; declared so that ``verify_key`` runs and
            rejects unauthenticated requests before the backend is contacted.

    Returns:
        The decoded JSON body of the backend response on success. When the
        backend answers with a 4xx/5xx status, its body, status code and
        content type are relayed unchanged.

    Raises:
        HTTPException: 502 when the backend cannot be reached.
    """
    # Function-scope import: only the error-relay path needs it.
    from fastapi.responses import Response

    try:
        async with httpx.AsyncClient() as client:
            upstream = await client.get(f"{OLLAMA_BASE}/api/version")
    except httpx.RequestError as exc:
        # Connection refused, DNS failure, timeout, ...: report a gateway
        # error instead of letting the exception become a generic 500.
        raise HTTPException(
            status_code=502, detail="Upstream backend is unreachable"
        ) from exc

    if upstream.status_code >= 400:
        # Relay the backend's own error rather than masking it as a 200.
        return Response(
            content=upstream.content,
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type"),
        )
    return upstream.json()
@app.get("/api/tags")
async def tags(key: str = Depends(verify_key)):
    """Forward ``GET /api/tags`` to the backend and return its JSON.

    Args:
        key: Unused in the body; declared so that ``verify_key`` runs and
            rejects unauthenticated requests before the backend is contacted.

    Returns:
        The decoded JSON body of the backend response on success. When the
        backend answers with a 4xx/5xx status, its body, status code and
        content type are relayed unchanged.

    Raises:
        HTTPException: 502 when the backend cannot be reached.
    """
    # Function-scope import: only the error-relay path needs it.
    from fastapi.responses import Response

    try:
        async with httpx.AsyncClient() as client:
            upstream = await client.get(f"{OLLAMA_BASE}/api/tags")
    except httpx.RequestError as exc:
        # Turn transport failures into a gateway error instead of a bare 500.
        raise HTTPException(
            status_code=502, detail="Upstream backend is unreachable"
        ) from exc

    if upstream.status_code >= 400:
        # Relay the backend's own error rather than masking it as a 200.
        return Response(
            content=upstream.content,
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type"),
        )
    return upstream.json()
async def _stream(url: str, body: dict):
    """POST ``body`` to ``url`` and yield the response bytes as they arrive.

    Args:
        url: Absolute URL of the backend endpoint to call.
        body: JSON-serialisable request payload. It is not modified; a copy
            with ``"stream"`` forced to ``True`` is what gets sent.

    Yields:
        Raw byte chunks of the backend response, in arrival order.
    """
    # Build a new dict instead of assigning into ``body`` so the caller's
    # object is left untouched.
    payload = {**body, "stream": True}
    # ``timeout=None`` disables httpx's default timeouts for this client.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload) as upstream:
            async for chunk in upstream.aiter_bytes():
                yield chunk
@app.post("/api/generate")
async def generate(request: Request, key: str = Depends(verify_key)):
    """Proxy ``POST /api/generate`` to the backend as a streamed response.

    The client's JSON body is forwarded with ``"stream"`` forced to ``True``,
    whatever value the client sent.

    Args:
        request: Incoming request; its body must be a JSON object.
        key: Unused in the body; declared so that ``verify_key`` runs first.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson`` that
        relays the backend's bytes.

    Raises:
        HTTPException: 400 when the body is not valid JSON or is not a JSON
            object.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # JSON decode errors are ValueError subclasses; report them as a
        # client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(body, dict):
        # The item assignment below requires a mapping.
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    body["stream"] = True
    return StreamingResponse(
        _stream(f"{OLLAMA_BASE}/api/generate", body),
        media_type="application/x-ndjson",
    )
@app.post("/api/chat")
async def chat(request: Request, key: str = Depends(verify_key)):
    """Proxy ``POST /api/chat`` to the backend as a streamed response.

    The client's JSON body is forwarded with ``"stream"`` forced to ``True``,
    whatever value the client sent.

    Args:
        request: Incoming request; its body must be a JSON object.
        key: Unused in the body; declared so that ``verify_key`` runs first.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson`` that
        relays the backend's bytes.

    Raises:
        HTTPException: 400 when the body is not valid JSON or is not a JSON
            object.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # JSON decode errors are ValueError subclasses; report them as a
        # client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(body, dict):
        # The item assignment below requires a mapping.
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    body["stream"] = True
    return StreamingResponse(
        _stream(f"{OLLAMA_BASE}/api/chat", body),
        media_type="application/x-ndjson",
    )
@app.post("/api/embeddings")
async def embeddings(request: Request, key: str = Depends(verify_key)):
    """Forward ``POST /api/embeddings`` to the backend and return its JSON.

    Args:
        request: Incoming request; its JSON body is forwarded as-is.
        key: Unused in the body; declared so that ``verify_key`` runs first.

    Returns:
        The decoded JSON body of the backend response on success. When the
        backend answers with a 4xx/5xx status, its body, status code and
        content type are relayed unchanged.

    Raises:
        HTTPException: 400 when the request body is not valid JSON; 502 when
            the backend cannot be reached.
    """
    # Function-scope import: only the error-relay path needs it.
    from fastapi.responses import Response

    try:
        body = await request.json()
    except ValueError as exc:
        # JSON decode errors are ValueError subclasses; report them as a
        # client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc

    try:
        # ``timeout=None`` disables httpx's default timeouts for this call.
        async with httpx.AsyncClient(timeout=None) as client:
            upstream = await client.post(
                f"{OLLAMA_BASE}/api/embeddings", json=body
            )
    except httpx.RequestError as exc:
        # Turn transport failures into a gateway error instead of a bare 500.
        raise HTTPException(
            status_code=502, detail="Upstream backend is unreachable"
        ) from exc

    if upstream.status_code >= 400:
        # Relay the backend's own error rather than masking it as a 200.
        return Response(
            content=upstream.content,
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type"),
        )
    return upstream.json()
|