| """ |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| β server.py β Cloudflare AI REST API β |
| β β |
| β OpenAI-compatible endpoints: β |
| β POST /v1/chat/completions (streaming + non-streaming) β |
| β GET /v1/models β |
| β GET /health β |
| β GET / β |
| β β |
| β Pool startup: up to 3 retries per slot, logs exact errors. β |
| β Health monitor: heals dead idle slots every 60s. β |
| β SSE: threadβasyncio bridge with backpressure. β |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| """ |
|
|
| import asyncio |
| import json |
| import logging |
| import os |
| import sys |
| import threading |
| import time |
| import traceback |
| import uuid |
| from contextlib import asynccontextmanager |
| from typing import AsyncGenerator, List, Optional |
|
|
| import uvicorn |
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse, StreamingResponse |
| from pydantic import BaseModel, Field |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from cloudflare_provider import CloudflareProvider |
|
|
| |
| |
| |
# Root logging setup: timestamped, level-aligned records go to stdout so that
# container platforms capture them.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
)
log = logging.getLogger("cf-api")
|
|
|
|
| |
| |
| |
# Runtime configuration — every value can be overridden via environment variables.
POOL_SIZE       = int(os.getenv("POOL_SIZE", "2"))          # number of concurrent provider slots
PORT            = int(os.getenv("PORT", "7860"))            # HTTP listen port
HOST            = os.getenv("HOST", "0.0.0.0")              # bind address
HEALTH_INTERVAL = int(os.getenv("HEALTH_INTERVAL", "60"))   # seconds between pool health sweeps
ACQUIRE_TIMEOUT = int(os.getenv("ACQUIRE_TIMEOUT", "90"))   # max seconds to wait for a free slot
STREAM_TIMEOUT  = int(os.getenv("STREAM_TIMEOUT", "120"))   # max seconds of silence before a stream aborts
DEFAULT_MODEL   = os.getenv("DEFAULT_MODEL", "@cf/moonshotai/kimi-k2.5")
DEFAULT_SYSTEM  = os.getenv("DEFAULT_SYSTEM", "You are a helpful assistant.")
SLOT_RETRIES    = int(os.getenv("SLOT_RETRIES", "3"))       # connect/heal attempts per slot
SLOT_RETRY_WAIT = int(os.getenv("SLOT_RETRY_WAIT", "10"))   # seconds between retry attempts
|
|
|
|
| |
| |
| |
class Message(BaseModel):
    """One chat turn: a role name and its text content."""
    role: str
    content: str
