# proxy.py — authenticated FastAPI proxy exposing Ollama-native and
# OpenAI-compatible endpoints in front of a local Ollama server.
import json
import os
import secrets
import time
import uuid

import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
app = FastAPI()
security = HTTPBearer()

# Runtime configuration. Both values must be overridden through environment
# variables (HF Space Settings -> Variables); the Polish placeholder text
# doubles as the "still unset" sentinel.
API_KEY = os.environ.get("API_KEY", "!TU MUSISZ EDYTOWAC!")
MODEL = os.environ.get("MODEL", "!TU MUSISZ EDYTOWAC!")
OLLAMA_BASE = "http://127.0.0.1:11434"

# Fail fast at import time if either variable still holds the placeholder.
if API_KEY == "!TU MUSISZ EDYTOWAC!" or MODEL == "!TU MUSISZ EDYTOWAC!":
    raise RuntimeError("Ustaw zmienne API_KEY i MODEL w HF Space Settings -> Variables")
def verify_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """FastAPI dependency: validate the bearer token against API_KEY.

    Returns the raw token on success; raises HTTPException(401) on mismatch.
    """
    # secrets.compare_digest performs a constant-time comparison, so an
    # attacker cannot recover the key byte-by-byte via response timing.
    if not secrets.compare_digest(credentials.credentials, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return credentials.credentials
# --- Ollama Compatible Endpoints (Directly at /api/...) ---
@app.post("/api/chat")
async def ollama_chat(request: Request, key: str = Depends(verify_key)):
    """Proxy a native Ollama chat request, relaying the NDJSON stream back.

    The client's `stream` and `model` fields are overridden: responses are
    always streamed and always use the configured MODEL.
    """
    payload = await request.json()
    payload.update(stream=True, model=MODEL)

    async def relay():
        async with httpx.AsyncClient(timeout=600.0) as client:
            async with client.stream("POST", f"{OLLAMA_BASE}/api/chat", json=payload) as upstream:
                async for raw in upstream.aiter_lines():
                    # Ollama's native chat stream already carries 'thinking'
                    # when the model supports it; forward lines untouched.
                    if raw:
                        yield raw + "\n"

    return StreamingResponse(relay(), media_type="application/x-ndjson")
@app.post("/api/generate")
async def ollama_generate(request: Request, key: str = Depends(verify_key)):
    """Proxy a native Ollama completion request as a streamed NDJSON relay.

    Forces `stream=True` and pins the configured MODEL regardless of what
    the client sent.
    """
    payload = await request.json()
    payload.update(stream=True, model=MODEL)

    async def relay():
        async with httpx.AsyncClient(timeout=600.0) as client:
            async with client.stream("POST", f"{OLLAMA_BASE}/api/generate", json=payload) as upstream:
                async for raw in upstream.aiter_lines():
                    if raw:
                        yield raw + "\n"

    return StreamingResponse(relay(), media_type="application/x-ndjson")
@app.get("/api/tags")
async def ollama_tags(key: str = Depends(verify_key)):
    """Ollama-compatible /api/tags: advertise only the configured model.

    Size, digest and details are placeholder values; clients typically use
    this endpoint only to discover model names.
    """
    details = {
        "parent_model": "",
        "format": "gguf",
        "family": "llama",
        "families": ["llama"],
        "parameter_size": "unknown",
        "quantization_level": "unknown",
    }
    entry = {
        "name": MODEL,
        "model": MODEL,
        "modified_at": "2024-01-01T00:00:00Z",
        "size": 0,
        "digest": "sha256:0000000000000000000000000000000000000000000000000000000000000000",
        "details": details,
    }
    return {"models": [entry]}
@app.get("/api/version")
async def ollama_version():
    """Report a synthetic proxy version (does not query the real Ollama)."""
    info = {"version": "0.1.0-proxy"}
    return info
# --- OpenAI Compatible Endpoints (Maintained for flexibility) ---
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, key: str = Depends(verify_key)):
    """OpenAI-compatible streaming chat endpoint backed by Ollama /api/chat.

    Always streams SSE chunks; the client's `model` field is ignored in
    favor of the configured MODEL. Ollama 'thinking' output is surfaced as
    `reasoning_content` in each delta.
    """
    body = await request.json()
    # NOTE: the client body is not mutated; only ollama_payload is forwarded,
    # so the old `body["stream"] = True` dead write has been dropped.
    ollama_payload = {
        "model": MODEL,
        "messages": body.get("messages", []),
        "stream": True,
        "options": {
            "temperature": body.get("temperature", 0.6),
            "top_p": body.get("top_p", 0.95),
        },
    }
    if "max_tokens" in body:
        # OpenAI's max_tokens maps onto Ollama's num_predict option.
        ollama_payload["options"]["num_predict"] = body["max_tokens"]

    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    async def generate():
        async with httpx.AsyncClient(timeout=600.0) as client:
            async with client.stream("POST", f"{OLLAMA_BASE}/api/chat", json=ollama_payload) as resp:
                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        # Skip malformed lines only; a bare except here would
                        # also swallow cancellation and system-exit signals
                        # inside this async generator.
                        continue
                    msg = chunk.get("message", {})
                    done = chunk.get("done", False)
                    delta = {}
                    # Map Ollama's native fields onto the OpenAI delta shape;
                    # 'thinking' becomes reasoning_content when present.
                    thinking = msg.get("thinking")
                    content = msg.get("content")
                    if thinking:
                        delta["reasoning_content"] = thinking
                    if content:
                        delta["content"] = content
                    data = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": MODEL,
                        "choices": [{
                            "index": 0,
                            "delta": delta,
                            "finish_reason": "stop" if done else None,
                        }],
                    }
                    yield f"data: {json.dumps(data)}\n\n"
                    if done:
                        break
        # Close the upstream connection before signalling end-of-stream.
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
@app.get("/v1/models")
async def list_models(key: str = Depends(verify_key)):
    """OpenAI-compatible model listing exposing the single configured model."""
    entry = {
        "id": MODEL,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "ollama",
    }
    return {"object": "list", "data": [entry]}
# --- Health Check ---
@app.get("/")
@app.get("/health")
async def health():
    """Liveness probe: 'ok' once the backing Ollama server answers, else 'starting'."""
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            resp = await client.get(f"{OLLAMA_BASE}/api/version")
        except Exception:
            # Best-effort probe: any failure just means Ollama isn't up yet.
            reachable = False
        else:
            reachable = resp.status_code == 200
    status = "ok" if reachable else "starting"
    return {"status": status, "model": MODEL}