Adarshu07 committed on
Commit
fcbe7be
·
verified ·
1 Parent(s): 596f075

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +207 -235
server.py CHANGED
@@ -8,11 +8,9 @@
8
  ║ GET /health ║
9
  ║ GET / ║
10
  ║ ║
11
- ║ Architecture: ║
12
- ║ • ProviderPool — N pre-warmed WS connections ║
13
- ║ • acquire() — queue-based fair checkout, auto-heal ║
14
- ║ • HealthMonitor — periodic background probe + heal ║
15
- ║ • SSE streaming — thread→asyncio bridge via Queue ║
16
  β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
17
  """
18
 
@@ -23,6 +21,7 @@ import os
23
  import sys
24
  import threading
25
  import time
 
26
  import uuid
27
  from contextlib import asynccontextmanager
28
  from typing import AsyncGenerator, List, Optional
@@ -33,7 +32,6 @@ from fastapi.middleware.cors import CORSMiddleware
33
  from fastapi.responses import JSONResponse, StreamingResponse
34
  from pydantic import BaseModel, Field
35
 
36
- # ─── Import provider ────────────────────────────────────────
37
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
38
  from cloudflare_provider import CloudflareProvider
39
 
@@ -41,24 +39,27 @@ from cloudflare_provider import CloudflareProvider
41
  # LOGGING
42
  # ═══════════════════════════════════════════════════════════
43
  logging.basicConfig(
44
- level=logging.INFO,
45
- format="%(asctime)s %(levelname)-8s %(message)s",
46
- stream=sys.stdout,
47
- datefmt="%H:%M:%S",
48
  )
49
  log = logging.getLogger("cf-api")
50
 
 
51
  # ═══════════════════════════════════════════════════════════
52
- # CONFIG (all tunable via environment variables)
53
  # ═══════════════════════════════════════════════════════════
54
- POOL_SIZE = int(os.getenv("POOL_SIZE", "2"))
55
- PORT = int(os.getenv("PORT", "7860"))
56
- HOST = os.getenv("HOST", "0.0.0.0")
57
- HEALTH_INTERVAL = int(os.getenv("HEALTH_INTERVAL", "60")) # seconds
58
- ACQUIRE_TIMEOUT = int(os.getenv("ACQUIRE_TIMEOUT", "60")) # wait for free slot
59
- STREAM_TIMEOUT = int(os.getenv("STREAM_TIMEOUT", "120")) # total stream timeout
60
- DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "@cf/moonshotai/kimi-k2.5")
61
- DEFAULT_SYSTEM = os.getenv("DEFAULT_SYSTEM", "You are a helpful assistant.")
 
 
62
 
63
 
64
  # ═══════════════════════════════════════════════════════════
@@ -69,42 +70,27 @@ class Message(BaseModel):
69
  content: str
70
 
71
  class ChatRequest(BaseModel):
72
- model: str = DEFAULT_MODEL
73
  messages: List[Message]
74
- temperature: float = Field(default=1.0, ge=0.0, le=2.0)
75
- max_tokens: Optional[int] = None
76
- stream: bool = True
77
- system: Optional[str] = None # extra system-prompt override
78
-
79
- class CompletionChoice(BaseModel):
80
- index: int
81
- message: dict
82
- finish_reason: str
83
-
84
- class CompletionResponse(BaseModel):
85
- id: str
86
- object: str
87
- created: int
88
- model: str
89
- choices: List[CompletionChoice]
90
- usage: dict
91
 
92
 
93
  # ═══════════════════════════════════════════════════════════
94
- # MANAGED PROVIDER (pool slot)
95
  # ═══════════════════════════════════════════════════════════
96
  class ManagedProvider:
97
- """A single pool slot wrapping one CloudflareProvider instance."""
98
-
99
  def __init__(self, slot_id: int):
100
- self.slot_id = slot_id
101
  self.provider: Optional[CloudflareProvider] = None
102
- self.busy = False
103
- self.born_at = 0.0
104
- self.error_count = 0
105
  self.request_count = 0
 
106
 
107
- # ── Health ──────────────────────────────────────
108
  def is_healthy(self) -> bool:
109
  if self.provider is None:
110
  return False
@@ -117,7 +103,6 @@ class ManagedProvider:
117
  except Exception:
118
  return False
119
 
120
- # ── Teardown ────────────────────────────────────
121
  def close(self):
122
  p = self.provider
123
  self.provider = None
@@ -130,92 +115,119 @@ class ManagedProvider:
130
  def __repr__(self):
131
  state = "busy" if self.busy else ("ok" if self.is_healthy() else "dead")
132
  mode = self.provider._mode if self.provider else "none"
133
- return (
134
- f"<Slot#{self.slot_id} {state} mode={mode!r} "
135
- f"reqs={self.request_count} errs={self.error_count}>"
136
- )
137
 
138
 
139
  # ═══════════════════════════════════════════════════════════
140
  # PROVIDER POOL
141
  # ═══════════════════════════════════════════════════════════
142
  class ProviderPool:
143
- """
144
- Pre-warmed pool of CloudflareProvider connections.
145
-
146
- β€’ initialize() β€” create all slots at startup
147
- β€’ acquire() β€” async context manager; blocks until a free slot
148
- β€’ health_monitor β€” background coroutine; heals broken idle slots
149
- β€’ shutdown() β€” clean teardown
150
- """
151
-
152
  def __init__(self, size: int = 2):
153
  self.size = size
154
- self._slots: List[ManagedProvider] = []
155
- self._queue: asyncio.Queue = None # set in initialize()
156
- self._loop: asyncio.AbstractEventLoop = None
157
- self._lock = asyncio.Lock()
158
 
159
- # ─── Startup ──────────────────────────────────
160
  async def initialize(self):
161
  self._loop = asyncio.get_event_loop()
162
  self._queue = asyncio.Queue(maxsize=self.size)
163
 
164
  log.info(f"πŸš€ Initializing provider pool (slots={self.size})")
 
 
 
165
 
166
  results = await asyncio.gather(
167
- *[self._spawn_slot(i) for i in range(self.size)],
168
  return_exceptions=True,
169
  )
170
 
171
- ok = sum(1 for r in results if not isinstance(r, Exception))
 
 
 
 
 
 
 
172
  log.info(f" Pool ready β€” {ok}/{self.size} slots healthy")
173
 
174
  if ok == 0:
175
  raise RuntimeError(
176
- "No provider slots could connect. Check network / Xvfb setup."
 
 
 
177
  )
178
 
179
- async def _spawn_slot(self, slot_id: int) -> ManagedProvider:
 
180
  managed = ManagedProvider(slot_id)
181
 
182
- def _create() -> CloudflareProvider:
183
- log.info(f" [S{slot_id}] Connecting...")
184
- return CloudflareProvider(
185
- model = DEFAULT_MODEL,
186
- system = DEFAULT_SYSTEM,
187
- debug = False,
188
- use_cache = True,
189
- )
190
 
191
- managed.provider = await asyncio.wait_for(
192
- self._loop.run_in_executor(None, _create),
193
- timeout=180,
194
- )
195
- managed.born_at = time.time()
 
 
196
 
197
- self._slots.append(managed)
198
- await self._queue.put(managed)
 
 
 
 
199
 
200
- mode = managed.provider._mode
201
- log.info(f" [S{slot_id}] βœ“ Ready mode={mode!r}")
202
- return managed
203
 
204
- # ─── Acquire ──────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  @asynccontextmanager
206
  async def acquire(self):
207
- """Checkout a provider, yield it, return on exit (healing if needed)."""
208
  managed: ManagedProvider = await asyncio.wait_for(
209
  self._queue.get(),
210
  timeout=ACQUIRE_TIMEOUT,
211
  )
212
  managed.busy = True
213
- ok = True
214
 
215
  try:
216
- # Heal before handing out
217
  if not managed.is_healthy():
218
- log.warning(f"[S{managed.slot_id}] Unhealthy β€” healing before use")
219
  await self._heal(managed)
220
 
221
  managed.request_count += 1
@@ -223,104 +235,98 @@ class ProviderPool:
223
 
224
  except Exception:
225
  managed.error_count += 1
226
- ok = False
227
  raise
228
 
229
  finally:
230
  managed.busy = False
231
- # After use: return if healthy, else heal in background
232
  if managed.is_healthy():
233
  await self._queue.put(managed)
234
  else:
235
- log.warning(f"[S{managed.slot_id}] Unhealthy after use β€” background heal")
236
  asyncio.create_task(self._heal_then_return(managed))
237
 
238
- # ─── Healing ──────────────────────────────────
239
  async def _heal(self, managed: ManagedProvider):
240
  sid = managed.slot_id
 
241
 
242
- def _recreate() -> CloudflareProvider:
243
  managed.close()
244
  return CloudflareProvider(
245
  model = DEFAULT_MODEL,
246
  system = DEFAULT_SYSTEM,
247
- debug = False,
248
  use_cache = True,
249
  )
250
 
251
- managed.provider = await asyncio.wait_for(
252
- self._loop.run_in_executor(None, _recreate),
253
- timeout=180,
254
- )
255
- managed.born_at = time.time()
256
- managed.error_count = 0
257
- log.info(f"[S{sid}] βœ“ Healed mode={managed.provider._mode!r}")
258
-
259
- async def _heal_then_return(self, managed: ManagedProvider):
260
  try:
261
- await self._heal(managed)
 
 
 
 
 
 
 
 
262
  except Exception as e:
263
- log.error(f"[S{managed.slot_id}] Heal failed: {e}")
264
- # Try a brand-new slot as last resort
265
- try:
266
- managed.close()
267
- managed.provider = await asyncio.wait_for(
268
- self._loop.run_in_executor(
269
- None,
270
- lambda: CloudflareProvider(
271
- model=DEFAULT_MODEL, system=DEFAULT_SYSTEM,
272
- debug=False, use_cache=True,
273
- ),
274
- ),
275
- timeout=180,
276
- )
277
- managed.born_at = time.time()
278
- managed.error_count = 0
279
- log.info(f"[S{managed.slot_id}] βœ“ Cold-boot recovery succeeded")
280
- except Exception as e2:
281
- log.error(f"[S{managed.slot_id}] Cold-boot also failed: {e2}")
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  await self._queue.put(managed)
284
 
285
- # ─── Health monitor ───────────────────────────
286
  async def health_monitor(self):
287
- """Periodic background coroutine β€” checks and heals idle slots."""
288
  while True:
289
  await asyncio.sleep(HEALTH_INTERVAL)
290
  healthy = sum(1 for m in self._slots if m.is_healthy())
291
  busy = sum(1 for m in self._slots if m.busy)
292
  log.info(
293
- f"β™₯ Health check β€” {healthy}/{self.size} healthy, "
294
- f"{busy} busy, queue={self._queue.qsize()}"
295
  )
296
-
297
  for managed in list(self._slots):
298
  if not managed.busy and not managed.is_healthy():
299
- log.warning(f"[S{managed.slot_id}] Idle but unhealthy β€” healing")
300
- # Pull from queue if it's still there, otherwise skip
301
  asyncio.create_task(self._heal_then_return(managed))
302
 
303
- # ─── Status ───────────────────────────────────
304
  @property
305
  def status(self) -> dict:
306
  return {
307
- "pool_size": self.size,
308
- "queue_free": self._queue.qsize() if self._queue else 0,
309
  "slots": [
310
  {
311
- "id": m.slot_id,
312
- "healthy": m.is_healthy(),
313
- "busy": m.busy,
314
- "mode": m.provider._mode if m.provider else "none",
315
- "errors": m.error_count,
316
- "requests": m.request_count,
317
- "age_s": round(time.time() - m.born_at, 1) if m.born_at else 0,
 
318
  }
319
  for m in self._slots
320
  ],
321
  }
322
 
323
- # ─── Shutdown ─────────────────────────────────
324
  async def shutdown(self):
325
  log.info("Shutting down provider pool...")
326
  for m in self._slots:
@@ -329,13 +335,13 @@ class ProviderPool:
329
 
330
 
331
  # ═══════════════════════════════════════════════════════════
332
- # GLOBAL POOL REFERENCE
333
  # ═══════════════════════════════════════════════════════════
334
  pool: ProviderPool = None
335
 
336
 
337
  # ═══════════════════════════════════════════════════════════
338
- # LIFESPAN (startup / shutdown)
339
  # ═══════════════════════════════════════════════════════════
340
  @asynccontextmanager
341
  async def lifespan(app: FastAPI):
@@ -344,7 +350,7 @@ async def lifespan(app: FastAPI):
344
  await pool.initialize()
345
 
346
  monitor = asyncio.create_task(pool.health_monitor())
347
- log.info(f"βœ… Server ready on {HOST}:{PORT}")
348
 
349
  yield
350
 
@@ -357,15 +363,15 @@ async def lifespan(app: FastAPI):
357
 
358
 
359
  # ═══════════════════════════════════════════════════════════
360
- # FASTAPI APP
361
  # ═══════════════════════════════════════════════════════════
362
  app = FastAPI(
363
- title = "Cloudflare AI API",
364
- description = "OpenAI-compatible streaming API via Cloudflare AI Playground",
365
- version = "1.0.0",
366
- lifespan = lifespan,
367
- docs_url = "/docs",
368
- redoc_url = "/redoc",
369
  )
370
 
371
  app.add_middleware(
@@ -377,85 +383,59 @@ app.add_middleware(
377
 
378
 
379
  # ═══════════════════════════════════════════════════════════
380
- # SSE STREAMING HELPERS
381
  # ═══════════════════════════════════════════════════════════
382
- def _sse_chunk(content: str, model: str, chunk_id: str) -> str:
383
- """Format one SSE data line in OpenAI chunk format."""
384
- payload = {
385
- "id": chunk_id,
386
  "object": "chat.completion.chunk",
387
  "created": int(time.time()),
388
  "model": model,
389
- "choices": [{
390
- "index": 0,
391
- "delta": {"content": content},
392
- "finish_reason": None,
393
- }],
394
- }
395
- return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
396
-
397
 
398
- def _sse_done(model: str, chunk_id: str) -> str:
399
- """Final SSE chunk with finish_reason=stop."""
400
- payload = {
401
- "id": chunk_id,
402
  "object": "chat.completion.chunk",
403
  "created": int(time.time()),
404
  "model": model,
405
- "choices": [{
406
- "index": 0,
407
- "delta": {},
408
- "finish_reason": "stop",
409
- }],
410
- }
411
- return f"data: {json.dumps(payload)}\n\ndata: [DONE]\n\n"
412
-
413
 
414
  def _sse_error(msg: str) -> str:
415
- return f"data: {{\"error\": {json.dumps(msg)}}}\n\ndata: [DONE]\n\n"
416
 
417
 
418
  async def _stream_generator(
419
  provider: CloudflareProvider,
420
  req: ChatRequest,
421
  ) -> AsyncGenerator[str, None]:
422
- """
423
- Bridge between the synchronous provider.chat() generator and
424
- FastAPI's async StreamingResponse.
425
-
426
- Strategy:
427
- 1. Spin up a background thread that runs provider.chat() and
428
- pushes each chunk into an asyncio.Queue.
429
- 2. Yield SSE-formatted chunks from the queue in the async loop.
430
- """
431
  loop = asyncio.get_event_loop()
432
- q: asyncio.Queue = asyncio.Queue(maxsize=256)
433
- chunk_id = f"chatcmpl-{uuid.uuid4().hex[:20]}"
434
  cancel = threading.Event()
435
 
436
- # Build kwargs for provider
437
  messages = [{"role": m.role, "content": m.content} for m in req.messages]
438
- kwargs: dict = {
439
  "messages": messages,
440
  "temperature": req.temperature,
 
441
  }
442
- if req.model:
443
- kwargs["model"] = req.model
444
  if req.max_tokens:
445
  kwargs["max_tokens"] = req.max_tokens
446
  if req.system:
447
  kwargs["system"] = req.system
448
 
449
- # ── Worker thread ────────────────────────────
450
  def _worker():
451
  try:
452
  for chunk in provider.chat(**kwargs):
453
  if cancel.is_set():
454
  break
455
  fut = asyncio.run_coroutine_threadsafe(q.put(chunk), loop)
456
- fut.result(timeout=10) # backpressure: block thread if queue full
457
  except Exception as exc:
458
- err = RuntimeError(f"Stream error: {exc}")
459
  asyncio.run_coroutine_threadsafe(q.put(err), loop).result(timeout=5)
460
  finally:
461
  asyncio.run_coroutine_threadsafe(q.put(None), loop).result(timeout=5)
@@ -463,25 +443,24 @@ async def _stream_generator(
463
  t = threading.Thread(target=_worker, daemon=True)
464
  t.start()
465
 
466
- # ── Async consumer ────────────────────────────
467
  try:
468
  while True:
469
  item = await asyncio.wait_for(q.get(), timeout=STREAM_TIMEOUT)
470
 
471
- if item is None: # sentinel β€” stream done
472
- yield _sse_done(req.model, chunk_id)
473
  break
474
 
475
- if isinstance(item, Exception): # error from worker
476
  yield _sse_error(str(item))
477
  break
478
 
479
- if item: # normal text chunk
480
- yield _sse_chunk(item, req.model, chunk_id)
481
 
482
  except asyncio.TimeoutError:
483
  cancel.set()
484
- yield _sse_error("Stream timed out β€” no data received")
485
 
486
  finally:
487
  cancel.set()
@@ -495,9 +474,10 @@ async def _stream_generator(
495
  @app.get("/", tags=["Info"])
496
  async def root():
497
  return {
498
- "service": "Cloudflare AI API",
499
- "version": "1.0.0",
500
- "status": "running",
 
501
  "endpoints": {
502
  "chat": "POST /v1/chat/completions",
503
  "models": "GET /v1/models",
@@ -510,9 +490,11 @@ async def root():
510
  @app.get("/health", tags=["Info"])
511
  async def health():
512
  if pool is None:
513
- raise HTTPException(503, detail="Pool not yet initialized")
 
514
  healthy = sum(1 for m in pool._slots if m.is_healthy())
515
  status = "ok" if healthy > 0 else "degraded"
 
516
  return JSONResponse(
517
  content={"status": status, "pool": pool.status},
518
  status_code=200 if status == "ok" else 206,
@@ -533,10 +515,10 @@ async def list_models():
533
  "object": "list",
534
  "data": [
535
  {
536
- "id": m["name"],
537
- "object": "model",
538
- "created": 0,
539
- "owned_by": "cloudflare",
540
  "context_window": m.get("context", 4096),
541
  }
542
  for m in models
@@ -548,7 +530,6 @@ async def list_models():
548
  async def chat_completions(req: ChatRequest, request: Request):
549
  if pool is None:
550
  raise HTTPException(503, detail="Pool not initialized")
551
-
552
  if not req.messages:
553
  raise HTTPException(400, detail="`messages` must not be empty")
554
 
@@ -557,7 +538,6 @@ async def chat_completions(req: ChatRequest, request: Request):
557
  async def _gen():
558
  async with pool.acquire() as provider:
559
  async for chunk in _stream_generator(provider, req):
560
- # Check if client disconnected
561
  if await request.is_disconnected():
562
  break
563
  yield chunk
@@ -574,22 +554,20 @@ async def chat_completions(req: ChatRequest, request: Request):
574
 
575
  # ── Non-streaming ──────────────────────────────────────
576
  messages = [{"role": m.role, "content": m.content} for m in req.messages]
577
- kwargs: dict = {
578
  "messages": messages,
579
  "temperature": req.temperature,
 
580
  }
581
- if req.model:
582
- kwargs["model"] = req.model
583
  if req.max_tokens:
584
  kwargs["max_tokens"] = req.max_tokens
585
  if req.system:
586
  kwargs["system"] = req.system
587
 
588
- loop = asyncio.get_event_loop()
 
589
 
590
  async with pool.acquire() as provider:
591
- full_parts: list[str] = []
592
-
593
  def _collect():
594
  for chunk in provider.chat(**kwargs):
595
  full_parts.append(chunk)
@@ -599,8 +577,6 @@ async def chat_completions(req: ChatRequest, request: Request):
599
  timeout=STREAM_TIMEOUT,
600
  )
601
 
602
- response_text = "".join(full_parts)
603
-
604
  return {
605
  "id": f"chatcmpl-{uuid.uuid4().hex[:20]}",
606
  "object": "chat.completion",
@@ -608,14 +584,10 @@ async def chat_completions(req: ChatRequest, request: Request):
608
  "model": req.model,
609
  "choices": [{
610
  "index": 0,
611
- "message": {"role": "assistant", "content": response_text},
612
  "finish_reason": "stop",
613
  }],
614
- "usage": {
615
- "prompt_tokens": 0,
616
- "completion_tokens": 0,
617
- "total_tokens": 0,
618
- },
619
  }
620
 
621
 
@@ -625,10 +597,10 @@ async def chat_completions(req: ChatRequest, request: Request):
625
  if __name__ == "__main__":
626
  uvicorn.run(
627
  "server:app",
628
- host = HOST,
629
- port = PORT,
630
- log_level = "info",
631
- workers = 1, # single worker β€” state is in-process
632
- loop = "asyncio",
633
  timeout_keep_alive = 30,
634
- )
 
8
  ║ GET /health ║
9
  ║ GET / ║
10
  ║ ║
11
+ ║ Pool startup: up to 3 retries per slot, logs exact errors. ║
12
+ ║ Health monitor: heals dead idle slots every 60s. ║
13
+ ║ SSE: thread→asyncio bridge with backpressure. ║
 
 
14
  β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
15
  """
