File size: 2,068 Bytes
ac35771
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import os
import secrets

import httpx
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

app = FastAPI()
security = HTTPBearer()

# Shared secret that clients must present as a bearer token. It is read from the
# environment so a real key does not have to be committed to source control; an
# unset or empty variable falls back to the previous built-in value.
API_KEY = os.environ.get("OLLAMA_PROXY_API_KEY") or "connectkey"

# Base URL of the backend that requests are forwarded to. Any trailing "/" is
# stripped because the route paths are appended with their own leading slash.
OLLAMA_BASE = (
    os.environ.get("OLLAMA_BASE") or "http://localhost:11434"
).rstrip("/")


def verify_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token supplied with a request.

    Args:
        credentials: Bearer credentials produced by the ``security`` dependency.

    Returns:
        The presented token, when it matches ``API_KEY``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the token does not match.
    """
    presented = credentials.credentials
    # Constant-time comparison: a plain ``!=`` can return as soon as the first
    # differing character is found. Both sides are encoded to bytes because
    # compare_digest() rejects str values containing non-ASCII characters.
    matches = secrets.compare_digest(
        presented.encode("utf-8"), API_KEY.encode("utf-8")
    )
    if not matches:
        raise HTTPException(
            status_code=401,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return presented


@app.get("/api/version")
async def version(key: str = Depends(verify_key)):
    """Forward ``GET /api/version`` to the backend and return its JSON body.

    Args:
        key: Token validated by ``verify_key``; present only to enforce auth.

    Returns:
        The decoded JSON body of the backend response.

    Raises:
        HTTPException: 502 when the backend cannot be reached or answers with a
            body that is not JSON; the backend's own status code when it
            answers with an error status.
    """
    try:
        async with httpx.AsyncClient() as client:
            r = await client.get(f"{OLLAMA_BASE}/api/version")
    except httpx.RequestError as exc:
        # Transport-level failure (refused connection, timeout, ...): report a
        # gateway error instead of letting it surface as an unhandled 500.
        raise HTTPException(
            status_code=502, detail="Backend is unreachable"
        ) from exc

    # Do not relay a backend error as a 200 response.
    if r.status_code >= 400:
        raise HTTPException(status_code=r.status_code, detail=r.text)

    try:
        return r.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502, detail="Backend returned a non-JSON response"
        ) from exc


@app.get("/api/tags")
async def tags(key: str = Depends(verify_key)):
    """Forward ``GET /api/tags`` to the backend and return its JSON body.

    Args:
        key: Token validated by ``verify_key``; present only to enforce auth.

    Returns:
        The decoded JSON body of the backend response.

    Raises:
        HTTPException: 502 when the backend cannot be reached or answers with a
            body that is not JSON; the backend's own status code when it
            answers with an error status.
    """
    try:
        async with httpx.AsyncClient() as client:
            r = await client.get(f"{OLLAMA_BASE}/api/tags")
    except httpx.RequestError as exc:
        # Transport-level failure (refused connection, timeout, ...): report a
        # gateway error instead of letting it surface as an unhandled 500.
        raise HTTPException(
            status_code=502, detail="Backend is unreachable"
        ) from exc

    # Do not relay a backend error as a 200 response.
    if r.status_code >= 400:
        raise HTTPException(status_code=r.status_code, detail=r.text)

    try:
        return r.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502, detail="Backend returned a non-JSON response"
        ) from exc


async def _stream(url: str, body: dict):
    """POST ``body`` to ``url`` with streaming forced on and relay the reply.

    Args:
        url: Absolute URL of the backend endpoint.
        body: JSON-serialisable request payload. It is not modified.

    Yields:
        Raw byte chunks of the backend response, in arrival order.
    """
    # Build a fresh payload rather than writing "stream" into the caller's dict.
    payload = {**body, "stream": True}
    # timeout=None: the response is consumed incrementally and may stay open
    # for a long time, so no client-side timeout is applied.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload) as r:
            async for chunk in r.aiter_bytes():
                yield chunk


@app.post("/api/generate")
async def generate(request: Request, key: str = Depends(verify_key)):
    """Forward ``POST /api/generate`` to the backend as a streamed response.

    The ``stream`` field of the request body is always set to ``True`` before
    forwarding, whatever the client sent.

    Args:
        request: Incoming request whose body is a JSON object.
        key: Token validated by ``verify_key``; present only to enforce auth.

    Returns:
        A ``StreamingResponse`` of media type ``application/x-ndjson`` relaying
        the backend's bytes.

    Raises:
        HTTPException: 400 when the body is not valid JSON or is not a JSON
            object.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # JSON decoding errors are ValueError subclasses; answer with a client
        # error instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(body, dict):
        # Setting body["stream"] below only makes sense on a JSON object.
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    body["stream"] = True
    return StreamingResponse(
        _stream(f"{OLLAMA_BASE}/api/generate", body),
        media_type="application/x-ndjson",
    )


@app.post("/api/chat")
async def chat(request: Request, key: str = Depends(verify_key)):
    """Forward ``POST /api/chat`` to the backend as a streamed response.

    The ``stream`` field of the request body is always set to ``True`` before
    forwarding, whatever the client sent.

    Args:
        request: Incoming request whose body is a JSON object.
        key: Token validated by ``verify_key``; present only to enforce auth.

    Returns:
        A ``StreamingResponse`` of media type ``application/x-ndjson`` relaying
        the backend's bytes.

    Raises:
        HTTPException: 400 when the body is not valid JSON or is not a JSON
            object.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # JSON decoding errors are ValueError subclasses; answer with a client
        # error instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(body, dict):
        # Setting body["stream"] below only makes sense on a JSON object.
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    body["stream"] = True
    return StreamingResponse(
        _stream(f"{OLLAMA_BASE}/api/chat", body),
        media_type="application/x-ndjson",
    )


@app.post("/api/embeddings")
async def embeddings(request: Request, key: str = Depends(verify_key)):
    """Forward ``POST /api/embeddings`` to the backend and return its JSON body.

    Args:
        request: Incoming request whose JSON body is forwarded unchanged.
        key: Token validated by ``verify_key``; present only to enforce auth.

    Returns:
        The decoded JSON body of the backend response.

    Raises:
        HTTPException: 400 when the request body is not valid JSON; 502 when
            the backend cannot be reached or answers with a body that is not
            JSON; the backend's own status code when it answers with an error
            status.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # JSON decoding errors are ValueError subclasses; answer with a client
        # error instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc

    try:
        # timeout=None: no client-side limit on how long the backend may take.
        async with httpx.AsyncClient(timeout=None) as client:
            r = await client.post(f"{OLLAMA_BASE}/api/embeddings", json=body)
    except httpx.RequestError as exc:
        raise HTTPException(
            status_code=502, detail="Backend is unreachable"
        ) from exc

    # Do not relay a backend error as a 200 response.
    if r.status_code >= 400:
        raise HTTPException(status_code=r.status_code, detail=r.text)

    try:
        return r.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502, detail="Backend returned a non-JSON response"
        ) from exc