import os
import json
import time
import asyncio
import requests
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
import subprocess
import shutil

# Check if ollama is available
OLLAMA_AVAILABLE = shutil.which("ollama") is not None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup and shutdown events"""
    if OLLAMA_AVAILABLE:
        print("Starting Ollama service...")
        # Set keep-alive before launching so the server process inherits it
        # and does not unload the model between requests
        os.environ["OLLAMA_KEEP_ALIVE"] = "24h"
        subprocess.Popen(["ollama", "serve"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        await asyncio.sleep(3)  # Wait for Ollama to start
        # Pull model if needed
        try:
            r = requests.get(f"{OLLAMA_BASE}/api/tags", timeout=5)
            models = [m["name"] for m in r.json().get("models", [])]
            if MODEL not in models:
                print(f"Pulling model {MODEL}...")
                subprocess.run(["ollama", "pull", MODEL], check=False)
        except Exception as e:
            print(f"Warning: Could not check/pull model: {e}")
    yield
    print("Shutting down...")

app = FastAPI(title="o87Dev Cloud LLM API", lifespan=lifespan)
security = HTTPBearer(auto_error=False)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

OLLAMA_BASE = "http://localhost:11434"
MODEL = os.environ.get("DEFAULT_MODEL", "qwen2.5-coder:7b-instruct-q4_K_M")
API_TOKEN = os.environ.get("API_TOKEN", "")
MAX_CTX = int(os.environ.get("MAX_CTX", "4096"))
MAX_OUT = int(os.environ.get("MAX_OUT", "1024"))
TIMEOUT = int(os.environ.get("TIMEOUT", "240"))  # 4 min limit

# Semaphore to limit concurrent requests (prevents OOM)
semaphore = asyncio.Semaphore(1)  # Only 1 request at a time for CPU Spaces
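# Example configuration (illustrative values, not from the original Space
# settings): these variables would typically be set as Space secrets/variables,
# e.g.
#   API_TOKEN=<your-secret-token>
#   DEFAULT_MODEL=qwen2.5-coder:7b-instruct-q4_K_M
#   MAX_CTX=4096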

def verify_token(creds: HTTPAuthorizationCredentials = Depends(security)):
    if not API_TOKEN:
        return "no-auth"
    if not creds or creds.credentials != API_TOKEN:
        raise HTTPException(401, "Invalid token")
    return creds.credentials

async def wait_for_ollama(max_retries=10, delay=1):
    """Wait for Ollama to be ready, with retries"""
    for _ in range(max_retries):
        try:
            r = requests.get(f"{OLLAMA_BASE}/api/tags", timeout=2)
            if r.status_code == 200:
                return True
        except requests.RequestException:
            pass
        await asyncio.sleep(delay)
    return False

async def ensure_model_loaded(model_name: str = None):
    """Pre-load model with a dummy request to force it into memory"""
    model = model_name or MODEL
    try:
        # Check if model is already loaded
        r = requests.get(f"{OLLAMA_BASE}/api/ps", timeout=2)
        loaded = [m.get("model") for m in r.json().get("models", [])]
        if model not in loaded:
            print(f"Pre-loading model {model}...")
            requests.post(
                f"{OLLAMA_BASE}/api/generate",
                json={"model": model, "prompt": "test", "stream": False},
                timeout=30
            )
            print(f"Model {model} loaded")
    except Exception as e:
        print(f"Warning: Could not pre-load model: {e}")

@app.get("/")
async def root():
    return {
        "status": "ok",
        "model": MODEL,
        "max_ctx": MAX_CTX,
        "ollama_available": OLLAMA_AVAILABLE
    }

@app.get("/health")
async def health():
    try:
        r = requests.get(f"{OLLAMA_BASE}/api/tags", timeout=5)
        models = [m["name"] for m in r.json().get("models", [])]
        return {
            "status": "ok" if MODEL in models else "model_missing",
            "model": MODEL,
            "model_available": MODEL in models,
            "available_models": models,
            "max_ctx": MAX_CTX
        }
    except Exception as e:
        return {"status": "starting", "error": str(e)}

@app.get("/v1/models")
async def list_models(token: str = Depends(verify_token)):
    try:
        r = requests.get(f"{OLLAMA_BASE}/api/tags", timeout=5)
        models = [{"id": m["name"], "object": "model"} for m in r.json().get("models", [])]
        return {"object": "list", "data": models}
    except Exception:
        return {"object": "list", "data": [{"id": MODEL, "object": "model"}]}

@app.post("/v1/chat/completions")
async def chat_completions(request: Request, token: str = Depends(verify_token)):
    """OpenAI-compatible endpoint with retries and better error handling"""
    # Wait for Ollama to be ready
    if not await wait_for_ollama():
        raise HTTPException(503, "Ollama service not ready")
    async with semaphore:
        body = await request.json()
        model = body.get("model", MODEL)
        stream = body.get("stream", False)
        # Ensure model is loaded before proceeding
        await ensure_model_loaded(model)
        payload = {
            "model": model,
            "messages": body.get("messages", []),
            "stream": stream,
            "options": {
                "num_ctx": MAX_CTX,
                "num_predict": min(body.get("max_tokens", MAX_OUT), MAX_OUT),
                "temperature": body.get("temperature", 0.7),
            }
        }
        if stream:
            def generate():
                try:
                    with requests.post(
                        f"{OLLAMA_BASE}/v1/chat/completions",
                        json=payload,
                        stream=True,
                        timeout=TIMEOUT
                    ) as r:
                        if r.status_code != 200:
                            error_msg = f"Ollama error: {r.status_code}"
                            yield f"data: {json.dumps({'error': error_msg})}\n\n".encode()
                            yield b"data: [DONE]\n\n"
                            return
                        for chunk in r.iter_content(chunk_size=None):
                            if chunk:
                                yield chunk
                except requests.Timeout:
                    yield f"data: {json.dumps({'error': 'Request timeout - try a shorter prompt'})}\n\n".encode()
                    yield b"data: [DONE]\n\n"
                except Exception as e:
                    yield f"data: {json.dumps({'error': str(e)})}\n\n".encode()
                    yield b"data: [DONE]\n\n"
            return StreamingResponse(generate(), media_type="text/event-stream")
        # Non-streaming request with retry logic
        max_retries = 2
        for attempt in range(max_retries):
            try:
                r = requests.post(
                    f"{OLLAMA_BASE}/v1/chat/completions",
                    json=payload,
                    timeout=TIMEOUT
                )
                if r.status_code == 200:
                    return r.json()
                elif r.status_code == 404:
                    # Model not found - try to pull it
                    if attempt < max_retries - 1:
                        print(f"Model {model} not found, attempting pull...")
                        subprocess.run(["ollama", "pull", model], check=False)
                        await asyncio.sleep(5)
                        continue
                raise HTTPException(r.status_code, f"Ollama error: {r.text}")
            except requests.Timeout:
                if attempt == max_retries - 1:
                    raise HTTPException(504, "Inference timeout - try a shorter prompt")
                await asyncio.sleep(2)
            except HTTPException:
                # Re-raise the Ollama error status as-is instead of masking it as a 500
                raise
            except Exception as e:
                if attempt == max_retries - 1:
                    raise HTTPException(500, str(e))
                await asyncio.sleep(2)

@app.post("/v1/messages")
async def messages(request: Request, token: str = Depends(verify_token)):
    """Anthropic-compatible messages endpoint"""
    if not await wait_for_ollama():
        raise HTTPException(503, "Ollama service not ready")
    async with semaphore:
        body = await request.json()
        model = body.get("model", MODEL)
        stream = body.get("stream", False)
        await ensure_model_loaded(model)
        payload = {
            "model": model,
            "messages": body.get("messages", []),
            "stream": stream,
            "options": {
                "num_ctx": MAX_CTX,
                "num_predict": min(body.get("max_tokens", MAX_OUT), MAX_OUT),
                "temperature": body.get("temperature", 0.7),
            }
        }
        if stream:
            def generate_anthropic():
                msg_id = f"msg_{int(time.time())}"
                yield f"event: message_start\ndata: {json.dumps({'type':'message_start','message':{'id':msg_id,'type':'message','role':'assistant','content':[],'model':model,'stop_reason':None,'usage':{'input_tokens':0,'output_tokens':0}}})}\n\n".encode()
                yield f"event: content_block_start\ndata: {json.dumps({'type':'content_block_start','index':0,'content_block':{'type':'text','text':''}})}\n\n".encode()
                yield b"event: ping\ndata: {\"type\":\"ping\"}\n\n"
                out_tokens = 0
                try:
                    with requests.post(
                        f"{OLLAMA_BASE}/v1/chat/completions",
                        json=payload, stream=True, timeout=TIMEOUT
                    ) as r:
                        if r.status_code != 200:
                            yield f"event: content_block_delta\ndata: {json.dumps({'type':'content_block_delta','index':0,'delta':{'type':'text_delta','text':f'Error: Ollama returned {r.status_code}'}})}\n\n".encode()
                        else:
                            buf = ""
                            for chunk in r.iter_content(chunk_size=None):
                                if not chunk:
                                    continue
                                buf += chunk.decode("utf-8", errors="ignore")
                                lines = buf.split("\n")
                                buf = lines.pop()
                                for line in lines:
                                    line = line.strip()
                                    if not line or not line.startswith("data: "):
                                        continue
                                    js = line[6:]
                                    if js == "[DONE]":
                                        break
                                    try:
                                        d = json.loads(js)
                                        if d.get("usage"):
                                            out_tokens = d["usage"].get("completion_tokens", 0)
                                        text = (d.get("choices") or [{}])[0].get("delta", {}).get("content", "")
                                        if text:
                                            yield f"event: content_block_delta\ndata: {json.dumps({'type':'content_block_delta','index':0,'delta':{'type':'text_delta','text':text}})}\n\n".encode()
                                    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
                                        # Skip malformed or partial SSE lines
                                        pass
                except Exception as e:
                    yield f"event: content_block_delta\ndata: {json.dumps({'type':'content_block_delta','index':0,'delta':{'type':'text_delta','text':f'Error: {e}'}})}\n\n".encode()
                yield b"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n"
                yield f"event: message_delta\ndata: {json.dumps({'type':'message_delta','delta':{'stop_reason':'end_turn','stop_sequence':None},'usage':{'output_tokens':out_tokens}})}\n\n".encode()
                yield b"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
            return StreamingResponse(generate_anthropic(), media_type="text/event-stream")
        # Non-streaming
        try:
            r = requests.post(f"{OLLAMA_BASE}/v1/chat/completions", json=payload, timeout=TIMEOUT)
            data = r.json()
            content = (data.get("choices") or [{}])[0].get("message", {}).get("content", "")
            return {
                "id": data.get("id", f"msg_{int(time.time())}"),
                "type": "message",
                "role": "assistant",
                "content": [{"type": "text", "text": content}],
                "model": model,
                "stop_reason": "end_turn",
                "usage": {
                    "input_tokens": data.get("usage", {}).get("prompt_tokens", 0),
                    "output_tokens": data.get("usage", {}).get("completion_tokens", 0)
                }
            }
        except requests.Timeout:
            raise HTTPException(504, "Inference timeout - try a shorter prompt")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
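
# Minimal client sketch (assumptions: the Space exposes the routes registered
# above on port 7860, API_TOKEN is set, and the host is adjusted for your
# deployment; this is illustrative, not part of the original app).
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       headers={"Authorization": "Bearer <API_TOKEN>"},
#       json={
#           "model": "qwen2.5-coder:7b-instruct-q4_K_M",
#           "messages": [{"role": "user", "content": "Write a haiku about Ollama."}],
#           "max_tokens": 128,
#       },
#       timeout=300,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])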