16
 
 
21
  import sys
22
  import threading
23
  import time
24
+ import traceback
25
  import uuid
26
  from contextlib import asynccontextmanager
27
  from typing import AsyncGenerator, List, Optional
 
32
  from fastapi.responses import JSONResponse, StreamingResponse
33
  from pydantic import BaseModel, Field
34
 
 
35
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
36
  from cloudflare_provider import CloudflareProvider
37
 
 
39
  # LOGGING
40
  # ═══════════════════════════════════════════════════════════
41
  logging.basicConfig(
42
+ level = logging.INFO,
43
+ format = "%(asctime)s %(levelname)-8s %(message)s",
44
+ stream = sys.stdout,
45
+ datefmt = "%H:%M:%S",
46
  )
47
  log = logging.getLogger("cf-api")
48
 
49
+
50
  # ═══════════════════════════════════════════════════════════
51
+ # CONFIG
52
  # ═══════════════════════════════════════════════════════════
53
+ POOL_SIZE = int(os.getenv("POOL_SIZE", "2"))
54
+ PORT = int(os.getenv("PORT", "7860"))
55
+ HOST = os.getenv("HOST", "0.0.0.0")
56
+ HEALTH_INTERVAL = int(os.getenv("HEALTH_INTERVAL", "60"))
57
+ ACQUIRE_TIMEOUT = int(os.getenv("ACQUIRE_TIMEOUT", "90"))
58
+ STREAM_TIMEOUT = int(os.getenv("STREAM_TIMEOUT", "120"))
59
+ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "@cf/moonshotai/kimi-k2.5")
60
+ DEFAULT_SYSTEM = os.getenv("DEFAULT_SYSTEM", "You are a helpful assistant.")
61
+ SLOT_RETRIES = int(os.getenv("SLOT_RETRIES", "3"))
62
+ SLOT_RETRY_WAIT = int(os.getenv("SLOT_RETRY_WAIT", "10")) # seconds between retries
63
 
