tventurella committed on
Commit
559415e
·
verified ·
1 Parent(s): 654c0be

Update scripts/chat_web.py

Browse files
Files changed (1) hide show
  1. scripts/chat_web.py +10 -79
scripts/chat_web.py CHANGED
@@ -136,6 +136,7 @@ logging.basicConfig(
136
  logger = logging.getLogger(__name__)
137
 
138
  device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 
139
 
140
  @dataclass
141
  class Worker:
@@ -267,57 +268,18 @@ def validate_chat_request(request: ChatRequest):
267
  detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
268
  )
269
 
270
- async def _load_model_background(app: FastAPI):
271
- """Download and load model in background so the server can respond to health checks immediately."""
272
- loop = asyncio.get_event_loop()
273
-
274
- def _download_and_load():
275
- import os
276
- model_dir = os.environ.get("NANOCHAT_BASE_DIR", "/app/nanochat_cache")
277
- checkpoint_dir = os.path.join(model_dir, "chatsft_checkpoints", "d18")
278
- model_file = os.path.join(checkpoint_dir, "model_000070.pt")
279
- model_repo = "tventurella/mr_chatterbox_model"
280
-
281
- # Download if not present
282
- if not os.path.exists(model_file):
283
- print("Downloading model checkpoint...", flush=True)
284
- from huggingface_hub import hf_hub_download
285
- os.makedirs(checkpoint_dir, exist_ok=True)
286
- hf_hub_download(model_repo, "model_000070.pt", local_dir=checkpoint_dir)
287
- hf_hub_download(model_repo, "meta_000070.json", local_dir=checkpoint_dir)
288
- print("Model downloaded.", flush=True)
289
- else:
290
- print("Model checkpoint already present.", flush=True)
291
-
292
- # Initialize compute
293
- print("Initializing compute...", flush=True)
294
- ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
295
-
296
- # Load model
297
- print("Loading model into memory...", flush=True)
298
- pool = WorkerPool(num_gpus=args.num_gpus)
299
- import asyncio as _asyncio
300
- _asyncio.run(pool.initialize(args.source, model_tag=args.model_tag, step=args.step))
301
- return pool
302
-
303
- pool = await loop.run_in_executor(None, _download_and_load)
304
- app.state.worker_pool = pool
305
- app.state.model_ready = True
306
- print(f"Model loaded! Server ready at http://localhost:{args.port}", flush=True)
307
-
308
  @asynccontextmanager
309
  async def lifespan(app: FastAPI):
310
- """Start server immediately, load model in background."""
311
- app.state.model_ready = False
312
- app.state.worker_pool = None
313
- # Start model loading in background
314
- load_task = asyncio.create_task(_load_model_background(app))
315
  # Start periodic log push (every hour)
316
  log_task = asyncio.create_task(periodic_log_push(3600))
317
  yield
318
  # Push any remaining logs on shutdown
319
  log_task.cancel()
320
- load_task.cancel()
321
  push_logs()
322
 
323
  app = FastAPI(lifespan=lifespan)
@@ -330,35 +292,9 @@ app.add_middleware(
330
  allow_headers=["*"],
331
  )
332
 
333
- LOADING_HTML = """<!DOCTYPE html>
334
- <html><head>
335
- <meta charset="UTF-8">
336
- <meta http-equiv="refresh" content="10">
337
- <title>Mr. Chatterbox — Loading</title>
338
- <style>
339
- @import url('https://fonts.googleapis.com/css2?family=Playfair+Display:wght@700&family=Lora:ital@0;1&display=swap');
340
- body { font-family: 'Lora', Georgia, serif; background: #f5f0e8; color: #2c1810;
341
- display: flex; justify-content: center; align-items: center; min-height: 100vh; margin: 0; }
342
- .box { text-align: center; max-width: 500px; padding: 2rem; }
343
- h1 { font-family: 'Playfair Display', Georgia, serif; color: #722f37; font-size: 2rem; }
344
- p { line-height: 1.7; color: #5c4033; }
345
- .spinner { display: inline-block; width: 40px; height: 40px; border: 3px solid #c4b59a;
346
- border-top-color: #722f37; border-radius: 50%; animation: spin 1s linear infinite; margin: 1rem 0; }
347
- @keyframes spin { to { transform: rotate(360deg); } }
348
- </style></head><body>
349
- <div class="box">
350
- <h1>Mr. Chatterbox</h1>
351
- <div class="spinner"></div>
352
- <p>The gentleman is preparing himself for conversation.<br>
353
- <em>This may take a few minutes on first visit.</em></p>
354
- <p style="font-size:0.85rem; color:#8b7355;">This page will refresh automatically.</p>
355
- </div></body></html>"""
356
-
357
  @app.get("/")
