Update server.py
Browse files
server.py
CHANGED
|
@@ -8,11 +8,9 @@
|
|
| 8 |
β GET /health β
|
| 9 |
β GET / β
|
| 10 |
β β
|
| 11 |
-
β
|
| 12 |
-
β
|
| 13 |
-
β
|
| 14 |
-
β β’ HealthMonitor β periodic background probe + heal β
|
| 15 |
-
β β’ SSE streaming β threadβasyncio bridge via Queue β
|
| 16 |
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 17 |
"""
|
| 18 |
|
|
@@ -23,6 +21,7 @@ import os
|
|
| 23 |
import sys
|
| 24 |
import threading
|
| 25 |
import time
|
|
|
|
| 26 |
import uuid
|
| 27 |
from contextlib import asynccontextmanager
|
| 28 |
from typing import AsyncGenerator, List, Optional
|
|
@@ -33,7 +32,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 33 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 34 |
from pydantic import BaseModel, Field
|
| 35 |
|
| 36 |
-
# βββ Import provider ββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 38 |
from cloudflare_provider import CloudflareProvider
|
| 39 |
|
|
@@ -41,24 +39,27 @@ from cloudflare_provider import CloudflareProvider
|
|
| 41 |
# LOGGING
|
| 42 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
logging.basicConfig(
|
| 44 |
-
level=logging.INFO,
|
| 45 |
-
format="%(asctime)s %(levelname)-8s %(message)s",
|
| 46 |
-
stream=sys.stdout,
|
| 47 |
-
datefmt="%H:%M:%S",
|
| 48 |
)
|
| 49 |
log = logging.getLogger("cf-api")
|
| 50 |
|
|
|
|
| 51 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
-
# CONFIG
|
| 53 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
-
POOL_SIZE
|
| 55 |
-
PORT
|
| 56 |
-
HOST
|
| 57 |
-
HEALTH_INTERVAL
|
| 58 |
-
ACQUIRE_TIMEOUT
|
| 59 |
-
STREAM_TIMEOUT
|
| 60 |
-
DEFAULT_MODEL
|
| 61 |
-
DEFAULT_SYSTEM
|
|
|
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -69,42 +70,27 @@ class Message(BaseModel):
|
|
| 69 |
content: str
|
| 70 |
|
| 71 |
class ChatRequest(BaseModel):
|
| 72 |
-
model: str
|
| 73 |
messages: List[Message]
|
| 74 |
-
temperature: float
|
| 75 |
-
max_tokens: Optional[int]
|
| 76 |
-
stream: bool
|
| 77 |
-
system: Optional[str]
|
| 78 |
-
|
| 79 |
-
class CompletionChoice(BaseModel):
|
| 80 |
-
index: int
|
| 81 |
-
message: dict
|
| 82 |
-
finish_reason: str
|
| 83 |
-
|
| 84 |
-
class CompletionResponse(BaseModel):
|
| 85 |
-
id: str
|
| 86 |
-
object: str
|
| 87 |
-
created: int
|
| 88 |
-
model: str
|
| 89 |
-
choices: List[CompletionChoice]
|
| 90 |
-
usage: dict
|
| 91 |
|
| 92 |
|
| 93 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 94 |
-
# MANAGED PROVIDER
|
| 95 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 96 |
class ManagedProvider:
|
| 97 |
-
"""A single pool slot wrapping one CloudflareProvider instance."""
|
| 98 |
-
|
| 99 |
def __init__(self, slot_id: int):
|
| 100 |
-
self.slot_id
|
| 101 |
self.provider: Optional[CloudflareProvider] = None
|
| 102 |
-
self.busy
|
| 103 |
-
self.born_at
|
| 104 |
-
self.error_count
|
| 105 |
self.request_count = 0
|
|
|
|
| 106 |
|
| 107 |
-
# ββ Health ββββββββββββββββββββββββββββββββββββββ
|
| 108 |
def is_healthy(self) -> bool:
|
| 109 |
if self.provider is None:
|
| 110 |
return False
|
|
@@ -117,7 +103,6 @@ class ManagedProvider:
|
|
| 117 |
except Exception:
|
| 118 |
return False
|
| 119 |
|
| 120 |
-
# ββ Teardown ββββββββββββββββββββββββββββββββββββ
|
| 121 |
def close(self):
|
| 122 |
p = self.provider
|
| 123 |
self.provider = None
|
|
@@ -130,92 +115,119 @@ class ManagedProvider:
|
|
| 130 |
def __repr__(self):
|
| 131 |
state = "busy" if self.busy else ("ok" if self.is_healthy() else "dead")
|
| 132 |
mode = self.provider._mode if self.provider else "none"
|
| 133 |
-
return
|
| 134 |
-
f"<Slot#{self.slot_id} {state} mode={mode!r} "
|
| 135 |
-
f"reqs={self.request_count} errs={self.error_count}>"
|
| 136 |
-
)
|
| 137 |
|
| 138 |
|
| 139 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 140 |
# PROVIDER POOL
|
| 141 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
class ProviderPool:
|
| 143 |
-
"""
|
| 144 |
-
Pre-warmed pool of CloudflareProvider connections.
|
| 145 |
-
|
| 146 |
-
β’ initialize() β create all slots at startup
|
| 147 |
-
β’ acquire() β async context manager; blocks until a free slot
|
| 148 |
-
β’ health_monitor β background coroutine; heals broken idle slots
|
| 149 |
-
β’ shutdown() β clean teardown
|
| 150 |
-
"""
|
| 151 |
-
|
| 152 |
def __init__(self, size: int = 2):
|
| 153 |
self.size = size
|
| 154 |
-
self._slots:
|
| 155 |
-
self._queue:
|
| 156 |
-
self._loop:
|
| 157 |
-
self._lock = asyncio.Lock()
|
| 158 |
|
| 159 |
-
# βββ Startup ββββββββββββββββββββββββββββββββββ
|
| 160 |
async def initialize(self):
|
| 161 |
self._loop = asyncio.get_event_loop()
|
| 162 |
self._queue = asyncio.Queue(maxsize=self.size)
|
| 163 |
|
| 164 |
log.info(f"π Initializing provider pool (slots={self.size})")
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
results = await asyncio.gather(
|
| 167 |
-
*[self.
|
| 168 |
return_exceptions=True,
|
| 169 |
)
|
| 170 |
|
| 171 |
-
ok
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
log.info(f" Pool ready β {ok}/{self.size} slots healthy")
|
| 173 |
|
| 174 |
if ok == 0:
|
| 175 |
raise RuntimeError(
|
| 176 |
-
"
|
|
|
|
|
|
|
|
|
|
| 177 |
)
|
| 178 |
|
| 179 |
-
async def
|
|
|
|
| 180 |
managed = ManagedProvider(slot_id)
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
model = DEFAULT_MODEL,
|
| 186 |
-
system = DEFAULT_SYSTEM,
|
| 187 |
-
debug = False,
|
| 188 |
-
use_cache = True,
|
| 189 |
-
)
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
| 196 |
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
return managed
|
| 203 |
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
@asynccontextmanager
|
| 206 |
async def acquire(self):
|
| 207 |
-
"""Checkout a provider, yield it, return on exit (healing if needed)."""
|
| 208 |
managed: ManagedProvider = await asyncio.wait_for(
|
| 209 |
self._queue.get(),
|
| 210 |
timeout=ACQUIRE_TIMEOUT,
|
| 211 |
)
|
| 212 |
managed.busy = True
|
| 213 |
-
ok = True
|
| 214 |
|
| 215 |
try:
|
| 216 |
-
# Heal before handing out
|
| 217 |
if not managed.is_healthy():
|
| 218 |
-
log.warning(f"[S{managed.slot_id}] Unhealthy β healing
|
| 219 |
await self._heal(managed)
|
| 220 |
|
| 221 |
managed.request_count += 1
|
|
@@ -223,104 +235,98 @@ class ProviderPool:
|
|
| 223 |
|
| 224 |
except Exception:
|
| 225 |
managed.error_count += 1
|
| 226 |
-
ok = False
|
| 227 |
raise
|
| 228 |
|
| 229 |
finally:
|
| 230 |
managed.busy = False
|
| 231 |
-
# After use: return if healthy, else heal in background
|
| 232 |
if managed.is_healthy():
|
| 233 |
await self._queue.put(managed)
|
| 234 |
else:
|
| 235 |
-
log.warning(f"[S{managed.slot_id}]
|
| 236 |
asyncio.create_task(self._heal_then_return(managed))
|
| 237 |
|
| 238 |
-
# βββ Healing ββββββββββββββββββββββββββββββββββ
|
| 239 |
async def _heal(self, managed: ManagedProvider):
|
| 240 |
sid = managed.slot_id
|
|
|
|
| 241 |
|
| 242 |
-
def _recreate()
|
| 243 |
managed.close()
|
| 244 |
return CloudflareProvider(
|
| 245 |
model = DEFAULT_MODEL,
|
| 246 |
system = DEFAULT_SYSTEM,
|
| 247 |
-
debug =
|
| 248 |
use_cache = True,
|
| 249 |
)
|
| 250 |
|
| 251 |
-
managed.provider = await asyncio.wait_for(
|
| 252 |
-
self._loop.run_in_executor(None, _recreate),
|
| 253 |
-
timeout=180,
|
| 254 |
-
)
|
| 255 |
-
managed.born_at = time.time()
|
| 256 |
-
managed.error_count = 0
|
| 257 |
-
log.info(f"[S{sid}] β Healed mode={managed.provider._mode!r}")
|
| 258 |
-
|
| 259 |
-
async def _heal_then_return(self, managed: ManagedProvider):
|
| 260 |
try:
|
| 261 |
-
await
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
except Exception as e:
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
managed.close()
|
| 267 |
-
managed.provider = await asyncio.wait_for(
|
| 268 |
-
self._loop.run_in_executor(
|
| 269 |
-
None,
|
| 270 |
-
lambda: CloudflareProvider(
|
| 271 |
-
model=DEFAULT_MODEL, system=DEFAULT_SYSTEM,
|
| 272 |
-
debug=False, use_cache=True,
|
| 273 |
-
),
|
| 274 |
-
),
|
| 275 |
-
timeout=180,
|
| 276 |
-
)
|
| 277 |
-
managed.born_at = time.time()
|
| 278 |
-
managed.error_count = 0
|
| 279 |
-
log.info(f"[S{managed.slot_id}] β Cold-boot recovery succeeded")
|
| 280 |
-
except Exception as e2:
|
| 281 |
-
log.error(f"[S{managed.slot_id}] Cold-boot also failed: {e2}")
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
await self._queue.put(managed)
|
| 284 |
|
| 285 |
-
# βββ Health monitor βββββββββββββββββββββββββββ
|
| 286 |
async def health_monitor(self):
|
| 287 |
-
"""Periodic background coroutine β checks and heals idle slots."""
|
| 288 |
while True:
|
| 289 |
await asyncio.sleep(HEALTH_INTERVAL)
|
| 290 |
healthy = sum(1 for m in self._slots if m.is_healthy())
|
| 291 |
busy = sum(1 for m in self._slots if m.busy)
|
| 292 |
log.info(
|
| 293 |
-
f"β₯
|
| 294 |
-
f"{busy} busy
|
| 295 |
)
|
| 296 |
-
|
| 297 |
for managed in list(self._slots):
|
| 298 |
if not managed.busy and not managed.is_healthy():
|
| 299 |
-
log.warning(f"[S{managed.slot_id}] Idle
|
| 300 |
-
# Pull from queue if it's still there, otherwise skip
|
| 301 |
asyncio.create_task(self._heal_then_return(managed))
|
| 302 |
|
| 303 |
-
# βββ Status βββββββββββββββββββββββββββββββββββ
|
| 304 |
@property
|
| 305 |
def status(self) -> dict:
|
| 306 |
return {
|
| 307 |
-
"pool_size":
|
| 308 |
-
"queue_free":
|
| 309 |
"slots": [
|
| 310 |
{
|
| 311 |
-
"id":
|
| 312 |
-
"healthy":
|
| 313 |
-
"busy":
|
| 314 |
-
"mode":
|
| 315 |
-
"errors":
|
| 316 |
-
"requests":
|
| 317 |
-
"age_s":
|
|
|
|
| 318 |
}
|
| 319 |
for m in self._slots
|
| 320 |
],
|
| 321 |
}
|
| 322 |
|
| 323 |
-
# βββ Shutdown βββββββββββββββββββββββββββββββββ
|
| 324 |
async def shutdown(self):
|
| 325 |
log.info("Shutting down provider pool...")
|
| 326 |
for m in self._slots:
|
|
@@ -329,13 +335,13 @@ class ProviderPool:
|
|
| 329 |
|
| 330 |
|
| 331 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 332 |
-
# GLOBAL POOL
|
| 333 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 334 |
pool: ProviderPool = None
|
| 335 |
|
| 336 |
|
| 337 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 338 |
-
# LIFESPAN
|
| 339 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 340 |
@asynccontextmanager
|
| 341 |
async def lifespan(app: FastAPI):
|
|
@@ -344,7 +350,7 @@ async def lifespan(app: FastAPI):
|
|
| 344 |
await pool.initialize()
|
| 345 |
|
| 346 |
monitor = asyncio.create_task(pool.health_monitor())
|
| 347 |
-
log.info(f"β
Server ready
|
| 348 |
|
| 349 |
yield
|
| 350 |
|
|
@@ -357,15 +363,15 @@ async def lifespan(app: FastAPI):
|
|
| 357 |
|
| 358 |
|
| 359 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 360 |
-
#
|
| 361 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 362 |
app = FastAPI(
|
| 363 |
-
title
|
| 364 |
-
description
|
| 365 |
-
version
|
| 366 |
-
lifespan
|
| 367 |
-
docs_url
|
| 368 |
-
redoc_url
|
| 369 |
)
|
| 370 |
|
| 371 |
app.add_middleware(
|
|
@@ -377,85 +383,59 @@ app.add_middleware(
|
|
| 377 |
|
| 378 |
|
| 379 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 380 |
-
# SSE
|
| 381 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 382 |
-
def _sse_chunk(content: str, model: str,
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
"id": chunk_id,
|
| 386 |
"object": "chat.completion.chunk",
|
| 387 |
"created": int(time.time()),
|
| 388 |
"model": model,
|
| 389 |
-
"choices": [{
|
| 390 |
-
|
| 391 |
-
"delta": {"content": content},
|
| 392 |
-
"finish_reason": None,
|
| 393 |
-
}],
|
| 394 |
-
}
|
| 395 |
-
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
| 396 |
-
|
| 397 |
|
| 398 |
-
def _sse_done(model: str,
|
| 399 |
-
""
|
| 400 |
-
|
| 401 |
-
"id": chunk_id,
|
| 402 |
"object": "chat.completion.chunk",
|
| 403 |
"created": int(time.time()),
|
| 404 |
"model": model,
|
| 405 |
-
"choices": [{
|
| 406 |
-
|
| 407 |
-
"delta": {},
|
| 408 |
-
"finish_reason": "stop",
|
| 409 |
-
}],
|
| 410 |
-
}
|
| 411 |
-
return f"data: {json.dumps(payload)}\n\ndata: [DONE]\n\n"
|
| 412 |
-
|
| 413 |
|
| 414 |
def _sse_error(msg: str) -> str:
|
| 415 |
-
return f
|
| 416 |
|
| 417 |
|
| 418 |
async def _stream_generator(
|
| 419 |
provider: CloudflareProvider,
|
| 420 |
req: ChatRequest,
|
| 421 |
) -> AsyncGenerator[str, None]:
|
| 422 |
-
"""
|
| 423 |
-
Bridge between the synchronous provider.chat() generator and
|
| 424 |
-
FastAPI's async StreamingResponse.
|
| 425 |
-
|
| 426 |
-
Strategy:
|
| 427 |
-
1. Spin up a background thread that runs provider.chat() and
|
| 428 |
-
pushes each chunk into an asyncio.Queue.
|
| 429 |
-
2. Yield SSE-formatted chunks from the queue in the async loop.
|
| 430 |
-
"""
|
| 431 |
loop = asyncio.get_event_loop()
|
| 432 |
-
q:
|
| 433 |
-
|
| 434 |
cancel = threading.Event()
|
| 435 |
|
| 436 |
-
# Build kwargs for provider
|
| 437 |
messages = [{"role": m.role, "content": m.content} for m in req.messages]
|
| 438 |
-
kwargs
|
| 439 |
"messages": messages,
|
| 440 |
"temperature": req.temperature,
|
|
|
|
| 441 |
}
|
| 442 |
-
if req.model:
|
| 443 |
-
kwargs["model"] = req.model
|
| 444 |
if req.max_tokens:
|
| 445 |
kwargs["max_tokens"] = req.max_tokens
|
| 446 |
if req.system:
|
| 447 |
kwargs["system"] = req.system
|
| 448 |
|
| 449 |
-
# ββ Worker thread ββββββββββββββββββββββββββββ
|
| 450 |
def _worker():
|
| 451 |
try:
|
| 452 |
for chunk in provider.chat(**kwargs):
|
| 453 |
if cancel.is_set():
|
| 454 |
break
|
| 455 |
fut = asyncio.run_coroutine_threadsafe(q.put(chunk), loop)
|
| 456 |
-
fut.result(timeout=10)
|
| 457 |
except Exception as exc:
|
| 458 |
-
err = RuntimeError(
|
| 459 |
asyncio.run_coroutine_threadsafe(q.put(err), loop).result(timeout=5)
|
| 460 |
finally:
|
| 461 |
asyncio.run_coroutine_threadsafe(q.put(None), loop).result(timeout=5)
|
|
@@ -463,25 +443,24 @@ async def _stream_generator(
|
|
| 463 |
t = threading.Thread(target=_worker, daemon=True)
|
| 464 |
t.start()
|
| 465 |
|
| 466 |
-
# ββ Async consumer ββββββββββββββββββββββββββββ
|
| 467 |
try:
|
| 468 |
while True:
|
| 469 |
item = await asyncio.wait_for(q.get(), timeout=STREAM_TIMEOUT)
|
| 470 |
|
| 471 |
-
if item is None:
|
| 472 |
-
yield _sse_done(req.model,
|
| 473 |
break
|
| 474 |
|
| 475 |
-
if isinstance(item, Exception):
|
| 476 |
yield _sse_error(str(item))
|
| 477 |
break
|
| 478 |
|
| 479 |
-
if item:
|
| 480 |
-
yield _sse_chunk(item, req.model,
|
| 481 |
|
| 482 |
except asyncio.TimeoutError:
|
| 483 |
cancel.set()
|
| 484 |
-
yield _sse_error("Stream timed out
|
| 485 |
|
| 486 |
finally:
|
| 487 |
cancel.set()
|
|
@@ -495,9 +474,10 @@ async def _stream_generator(
|
|
| 495 |
@app.get("/", tags=["Info"])
|
| 496 |
async def root():
|
| 497 |
return {
|
| 498 |
-
"service":
|
| 499 |
-
"version":
|
| 500 |
-
"status":
|
|
|
|
| 501 |
"endpoints": {
|
| 502 |
"chat": "POST /v1/chat/completions",
|
| 503 |
"models": "GET /v1/models",
|
|
@@ -510,9 +490,11 @@ async def root():
|
|
| 510 |
@app.get("/health", tags=["Info"])
|
| 511 |
async def health():
|
| 512 |
if pool is None:
|
| 513 |
-
raise HTTPException(503, detail="Pool not
|
|
|
|
| 514 |
healthy = sum(1 for m in pool._slots if m.is_healthy())
|
| 515 |
status = "ok" if healthy > 0 else "degraded"
|
|
|
|
| 516 |
return JSONResponse(
|
| 517 |
content={"status": status, "pool": pool.status},
|
| 518 |
status_code=200 if status == "ok" else 206,
|
|
@@ -533,10 +515,10 @@ async def list_models():
|
|
| 533 |
"object": "list",
|
| 534 |
"data": [
|
| 535 |
{
|
| 536 |
-
"id":
|
| 537 |
-
"object":
|
| 538 |
-
"created":
|
| 539 |
-
"owned_by":
|
| 540 |
"context_window": m.get("context", 4096),
|
| 541 |
}
|
| 542 |
for m in models
|
|
@@ -548,7 +530,6 @@ async def list_models():
|
|
| 548 |
async def chat_completions(req: ChatRequest, request: Request):
|
| 549 |
if pool is None:
|
| 550 |
raise HTTPException(503, detail="Pool not initialized")
|
| 551 |
-
|
| 552 |
if not req.messages:
|
| 553 |
raise HTTPException(400, detail="`messages` must not be empty")
|
| 554 |
|
|
@@ -557,7 +538,6 @@ async def chat_completions(req: ChatRequest, request: Request):
|
|
| 557 |
async def _gen():
|
| 558 |
async with pool.acquire() as provider:
|
| 559 |
async for chunk in _stream_generator(provider, req):
|
| 560 |
-
# Check if client disconnected
|
| 561 |
if await request.is_disconnected():
|
| 562 |
break
|
| 563 |
yield chunk
|
|
@@ -574,22 +554,20 @@ async def chat_completions(req: ChatRequest, request: Request):
|
|
| 574 |
|
| 575 |
# ββ Non-streaming ββββββββββββββββββββββββββββββββββββββ
|
| 576 |
messages = [{"role": m.role, "content": m.content} for m in req.messages]
|
| 577 |
-
kwargs
|
| 578 |
"messages": messages,
|
| 579 |
"temperature": req.temperature,
|
|
|
|
| 580 |
}
|
| 581 |
-
if req.model:
|
| 582 |
-
kwargs["model"] = req.model
|
| 583 |
if req.max_tokens:
|
| 584 |
kwargs["max_tokens"] = req.max_tokens
|
| 585 |
if req.system:
|
| 586 |
kwargs["system"] = req.system
|
| 587 |
|
| 588 |
-
loop
|
|
|
|
| 589 |
|
| 590 |
async with pool.acquire() as provider:
|
| 591 |
-
full_parts: list[str] = []
|
| 592 |
-
|
| 593 |
def _collect():
|
| 594 |
for chunk in provider.chat(**kwargs):
|
| 595 |
full_parts.append(chunk)
|
|
@@ -599,8 +577,6 @@ async def chat_completions(req: ChatRequest, request: Request):
|
|
| 599 |
timeout=STREAM_TIMEOUT,
|
| 600 |
)
|
| 601 |
|
| 602 |
-
response_text = "".join(full_parts)
|
| 603 |
-
|
| 604 |
return {
|
| 605 |
"id": f"chatcmpl-{uuid.uuid4().hex[:20]}",
|
| 606 |
"object": "chat.completion",
|
|
@@ -608,14 +584,10 @@ async def chat_completions(req: ChatRequest, request: Request):
|
|
| 608 |
"model": req.model,
|
| 609 |
"choices": [{
|
| 610 |
"index": 0,
|
| 611 |
-
"message": {"role": "assistant", "content":
|
| 612 |
"finish_reason": "stop",
|
| 613 |
}],
|
| 614 |
-
"usage": {
|
| 615 |
-
"prompt_tokens": 0,
|
| 616 |
-
"completion_tokens": 0,
|
| 617 |
-
"total_tokens": 0,
|
| 618 |
-
},
|
| 619 |
}
|
| 620 |
|
| 621 |
|
|
@@ -625,10 +597,10 @@ async def chat_completions(req: ChatRequest, request: Request):
|
|
| 625 |
if __name__ == "__main__":
|
| 626 |
uvicorn.run(
|
| 627 |
"server:app",
|
| 628 |
-
host
|
| 629 |
-
port
|
| 630 |
-
log_level
|
| 631 |
-
workers
|
| 632 |
-
loop
|
| 633 |
timeout_keep_alive = 30,
|
| 634 |
-
)
|
|
|
|
| 8 |
β GET /health β
|
| 9 |
β GET / β
|
| 10 |
β β
|
| 11 |
+
β Pool startup: up to 3 retries per slot, logs exact errors. β
|
| 12 |
+
β Health monitor: heals dead idle slots every 60s. β
|
| 13 |
+
β SSE: threadβasyncio bridge with backpressure. β
|
|
|
|
|
|
|
| 14 |
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
"""
|
| 16 |
|
|
|
|
| 21 |
import sys
|
| 22 |
import threading
|
| 23 |
import time
|
| 24 |
+
import traceback
|
| 25 |
import uuid
|
| 26 |
from contextlib import asynccontextmanager
|
| 27 |
from typing import AsyncGenerator, List, Optional
|
|
|
|
| 32 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 33 |
from pydantic import BaseModel, Field
|
| 34 |
|
|
|
|
| 35 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 36 |
from cloudflare_provider import CloudflareProvider
|
| 37 |
|
|
|
|
| 39 |
# LOGGING
|
| 40 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
logging.basicConfig(
|
| 42 |
+
level = logging.INFO,
|
| 43 |
+
format = "%(asctime)s %(levelname)-8s %(message)s",
|
| 44 |
+
stream = sys.stdout,
|
| 45 |
+
datefmt = "%H:%M:%S",
|
| 46 |
)
|
| 47 |
log = logging.getLogger("cf-api")
|
| 48 |
|
| 49 |
+
|
| 50 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
+
# CONFIG
|
| 52 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
+
POOL_SIZE = int(os.getenv("POOL_SIZE", "2"))
|
| 54 |
+
PORT = int(os.getenv("PORT", "7860"))
|
| 55 |
+
HOST = os.getenv("HOST", "0.0.0.0")
|
| 56 |
+
HEALTH_INTERVAL = int(os.getenv("HEALTH_INTERVAL", "60"))
|
| 57 |
+
ACQUIRE_TIMEOUT = int(os.getenv("ACQUIRE_TIMEOUT", "90"))
|
| 58 |
+
STREAM_TIMEOUT = int(os.getenv("STREAM_TIMEOUT", "120"))
|
| 59 |
+
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "@cf/moonshotai/kimi-k2.5")
|
| 60 |
+
DEFAULT_SYSTEM = os.getenv("DEFAULT_SYSTEM", "You are a helpful assistant.")
|
| 61 |
+
SLOT_RETRIES = int(os.getenv("SLOT_RETRIES", "3"))
|
| 62 |
+
SLOT_RETRY_WAIT = int(os.getenv("SLOT_RETRY_WAIT", "10")) # seconds between retries
|
| 63 |
|
| 64 |
|
| 65 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 70 |
content: str
|
| 71 |
|
| 72 |
class ChatRequest(BaseModel):
|
| 73 |
+
model: str = DEFAULT_MODEL
|
| 74 |
messages: List[Message]
|
| 75 |
+
temperature: float = Field(default=1.0, ge=0.0, le=2.0)
|
| 76 |
+
max_tokens: Optional[int] = None
|
| 77 |
+
stream: bool = True
|
| 78 |
+
system: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
# MANAGED PROVIDER SLOT
|
| 83 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
class ManagedProvider:
|
|
|
|
|
|
|
| 85 |
def __init__(self, slot_id: int):
|
| 86 |
+
self.slot_id = slot_id
|
| 87 |
self.provider: Optional[CloudflareProvider] = None
|
| 88 |
+
self.busy = False
|
| 89 |
+
self.born_at = 0.0
|
| 90 |
+
self.error_count = 0
|
| 91 |
self.request_count = 0
|
| 92 |
+
self.last_error = ""
|
| 93 |
|
|
|
|
| 94 |
def is_healthy(self) -> bool:
|
| 95 |
if self.provider is None:
|
| 96 |
return False
|
|
|
|
| 103 |
except Exception:
|
| 104 |
return False
|
| 105 |
|
|
|
|
| 106 |
def close(self):
|
| 107 |
p = self.provider
|
| 108 |
self.provider = None
|
|
|
|
| 115 |
def __repr__(self):
|
| 116 |
state = "busy" if self.busy else ("ok" if self.is_healthy() else "dead")
|
| 117 |
mode = self.provider._mode if self.provider else "none"
|
| 118 |
+
return f"<Slot#{self.slot_id} {state} mode={mode!r} reqs={self.request_count}>"
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 122 |
# PROVIDER POOL
|
| 123 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 124 |
class ProviderPool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
def __init__(self, size: int = 2):
|
| 126 |
self.size = size
|
| 127 |
+
self._slots: List[ManagedProvider] = []
|
| 128 |
+
self._queue: asyncio.Queue = None
|
| 129 |
+
self._loop: asyncio.AbstractEventLoop = None
|
|
|
|
| 130 |
|
| 131 |
+
# βββ Startup ββββββββββββββββββββββββββββββββββββββββββ
|
| 132 |
async def initialize(self):
|
| 133 |
self._loop = asyncio.get_event_loop()
|
| 134 |
self._queue = asyncio.Queue(maxsize=self.size)
|
| 135 |
|
| 136 |
log.info(f"π Initializing provider pool (slots={self.size})")
|
| 137 |
+
log.info(f" DISPLAY={os.environ.get('DISPLAY', 'NOT SET')}")
|
| 138 |
+
log.info(f" XVFB_EXTERNAL={os.environ.get('XVFB_EXTERNAL', '0')}")
|
| 139 |
+
log.info(f" VR_DISPLAY={os.environ.get('VR_DISPLAY', '0')}")
|
| 140 |
|
| 141 |
results = await asyncio.gather(
|
| 142 |
+
*[self._spawn_slot_with_retry(i) for i in range(self.size)],
|
| 143 |
return_exceptions=True,
|
| 144 |
)
|
| 145 |
|
| 146 |
+
ok = sum(1 for r in results if not isinstance(r, Exception))
|
| 147 |
+
fail = sum(1 for r in results if isinstance(r, Exception))
|
| 148 |
+
|
| 149 |
+
if fail:
|
| 150 |
+
for i, r in enumerate(results):
|
| 151 |
+
if isinstance(r, Exception):
|
| 152 |
+
log.error(f" [S{i}] FAILED: {r}")
|
| 153 |
+
|
| 154 |
log.info(f" Pool ready β {ok}/{self.size} slots healthy")
|
| 155 |
|
| 156 |
if ok == 0:
|
| 157 |
raise RuntimeError(
|
| 158 |
+
f"All {self.size} provider slots failed to connect.\n"
|
| 159 |
+
f" β Check DISPLAY / XVFB_EXTERNAL environment variables.\n"
|
| 160 |
+
f" β Ensure entrypoint.sh started Xvfb before the server.\n"
|
| 161 |
+
f" β Check network connectivity to playground.ai.cloudflare.com."
|
| 162 |
)
|
| 163 |
|
| 164 |
+
async def _spawn_slot_with_retry(self, slot_id: int) -> "ManagedProvider":
|
| 165 |
+
"""Try to create a slot, retrying up to SLOT_RETRIES times."""
|
| 166 |
managed = ManagedProvider(slot_id)
|
| 167 |
|
| 168 |
+
for attempt in range(1, SLOT_RETRIES + 1):
|
| 169 |
+
try:
|
| 170 |
+
log.info(f" [S{slot_id}] Connecting... (attempt {attempt}/{SLOT_RETRIES})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
+
def _create():
|
| 173 |
+
return CloudflareProvider(
|
| 174 |
+
model = DEFAULT_MODEL,
|
| 175 |
+
system = DEFAULT_SYSTEM,
|
| 176 |
+
debug = True, # verbose during init so we can see failures
|
| 177 |
+
use_cache = True,
|
| 178 |
+
)
|
| 179 |
|
| 180 |
+
managed.provider = await asyncio.wait_for(
|
| 181 |
+
self._loop.run_in_executor(None, _create),
|
| 182 |
+
timeout=180,
|
| 183 |
+
)
|
| 184 |
+
managed.provider.debug = False # quiet after successful boot
|
| 185 |
+
managed.born_at = time.time()
|
| 186 |
|
| 187 |
+
self._slots.append(managed)
|
| 188 |
+
await self._queue.put(managed)
|
|
|
|
| 189 |
|
| 190 |
+
mode = managed.provider._mode
|
| 191 |
+
log.info(f" [S{slot_id}] β Ready mode={mode!r}")
|
| 192 |
+
return managed
|
| 193 |
+
|
| 194 |
+
except asyncio.TimeoutError:
|
| 195 |
+
err = f"Slot {slot_id} timed out (attempt {attempt})"
|
| 196 |
+
log.warning(f" [S{slot_id}] β {err}")
|
| 197 |
+
managed.last_error = err
|
| 198 |
+
managed.close()
|
| 199 |
+
|
| 200 |
+
except Exception as exc:
|
| 201 |
+
err = str(exc)
|
| 202 |
+
# Print full traceback for debugging
|
| 203 |
+
log.warning(
|
| 204 |
+
f" [S{slot_id}] β Attempt {attempt} failed:\n"
|
| 205 |
+
+ traceback.format_exc()
|
| 206 |
+
)
|
| 207 |
+
managed.last_error = err
|
| 208 |
+
managed.close()
|
| 209 |
+
|
| 210 |
+
if attempt < SLOT_RETRIES:
|
| 211 |
+
log.info(f" [S{slot_id}] Retrying in {SLOT_RETRY_WAIT}s...")
|
| 212 |
+
await asyncio.sleep(SLOT_RETRY_WAIT)
|
| 213 |
+
|
| 214 |
+
raise RuntimeError(
|
| 215 |
+
f"Slot {slot_id} failed after {SLOT_RETRIES} attempts. "
|
| 216 |
+
f"Last error: {managed.last_error}"
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# βββ Acquire ββββββββββββββββββββββββββββββββββββββββββ
|
| 220 |
@asynccontextmanager
|
| 221 |
async def acquire(self):
|
|
|
|
| 222 |
managed: ManagedProvider = await asyncio.wait_for(
|
| 223 |
self._queue.get(),
|
| 224 |
timeout=ACQUIRE_TIMEOUT,
|
| 225 |
)
|
| 226 |
managed.busy = True
|
|
|
|
| 227 |
|
| 228 |
try:
|
|
|
|
| 229 |
if not managed.is_healthy():
|
| 230 |
+
log.warning(f"[S{managed.slot_id}] Unhealthy at checkout β healing now")
|
| 231 |
await self._heal(managed)
|
| 232 |
|
| 233 |
managed.request_count += 1
|
|
|
|
| 235 |
|
| 236 |
except Exception:
|
| 237 |
managed.error_count += 1
|
|
|
|
| 238 |
raise
|
| 239 |
|
| 240 |
finally:
|
| 241 |
managed.busy = False
|
|
|
|
| 242 |
if managed.is_healthy():
|
| 243 |
await self._queue.put(managed)
|
| 244 |
else:
|
| 245 |
+
log.warning(f"[S{managed.slot_id}] Dead after use β background heal")
|
| 246 |
asyncio.create_task(self._heal_then_return(managed))
|
| 247 |
|
| 248 |
+
# βββ Healing ββββββββββββββββββββββββββββββββββββββββββ
|
| 249 |
async def _heal(self, managed: ManagedProvider):
|
| 250 |
sid = managed.slot_id
|
| 251 |
+
log.info(f"[S{sid}] Healing slot...")
|
| 252 |
|
| 253 |
+
def _recreate():
|
| 254 |
managed.close()
|
| 255 |
return CloudflareProvider(
|
| 256 |
model = DEFAULT_MODEL,
|
| 257 |
system = DEFAULT_SYSTEM,
|
| 258 |
+
debug = True,
|
| 259 |
use_cache = True,
|
| 260 |
)
|
| 261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
try:
|
| 263 |
+
managed.provider = await asyncio.wait_for(
|
| 264 |
+
self._loop.run_in_executor(None, _recreate),
|
| 265 |
+
timeout=180,
|
| 266 |
+
)
|
| 267 |
+
managed.provider.debug = False
|
| 268 |
+
managed.born_at = time.time()
|
| 269 |
+
managed.error_count = 0
|
| 270 |
+
managed.last_error = ""
|
| 271 |
+
log.info(f"[S{sid}] β Healed mode={managed.provider._mode!r}")
|
| 272 |
except Exception as e:
|
| 273 |
+
managed.last_error = str(e)
|
| 274 |
+
log.error(f"[S{sid}] β Heal failed: {e}\n{traceback.format_exc()}")
|
| 275 |
+
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
+
async def _heal_then_return(self, managed: ManagedProvider):
|
| 278 |
+
sid = managed.slot_id
|
| 279 |
+
for attempt in range(1, SLOT_RETRIES + 1):
|
| 280 |
+
try:
|
| 281 |
+
await self._heal(managed)
|
| 282 |
+
await self._queue.put(managed)
|
| 283 |
+
return
|
| 284 |
+
except Exception as e:
|
| 285 |
+
log.warning(f"[S{sid}] Heal attempt {attempt}/{SLOT_RETRIES} failed: {e}")
|
| 286 |
+
if attempt < SLOT_RETRIES:
|
| 287 |
+
await asyncio.sleep(SLOT_RETRY_WAIT)
|
| 288 |
+
|
| 289 |
+
# Last resort: put it back anyway so queue doesn't shrink permanently
|
| 290 |
+
log.error(f"[S{sid}] All heal attempts failed β slot may be non-functional")
|
| 291 |
await self._queue.put(managed)
|
| 292 |
|
| 293 |
+
# βββ Health monitor βββββββββββββββββββββββββββββββββββ
|
| 294 |
async def health_monitor(self):
|
|
|
|
| 295 |
while True:
|
| 296 |
await asyncio.sleep(HEALTH_INTERVAL)
|
| 297 |
healthy = sum(1 for m in self._slots if m.is_healthy())
|
| 298 |
busy = sum(1 for m in self._slots if m.busy)
|
| 299 |
log.info(
|
| 300 |
+
f"β₯ Pool β {healthy}/{self.size} healthy "
|
| 301 |
+
f"{busy} busy queue={self._queue.qsize()}"
|
| 302 |
)
|
|
|
|
| 303 |
for managed in list(self._slots):
|
| 304 |
if not managed.busy and not managed.is_healthy():
|
| 305 |
+
log.warning(f"[S{managed.slot_id}] Idle+dead β healing in background")
|
|
|
|
| 306 |
asyncio.create_task(self._heal_then_return(managed))
|
| 307 |
|
| 308 |
+
# βββ Status βββββββββββββββββββββββββββββββββββββββββββ
|
| 309 |
@property
|
| 310 |
def status(self) -> dict:
|
| 311 |
return {
|
| 312 |
+
"pool_size": self.size,
|
| 313 |
+
"queue_free": self._queue.qsize() if self._queue else 0,
|
| 314 |
"slots": [
|
| 315 |
{
|
| 316 |
+
"id": m.slot_id,
|
| 317 |
+
"healthy": m.is_healthy(),
|
| 318 |
+
"busy": m.busy,
|
| 319 |
+
"mode": m.provider._mode if m.provider else "none",
|
| 320 |
+
"errors": m.error_count,
|
| 321 |
+
"requests": m.request_count,
|
| 322 |
+
"age_s": round(time.time() - m.born_at, 1) if m.born_at else 0,
|
| 323 |
+
"last_error": m.last_error or None,
|
| 324 |
}
|
| 325 |
for m in self._slots
|
| 326 |
],
|
| 327 |
}
|
| 328 |
|
| 329 |
+
# βββ Shutdown βββββββββββββββββββββββββββββββββββββββββ
|
| 330 |
async def shutdown(self):
|
| 331 |
log.info("Shutting down provider pool...")
|
| 332 |
for m in self._slots:
|
|
|
|
| 335 |
|
| 336 |
|
| 337 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 338 |
+
# GLOBAL POOL
|
| 339 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 340 |
pool: ProviderPool = None
|
| 341 |
|
| 342 |
|
| 343 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 344 |
+
# LIFESPAN
|
| 345 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 346 |
@asynccontextmanager
|
| 347 |
async def lifespan(app: FastAPI):
|
|
|
|
| 350 |
await pool.initialize()
|
| 351 |
|
| 352 |
monitor = asyncio.create_task(pool.health_monitor())
|
| 353 |
+
log.info(f"β
Server ready {HOST}:{PORT}")
|
| 354 |
|
| 355 |
yield
|
| 356 |
|
|
|
|
| 363 |
|
| 364 |
|
| 365 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 366 |
+
# APP
|
| 367 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 368 |
app = FastAPI(
|
| 369 |
+
title = "Cloudflare AI API",
|
| 370 |
+
description = "OpenAI-compatible API via Cloudflare AI Playground",
|
| 371 |
+
version = "1.1.0",
|
| 372 |
+
lifespan = lifespan,
|
| 373 |
+
docs_url = "/docs",
|
| 374 |
+
redoc_url = "/redoc",
|
| 375 |
)
|
| 376 |
|
| 377 |
app.add_middleware(
|
|
|
|
| 383 |
|
| 384 |
|
| 385 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 386 |
+
# SSE HELPERS
|
| 387 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 388 |
+
def _sse_chunk(content: str, model: str, cid: str) -> str:
|
| 389 |
+
return "data: " + json.dumps({
|
| 390 |
+
"id": cid,
|
|
|
|
| 391 |
"object": "chat.completion.chunk",
|
| 392 |
"created": int(time.time()),
|
| 393 |
"model": model,
|
| 394 |
+
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
|
| 395 |
+
}, ensure_ascii=False) + "\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
|
| 397 |
+
def _sse_done(model: str, cid: str) -> str:
|
| 398 |
+
return "data: " + json.dumps({
|
| 399 |
+
"id": cid,
|
|
|
|
| 400 |
"object": "chat.completion.chunk",
|
| 401 |
"created": int(time.time()),
|
| 402 |
"model": model,
|
| 403 |
+
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
| 404 |
+
}) + "\n\ndata: [DONE]\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
def _sse_error(msg: str) -> str:
|
| 407 |
+
return f'data: {{"error": {json.dumps(msg)}}}\n\ndata: [DONE]\n\n'
|
| 408 |
|
| 409 |
|
| 410 |
async def _stream_generator(
|
| 411 |
provider: CloudflareProvider,
|
| 412 |
req: ChatRequest,
|
| 413 |
) -> AsyncGenerator[str, None]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
loop = asyncio.get_event_loop()
|
| 415 |
+
q: asyncio.Queue = asyncio.Queue(maxsize=512)
|
| 416 |
+
cid = f"chatcmpl-{uuid.uuid4().hex[:20]}"
|
| 417 |
cancel = threading.Event()
|
| 418 |
|
|
|
|
| 419 |
messages = [{"role": m.role, "content": m.content} for m in req.messages]
|
| 420 |
+
kwargs = {
|
| 421 |
"messages": messages,
|
| 422 |
"temperature": req.temperature,
|
| 423 |
+
"model": req.model,
|
| 424 |
}
|
|
|
|
|
|
|
| 425 |
if req.max_tokens:
|
| 426 |
kwargs["max_tokens"] = req.max_tokens
|
| 427 |
if req.system:
|
| 428 |
kwargs["system"] = req.system
|
| 429 |
|
|
|
|
| 430 |
def _worker():
|
| 431 |
try:
|
| 432 |
for chunk in provider.chat(**kwargs):
|
| 433 |
if cancel.is_set():
|
| 434 |
break
|
| 435 |
fut = asyncio.run_coroutine_threadsafe(q.put(chunk), loop)
|
| 436 |
+
fut.result(timeout=10)
|
| 437 |
except Exception as exc:
|
| 438 |
+
err = RuntimeError(str(exc))
|
| 439 |
asyncio.run_coroutine_threadsafe(q.put(err), loop).result(timeout=5)
|
| 440 |
finally:
|
| 441 |
asyncio.run_coroutine_threadsafe(q.put(None), loop).result(timeout=5)
|
|
|
|
| 443 |
t = threading.Thread(target=_worker, daemon=True)
|
| 444 |
t.start()
|
| 445 |
|
|
|
|
| 446 |
try:
|
| 447 |
while True:
|
| 448 |
item = await asyncio.wait_for(q.get(), timeout=STREAM_TIMEOUT)
|
| 449 |
|
| 450 |
+
if item is None:
|
| 451 |
+
yield _sse_done(req.model, cid)
|
| 452 |
break
|
| 453 |
|
| 454 |
+
if isinstance(item, Exception):
|
| 455 |
yield _sse_error(str(item))
|
| 456 |
break
|
| 457 |
|
| 458 |
+
if item:
|
| 459 |
+
yield _sse_chunk(item, req.model, cid)
|
| 460 |
|
| 461 |
except asyncio.TimeoutError:
|
| 462 |
cancel.set()
|
| 463 |
+
yield _sse_error("Stream timed out")
|
| 464 |
|
| 465 |
finally:
|
| 466 |
cancel.set()
|
|
|
|
| 474 |
@app.get("/", tags=["Info"])
|
| 475 |
async def root():
|
| 476 |
return {
|
| 477 |
+
"service": "Cloudflare AI API",
|
| 478 |
+
"version": "1.1.0",
|
| 479 |
+
"status": "running",
|
| 480 |
+
"display": os.environ.get("DISPLAY", "not set"),
|
| 481 |
"endpoints": {
|
| 482 |
"chat": "POST /v1/chat/completions",
|
| 483 |
"models": "GET /v1/models",
|
|
|
|
| 490 |
@app.get("/health", tags=["Info"])
|
| 491 |
async def health():
|
| 492 |
if pool is None:
|
| 493 |
+
raise HTTPException(503, detail="Pool not initialized")
|
| 494 |
+
|
| 495 |
healthy = sum(1 for m in pool._slots if m.is_healthy())
|
| 496 |
status = "ok" if healthy > 0 else "degraded"
|
| 497 |
+
|
| 498 |
return JSONResponse(
|
| 499 |
content={"status": status, "pool": pool.status},
|
| 500 |
status_code=200 if status == "ok" else 206,
|
|
|
|
| 515 |
"object": "list",
|
| 516 |
"data": [
|
| 517 |
{
|
| 518 |
+
"id": m["name"],
|
| 519 |
+
"object": "model",
|
| 520 |
+
"created": 0,
|
| 521 |
+
"owned_by": "cloudflare",
|
| 522 |
"context_window": m.get("context", 4096),
|
| 523 |
}
|
| 524 |
for m in models
|
|
|
|
| 530 |
async def chat_completions(req: ChatRequest, request: Request):
|
| 531 |
if pool is None:
|
| 532 |
raise HTTPException(503, detail="Pool not initialized")
|
|
|
|
| 533 |
if not req.messages:
|
| 534 |
raise HTTPException(400, detail="`messages` must not be empty")
|
| 535 |
|
|
|
|
| 538 |
async def _gen():
|
| 539 |
async with pool.acquire() as provider:
|
| 540 |
async for chunk in _stream_generator(provider, req):
|
|
|
|
| 541 |
if await request.is_disconnected():
|
| 542 |
break
|
| 543 |
yield chunk
|
|
|
|
| 554 |
|
| 555 |
# ββ Non-streaming ββββββββββββββββββββββββββββββββββββββ
|
| 556 |
messages = [{"role": m.role, "content": m.content} for m in req.messages]
|
| 557 |
+
kwargs = {
|
| 558 |
"messages": messages,
|
| 559 |
"temperature": req.temperature,
|
| 560 |
+
"model": req.model,
|
| 561 |
}
|
|
|
|
|
|
|
| 562 |
if req.max_tokens:
|
| 563 |
kwargs["max_tokens"] = req.max_tokens
|
| 564 |
if req.system:
|
| 565 |
kwargs["system"] = req.system
|
| 566 |
|
| 567 |
+
loop = asyncio.get_event_loop()
|
| 568 |
+
full_parts: list[str] = []
|
| 569 |
|
| 570 |
async with pool.acquire() as provider:
|
|
|
|
|
|
|
| 571 |
def _collect():
|
| 572 |
for chunk in provider.chat(**kwargs):
|
| 573 |
full_parts.append(chunk)
|
|
|
|
| 577 |
timeout=STREAM_TIMEOUT,
|
| 578 |
)
|
| 579 |
|
|
|
|
|
|
|
| 580 |
return {
|
| 581 |
"id": f"chatcmpl-{uuid.uuid4().hex[:20]}",
|
| 582 |
"object": "chat.completion",
|
|
|
|
| 584 |
"model": req.model,
|
| 585 |
"choices": [{
|
| 586 |
"index": 0,
|
| 587 |
+
"message": {"role": "assistant", "content": "".join(full_parts)},
|
| 588 |
"finish_reason": "stop",
|
| 589 |
}],
|
| 590 |
+
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 591 |
}
|
| 592 |
|
| 593 |
|
|
|
|
| 597 |
if __name__ == "__main__":
    # Launch the ASGI server directly. workers=1 as configured upstream
    # (presumably because the provider pool holds in-process state — verify
    # before raising it). Fix: PEP 8 — no spaces around `=` in keyword args.
    uvicorn.run(
        "server:app",
        host=HOST,
        port=PORT,
        log_level="info",
        workers=1,
        loop="asyncio",
        timeout_keep_alive=30,
    )
|