64
 
65
  # ═══════════════════════════════════════════════════════════
 
70
  content: str
71
 
72
  class ChatRequest(BaseModel):
73
+ model: str = DEFAULT_MODEL
74
  messages: List[Message]
75
+ temperature: float = Field(default=1.0, ge=0.0, le=2.0)
76
+ max_tokens: Optional[int] = None
77
+ stream: bool = True
78
+ system: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
 
81
  # ═══════════════════════════════════════════════════════════
82
+ # MANAGED PROVIDER SLOT
83
  # ═══════════════════════════════════════════════════════════
84
  class ManagedProvider:
 
 
85
  def __init__(self, slot_id: int):
86
+ self.slot_id = slot_id
87
  self.provider: Optional[CloudflareProvider] = None
88
+ self.busy = False
89
+ self.born_at = 0.0
90
+ self.error_count = 0
91
  self.request_count = 0
92
+ self.last_error = ""
93
 
 
94
  def is_healthy(self) -> bool:
95
  if self.provider is None:
96
  return False
 
103
  except Exception:
104
  return False
105
 
 
106
  def close(self):
107
  p = self.provider
108
  self.provider = None
 
115
  def __repr__(self):
116
  state = "busy" if self.busy else ("ok" if self.is_healthy() else "dead")
117
  mode = self.provider._mode if self.provider else "none"