358
  async def root():
359
- """Serve the chat UI, or a loading page if model isn't ready yet."""
360
- if not getattr(app.state, 'model_ready', False):
361
- return HTMLResponse(content=LOADING_HTML)
362
  ui_html_path = os.path.join("nanochat", "ui.html")
363
  with open(ui_html_path, "r", encoding="utf-8") as f:
364
  html_content = f.read()
@@ -432,10 +368,6 @@ async def generate_stream(
432
  async def chat_completions(request: ChatRequest):
433
  """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
434
 
435
- # Block requests while model is still loading
436
- if not getattr(app.state, 'model_ready', False):
437
- raise HTTPException(status_code=503, detail="Model is still loading. Please wait.")
438
-
439
  # Basic validation to prevent abuse
440
  validate_chat_request(request)
441
 
@@ -520,12 +452,11 @@ async def chat_completions(request: ChatRequest):
520
 
521
  @app.get("/health")
522
  async def health():
523
- """Health check endpoint — always returns 200 so HF doesn't time out."""
524
- model_ready = getattr(app.state, 'model_ready', False)
525
  worker_pool = getattr(app.state, 'worker_pool', None)
526
  return {
527
- "status": "ok" if model_ready else "loading",
528
- "ready": model_ready and worker_pool is not None and len(worker_pool.workers) > 0,
529
  "num_gpus": worker_pool.num_gpus if worker_pool else 0,
530
  "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
531
  }
 
136
  logger = logging.getLogger(__name__)
137
 
138
  device_type = autodetect_device_type() if args.device_type == "" else args.device_type
139
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
140
 
141
  @dataclass
142
  class Worker:
 
268
  detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
269
  )
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  @asynccontextmanager
272
  async def lifespan(app: FastAPI):
273
+ """Load models on all GPUs on startup."""
274
+ print("Loading nanochat models across GPUs...")
275
+ app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
276
+ await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
277
+ print(f"Server ready at http://localhost:{args.port}")
278
  # Start periodic log push (every hour)
279
  log_task = asyncio.create_task(periodic_log_push(3600))
280
  yield
281
  # Push any remaining logs on shutdown
282
  log_task.cancel()
 
283
  push_logs()
284
 
285
  app = FastAPI(lifespan=lifespan)
 
292
  allow_headers=["*"],
293
  )
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  @app.get("/")
296
  async def root():
297
+ """Serve the chat UI."""
 
 
298
  ui_html_path = os.path.join("nanochat", "ui.html")
299
  with open(ui_html_path, "r", encoding="utf-8") as f:
300
  html_content = f.read()
 
368
  async def chat_completions(request: ChatRequest):
369
  """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
370
 
 
 
 
 
371
  # Basic validation to prevent abuse
372
  validate_chat_request(request)
373
 
 
452
 
453
  @app.get("/health")
454
  async def health():
455
+ """Health check endpoint."""
 
456
  worker_pool = getattr(app.state, 'worker_pool', None)
457
  return {
458
+ "status": "ok",
459
+ "ready": worker_pool is not None and len(worker_pool.workers) > 0,
460
  "num_gpus": worker_pool.num_gpus if worker_pool else 0,
461
  "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
462
  }