# server.py — Cloudflare AI REST API (author: Adarshu07, revision fcbe7be, verified)
"""
╔═══════════════════════════════════════════════════════════════╗
β•‘ server.py β€” Cloudflare AI REST API β•‘
β•‘ β•‘
β•‘ OpenAI-compatible endpoints: β•‘
β•‘ POST /v1/chat/completions (streaming + non-streaming) β•‘
β•‘ GET /v1/models β•‘
β•‘ GET /health β•‘
β•‘ GET / β•‘
β•‘ β•‘
β•‘ Pool startup: up to 3 retries per slot, logs exact errors. β•‘
β•‘ Health monitor: heals dead idle slots every 60s. β•‘
║ SSE: thread→asyncio bridge with backpressure. ║
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
"""
import asyncio
import json
import logging
import os
import sys
import threading
import time
import traceback
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Optional
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from cloudflare_provider import CloudflareProvider
# ═══════════════════════════════════════════════════════════
# LOGGING
# ═══════════════════════════════════════════════════════════
# Root logging: timestamped, level-aligned lines emitted on stdout.
_LOG_FORMAT = "%(asctime)s %(levelname)-8s %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    stream=sys.stdout,
    datefmt="%H:%M:%S",
)
log = logging.getLogger("cf-api")
# ═══════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════
def _env_int(name: str, default: str) -> int:
    """Read an integer setting from the environment, falling back to *default*."""
    return int(os.getenv(name, default))