118
+ return f"<Slot#{self.slot_id} {state} mode={mode!r} reqs={self.request_count}>"
 
 
 
119
 
120
 
121
  # ═══════════════════════════════════════════════════════════
122
  # PROVIDER POOL
123
  # ═══════════════════════════════════════════════════════════
124
  class ProviderPool:
 
 
 
 
 
 
 
 
 
125
  def __init__(self, size: int = 2):
126
  self.size = size
127
+ self._slots: List[ManagedProvider] = []
128
+ self._queue: asyncio.Queue = None
129
+ self._loop: asyncio.AbstractEventLoop = None
 
130
 
131
+ # ─── Startup ──────────────────────────────────────────
132
  async def initialize(self):
133
  self._loop = asyncio.get_event_loop()
134
  self._queue = asyncio.Queue(maxsize=self.size)
135
 
136
  log.info(f"πŸš€ Initializing provider pool (slots={self.size})")
137
+ log.info(f" DISPLAY={os.environ.get('DISPLAY', 'NOT SET')}")
138
+ log.info(f" XVFB_EXTERNAL={os.environ.get('XVFB_EXTERNAL', '0')}")
139
+ log.info(f" VR_DISPLAY={os.environ.get('VR_DISPLAY', '0')}")
140
 
141
  results = await asyncio.gather(
142
+ *[self._spawn_slot_with_retry(i) for i in range(self.size)],
143
  return_exceptions=True,
144
  )