|
|
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible subset)."""
    model: str = DEFAULT_MODEL
    messages: List[Message]
    temperature: float = Field(default=1.0, ge=0.0, le=2.0)  # validated into [0, 2]
    max_tokens: Optional[int] = None  # None -> provider default
    stream: bool = True               # SSE streaming is the default
    system: Optional[str] = None      # optional system-prompt override
|
|
|
|
| |
| |
| |
class ManagedProvider:
    """One pool slot: a CloudflareProvider plus its bookkeeping.

    Tracks whether the slot is checked out (`busy`), when the provider was
    created (`born_at`), request/error counters, and the last failure text.
    """

    def __init__(self, slot_id: int):
        self.slot_id = slot_id
        self.provider: Optional[CloudflareProvider] = None
        self.busy = False
        self.born_at = 0.0
        self.error_count = 0
        self.request_count = 0
        self.last_error = ""

    def is_healthy(self) -> bool:
        """Report whether the wrapped provider exists and its transport is alive."""
        p = self.provider
        if p is None:
            return False
        try:
            transport = p._transport
            return p._on and transport is not None and transport.alive
        except Exception:
            # Any attribute error or teardown race counts as unhealthy.
            return False

    def close(self):
        """Detach and close the wrapped provider, swallowing close-time errors."""
        detached, self.provider = self.provider, None
        if detached is None:
            return
        try:
            detached.close()
        except Exception:
            pass

    def __repr__(self):
        if self.busy:
            state = "busy"
        elif self.is_healthy():
            state = "ok"
        else:
            state = "dead"
        mode = self.provider._mode if self.provider else "none"
        return f"<Slot#{self.slot_id} {state} mode={mode!r} reqs={self.request_count}>"
|
|
|
|
| |
| |
| |
class ProviderPool:
    """Fixed-size pool of ManagedProvider slots backed by an asyncio.Queue.

    Free slots sit in the queue; acquire() checks one out (healing it first if
    it is dead) and returns it on exit. Slots that die while in use are healed
    in the background and re-queued. health_monitor() periodically heals idle
    dead slots.
    """

    def __init__(self, size: int = 2):
        self.size = size
        self._slots: "List[ManagedProvider]" = []
        # Both created in initialize(), once an event loop is running.
        self._queue: Optional[asyncio.Queue] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # asyncio.create_task() only keeps a weak reference to its task, so a
        # pending background heal could be garbage-collected mid-flight unless
        # we hold a strong reference here until it finishes.
        self._bg_tasks: set = set()

    def _spawn_bg(self, coro) -> None:
        """Schedule *coro* as a background task, holding a strong ref until done."""
        task = asyncio.create_task(coro)
        self._bg_tasks.add(task)
        task.add_done_callback(self._bg_tasks.discard)

    async def initialize(self):
        """Connect all slots concurrently; raise only if *every* slot failed."""
        # get_running_loop() is the non-deprecated way to grab the loop from
        # inside a coroutine (get_event_loop() is deprecated here since 3.10).
        self._loop = asyncio.get_running_loop()
        self._queue = asyncio.Queue(maxsize=self.size)

        log.info(f"π Initializing provider pool (slots={self.size})")
        log.info(f"   DISPLAY={os.environ.get('DISPLAY', 'NOT SET')}")
        log.info(f"   XVFB_EXTERNAL={os.environ.get('XVFB_EXTERNAL', '0')}")
        log.info(f"   VR_DISPLAY={os.environ.get('VR_DISPLAY', '0')}")

        results = await asyncio.gather(
            *[self._spawn_slot_with_retry(i) for i in range(self.size)],
            return_exceptions=True,
        )

        ok = sum(1 for r in results if not isinstance(r, Exception))
        fail = self.size - ok

        if fail:
            for i, r in enumerate(results):
                if isinstance(r, Exception):
                    log.error(f"   [S{i}] FAILED: {r}")

        log.info(f"   Pool ready β {ok}/{self.size} slots healthy")

        if ok == 0:
            raise RuntimeError(
                f"All {self.size} provider slots failed to connect.\n"
                f"  β Check DISPLAY / XVFB_EXTERNAL environment variables.\n"
                f"  β Ensure entrypoint.sh started Xvfb before the server.\n"
                f"  β Check network connectivity to playground.ai.cloudflare.com."
            )

    async def _spawn_slot_with_retry(self, slot_id: int) -> "ManagedProvider":
        """Create one slot, retrying up to SLOT_RETRIES times.

        On success the slot is registered in self._slots and queued as free.
        Raises RuntimeError carrying the last error after the final attempt.
        """
        managed = ManagedProvider(slot_id)

        for attempt in range(1, SLOT_RETRIES + 1):
            try:
                log.info(f"   [S{slot_id}] Connecting... (attempt {attempt}/{SLOT_RETRIES})")

                def _create():
                    # Blocking constructor; executed in the default executor below.
                    return CloudflareProvider(
                        model=DEFAULT_MODEL,
                        system=DEFAULT_SYSTEM,
                        debug=True,        # verbose while connecting...
                        use_cache=True,
                    )

                managed.provider = await asyncio.wait_for(
                    self._loop.run_in_executor(None, _create),
                    timeout=180,
                )
                managed.provider.debug = False  # ...quiet once connected
                managed.born_at = time.time()

                self._slots.append(managed)
                await self._queue.put(managed)

                mode = managed.provider._mode
                log.info(f"   [S{slot_id}] β Ready mode={mode!r}")
                return managed

            except asyncio.TimeoutError:
                err = f"Slot {slot_id} timed out (attempt {attempt})"
                log.warning(f"   [S{slot_id}] β {err}")
                managed.last_error = err
                managed.close()

            except Exception as exc:
                managed.last_error = str(exc)
                log.warning(
                    f"   [S{slot_id}] β Attempt {attempt} failed:\n"
                    + traceback.format_exc()
                )
                managed.close()

            if attempt < SLOT_RETRIES:
                log.info(f"   [S{slot_id}] Retrying in {SLOT_RETRY_WAIT}s...")
                await asyncio.sleep(SLOT_RETRY_WAIT)

        raise RuntimeError(
            f"Slot {slot_id} failed after {SLOT_RETRIES} attempts. "
            f"Last error: {managed.last_error}"
        )

    @asynccontextmanager
    async def acquire(self):
        """Check out a healthy provider for one request.

        Raises asyncio.TimeoutError after ACQUIRE_TIMEOUT if no slot frees up.
        Yields the raw CloudflareProvider. On exit the slot is re-queued if it
        is still healthy, otherwise it is healed in the background first.
        """
        managed: ManagedProvider = await asyncio.wait_for(
            self._queue.get(),
            timeout=ACQUIRE_TIMEOUT,
        )
        managed.busy = True

        try:
            if not managed.is_healthy():
                # Heal synchronously — the caller is waiting for a live provider.
                log.warning(f"[S{managed.slot_id}] Unhealthy at checkout β healing now")
                await self._heal(managed)

            managed.request_count += 1
            yield managed.provider

        except Exception:
            managed.error_count += 1
            raise

        finally:
            managed.busy = False
            if managed.is_healthy():
                await self._queue.put(managed)
            else:
                log.warning(f"[S{managed.slot_id}] Dead after use β background heal")
                self._spawn_bg(self._heal_then_return(managed))

    async def _heal(self, managed: "ManagedProvider"):
        """Tear down and rebuild one slot's provider (blocking work in executor)."""
        sid = managed.slot_id
        log.info(f"[S{sid}] Healing slot...")

        def _recreate():
            managed.close()
            return CloudflareProvider(
                model=DEFAULT_MODEL,
                system=DEFAULT_SYSTEM,
                debug=True,
                use_cache=True,
            )

        try:
            managed.provider = await asyncio.wait_for(
                self._loop.run_in_executor(None, _recreate),
                timeout=180,
            )
            managed.provider.debug = False
            managed.born_at = time.time()
            managed.error_count = 0
            managed.last_error = ""
            log.info(f"[S{sid}] β Healed mode={managed.provider._mode!r}")
        except Exception as e:
            managed.last_error = str(e)
            log.error(f"[S{sid}] β Heal failed: {e}\n{traceback.format_exc()}")
            raise

    async def _heal_then_return(self, managed: "ManagedProvider"):
        """Heal with retries, then put the slot back in the free queue.

        The slot is re-queued even after total failure so that acquire() can
        retry the heal at checkout instead of permanently shrinking the pool.
        """
        sid = managed.slot_id
        for attempt in range(1, SLOT_RETRIES + 1):
            try:
                await self._heal(managed)
                await self._queue.put(managed)
                return
            except Exception as e:
                log.warning(f"[S{sid}] Heal attempt {attempt}/{SLOT_RETRIES} failed: {e}")
                if attempt < SLOT_RETRIES:
                    await asyncio.sleep(SLOT_RETRY_WAIT)

        log.error(f"[S{sid}] All heal attempts failed β slot may be non-functional")
        await self._queue.put(managed)

    async def health_monitor(self):
        """Every HEALTH_INTERVAL seconds: log pool stats and heal idle dead slots."""
        while True:
            await asyncio.sleep(HEALTH_INTERVAL)
            healthy = sum(1 for m in self._slots if m.is_healthy())
            busy = sum(1 for m in self._slots if m.busy)
            log.info(
                f"β₯ Pool β {healthy}/{self.size} healthy "
                f"{busy} busy queue={self._queue.qsize()}"
            )
            for managed in list(self._slots):
                # Busy slots are left alone: acquire()'s finally-block handles them.
                if not managed.busy and not managed.is_healthy():
                    log.warning(f"[S{managed.slot_id}] Idle+dead β healing in background")
                    self._spawn_bg(self._heal_then_return(managed))

    @property
    def status(self) -> dict:
        """JSON-serializable snapshot of the pool (used by / and /health)."""
        return {
            "pool_size": self.size,
            "queue_free": self._queue.qsize() if self._queue else 0,
            "slots": [
                {
                    "id": m.slot_id,
                    "healthy": m.is_healthy(),
                    "busy": m.busy,
                    "mode": m.provider._mode if m.provider else "none",
                    "errors": m.error_count,
                    "requests": m.request_count,
                    "age_s": round(time.time() - m.born_at, 1) if m.born_at else 0,
                    "last_error": m.last_error or None,
                }
                for m in self._slots
            ],
        }

    async def shutdown(self):
        """Cancel pending background heals and close every slot."""
        log.info("Shutting down provider pool...")
        for task in list(self._bg_tasks):
            task.cancel()
        for m in self._slots:
            m.close()
        log.info("Pool shut down.")
