"""Authenticated HTTP proxy exposing a local Ollama server through both
Ollama-native (/api/...) and OpenAI-compatible (/v1/...) endpoints.

Requires two environment variables (set in HF Space Settings -> Variables):
  API_KEY - bearer token clients must present
  MODEL   - the single Ollama model name this proxy serves
"""

import json
import os
import secrets
import time
import uuid

import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

app = FastAPI()
security = HTTPBearer()

# "!TU MUSISZ EDYTOWAC!" is a sentinel meaning "you must edit this" (Polish);
# startup is refused below until both variables are configured.
API_KEY = os.environ.get("API_KEY", "!TU MUSISZ EDYTOWAC!")
MODEL = os.environ.get("MODEL", "!TU MUSISZ EDYTOWAC!")
OLLAMA_BASE = "http://127.0.0.1:11434"

if "!TU MUSISZ EDYTOWAC!" in (API_KEY, MODEL):
    raise RuntimeError("Ustaw zmienne API_KEY i MODEL w HF Space Settings -> Variables")


def verify_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Validate the Bearer token against API_KEY.

    Returns the presented token on success; raises HTTP 401 otherwise.
    """
    # compare_digest: constant-time comparison so an attacker cannot
    # recover the key byte-by-byte from response timing.
    if not secrets.compare_digest(credentials.credentials, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return credentials.credentials


def _proxy_ndjson(path: str, body: dict) -> StreamingResponse:
    """Stream an Ollama NDJSON endpoint's response through unchanged.

    Shared by /api/chat and /api/generate: forwards `body` to
    `OLLAMA_BASE + path` and relays each non-empty line as NDJSON.
    """

    async def generate():
        async with httpx.AsyncClient(timeout=600.0) as client:
            async with client.stream("POST", f"{OLLAMA_BASE}{path}", json=body) as resp:
                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    # Ollama's native API already includes 'thinking' if supported.
                    yield line + "\n"

    return StreamingResponse(generate(), media_type="application/x-ndjson")


# --- Ollama Compatible Endpoints (Directly at /api/...) ---

@app.post("/api/chat")
async def ollama_chat(request: Request, key: str = Depends(verify_key)):
    """Proxy Ollama's /api/chat, forcing streaming and the configured model."""
    body = await request.json()
    body["stream"] = True
    body["model"] = MODEL
    return _proxy_ndjson("/api/chat", body)


@app.post("/api/generate")
async def ollama_generate(request: Request, key: str = Depends(verify_key)):
    """Proxy Ollama's /api/generate, forcing streaming and the configured model."""
    body = await request.json()
    body["stream"] = True
    body["model"] = MODEL
    return _proxy_ndjson("/api/generate", body)


@app.get("/api/tags")
async def ollama_tags(key: str = Depends(verify_key)):
    """Advertise the single configured model in Ollama's /api/tags format.

    Size/digest/details are placeholders; only the model name is real.
    """
    return {
        "models": [{
            "name": MODEL,
            "model": MODEL,
            "modified_at": "2024-01-01T00:00:00Z",
            "size": 0,
            "digest": "sha256:0000000000000000000000000000000000000000000000000000000000000000",
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "llama",
                "families": ["llama"],
                "parameter_size": "unknown",
                "quantization_level": "unknown",
            },
        }]
    }


@app.get("/api/version")
async def ollama_version():
    """Report a static proxy version (unauthenticated, like Ollama's own)."""
    return {"version": "0.1.0-proxy"}


# --- OpenAI Compatible Endpoints (Maintained for flexibility) ---

@app.post("/v1/chat/completions")
async def chat_completions(request: Request, key: str = Depends(verify_key)):
    """Translate an OpenAI chat-completions request into an Ollama /api/chat
    stream, re-emitting each chunk as an OpenAI SSE 'chat.completion.chunk'.

    Always streams, regardless of the client's 'stream' flag.
    """
    body = await request.json()

    ollama_payload = {
        "model": MODEL,
        "messages": body.get("messages", []),
        "stream": True,
        "options": {
            "temperature": body.get("temperature", 0.6),
            "top_p": body.get("top_p", 0.95),
        },
    }
    if "max_tokens" in body:
        # OpenAI's max_tokens maps to Ollama's num_predict option.
        ollama_payload["options"]["num_predict"] = body["max_tokens"]

    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    async def generate():
        async with httpx.AsyncClient(timeout=600.0) as client:
            async with client.stream("POST", f"{OLLAMA_BASE}/api/chat", json=ollama_payload) as resp:
                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        # Skip malformed lines only; anything else should surface.
                        continue

                    msg = chunk.get("message", {})
                    done = chunk.get("done", False)

                    delta = {}
                    # Map Ollama's separate 'thinking' stream to the
                    # reasoning_content field some OpenAI clients understand.
                    thinking = msg.get("thinking")
                    content = msg.get("content")
                    if thinking:
                        delta["reasoning_content"] = thinking
                    if content:
                        delta["content"] = content

                    data = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": MODEL,
                        "choices": [{
                            "index": 0,
                            "delta": delta,
                            "finish_reason": "stop" if done else None,
                        }],
                    }
                    yield f"data: {json.dumps(data)}\n\n"
                    if done:
                        break
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")


@app.get("/v1/models")
async def list_models(key: str = Depends(verify_key)):
    """Advertise the single configured model in OpenAI's /v1/models format."""
    return {
        "object": "list",
        "data": [{
            "id": MODEL,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "ollama",
        }],
    }


# --- Health Check ---

@app.get("/")
@app.get("/health")
async def health():
    """Report 'ok' once the local Ollama answers /api/version, else 'starting'."""
    async with httpx.AsyncClient(timeout=5.0) as client:
        try:
            r = await client.get(f"{OLLAMA_BASE}/api/version")
            ollama_ok = r.status_code == 200
        except Exception:
            # Best-effort probe: any transport failure just means "not up yet".
            ollama_ok = False
    return {"status": "ok" if ollama_ok else "starting", "model": MODEL}