145
 
146
+ ok = sum(1 for r in results if not isinstance(r, Exception))
147
+ fail = sum(1 for r in results if isinstance(r, Exception))
148
+
149
+ if fail:
150
+ for i, r in enumerate(results):
151
+ if isinstance(r, Exception):
152
+ log.error(f" [S{i}] FAILED: {r}")
153
+
154
  log.info(f" Pool ready β€” {ok}/{self.size} slots healthy")
155
 
156
  if ok == 0:
157
  raise RuntimeError(
158
+ f"All {self.size} provider slots failed to connect.\n"
159
+ f" β†’ Check DISPLAY / XVFB_EXTERNAL environment variables.\n"
160
+ f" β†’ Ensure entrypoint.sh started Xvfb before the server.\n"
161
+ f" β†’ Check network connectivity to playground.ai.cloudflare.com."
162
  )
163
 
164
+ async def _spawn_slot_with_retry(self, slot_id: int) -> "ManagedProvider":
165
+ """Try to create a slot, retrying up to SLOT_RETRIES times."""
166
  managed = ManagedProvider(slot_id)
167
 
168
+ for attempt in range(1, SLOT_RETRIES + 1):
169
+ try:
170
+ log.info(f" [S{slot_id}] Connecting... (attempt {attempt}/{SLOT_RETRIES})")
 
 
 
 
 