|
|
|
|
| |
| |
| |
# Global pool instance; created during FastAPI lifespan startup, None before that.
pool: Optional[ProviderPool] = None
|
|
|
|
| |
| |
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the pool on startup, tear it down on exit.

    Propagates the RuntimeError from pool.initialize() if no slot could
    connect, which aborts server startup.
    """
    global pool
    pool = ProviderPool(size=POOL_SIZE)
    await pool.initialize()

    # Periodic background task that heals idle dead slots.
    monitor = asyncio.create_task(pool.health_monitor())
    log.info(f"Server ready {HOST}:{PORT}")

    yield

    # Shutdown: cancel the monitor first so it cannot race with pool.shutdown().
    monitor.cancel()
    try:
        await monitor
    except asyncio.CancelledError:
        pass
    await pool.shutdown()
|
|
|
|
| |
| |
| |
# FastAPI application; `lifespan` manages pool startup/shutdown.
app = FastAPI(
    title = "Cloudflare AI API",
    description = "OpenAI-compatible API via Cloudflare AI Playground",
    version = "1.1.0",
    lifespan = lifespan,
    docs_url = "/docs",
    redoc_url = "/redoc",
)


# Wide-open CORS: any origin/method/header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins = ["*"],
    allow_methods = ["*"],
    allow_headers = ["*"],
)
|
|
|
|
| |
| |
| |
| def _sse_chunk(content: str, model: str, cid: str) -> str: |
| return "data: " + json.dumps({ |
| "id": cid, |
| "object": "chat.completion.chunk", |
| "created": int(time.time()), |
| "model": model, |
| "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}], |
| }, ensure_ascii=False) + "\n\n" |
|
|
| def _sse_done(model: str, cid: str) -> str: |
| return "data: " + json.dumps({ |
| "id": cid, |
| "object": "chat.completion.chunk", |
| "created": int(time.time()), |
| "model": model, |
| "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], |
| }) + "\n\ndata: [DONE]\n\n" |
|
|
| def _sse_error(msg: str) -> str: |
| return f'data: {{"error": {json.dumps(msg)}}}\n\ndata: [DONE]\n\n' |
|
|
|
|
async def _stream_generator(
    provider: CloudflareProvider,
    req: ChatRequest,
) -> AsyncGenerator[str, None]:
    """Bridge the provider's blocking chat() iterator onto the event loop as SSE.

    A daemon worker thread pulls chunks from provider.chat() and pushes them
    into an asyncio.Queue; this coroutine drains the queue and yields SSE
    frames. Queue protocol: str chunk = data, Exception = error, None = end.
    """
    # NOTE(review): get_event_loop() inside a coroutine is deprecated since
    # Python 3.10; get_running_loop() is the modern equivalent.
    loop = asyncio.get_event_loop()
    # Bounded queue provides backpressure: the worker blocks once 512 chunks
    # are buffered and the client has not consumed them yet.
    q: asyncio.Queue = asyncio.Queue(maxsize=512)
    cid = f"chatcmpl-{uuid.uuid4().hex[:20]}"
    cancel = threading.Event()  # signals the worker thread to stop early

    messages = [{"role": m.role, "content": m.content} for m in req.messages]
    kwargs = {
        "messages": messages,
        "temperature": req.temperature,
        "model": req.model,
    }
    if req.max_tokens:
        kwargs["max_tokens"] = req.max_tokens
    if req.system:
        kwargs["system"] = req.system

    def _worker():
        # Runs in a daemon thread: iterates the blocking generator and hands
        # each chunk to the event loop. fut.result(timeout=10) is where the
        # backpressure bites — it blocks the thread until the queue has room.
        try:
            for chunk in provider.chat(**kwargs):
                if cancel.is_set():
                    break
                fut = asyncio.run_coroutine_threadsafe(q.put(chunk), loop)
                fut.result(timeout=10)
        except Exception as exc:
            err = RuntimeError(str(exc))
            asyncio.run_coroutine_threadsafe(q.put(err), loop).result(timeout=5)
        finally:
            # Always enqueue the end-of-stream sentinel, even after an error.
            asyncio.run_coroutine_threadsafe(q.put(None), loop).result(timeout=5)

    t = threading.Thread(target=_worker, daemon=True)
    t.start()

    try:
        while True:
            # STREAM_TIMEOUT guards against a wedged provider producing nothing.
            item = await asyncio.wait_for(q.get(), timeout=STREAM_TIMEOUT)

            if item is None:  # normal end of stream
                yield _sse_done(req.model, cid)
                break

            if isinstance(item, Exception):  # worker reported a failure
                yield _sse_error(str(item))
                break

            if item:  # skip empty chunks
                yield _sse_chunk(item, req.model, cid)

    except asyncio.TimeoutError:
        cancel.set()
        yield _sse_error("Stream timed out")

    finally:
        # Stop the worker and give it a moment to exit; daemon=True means it
        # cannot block interpreter shutdown even if it ignores the signal.
        cancel.set()
        t.join(timeout=5)
|
|
|
|
| |
| |
| |
|
|
| @app.get("/", tags=["Info"]) |
| async def root(): |
| return { |
| "service": "Cloudflare AI API", |
| "version": "1.1.0", |
| "status": "running", |
| "display": os.environ.get("DISPLAY", "not set"), |
| "endpoints": { |
| "chat": "POST /v1/chat/completions", |
| "models": "GET /v1/models", |
| "health": "GET /health", |
| "docs": "GET /docs", |
| }, |
| } |
|
|
|
|
| @app.get("/health", tags=["Info"]) |
| async def health(): |
| if pool is None: |
| raise HTTPException(503, detail="Pool not initialized") |
|
|
| healthy = sum(1 for m in pool._slots if m.is_healthy()) |
| status = "ok" if healthy > 0 else "degraded" |
|
|
| return JSONResponse( |
| content={"status": status, "pool": pool.status}, |
| status_code=200 if status == "ok" else 206, |
| ) |
|
|
|
|
| @app.get("/v1/models", tags=["Models"]) |
| async def list_models(): |
| if pool is None: |
| raise HTTPException(503, detail="Pool not initialized") |
|
|
| async with pool.acquire() as provider: |
| models = await asyncio.get_event_loop().run_in_executor( |
| None, provider.list_models |
| ) |
|
|
| return { |
| "object": "list", |
| "data": [ |
| { |
| "id": m["name"], |
| "object": "model", |
| "created": 0, |
| "owned_by": "cloudflare", |
| "context_window": m.get("context", 4096), |
| } |
| for m in models |
| ], |
| } |
|
|
|
|
| @app.post("/v1/chat/completions", tags=["Chat"]) |
| async def chat_completions(req: ChatRequest, request: Request): |
| if pool is None: |
| raise HTTPException(503, detail="Pool not initialized") |
| if not req.messages: |
| raise HTTPException(400, detail="`messages` must not be empty") |
|
|
| |
| if req.stream: |
| async def _gen(): |
| async with pool.acquire() as provider: |
| async for chunk in _stream_generator(provider, req): |
| if await request.is_disconnected(): |
| break |
| yield chunk |
|
|
| return StreamingResponse( |
| _gen(), |
| media_type = "text/event-stream", |
| headers = { |
| "Cache-Control": "no-cache", |
| "X-Accel-Buffering": "no", |
| "Connection": "keep-alive", |
| }, |
| ) |
|
|
| |
| messages = [{"role": m.role, "content": m.content} for m in req.messages] |
| kwargs = { |
| "messages": messages, |
| "temperature": req.temperature, |
| "model": req.model, |
| } |
| if req.max_tokens: |
| kwargs["max_tokens"] = req.max_tokens |
| if req.system: |
| kwargs["system"] = req.system |
|
|
| loop = asyncio.get_event_loop() |
| full_parts: list[str] = [] |
|
|
| async with pool.acquire() as provider: |
| def _collect(): |
| for chunk in provider.chat(**kwargs): |
| full_parts.append(chunk) |
|
|
| await asyncio.wait_for( |
| loop.run_in_executor(None, _collect), |
| timeout=STREAM_TIMEOUT, |
| ) |
|
|
| return { |
| "id": f"chatcmpl-{uuid.uuid4().hex[:20]}", |
| "object": "chat.completion", |
| "created": int(time.time()), |
| "model": req.model, |
| "choices": [{ |
| "index": 0, |
| "message": {"role": "assistant", "content": "".join(full_parts)}, |
| "finish_reason": "stop", |
| }], |
| "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, |
| } |
|
|
|
|
| |
| |
| |
| if __name__ == "__main__": |
| uvicorn.run( |
| "server:app", |
| host = HOST, |
| port = PORT, |
| log_level = "info", |
| workers = 1, |
| loop = "asyncio", |
| timeout_keep_alive = 30, |
| ) |