Spaces:
Running
Running
Update scripts/chat_web.py
Browse files- scripts/chat_web.py +10 -79
scripts/chat_web.py
CHANGED
|
@@ -136,6 +136,7 @@ logging.basicConfig(
|
|
| 136 |
logger = logging.getLogger(__name__)
|
| 137 |
|
| 138 |
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
|
|
|
| 139 |
|
| 140 |
@dataclass
|
| 141 |
class Worker:
|
|
@@ -267,57 +268,18 @@ def validate_chat_request(request: ChatRequest):
|
|
| 267 |
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
|
| 268 |
)
|
| 269 |
|
| 270 |
-
async def _load_model_background(app: FastAPI):
|
| 271 |
-
"""Download and load model in background so the server can respond to health checks immediately."""
|
| 272 |
-
loop = asyncio.get_event_loop()
|
| 273 |
-
|
| 274 |
-
def _download_and_load():
|
| 275 |
-
import os
|
| 276 |
-
model_dir = os.environ.get("NANOCHAT_BASE_DIR", "/app/nanochat_cache")
|
| 277 |
-
checkpoint_dir = os.path.join(model_dir, "chatsft_checkpoints", "d18")
|
| 278 |
-
model_file = os.path.join(checkpoint_dir, "model_000070.pt")
|
| 279 |
-
model_repo = "tventurella/mr_chatterbox_model"
|
| 280 |
-
|
| 281 |
-
# Download if not present
|
| 282 |
-
if not os.path.exists(model_file):
|
| 283 |
-
print("Downloading model checkpoint...", flush=True)
|
| 284 |
-
from huggingface_hub import hf_hub_download
|
| 285 |
-
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 286 |
-
hf_hub_download(model_repo, "model_000070.pt", local_dir=checkpoint_dir)
|
| 287 |
-
hf_hub_download(model_repo, "meta_000070.json", local_dir=checkpoint_dir)
|
| 288 |
-
print("Model downloaded.", flush=True)
|
| 289 |
-
else:
|
| 290 |
-
print("Model checkpoint already present.", flush=True)
|
| 291 |
-
|
| 292 |
-
# Initialize compute
|
| 293 |
-
print("Initializing compute...", flush=True)
|
| 294 |
-
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
| 295 |
-
|
| 296 |
-
# Load model
|
| 297 |
-
print("Loading model into memory...", flush=True)
|
| 298 |
-
pool = WorkerPool(num_gpus=args.num_gpus)
|
| 299 |
-
import asyncio as _asyncio
|
| 300 |
-
_asyncio.run(pool.initialize(args.source, model_tag=args.model_tag, step=args.step))
|
| 301 |
-
return pool
|
| 302 |
-
|
| 303 |
-
pool = await loop.run_in_executor(None, _download_and_load)
|
| 304 |
-
app.state.worker_pool = pool
|
| 305 |
-
app.state.model_ready = True
|
| 306 |
-
print(f"Model loaded! Server ready at http://localhost:{args.port}", flush=True)
|
| 307 |
-
|
| 308 |
@asynccontextmanager
|
| 309 |
async def lifespan(app: FastAPI):
|
| 310 |
-
"""
|
| 311 |
-
|
| 312 |
-
app.state.worker_pool =
|
| 313 |
-
|
| 314 |
-
|
| 315 |
# Start periodic log push (every hour)
|
| 316 |
log_task = asyncio.create_task(periodic_log_push(3600))
|
| 317 |
yield
|
| 318 |
# Push any remaining logs on shutdown
|
| 319 |
log_task.cancel()
|
| 320 |
-
load_task.cancel()
|
| 321 |
push_logs()
|
| 322 |
|
| 323 |
app = FastAPI(lifespan=lifespan)
|
|
@@ -330,35 +292,9 @@ app.add_middleware(
|
|
| 330 |
allow_headers=["*"],
|
| 331 |
)
|
| 332 |
|
| 333 |
-
LOADING_HTML = """<!DOCTYPE html>
|
| 334 |
-
<html><head>
|
| 335 |
-
<meta charset="UTF-8">
|
| 336 |
-
<meta http-equiv="refresh" content="10">
|
| 337 |
-
<title>Mr. Chatterbox — Loading</title>
|
| 338 |
-
<style>
|
| 339 |
-
@import url('https://fonts.googleapis.com/css2?family=Playfair+Display:wght@700&family=Lora:ital@0;1&display=swap');
|
| 340 |
-
body { font-family: 'Lora', Georgia, serif; background: #f5f0e8; color: #2c1810;
|
| 341 |
-
display: flex; justify-content: center; align-items: center; min-height: 100vh; margin: 0; }
|
| 342 |
-
.box { text-align: center; max-width: 500px; padding: 2rem; }
|
| 343 |
-
h1 { font-family: 'Playfair Display', Georgia, serif; color: #722f37; font-size: 2rem; }
|
| 344 |
-
p { line-height: 1.7; color: #5c4033; }
|
| 345 |
-
.spinner { display: inline-block; width: 40px; height: 40px; border: 3px solid #c4b59a;
|
| 346 |
-
border-top-color: #722f37; border-radius: 50%; animation: spin 1s linear infinite; margin: 1rem 0; }
|
| 347 |
-
@keyframes spin { to { transform: rotate(360deg); } }
|
| 348 |
-
</style></head><body>
|
| 349 |
-
<div class="box">
|
| 350 |
-
<h1>Mr. Chatterbox</h1>
|
| 351 |
-
<div class="spinner"></div>
|
| 352 |
-
<p>The gentleman is preparing himself for conversation.<br>
|
| 353 |
-
<em>This may take a few minutes on first visit.</em></p>
|
| 354 |
-
<p style="font-size:0.85rem; color:#8b7355;">This page will refresh automatically.</p>
|
| 355 |
-
</div></body></html>"""
|
| 356 |
-
|
| 357 |
@app.get("/")
|
| 358 |
async def root():
|
| 359 |
-
"""Serve the chat UI
|
| 360 |
-
if not getattr(app.state, 'model_ready', False):
|
| 361 |
-
return HTMLResponse(content=LOADING_HTML)
|
| 362 |
ui_html_path = os.path.join("nanochat", "ui.html")
|
| 363 |
with open(ui_html_path, "r", encoding="utf-8") as f:
|
| 364 |
html_content = f.read()
|
|
@@ -432,10 +368,6 @@ async def generate_stream(
|
|
| 432 |
async def chat_completions(request: ChatRequest):
|
| 433 |
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
|
| 434 |
|
| 435 |
-
# Block requests while model is still loading
|
| 436 |
-
if not getattr(app.state, 'model_ready', False):
|
| 437 |
-
raise HTTPException(status_code=503, detail="Model is still loading. Please wait.")
|
| 438 |
-
|
| 439 |
# Basic validation to prevent abuse
|
| 440 |
validate_chat_request(request)
|
| 441 |
|
|
@@ -520,12 +452,11 @@ async def chat_completions(request: ChatRequest):
|
|
| 520 |
|
| 521 |
@app.get("/health")
|
| 522 |
async def health():
|
| 523 |
-
"""Health check endpoint
|
| 524 |
-
model_ready = getattr(app.state, 'model_ready', False)
|
| 525 |
worker_pool = getattr(app.state, 'worker_pool', None)
|
| 526 |
return {
|
| 527 |
-
"status": "ok"
|
| 528 |
-
"ready":
|
| 529 |
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
|
| 530 |
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
|
| 531 |
}
|
|
|
|
| 136 |
logger = logging.getLogger(__name__)
|
| 137 |
|
| 138 |
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
| 139 |
+
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
| 140 |
|
| 141 |
@dataclass
|
| 142 |
class Worker:
|
|
|
|
| 268 |
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
|
| 269 |
)
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
@asynccontextmanager
|
| 272 |
async def lifespan(app: FastAPI):
|
| 273 |
+
"""Load models on all GPUs on startup."""
|
| 274 |
+
print("Loading nanochat models across GPUs...")
|
| 275 |
+
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
|
| 276 |
+
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
|
| 277 |
+
print(f"Server ready at http://localhost:{args.port}")
|
| 278 |
# Start periodic log push (every hour)
|
| 279 |
log_task = asyncio.create_task(periodic_log_push(3600))
|
| 280 |
yield
|
| 281 |
# Push any remaining logs on shutdown
|
| 282 |
log_task.cancel()
|
|
|
|
| 283 |
push_logs()
|
| 284 |
|
| 285 |
app = FastAPI(lifespan=lifespan)
|
|
|
|
| 292 |
allow_headers=["*"],
|
| 293 |
)
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
@app.get("/")
|
| 296 |
async def root():
|
| 297 |
+
"""Serve the chat UI."""
|
|
|
|
|
|
|
| 298 |
ui_html_path = os.path.join("nanochat", "ui.html")
|
| 299 |
with open(ui_html_path, "r", encoding="utf-8") as f:
|
| 300 |
html_content = f.read()
|
|
|
|
| 368 |
async def chat_completions(request: ChatRequest):
|
| 369 |
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
|
| 370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
# Basic validation to prevent abuse
|
| 372 |
validate_chat_request(request)
|
| 373 |
|
|
|
|
| 452 |
|
| 453 |
@app.get("/health")
|
| 454 |
async def health():
|
| 455 |
+
"""Health check endpoint."""
|
|
|
|
| 456 |
worker_pool = getattr(app.state, 'worker_pool', None)
|
| 457 |
return {
|
| 458 |
+
"status": "ok",
|
| 459 |
+
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
|
| 460 |
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
|
| 461 |
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
|
| 462 |
}
|