# Main router / load balancer for the Qwen3 mini-server cluster.
# (Hugging Face "Spaces / Sleeping" page residue removed from the scrape.)
| import os | |
| import time | |
| from typing import List, AsyncGenerator, Optional | |
| import httpx | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse, ORJSONResponse, HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
# ===== FastAPI app =====
app = FastAPI(
    title="Qwen3 Main Router",
    description="Main router / load balancer for Qwen3 mini servers",
    version="1.0.0",
    default_response_class=ORJSONResponse,  # orjson-backed JSON responses
)

# Open CORS so a browser frontend on any origin can call the router.
# NOTE(review): browsers reject credentialed requests when the server answers
# with a wildcard origin, so allow_credentials=True + allow_origins=["*"] is
# contradictory — confirm whether credentials (cookies/auth headers) are
# actually needed by the frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Comma-separated list of mini server base URLs
# Example:
#   MINI_SERVERS="https://username-mini1.hf.space,https://username-mini2.hf.space"
# NOTE(review): despite the env-var-style example above, the list is
# hard-coded here — `os` is imported but never used to read it.
# Order matters: choose_mini() tries these in order, so earlier entries are
# preferred.
MINI_SERVERS = [
    "https://antaram-server1.hf.space",
    "https://antaram-server2.hf.space",
]

# Shared pooled HTTP client; created in startup(), closed in shutdown().
http_client: Optional[httpx.AsyncClient] = None

# Usage stats per mini for /gui
MINI_USAGE = {}  # { base_url: {"total_requests": int, "last_used": float or None} }
# Fix: the handler was defined but never attached to the app, so the shared
# HTTP client was never created and every proxy call would hit a None client.
@app.on_event("startup")
async def startup():
    """Create the shared pooled HTTP client and reset per-mini usage stats.

    One HTTP/2, keep-alive client is reused for all proxying; the 300 s
    overall timeout accommodates long LLM generations while connects are
    capped at 10 s.
    """
    global http_client, MINI_USAGE
    http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(300.0, connect=10.0),
        limits=httpx.Limits(max_keepalive_connections=50, max_connections=100),
        http2=True,
    )
    MINI_USAGE = {base_url: {"total_requests": 0, "last_used": None} for base_url in MINI_SERVERS}
# Fix: the handler was defined but never attached to the app, so pooled
# connections were never closed on shutdown.
@app.on_event("shutdown")
async def shutdown():
    """Close the shared HTTP client, releasing its pooled connections."""
    global http_client
    if http_client:
        await http_client.aclose()
| # ===== Shared models (same as mini server) ===== | |
class Message(BaseModel):
    """Single chat message, mirroring the mini servers' schema."""

    role: str  # e.g. "user" / "assistant" — forwarded verbatim to the mini
    content: str  # message text

    class Config:
        # Drop unknown fields instead of rejecting the request.
        extra = "ignore"