171
 
172
+ def _create():
173
+ return CloudflareProvider(
174
+ model = DEFAULT_MODEL,
175
+ system = DEFAULT_SYSTEM,
176
+ debug = True, # verbose during init so we can see failures
177
+ use_cache = True,
178
+ )
179
 
180
+ managed.provider = await asyncio.wait_for(
181
+ self._loop.run_in_executor(None, _create),
182
+ timeout=180,
183
+ )
184
+ managed.provider.debug = False # quiet after successful boot
185
+ managed.born_at = time.time()
186
 
187
+ self._slots.append(managed)
188
+ await self._queue.put(managed)
 
189
 
190
+ mode = managed.provider._mode
191
+ log.info(f" [S{slot_id}] βœ“ Ready mode={mode!r}")
192
+ return managed
193
+
194
+ except asyncio.TimeoutError:
195
+ err = f"Slot {slot_id} timed out (attempt {attempt})"
196
+ log.warning(f" [S{slot_id}] ⚠ {err}")
197
+ managed.last_error = err
198
+ managed.close()
199
+
200
+ except Exception as exc:
201
+ err = str(exc)
202
+ # Print full traceback for debugging
203
+ log.warning(
204
+ f" [S{slot_id}] ⚠ Attempt {attempt} failed:\n"
205
+ + traceback.format_exc()
206
+ )
207
+ managed.last_error = err
208
+ managed.close()
209
+
210
+ if attempt < SLOT_RETRIES:
211
+ log.info(f" [S{slot_id}] Retrying in {SLOT_RETRY_WAIT}s...")
212
+ await asyncio.sleep(SLOT_RETRY_WAIT)
213
+
214
+ raise RuntimeError(
215
+ f"Slot {slot_id} failed after {SLOT_RETRIES} attempts. "
216
+ f"Last error: {managed.last_error}"
217
+ )
218
+
219
+ # ─── Acquire ──────────────────────────────────────────
220
  @asynccontextmanager
221
  async def acquire(self):
 
222
  managed: ManagedProvider = await asyncio.wait_for(
223
  self._queue.get(),
224
  timeout=ACQUIRE_TIMEOUT,
225
  )
226
  managed.busy = True
 
227
 
228
  try:
 
229
  if not managed.is_healthy():
230
+ log.warning(f"[S{managed.slot_id}] Unhealthy at checkout β€” healing now")
231
  await self._heal(managed)
232
 
233
  managed.request_count += 1
 
235
 
236
  except Exception:
237
  managed.error_count += 1
 
238
  raise
239
 
240
  finally:
241
  managed.busy = False
 
242
  if managed.is_healthy():
243
  await self._queue.put(managed)
244
  else:
245
+ log.warning(f"[S{managed.slot_id}] Dead after use β€” background heal")
246
  asyncio.create_task(self._heal_then_return(managed))
247
 
248
+ # ─── Healing ──────────────────────────────────────────
249
  async def _heal(self, managed: ManagedProvider):
250
  sid = managed.slot_id
251
+ log.info(f"[S{sid}] Healing slot...")
252
 
253
+ def _recreate():
254
  managed.close()
255
  return CloudflareProvider(
256
  model = DEFAULT_MODEL,
257
  system = DEFAULT_SYSTEM,
258
+ debug = True,
259
  use_cache = True,
260
  )
261
 
 
 
 
 
 
 
 
 
 
262
  try:
