import json
import os
import secrets
import time
import uuid

import httpx
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
# FastAPI application exposing an Ollama-compatible and an OpenAI-compatible
# proxy API in front of a local Ollama server.
app = FastAPI()
# Bearer-token scheme used to guard the proxied endpoints.
security = HTTPBearer()


# Shared secret clients must present as a Bearer token; the placeholder value
# means "not configured yet".
API_KEY = os.environ.get("API_KEY", "!TU MUSISZ EDYTOWAC!")
# The single Ollama model this proxy serves; same placeholder convention.
MODEL = os.environ.get("MODEL", "!TU MUSISZ EDYTOWAC!")
# Local Ollama server all requests are forwarded to.
OLLAMA_BASE = "http://127.0.0.1:11434"


# Fail fast at import time if either variable is still the placeholder.
# (Polish: "Set the API_KEY and MODEL variables in HF Space Settings -> Variables".)
if "!TU MUSISZ EDYTOWAC!" in (API_KEY, MODEL):
    raise RuntimeError("Ustaw zmienne API_KEY i MODEL w HF Space Settings -> Variables")
|
|
def verify_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Validate the presented Bearer token against the configured API_KEY.

    Uses ``secrets.compare_digest`` so the comparison runs in constant time,
    preventing recovery of the key via a timing side channel (a plain ``!=``
    short-circuits on the first differing character).

    Raises:
        HTTPException: 401 when the token does not match API_KEY.

    Returns:
        The validated token string.
    """
    if not secrets.compare_digest(credentials.credentials, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return credentials.credentials
|
|
| |
|
|
| @app.post("/api/chat") |
| async def ollama_chat(request: Request, key: str = Depends(verify_key)): |
| body = await request.json() |
| |
| body["stream"] = True |
| body["model"] = MODEL |
| |
| async def generate(): |
| async with httpx.AsyncClient(timeout=600.0) as client: |
| async with client.stream("POST", f"{OLLAMA_BASE}/api/chat", json=body) as resp: |
| async for line in resp.aiter_lines(): |
| if not line: continue |
| |
| yield line + "\n" |
|
|
| return StreamingResponse(generate(), media_type="application/x-ndjson") |
|
|
| @app.post("/api/generate") |
| async def ollama_generate(request: Request, key: str = Depends(verify_key)): |
| body = await request.json() |
| body["stream"] = True |
| body["model"] = MODEL |
| |
| async def generate(): |
| async with httpx.AsyncClient(timeout=600.0) as client: |
| async with client.stream("POST", f"{OLLAMA_BASE}/api/generate", json=body) as resp: |
| async for line in resp.aiter_lines(): |
| if not line: continue |
| yield line + "\n" |
|
|
| return StreamingResponse(generate(), media_type="application/x-ndjson") |
|
|
| @app.get("/api/tags") |
| async def ollama_tags(key: str = Depends(verify_key)): |
| return { |
| "models": [{ |
| "name": MODEL, |
| "model": MODEL, |
| "modified_at": "2024-01-01T00:00:00Z", |
| "size": 0, |
| "digest": "sha256:0000000000000000000000000000000000000000000000000000000000000000", |
| "details": { |
| "parent_model": "", |
| "format": "gguf", |
| "family": "llama", |
| "families": ["llama"], |
| "parameter_size": "unknown", |
| "quantization_level": "unknown" |
| } |
| }] |
| } |
|
|
| @app.get("/api/version") |
| async def ollama_version(): |
| return {"version": "0.1.0-proxy"} |
|
|
| |
|
|
| @app.post("/v1/chat/completions") |
| async def chat_completions(request: Request, key: str = Depends(verify_key)): |
| body = await request.json() |
| body["stream"] = True |
| |
| ollama_payload = { |
| "model": MODEL, |
| "messages": body.get("messages", []), |
| "stream": True, |
| "options": { |
| "temperature": body.get("temperature", 0.6), |
| "top_p": body.get("top_p", 0.95), |
| } |
| } |
| if "max_tokens" in body: |
| ollama_payload["options"]["num_predict"] = body["max_tokens"] |
|
|
| completion_id = f"chatcmpl-{uuid.uuid4().hex}" |
| created = int(time.time()) |
|
|
| async def generate(): |
| async with httpx.AsyncClient(timeout=600.0) as client: |
| async with client.stream("POST", f"{OLLAMA_BASE}/api/chat", json=ollama_payload) as resp: |
| async for line in resp.aiter_lines(): |
| if not line: continue |
| try: |
| chunk = json.loads(line) |
| except: continue |
|
|
| msg = chunk.get("message", {}) |
| done = chunk.get("done", False) |
| |
| delta = {} |
| |
| thinking = msg.get("thinking") |
| content = msg.get("content") |
| |
| if thinking: |
| delta["reasoning_content"] = thinking |
| if content: |
| delta["content"] = content |
|
|
| data = { |
| "id": completion_id, |
| "object": "chat.completion.chunk", |
| "created": created, |
| "model": MODEL, |
| "choices": [{ |
| "index": 0, |
| "delta": delta, |
| "finish_reason": "stop" if done else None, |
| }] |
| } |
| yield f"data: {json.dumps(data)}\n\n" |
| if done: break |
| yield "data: [DONE]\n\n" |
|
|
| return StreamingResponse(generate(), media_type="text/event-stream") |
|
|
| @app.get("/v1/models") |
| async def list_models(key: str = Depends(verify_key)): |
| return { |
| "object": "list", |
| "data": [{ |
| "id": MODEL, |
| "object": "model", |
| "created": int(time.time()), |
| "owned_by": "ollama", |
| }] |
| } |
|
|
| |
|
|
| @app.get("/") |
| @app.get("/health") |
| async def health(): |
| async with httpx.AsyncClient(timeout=5.0) as client: |
| try: |
| r = await client.get(f"{OLLAMA_BASE}/api/version") |
| ollama_ok = r.status_code == 200 |
| except Exception: |
| ollama_ok = False |
| return {"status": "ok" if ollama_ok else "starting", "model": MODEL} |
|
|