class ChatRequest(BaseModel):
    """OpenAI-style chat request, forwarded as-is to a mini server."""

    messages: List[Message]
    temperature: float = Field(default=0.6, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: int = Field(default=4096, ge=1, le=32768)
    stream: bool = Field(default=True)  # SSE streaming by default

    class Config:
        # Drop unknown fields instead of rejecting the request.
        extra = "ignore"
class SimpleChatRequest(BaseModel):
    """Prompt-only chat request used by the /chat-style endpoints."""

    prompt: str
    temperature: float = Field(default=0.6, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: int = Field(default=4096, ge=1, le=32768)
    stream: bool = Field(default=True)  # SSE streaming by default

    class Config:
        # Drop unknown fields instead of rejecting the request.
        extra = "ignore"
| # ===== Mini server coordination helpers ===== | |
async def reserve_on_mini(base_url: str) -> bool:
    """Ask *base_url* for a processing slot via POST /reserve.

    Returns:
        True when the mini answers 200 (slot granted); False on any other
        status (e.g. 429 when full) or on any network error.
    """
    try:
        resp = await http_client.post(f"{base_url}/reserve", timeout=5.0)
    except Exception:
        # An unreachable mini is treated the same as a busy one.
        return False
    return resp.status_code == 200
async def release_on_mini(base_url: str) -> None:
    """Tell *base_url* its reserved slot is free again (POST /release).

    Best-effort by design: all failures are swallowed, since the router has
    nothing useful to do when the release cannot be delivered.
    """
    release_url = f"{base_url}/release"
    try:
        await http_client.post(release_url, timeout=5.0)
    except Exception:
        pass
async def choose_mini() -> str:
    """Pick the first mini server that grants a reservation.

    Minis are tried in the configured order, so earlier entries act as
    primaries and later ones as overflow — i.e. if mini1 is busy, mini2 is
    tried, and so on.  Usage counters are bumped for the chosen mini.

    Raises:
        HTTPException: 503 when no minis are configured or all are busy.
    """
    if not MINI_SERVERS:
        raise HTTPException(status_code=503, detail="No mini servers configured")
    for candidate in MINI_SERVERS:
        if not await reserve_on_mini(candidate):
            continue
        stats = MINI_USAGE.setdefault(candidate, {"total_requests": 0, "last_used": None})
        stats["total_requests"] += 1
        stats["last_used"] = time.time()
        return candidate
    raise HTTPException(status_code=503, detail="All mini servers are busy")
| # ===== Proxy helpers ===== | |
async def proxy_sse_to_mini(path: str, payload: dict) -> AsyncGenerator[bytes, None]:
    """Streaming proxy:
    frontend -> main -> mini (SSE) -> main -> frontend

    Reserves a mini, POSTs *payload* to *path* on it, and relays the SSE
    byte stream unchanged.  The reservation is released in ``finally``, so
    it also runs when the client disconnects and the generator is closed.

    NOTE(review): the HTTPException on a non-200 mini response is raised
    from inside a streaming generator; by then the 200 response headers may
    already have been sent, so the client likely sees a truncated stream
    rather than the error status — confirm against the frontend's error
    handling.
    """
    mini_url = await choose_mini()
    full_url = f"{mini_url}{path}"
    try:
        async with http_client.stream(
            "POST",
            full_url,
            json=payload,
            headers={"Accept": "text/event-stream"},
        ) as resp:
            if resp.status_code != 200:
                body = await resp.aread()
                raise HTTPException(
                    status_code=resp.status_code,
                    detail=f"Mini error: {body.decode(errors='ignore')}",
                )
            async for chunk in resp.aiter_raw():
                # pass SSE bytes straight through
                yield chunk
    finally:
        await release_on_mini(mini_url)
async def proxy_json_to_mini(path: str, payload: dict) -> ORJSONResponse:
    """Non-streaming proxy: POST *payload* to *path* on a reserved mini.

    Mirrors the mini's status code back to the caller.  The reservation is
    always released, even on failure.

    Raises:
        HTTPException: 502 when the mini is unreachable or returns a body
            that is not valid JSON (previously these surfaced as opaque
            500s because resp.json() raised unguarded).
    """
    mini_url = await choose_mini()
    full_url = f"{mini_url}{path}"
    try:
        try:
            resp = await http_client.post(full_url, json=payload)
        except httpx.HTTPError as e:
            # Network/timeout failure talking to the mini -> bad gateway.
            raise HTTPException(status_code=502, detail=f"Mini unreachable: {e}")
        try:
            data = resp.json()
        except ValueError:
            # e.g. an HTML error page from infrastructure in front of the mini.
            raise HTTPException(
                status_code=502,
                detail=f"Mini returned non-JSON response: {resp.text[:200]}",
            )
        return ORJSONResponse(content=data, status_code=resp.status_code)
    finally:
        await release_on_mini(mini_url)
| # ===== Public endpoints for frontend ===== | |
# Fix: handler was defined but never registered as a route, so "/" was
# unreachable.
@app.get("/")
async def root():
    """Liveness endpoint: confirms the router is up and lists its minis."""
    return {
        "status": "ok",
        "message": "Main Qwen3 router is running",
        "mini_servers": MINI_SERVERS,
    }
# Fix: handler was defined but never registered as a route.
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing.

    Kept static: every mini serves the same model ID, so there is nothing
    to aggregate from the cluster.
    """
    return {
        "object": "list",
        "data": [{
            "id": "qwen3-0.6b",
            "object": "model",
            "created": int(time.time()),
            "owned_by": "cluster",
        }],
    }
# Fix: handler was defined but never registered; the /gui dashboard fetches
# '/health', so without this route the GUI had no data source.
@app.get("/health")
async def health():
    """Aggregated health + status across minis.

    Polls each mini's GET /health (LLM backend health) and GET /status
    (load) with short timeouts; unreachable minis are reported in the
    payload rather than raising.  Consumed by the /gui dashboard.
    """
    results = []
    for base_url in MINI_SERVERS:
        mini_health = None
        mini_status = None
        try:
            # Low-level LLM backend health from mini
            resp_h = await http_client.get(f"{base_url}/health", timeout=5.0)
            mini_health = resp_h.json()
        except Exception as e:
            mini_health = {"status": "unreachable", "error": str(e)}
        try:
            # Load status from mini
            resp_s = await http_client.get(f"{base_url}/status", timeout=5.0)
            mini_status = resp_s.json()
        except Exception as e:
            mini_status = {"status": "unknown", "error": str(e)}
        usage = MINI_USAGE.get(base_url, {"total_requests": 0, "last_used": None})
        results.append(
            {
                "mini": base_url,
                "health": mini_health,
                "status": mini_status,
                "usage": usage,
            }
        )
    return {"mini_servers": results}
# Fix: handler was defined but never registered as a route.
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    """OpenAI-compatible chat endpoint, proxied to the first free mini.

    Streams SSE straight through when ``request.stream`` is set; otherwise
    does a single JSON round-trip via proxy_json_to_mini.
    """
    payload = request.dict()
    if request.stream:
        # Transfer-Encoding is hop-by-hop and set by the ASGI server for
        # streaming responses; setting it manually risks a duplicate header,
        # so it is intentionally omitted here.
        return StreamingResponse(
            proxy_sse_to_mini("/v1/chat/completions", payload),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disable nginx-style proxy buffering so tokens flush promptly.
                "X-Accel-Buffering": "no",
            },
        )
    return await proxy_json_to_mini("/v1/chat/completions", payload)
# Fix: handler was defined but never registered as a route.
@app.post("/chat")
async def simple_chat(request: SimpleChatRequest):
    """Prompt-only chat endpoint, proxied to a mini's /chat route.

    Streams SSE when ``request.stream`` is set, otherwise proxies a single
    JSON round-trip.
    """
    payload = {
        "prompt": request.prompt,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "max_tokens": request.max_tokens,
        "stream": request.stream,
    }
    if request.stream:
        return StreamingResponse(
            proxy_sse_to_mini("/chat", payload),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disable nginx-style proxy buffering so tokens flush promptly.
                "X-Accel-Buffering": "no",
            },
        )
    return await proxy_json_to_mini("/chat", payload)
# Fix: handler was defined but never registered as a route.
@app.post("/chat/raw")
async def raw_chat(request: SimpleChatRequest):
    """Always-streaming plain-text chat, proxied to a mini's /chat/raw.

    ``stream`` is forced to True regardless of the request body, and the
    response is served as text/plain rather than SSE.
    """
    payload = {
        "prompt": request.prompt,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "max_tokens": request.max_tokens,
        "stream": True,  # this endpoint only streams
    }
    return StreamingResponse(
        proxy_sse_to_mini("/chat/raw", payload),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable nginx-style proxy buffering so tokens flush promptly.
            "X-Accel-Buffering": "no",
        },
    )
# Fix: handler was defined but never registered as a route.
@app.get("/fast")
async def fast_chat(prompt: str = "", max_tokens: int = 512):
    """Quick non-streaming completion driven by query parameters.

    Uses fixed sampling settings (temperature 0.6, top_p 0.95) and proxies
    to a mini's /fast route.
    """
    payload = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "stream": False,
        "temperature": 0.6,
        "top_p": 0.95,
    }
    return await proxy_json_to_mini("/fast", payload)
| # ===== Simple GUI at /gui ===== | |
# Fix: handler was defined but never registered as a route, and without
# response_class=HTMLResponse the HTML string would be JSON-encoded.
# (HTMLResponse is imported at the top of the file but was otherwise unused.)
@app.get("/gui", response_class=HTMLResponse)
async def gui():
    """Simple HTML dashboard showing mini servers with lights and stats.

    Frontend: GET mainserver.space/gui.  The page polls this server's
    /health endpoint every 2 s and renders one card per mini.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Qwen3 Cluster GUI</title>
<style>
body { font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; background: #0f172a; color: #e5e7eb; margin: 0; padding: 24px; }
h1 { margin-bottom: 8px; }
.subtitle { color: #9ca3af; margin-bottom: 24px; }
.grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); gap: 16px; }
.card { background: #020617; border-radius: 16px; padding: 16px 18px; border: 1px solid #1f2937; box-shadow: 0 18px 40px rgba(0,0,0,0.4); }
.card-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px; }
.mini-name { font-weight: 600; font-size: 14px; }
.badge { font-size: 11px; padding: 4px 8px; border-radius: 999px; text-transform: uppercase; letter-spacing: .06em; }
.badge-up { background: rgba(34,197,94,0.15); color: #4ade80; border: 1px solid rgba(34,197,94,0.4); }
.badge-down { background: rgba(248,113,113,0.12); color: #fca5a5; border: 1px solid rgba(248,113,113,0.4); }
.badge-unknown { background: rgba(148,163,184,0.15); color: #e5e7eb; border: 1px solid rgba(148,163,184,0.4); }
.status-row { display: flex; align-items: center; gap: 8px; margin-bottom: 6px; }
.dot { width: 10px; height: 10px; border-radius: 999px; }
.dot-idle { background: #22c55e; box-shadow: 0 0 12px rgba(34,197,94,0.8); }
.dot-busy { background: #eab308; box-shadow: 0 0 12px rgba(234,179,8,0.8); }
.dot-off { background: #6b7280; }
.label { font-size: 12px; color: #9ca3af; }
.value { font-size: 13px; }
.muted { font-size: 12px; color: #6b7280; }
.footer { margin-top: 18px; font-size: 11px; color: #6b7280; }
#updated { font-size: 11px; color: #9ca3af; margin-bottom: 12px; }
</style>
</head>
<body>
<h1>Qwen3 Cluster</h1>
<div class="subtitle">Main router view of all mini servers. Lights = idle, busy, or offline.</div>
<div id="updated">Loading...</div>
<div id="grid" class="grid"></div>
<div class="footer">Data source: <code>/health</code> endpoint on this main server. Auto-refresh every 2s.</div>
<script>
async function loadStatus() {
  try {
    const res = await fetch('/health', { cache: 'no-store' });
    const data = await res.json();
    const grid = document.getElementById('grid');
    grid.innerHTML = '';
    const list = data.mini_servers || [];
    if (list.length === 0) {
      grid.innerHTML = '<div class="card"><div class="value">No mini servers configured.</div></div>';
    }
    list.forEach(item => {
      const mini = item.mini;
      const health = item.health || {};
      const status = item.status || {};
      const usage = item.usage || {};
      const up = health.status === 'healthy';
      const statusText = status.status || 'unknown';
      const dotClass =
        !up ? 'dot-off' :
        (statusText === 'busy' ? 'dot-busy' : 'dot-idle');
      const badgeClass =
        !up ? 'badge-down' :
        'badge-up';
      const badgeText =
        !up ? 'OFFLINE' :
        'ONLINE';
      const total = usage.total_requests || 0;
      const lastUsed = usage.last_used
        ? new Date(usage.last_used * 1000).toLocaleTimeString()
        : 'never';
      const currentReq = status.current_requests !== undefined ? status.current_requests : '?';
      const maxReq = status.max_concurrent !== undefined ? status.max_concurrent : '?';
      const card = document.createElement('div');
      card.className = 'card';
      card.innerHTML = `
        <div class="card-header">
          <div class="mini-name">${mini}</div>
          <span class="badge ${badgeClass}">${badgeText}</span>
        </div>
        <div class="status-row">
          <div class="dot ${dotClass}"></div>
          <div class="value">${statusText.toUpperCase()}</div>
        </div>
        <div class="label">Current load</div>
        <div class="value">${currentReq} / ${maxReq} active</div>
        <div class="label" style="margin-top:6px;">Usage</div>
        <div class="value">Total requests: ${total}</div>
        <div class="muted">Last used: ${lastUsed}</div>
      `;
      grid.appendChild(card);
    });
    const now = new Date();
    document.getElementById('updated').textContent =
      'Last updated: ' + now.toLocaleTimeString();
  } catch (e) {
    document.getElementById('updated').textContent =
      'Error loading status: ' + e;
  }
}
loadStatus();
setInterval(loadStatus, 2000);
</script>
</body>
</html>
"""
if __name__ == "__main__":
    # Local / container entry point.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all interfaces (container deployment)
        port=8000,
        loop="uvloop",  # NOTE(review): requires uvloop to be installed — confirm
        http="httptools",  # C-accelerated HTTP parser
        access_log=False,  # skip per-request logging on the hot path
        workers=1,  # single process: MINI_USAGE / http_client are in-memory state
    )