# ─── Runtime configuration (every value overridable via environment) ───
POOL_SIZE = _env_int("POOL_SIZE", "2")              # concurrent provider slots
PORT = _env_int("PORT", "7860")
HOST = os.getenv("HOST", "0.0.0.0")
HEALTH_INTERVAL = _env_int("HEALTH_INTERVAL", "60")  # seconds between pool checks
ACQUIRE_TIMEOUT = _env_int("ACQUIRE_TIMEOUT", "90")  # max wait for a free slot
STREAM_TIMEOUT = _env_int("STREAM_TIMEOUT", "120")   # max silence on a stream
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "@cf/moonshotai/kimi-k2.5")
DEFAULT_SYSTEM = os.getenv("DEFAULT_SYSTEM", "You are a helpful assistant.")
SLOT_RETRIES = _env_int("SLOT_RETRIES", "3")
SLOT_RETRY_WAIT = _env_int("SLOT_RETRY_WAIT", "10")  # seconds between retries
# ═══════════════════════════════════════════════════════════
# PYDANTIC SCHEMAS
# ═══════════════════════════════════════════════════════════
class Message(BaseModel):
    """One chat turn in OpenAI wire format."""
    # Expected values are "system" / "user" / "assistant"; not validated here,
    # the value is forwarded verbatim to the provider.
    role: str
    content: str
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible subset)."""
    model: str = DEFAULT_MODEL
    messages: List[Message]
    # Clamped to the OpenAI-style 0.0–2.0 range by pydantic validation.
    temperature: float = Field(default=1.0, ge=0.0, le=2.0)
    max_tokens: Optional[int] = None  # omitted from provider kwargs when falsy
    stream: bool = True  # SSE streaming is the default response mode
    system: Optional[str] = None  # overrides DEFAULT_SYSTEM when set
# ═══════════════════════════════════════════════════════════
# MANAGED PROVIDER SLOT
# ═══════════════════════════════════════════════════════════
class ManagedProvider:
    """One pool slot: a CloudflareProvider plus per-slot bookkeeping.

    Tracks busy/healthy state, request and error counters, creation time,
    and the last error string so the pool can report and heal slots.
    """

    def __init__(self, slot_id: int):
        self.slot_id = slot_id
        self.provider: Optional[CloudflareProvider] = None
        self.busy = False
        self.born_at = 0.0
        self.error_count = 0
        self.request_count = 0
        self.last_error = ""

    def is_healthy(self) -> bool:
        """Whether the wrapped provider looks usable (connected, live transport)."""
        prov = self.provider
        if prov is None:
            return False
        try:
            transport = prov._transport
            return prov._on and transport is not None and transport.alive
        except Exception:
            # Any probing error counts as unhealthy.
            return False

    def close(self):
        """Detach the provider and close it, swallowing shutdown errors."""
        prov, self.provider = self.provider, None
        if prov is not None:
            try:
                prov.close()
            except Exception:
                pass

    def __repr__(self):
        if self.busy:
            state = "busy"
        elif self.is_healthy():
            state = "ok"
        else:
            state = "dead"
        mode = self.provider._mode if self.provider else "none"
        return f"<Slot#{self.slot_id} {state} mode={mode!r} reqs={self.request_count}>"
# ═══════════════════════════════════════════════════════════
# PROVIDER POOL
# ═══════════════════════════════════════════════════════════
class ProviderPool:
    """Fixed-size pool of ManagedProvider slots behind an asyncio.Queue.

    Free slots live in the queue; ``acquire()`` checks one out, heals it on
    checkout if dead, and on exit either returns it to the queue or hands it
    to a background heal task so the queue never shrinks permanently.
    """

    def __init__(self, size: int = 2):
        self.size = size
        self._slots: List[ManagedProvider] = []
        # Created in initialize(), once an event loop is running.
        self._queue: Optional[asyncio.Queue] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    # ─── Startup ──────────────────────────────────────────
    async def initialize(self):
        """Spawn all slots concurrently; raise only when every slot failed."""
        self._loop = asyncio.get_event_loop()
        self._queue = asyncio.Queue(maxsize=self.size)
        log.info(f"πŸš€ Initializing provider pool (slots={self.size})")
        # Environment dump helps diagnose headless-display startup failures.
        log.info(f" DISPLAY={os.environ.get('DISPLAY', 'NOT SET')}")
        log.info(f" XVFB_EXTERNAL={os.environ.get('XVFB_EXTERNAL', '0')}")
        log.info(f" VR_DISPLAY={os.environ.get('VR_DISPLAY', '0')}")
        # return_exceptions=True: one slot failing must not abort the others.
        results = await asyncio.gather(
            *[self._spawn_slot_with_retry(i) for i in range(self.size)],
            return_exceptions=True,
        )
        ok = sum(1 for r in results if not isinstance(r, Exception))
        fail = sum(1 for r in results if isinstance(r, Exception))
        if fail:
            for i, r in enumerate(results):
                if isinstance(r, Exception):
                    log.error(f" [S{i}] FAILED: {r}")
        log.info(f" Pool ready β€” {ok}/{self.size} slots healthy")
        if ok == 0:
            raise RuntimeError(
                f"All {self.size} provider slots failed to connect.\n"
                f" β†’ Check DISPLAY / XVFB_EXTERNAL environment variables.\n"
                f" β†’ Ensure entrypoint.sh started Xvfb before the server.\n"
                f" β†’ Check network connectivity to playground.ai.cloudflare.com."
            )

    async def _spawn_slot_with_retry(self, slot_id: int) -> "ManagedProvider":
        """Try to create a slot, retrying up to SLOT_RETRIES times."""
        managed = ManagedProvider(slot_id)
        for attempt in range(1, SLOT_RETRIES + 1):
            try:
                log.info(f" [S{slot_id}] Connecting... (attempt {attempt}/{SLOT_RETRIES})")
                def _create():
                    # Blocking constructor; runs in the default executor below.
                    return CloudflareProvider(
                        model = DEFAULT_MODEL,
                        system = DEFAULT_SYSTEM,
                        debug = True,  # verbose during init so we can see failures
                        use_cache = True,
                    )
                managed.provider = await asyncio.wait_for(
                    self._loop.run_in_executor(None, _create),
                    timeout=180,
                )
                managed.provider.debug = False  # quiet after successful boot
                managed.born_at = time.time()
                self._slots.append(managed)
                await self._queue.put(managed)
                mode = managed.provider._mode
                log.info(f" [S{slot_id}] βœ“ Ready mode={mode!r}")
                return managed
            except asyncio.TimeoutError:
                err = f"Slot {slot_id} timed out (attempt {attempt})"
                log.warning(f" [S{slot_id}] ⚠ {err}")
                managed.last_error = err
                managed.close()
            except Exception as exc:
                err = str(exc)
                # Print full traceback for debugging
                log.warning(
                    f" [S{slot_id}] ⚠ Attempt {attempt} failed:\n"
                    + traceback.format_exc()
                )
                managed.last_error = err
                managed.close()
            if attempt < SLOT_RETRIES:
                log.info(f" [S{slot_id}] Retrying in {SLOT_RETRY_WAIT}s...")
                await asyncio.sleep(SLOT_RETRY_WAIT)
        raise RuntimeError(
            f"Slot {slot_id} failed after {SLOT_RETRIES} attempts. "
            f"Last error: {managed.last_error}"
        )

    # ─── Acquire ──────────────────────────────────────────
    @asynccontextmanager
    async def acquire(self):
        """Check a provider out of the pool.

        Heals a dead slot at checkout; on exit, returns a healthy slot to the
        queue or schedules a background heal-then-return for a dead one.
        Raises asyncio.TimeoutError if no slot frees up within ACQUIRE_TIMEOUT.
        """
        managed: ManagedProvider = await asyncio.wait_for(
            self._queue.get(),
            timeout=ACQUIRE_TIMEOUT,
        )
        managed.busy = True
        try:
            if not managed.is_healthy():
                log.warning(f"[S{managed.slot_id}] Unhealthy at checkout β€” healing now")
                await self._heal(managed)
            managed.request_count += 1
            yield managed.provider
        except Exception:
            managed.error_count += 1
            raise
        finally:
            managed.busy = False
            if managed.is_healthy():
                await self._queue.put(managed)
            else:
                log.warning(f"[S{managed.slot_id}] Dead after use β€” background heal")
                # Fire-and-forget so the caller isn't blocked by the heal.
                asyncio.create_task(self._heal_then_return(managed))

    # ─── Healing ──────────────────────────────────────────
    async def _heal(self, managed: ManagedProvider):
        """Replace the slot's provider with a freshly constructed one.

        Raises whatever the (re)construction raised; callers decide retries.
        """
        sid = managed.slot_id
        log.info(f"[S{sid}] Healing slot...")
        def _recreate():
            # Close the old provider first, then build a new one (blocking).
            managed.close()
            return CloudflareProvider(
                model = DEFAULT_MODEL,
                system = DEFAULT_SYSTEM,
                debug = True,
                use_cache = True,
            )
        try:
            managed.provider = await asyncio.wait_for(
                self._loop.run_in_executor(None, _recreate),
                timeout=180,
            )
            managed.provider.debug = False
            managed.born_at = time.time()
            managed.error_count = 0
            managed.last_error = ""
            log.info(f"[S{sid}] βœ“ Healed mode={managed.provider._mode!r}")
        except Exception as e:
            managed.last_error = str(e)
            log.error(f"[S{sid}] βœ— Heal failed: {e}\n{traceback.format_exc()}")
            raise

    async def _heal_then_return(self, managed: ManagedProvider):
        """Heal with retries, then put the slot back into the free queue."""
        sid = managed.slot_id
        for attempt in range(1, SLOT_RETRIES + 1):
            try:
                await self._heal(managed)
                await self._queue.put(managed)
                return
            except Exception as e:
                log.warning(f"[S{sid}] Heal attempt {attempt}/{SLOT_RETRIES} failed: {e}")
                if attempt < SLOT_RETRIES:
                    await asyncio.sleep(SLOT_RETRY_WAIT)
        # Last resort: put it back anyway so queue doesn't shrink permanently
        log.error(f"[S{sid}] All heal attempts failed β€” slot may be non-functional")
        await self._queue.put(managed)

    # ─── Health monitor ───────────────────────────────────
    async def health_monitor(self):
        """Forever: every HEALTH_INTERVAL seconds, log pool state and heal
        any slot that is idle and dead (busy slots heal on release)."""
        while True:
            await asyncio.sleep(HEALTH_INTERVAL)
            healthy = sum(1 for m in self._slots if m.is_healthy())
            busy = sum(1 for m in self._slots if m.busy)
            log.info(
                f"β™₯ Pool β€” {healthy}/{self.size} healthy "
                f"{busy} busy queue={self._queue.qsize()}"
            )
            # Iterate over a copy: heal tasks may mutate slot state concurrently.
            for managed in list(self._slots):
                if not managed.busy and not managed.is_healthy():
                    log.warning(f"[S{managed.slot_id}] Idle+dead β€” healing in background")
                    asyncio.create_task(self._heal_then_return(managed))

    # ─── Status ───────────────────────────────────────────
    @property
    def status(self) -> dict:
        """JSON-serializable snapshot of every slot; consumed by /health."""
        return {
            "pool_size": self.size,
            "queue_free": self._queue.qsize() if self._queue else 0,
            "slots": [
                {
                    "id": m.slot_id,
                    "healthy": m.is_healthy(),
                    "busy": m.busy,
                    "mode": m.provider._mode if m.provider else "none",
                    "errors": m.error_count,
                    "requests": m.request_count,
                    "age_s": round(time.time() - m.born_at, 1) if m.born_at else 0,
                    "last_error": m.last_error or None,
                }
                for m in self._slots
            ],
        }

    # ─── Shutdown ─────────────────────────────────────────
    async def shutdown(self):
        """Close every slot's provider; errors are swallowed by close()."""
        log.info("Shutting down provider pool...")
        for m in self._slots:
            m.close()
        log.info("Pool shut down.")