263
+ managed.provider = await asyncio.wait_for(
264
+ self._loop.run_in_executor(None, _recreate),
265
+ timeout=180,
266
+ )
267
+ managed.provider.debug = False
268
+ managed.born_at = time.time()
269
+ managed.error_count = 0
270
+ managed.last_error = ""
271
+ log.info(f"[S{sid}] βœ“ Healed mode={managed.provider._mode!r}")
272
  except Exception as e:
273
+ managed.last_error = str(e)
274
+ log.error(f"[S{sid}] βœ— Heal failed: {e}\n{traceback.format_exc()}")
275
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
+ async def _heal_then_return(self, managed: ManagedProvider):
278
+ sid = managed.slot_id
279
+ for attempt in range(1, SLOT_RETRIES + 1):
280
+ try:
281
+ await self._heal(managed)
282
+ await self._queue.put(managed)
283
+ return
284
+ except Exception as e:
285
+ log.warning(f"[S{sid}] Heal attempt {attempt}/{SLOT_RETRIES} failed: {e}")
286
+ if attempt < SLOT_RETRIES:
287
+ await asyncio.sleep(SLOT_RETRY_WAIT)
288
+
289
+ # Last resort: put it back anyway so queue doesn't shrink permanently
290
+ log.error(f"[S{sid}] All heal attempts failed β€” slot may be non-functional")
291
  await self._queue.put(managed)
292
 
293
+ # ─── Health monitor ───────────────────────────────────
294
  async def health_monitor(self):
 
295
  while True:
296
  await asyncio.sleep(HEALTH_INTERVAL)
297
  healthy = sum(1 for m in self._slots if m.is_healthy())
298
  busy = sum(1 for m in self._slots if m.busy)
299
  log.info(
300
+ f"β™₯ Pool β€” {healthy}/{self.size} healthy "
301
+ f"{busy} busy queue={self._queue.qsize()}"
302
  )
 
303
  for managed in list(self._slots):
304
  if not managed.busy and not managed.is_healthy():
305
+ log.warning(f"[S{managed.slot_id}] Idle+dead β€” healing in background")
 
306
  asyncio.create_task(self._heal_then_return(managed))
307
 
308
+ # ─── Status ───────────────────────────────────────────
309
  @property
310
  def status(self) -> dict:
311
  return {
312
+ "pool_size": self.size,
313
+ "queue_free": self._queue.qsize() if self._queue else 0,
314
  "slots": [
315
  {
316
+ "id": m.slot_id,
317
+ "healthy": m.is_healthy(),
318
+ "busy": m.busy,
319
+ "mode": m.provider._mode if m.provider else "none",
320
+ "errors": m.error_count,
321
+ "requests": m.request_count,
322
+ "age_s": round(time.time() - m.born_at, 1) if m.born_at else 0,
323
+ "last_error": m.last_error or None,
324
  }
325
  for m in self._slots
326
  ],
327
  }
328
 
329
+ # ─── Shutdown ─────────────────────────────────────────
330
  async def shutdown(self):
331
  log.info("Shutting down provider pool...")
332
  for m in self._slots:
 
335
 
336
 
337
  # ═══════════════════════════════════════════════════════════
338
+ # GLOBAL POOL
339
  # ═══════════════════════════════════════════════════════════
340
  pool: ProviderPool = None
341
 
342
 
343
  # ═══════════════════════════════════════════════════════════
344
+ # LIFESPAN
345
  # ═══════════════════════════════════════════════════════════
346
  @asynccontextmanager
347
  async def lifespan(app: FastAPI):
 
350
  await pool.initialize()
351
 
352
  monitor = asyncio.create_task(pool.health_monitor())
353
+ log.info(f"βœ… Server ready {HOST}:{PORT}")
354
 
355
  yield
356
 
 
363
 
364
 
365
  # ═══════════════════════════════════════════════════════════
366
+ # APP
367
  # ═══════════════════════════════════════════════════════════
368
  app = FastAPI(
369
+ title = "Cloudflare AI API",
370
+ description = "OpenAI-compatible API via Cloudflare AI Playground",
371
+ version = "1.1.0",
372
+ lifespan = lifespan,
373
+ docs_url = "/docs",
374
+ redoc_url = "/redoc",
375
  )
376
 