# ═══════════════════════════════════════════════════════════
# GLOBAL POOL
# ═══════════════════════════════════════════════════════════
pool: ProviderPool = None
# ═══════════════════════════════════════════════════════════
# LIFESPAN
# ═══════════════════════════════════════════════════════════
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: boot the provider pool on startup, tear down on exit."""
    global pool
    pool = ProviderPool(size=POOL_SIZE)
    await pool.initialize()
    health_task = asyncio.create_task(pool.health_monitor())
    log.info(f"βœ… Server ready {HOST}:{PORT}")
    yield
    # Shutdown path: stop the monitor first, then close every slot.
    health_task.cancel()
    try:
        await health_task
    except asyncio.CancelledError:
        pass  # expected: the monitor loops forever until cancelled
    await pool.shutdown()
# ═══════════════════════════════════════════════════════════
# APP
# ═══════════════════════════════════════════════════════════
# FastAPI application with interactive docs at /docs and /redoc.
app = FastAPI(
    title="Cloudflare AI API",
    description="OpenAI-compatible API via Cloudflare AI Playground",
    version="1.1.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Wide-open CORS: any origin, method, or header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ═══════════════════════════════════════════════════════════
# SSE HELPERS
# ═══════════════════════════════════════════════════════════
def _sse_chunk(content: str, model: str, cid: str) -> str:
return "data: " + json.dumps({
"id": cid,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
}, ensure_ascii=False) + "\n\n"
def _sse_done(model: str, cid: str) -> str:
return "data: " + json.dumps({
"id": cid,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}) + "\n\ndata: [DONE]\n\n"
def _sse_error(msg: str) -> str:
return f'data: {{"error": {json.dumps(msg)}}}\n\ndata: [DONE]\n\n'
async def _stream_generator(
    provider: CloudflareProvider,
    req: ChatRequest,
) -> AsyncGenerator[str, None]:
    """Bridge the provider's blocking chat() iterator onto this event loop.

    A daemon worker thread pushes chunks into an asyncio.Queue (bounded at
    512 for backpressure); this coroutine drains it and yields SSE lines.
    Queue protocol: str = content chunk, Exception = worker failure,
    None = end-of-stream sentinel.
    """
    loop = asyncio.get_event_loop()
    q: asyncio.Queue = asyncio.Queue(maxsize=512)
    cid = f"chatcmpl-{uuid.uuid4().hex[:20]}"
    cancel = threading.Event()  # signals the worker to stop mid-stream
    messages = [{"role": m.role, "content": m.content} for m in req.messages]
    kwargs = {
        "messages": messages,
        "temperature": req.temperature,
        "model": req.model,
    }
    # Optional parameters are only forwarded when set (0/None/"" are dropped).
    if req.max_tokens:
        kwargs["max_tokens"] = req.max_tokens
    if req.system:
        kwargs["system"] = req.system
    def _worker():
        # Runs in a plain thread; every queue put hops to the loop thread-safely.
        try:
            for chunk in provider.chat(**kwargs):
                if cancel.is_set():
                    break
                fut = asyncio.run_coroutine_threadsafe(q.put(chunk), loop)
                # Backpressure: block this thread while the consumer is slow.
                fut.result(timeout=10)
            except Exception as exc:
                err = RuntimeError(str(exc))
                asyncio.run_coroutine_threadsafe(q.put(err), loop).result(timeout=5)
        finally:
            # Always deliver the end-of-stream sentinel, even after an error.
            asyncio.run_coroutine_threadsafe(q.put(None), loop).result(timeout=5)
    t = threading.Thread(target=_worker, daemon=True)
    t.start()
    try:
        while True:
            item = await asyncio.wait_for(q.get(), timeout=STREAM_TIMEOUT)
            if item is None:
                yield _sse_done(req.model, cid)
                break
            if isinstance(item, Exception):
                yield _sse_error(str(item))
                break
            if item:
                yield _sse_chunk(item, req.model, cid)
    except asyncio.TimeoutError:
        cancel.set()
        yield _sse_error("Stream timed out")
    finally:
        cancel.set()
        # NOTE(review): join() runs on the event loop thread and can block it
        # for up to 5s if the worker is stuck mid-put — confirm acceptable.
        t.join(timeout=5)
# ═══════════════════════════════════════════════════════════
# ENDPOINTS
# ═══════════════════════════════════════════════════════════
@app.get("/", tags=["Info"])
async def root():
return {
"service": "Cloudflare AI API",
"version": "1.1.0",
"status": "running",
"display": os.environ.get("DISPLAY", "not set"),
"endpoints": {
"chat": "POST /v1/chat/completions",
"models": "GET /v1/models",
"health": "GET /health",
"docs": "GET /docs",
},
}
@app.get("/health", tags=["Info"])
async def health():
if pool is None:
raise HTTPException(503, detail="Pool not initialized")
healthy = sum(1 for m in pool._slots if m.is_healthy())
status = "ok" if healthy > 0 else "degraded"
return JSONResponse(
content={"status": status, "pool": pool.status},
status_code=200 if status == "ok" else 206,
)
@app.get("/v1/models", tags=["Models"])
async def list_models():
if pool is None:
raise HTTPException(503, detail="Pool not initialized")
async with pool.acquire() as provider:
models = await asyncio.get_event_loop().run_in_executor(
None, provider.list_models
)
return {
"object": "list",
"data": [
{
"id": m["name"],
"object": "model",
"created": 0,
"owned_by": "cloudflare",
"context_window": m.get("context", 4096),
}
for m in models
],
}
@app.post("/v1/chat/completions", tags=["Chat"])
async def chat_completions(req: ChatRequest, request: Request):
if pool is None:
raise HTTPException(503, detail="Pool not initialized")
if not req.messages:
raise HTTPException(400, detail="`messages` must not be empty")
# ── Streaming ──────────────────────────────────────────
if req.stream:
async def _gen():
async with pool.acquire() as provider:
async for chunk in _stream_generator(provider, req):
if await request.is_disconnected():
break
yield chunk
return StreamingResponse(
_gen(),
media_type = "text/event-stream",
headers = {
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
},
)
# ── Non-streaming ──────────────────────────────────────
messages = [{"role": m.role, "content": m.content} for m in req.messages]
kwargs = {
"messages": messages,
"temperature": req.temperature,
"model": req.model,
}
if req.max_tokens:
kwargs["max_tokens"] = req.max_tokens
if req.system:
kwargs["system"] = req.system
loop = asyncio.get_event_loop()
full_parts: list[str] = []
async with pool.acquire() as provider:
def _collect():
for chunk in provider.chat(**kwargs):
full_parts.append(chunk)
await asyncio.wait_for(
loop.run_in_executor(None, _collect),
timeout=STREAM_TIMEOUT,
)
return {
"id": f"chatcmpl-{uuid.uuid4().hex[:20]}",
"object": "chat.completion",
"created": int(time.time()),
"model": req.model,
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": "".join(full_parts)},
"finish_reason": "stop",
}],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
# ═══════════════════════════════════════════════════════════
# ENTRY POINT
# ═══════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Single worker on purpose: the provider pool must live in one process.
    server_opts = dict(
        host=HOST,
        port=PORT,
        log_level="info",
        workers=1,
        loop="asyncio",
        timeout_keep_alive=30,
    )
    uvicorn.run("server:app", **server_opts)