377
  app.add_middleware(
 
383
 
384
 
385
  # ═══════════════════════════════════════════════════════════
386
+ # SSE HELPERS
387
  # ═══════════════════════════════════════════════════════════
388
+ def _sse_chunk(content: str, model: str, cid: str) -> str:
389
+ return "data: " + json.dumps({
390
+ "id": cid,
 
391
  "object": "chat.completion.chunk",
392
  "created": int(time.time()),
393
  "model": model,
394
+ "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
395
+ }, ensure_ascii=False) + "\n\n"
 
 
 
 
 
 
396
 
397
+ def _sse_done(model: str, cid: str) -> str:
398
+ return "data: " + json.dumps({
399
+ "id": cid,
 
400
  "object": "chat.completion.chunk",
401
  "created": int(time.time()),
402
  "model": model,
403
+ "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
404
+ }) + "\n\ndata: [DONE]\n\n"
 
 
 
 
 
 
405
 
406
  def _sse_error(msg: str) -> str:
407
+ return f'data: {{"error": {json.dumps(msg)}}}\n\ndata: [DONE]\n\n'
408
 
409
 
410
  async def _stream_generator(
411
  provider: CloudflareProvider,
412
  req: ChatRequest,
413
  ) -> AsyncGenerator[str, None]:
 
 
 
 
 
 
 
 
 
414
  loop = asyncio.get_event_loop()
415
+ q: asyncio.Queue = asyncio.Queue(maxsize=512)
416
+ cid = f"chatcmpl-{uuid.uuid4().hex[:20]}"
417
  cancel = threading.Event()
418
 
 
419
  messages = [{"role": m.role, "content": m.content} for m in req.messages]
420
+ kwargs = {
421
  "messages": messages,
422
  "temperature": req.temperature,
423
+ "model": req.model,
424
  }
 
 
425
  if req.max_tokens:
426
  kwargs["max_tokens"] = req.max_tokens
427
  if req.system:
428
  kwargs["system"] = req.system
429
 
 
430
  def _worker():
431
  try:
432
  for chunk in provider.chat(**kwargs):
433
  if cancel.is_set():
434
  break
435
  fut = asyncio.run_coroutine_threadsafe(q.put(chunk), loop)
436
+ fut.result(timeout=10)
437
  except Exception as exc:
438
+ err = RuntimeError(str(exc))
439
  asyncio.run_coroutine_threadsafe(q.put(err), loop).result(timeout=5)
440
  finally:
441
  asyncio.run_coroutine_threadsafe(q.put(None), loop).result(timeout=5)
 
443
  t = threading.Thread(target=_worker, daemon=True)
444
  t.start()
445
 
 
446
  try:
447
  while True:
448
  item = await asyncio.wait_for(q.get(), timeout=STREAM_TIMEOUT)
449
 
450
+ if item is None:
451
+ yield _sse_done(req.model, cid)
452
  break
453
 
454
+ if isinstance(item, Exception):
455
  yield _sse_error(str(item))
456
  break
457
 
458
+ if item:
459
+ yield _sse_chunk(item, req.model, cid)
460
 
461
  except asyncio.TimeoutError:
462
  cancel.set()
463
+ yield _sse_error("Stream timed out")
464
 
465
  finally:
466
  cancel.set()
 
474
  @app.get("/", tags=["Info"])
475
  async def root():
476
  return {
477
+ "service": "Cloudflare AI API",
478
+ "version": "1.1.0",
479
+ "status": "running",
480
+ "display": os.environ.get("DISPLAY", "not set"),
481
  "endpoints": {
482
  "chat": "POST /v1/chat/completions",
483
  "models": "GET /v1/models",
 
490
  @app.get("/health", tags=["Info"])
491
  async def health():
492
  if pool is None:
493
+ raise HTTPException(503, detail="Pool not initialized")
494
+
495
  healthy = sum(1 for m in pool._slots if m.is_healthy())
496
  status = "ok" if healthy > 0 else "degraded"
497
+
498
  return JSONResponse(
499
  content={"status": status, "pool": pool.status},
500
  status_code=200 if status == "ok" else 206,
 
515
  "object": "list",
516
  "data": [
517
  {
518
+ "id": m["name"],
519
+ "object": "model",
520
+ "created": 0,
521
+ "owned_by": "cloudflare",
522
  "context_window": m.get("context", 4096),
523
  }
524
  for m in models
 
530
  async def chat_completions(req: ChatRequest, request: Request):
531
  if pool is None:
532
  raise HTTPException(503, detail="Pool not initialized")
 
533
  if not req.messages:
534
  raise HTTPException(400, detail="`messages` must not be empty")
535
 
 
538
  async def _gen():
539
  async with pool.acquire() as provider:
540
  async for chunk in _stream_generator(provider, req):
 
541
  if await request.is_disconnected():
542
  break
543
  yield chunk
 
554
 
555
  # ── Non-streaming ──────────────────────────────────────
556
  messages = [{"role": m.role, "content": m.content} for m in req.messages]
557
+ kwargs = {
558
  "messages": messages,
559
  "temperature": req.temperature,
560
+ "model": req.model,
561
  }
 
 
562
  if req.max_tokens:
563
  kwargs["max_tokens"] = req.max_tokens
564
  if req.system:
565
  kwargs["system"] = req.system
566
 
567
+ loop = asyncio.get_event_loop()
568
+ full_parts: list[str] = []
569
 
570
  async with pool.acquire() as provider:
 
 
571
  def _collect():
572
  for chunk in provider.chat(**kwargs):
573
  full_parts.append(chunk)
 
577
  timeout=STREAM_TIMEOUT,
578
  )
579
 
 
 
580
  return {
581
  "id": f"chatcmpl-{uuid.uuid4().hex[:20]}",
582
  "object": "chat.completion",
 
584
  "model": req.model,
585
  "choices": [{
586
  "index": 0,
587
+ "message": {"role": "assistant", "content": "".join(full_parts)},
588
  "finish_reason": "stop",
589
  }],
590
+ "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
 
 
 
 
591
  }
592
 
593
 
 
597
  if __name__ == "__main__":
598
  uvicorn.run(
599
  "server:app",
600
+ host = HOST,
601
+ port = PORT,
602
+ log_level = "info",
603
+ workers = 1,
604
+ loop = "asyncio",
605
  timeout_keep_alive = 30,
606
+ )