# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
"""
SAM-Z-1 Distributed Compute Cluster Head Node v5.0
- Smart load balancing with distributed compute
- Real-time status dashboard
- Auto-detects worker version (v4 vs v5)
- Supports 4 new models with backward compatibility
"""
| from fastapi import FastAPI, HTTPException, WebSocket | |
| from fastapi.responses import StreamingResponse, HTMLResponse | |
| from pydantic import BaseModel | |
| import httpx | |
| import asyncio | |
| import json | |
| import time | |
| from typing import List, Optional, Dict | |
| from collections import deque | |
| import random | |
# ASGI application object; handed to uvicorn when this module runs as a script.
app = FastAPI(
    title="SAM-Z-1 Distributed Cluster",
    version="5.0.0",
)
| # ============================================================================ | |
| # Configuration | |
| # ============================================================================ | |
# Base URLs of the worker nodes (no trailing slash).
# NOTE(review): the previous list named "bc-ai-worker-sam-z-api" twice, which
# made that worker get health-checked twice per cycle and inflated the worker
# count. If a sixth, distinct worker was intended, add its URL here.
# dict.fromkeys() keeps the listed order while dropping any future duplicates.
WORKER_URLS = list(dict.fromkeys([
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
    "https://bc-ai-worker-3.hf.space",
    "https://bc-ai-worker-4.hf.space",
    "https://bc-ai-worker-5.hf.space",
]))

# Seconds slept between two health-check passes.
HEALTH_CHECK_INTERVAL = 5
# Width, in seconds, of the sliding window used to count recent requests.
LOAD_CHECK_WINDOW = 10
# Request counts (within the window) that separate light / medium / heavy load.
LIGHT_LOAD_THRESHOLD = 2
HEAVY_LOAD_THRESHOLD = 5

# New models added in v5
NEW_MODELS = [
    "SAM-X-1-Large",
    "SAM-X-1-Fast",
    "SAM-X-1-Mini",
    "SAM-X-1-Nano",
]

# Per-worker mutable state, keyed by worker URL. Every worker gets its own
# dict (and its own "supports_models" list), so updates never leak between
# workers.
worker_health = {
    url: {
        "healthy": True,
        "last_check": 0,
        "active_requests": 0,
        "total_requests": 0,
        "total_tokens": 0,
        "avg_latency": 0,
        "role": "idle",
        "version": None,  # Will be auto-detected: "v4" or "v5"
        "supports_models": [],  # Models this worker supports
    }
    for url in WORKER_URLS
}

# Arrival times (time.time()) of the most recent requests, capped at 100.
request_timestamps = deque(maxlen=100)
# Last load mode computed by update_load_mode().
current_load_mode = "light"
# Cluster-wide counters plus the process start time used for uptime.
cluster_stats = {
    "total_requests": 0,
    "successful_requests": 0,
    "failed_requests": 0,
    "uptime_start": time.time(),
}
# Currently connected dashboard WebSocket clients.
active_connections = set()
| # ============================================================================ | |
| # Request Models | |
| # ============================================================================ | |
class GenerateRequest(BaseModel):
    """Request body for prompt-based text generation."""

    # Prompt text forwarded to the selected worker.
    prompt: str

    # Length and sampling settings, passed to the worker unchanged.
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # NOTE(review): the generate handler in this file always asks the worker
    # to stream, whatever this flag says - confirm that is intended.
    stream: bool = True

    # Optional model name (new in v5); None lets any healthy worker serve it.
    model: Optional[str] = None
class ChatMessage(BaseModel):
    """One message of a chat conversation."""

    # Speaker label; forwarded to the worker as-is (no validation here).
    role: str
    # Message text.
    content: str
class ChatRequest(BaseModel):
    """Request body for chat-style generation."""

    # Conversation so far, in order.
    messages: List[ChatMessage]

    # Length and sampling settings, passed to the worker unchanged.
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # NOTE(review): the chat handler in this file always asks the worker to
    # stream, whatever this flag says - confirm that is intended.
    stream: bool = True

    # Optional model name (new in v5); None lets any healthy worker serve it.
    model: Optional[str] = None
| # ============================================================================ | |
| # Worker Version Detection | |
| # ============================================================================ | |
async def _fetch_json_object(client, url: str):
    """GET *url* and return the decoded JSON object, or None if the probe fails.

    A non-200 status, a network error, invalid JSON or a JSON value that is
    not an object all count as "this probe has no answer".
    """
    try:
        response = await client.get(url)
        if response.status_code != 200:
            return None
        data = response.json()
    except Exception:
        # Deliberate best effort: the caller falls through to the next probe.
        # `except Exception` (not a bare except) lets task cancellation through.
        return None
    return data if isinstance(data, dict) else None


async def detect_worker_version(worker_url: str) -> tuple:
    """
    Detect worker version and supported models.

    Probes ``GET {worker_url}/info`` first, then ``GET {worker_url}/models``
    (both v5 features). A worker that answers neither with usable data is
    treated as v4, which has no model selection.

    Args:
        worker_url: Base URL of the worker, without a trailing slash.

    Returns: (version: str, supported_models: List[str])
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            # Try to get worker info endpoint (v5 feature)
            info = await _fetch_json_object(client, f"{worker_url}/info")
            if info is not None:
                # NOTE(review): the version string is passed through verbatim;
                # other code compares it against "v4"/"v5" - confirm that
                # workers report exactly those values.
                version = info.get("version", "v5")
                models = info.get("models")
                if not isinstance(models, list):
                    # Missing or malformed list: assume the v5 defaults. Copy
                    # so per-worker state never aliases the global list.
                    models = list(NEW_MODELS)
                return version, models

            # Try to get models list (v5 feature)
            listing = await _fetch_json_object(client, f"{worker_url}/models")
            if listing is not None:
                models = listing.get("models", [])
                if models and isinstance(models, list):
                    return "v5", models

            # Fallback: worker is v4 (no model selection)
            return "v4", []
    except Exception as e:
        print(f"⚠️ Version detection failed for {worker_url}: {e}")
        return "v4", []
async def check_worker_health(worker_url: str) -> bool:
    """Return True when ``GET {worker_url}/health`` answers with HTTP 200.

    Any request failure (timeout, connection error, ...) counts as unhealthy.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{worker_url}/health")
            return response.status_code == 200
    except Exception:
        # Not a bare `except:` - that would also swallow asyncio cancellation
        # and keep the health loop from shutting down cleanly.
        return False
async def health_check_loop():
    """Poll every worker forever, refreshing health and detecting versions.

    Each pass checks all workers in turn, pushes fresh stats to the dashboard
    clients and then sleeps for HEALTH_CHECK_INTERVAL seconds. Runs until the
    task is cancelled.
    """
    while True:
        for worker_url in WORKER_URLS:
            state = worker_health[worker_url]
            try:
                was_healthy = state["healthy"]

                # Check health
                healthy = await check_worker_health(worker_url)
                state["healthy"] = healthy
                state["last_check"] = time.time()

                # Detect the version only while the worker is reachable.
                # Probing an unreachable worker always yields the "v4"
                # fallback, which would then be cached for good.
                detected = False
                if healthy and state["version"] is None:
                    version, models = await detect_worker_version(worker_url)
                    state["version"] = version
                    state["supports_models"] = models
                    detected = True

                # Log on first detection and whenever the health state flips.
                if detected or healthy != was_healthy:
                    status = "✅" if healthy else "❌"
                    name = worker_url.split('//')[1].split('.')[0]
                    print(
                        f"{status} Worker: {name} | Version: {state['version']}"
                        f" | Models: {len(state['supports_models'])}"
                    )
            except Exception as e:
                # One bad worker must not stop monitoring of the others.
                print(f"⚠️ Health check failed for {worker_url}: {e}")

        try:
            await broadcast_stats()
        except Exception as e:
            # A dashboard problem must not kill the health loop.
            print(f"⚠️ Stats broadcast failed: {e}")

        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
# Strong references to background tasks. The event loop only keeps weak
# references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


# NOTE(review): no registration for this hook is visible in the file, so the
# health loop never started; registering it as a FastAPI startup handler is
# the presumed original intent - confirm.
@app.on_event("startup")
async def startup_event():
    """Start the background health-check loop when the application starts."""
    task = asyncio.create_task(health_check_loop())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
| # ============================================================================ | |
| # Smart Worker Selection | |
| # ============================================================================ | |
def get_workers_for_model(model_name: Optional[str]) -> List[str]:
    """Return the healthy workers able to serve *model_name*.

    With no model requested, every healthy worker qualifies. Otherwise a
    worker qualifies when it is v5 and lists the model, or when it is v4
    (v4 has no model selection and serves its default). If nothing
    qualifies, all healthy workers are returned as a fallback.
    """
    candidates = get_healthy_workers()
    if not model_name:
        return candidates

    def can_serve(url: str) -> bool:
        state = worker_health[url]
        release = state["version"]
        offered = state["supports_models"]
        if release == "v5" and model_name in offered:
            return True
        return release == "v4"

    matching = [url for url in candidates if can_serve(url)]
    return matching or candidates
def get_healthy_workers() -> List[str]:
    """Return the URLs of all workers currently flagged healthy."""
    healthy_urls = []
    for url, state in worker_health.items():
        if state["healthy"]:
            healthy_urls.append(url)
    return healthy_urls
def get_least_busy_worker(worker_list: List[str] = None) -> Optional[str]:
    """Return the worker with the fewest active requests.

    Picks from *worker_list* when given, otherwise from all healthy workers.
    Returns None when there is nothing to pick from; ties go to the earliest
    entry.
    """
    candidates = get_healthy_workers() if worker_list is None else worker_list
    if not candidates:
        return None

    def active_count(url: str) -> int:
        return worker_health[url]["active_requests"]

    return min(candidates, key=active_count)
def select_distributed_workers(model_name: Optional[str] = None) -> tuple:
    """
    Split the compatible workers into one generator and up to four decoders.

    The least-busy compatible worker becomes the generator; the next (at
    most four) least-busy workers become decoders. With a single compatible
    worker there are no decoders, and with none both lists are empty.

    Returns: (generators: List[str], decoders: List[str])
    """
    compatible = get_workers_for_model(model_name)
    if len(compatible) < 2:
        # Zero or one worker: nothing to sort and nobody left to decode.
        return (compatible[:1], [])

    by_load = sorted(compatible, key=lambda url: worker_health[url]["active_requests"])
    return (by_load[:1], by_load[1:5])
| # ============================================================================ | |
| # Load Management | |
| # ============================================================================ | |
def get_current_load() -> int:
    """Count the requests recorded within the last LOAD_CHECK_WINDOW seconds."""
    now = time.time()
    recent = [ts for ts in request_timestamps if now - ts < LOAD_CHECK_WINDOW]
    return len(recent)
def update_load_mode():
    """Recompute the cluster load mode and store it in ``current_load_mode``.

    The mode depends on the recent request count and on how many workers
    are healthy:

    * 5 or more healthy: "light" up to LIGHT_LOAD_THRESHOLD, "medium" up to
      HEAVY_LOAD_THRESHOLD, "heavy" above that.
    * 3 or 4 healthy: "light" up to LIGHT_LOAD_THRESHOLD, otherwise "heavy".
    * fewer than 3 healthy: always "heavy".

    Returns:
        (mode, load): the mode name and the request count it was based on.
    """
    global current_load_mode
    load = get_current_load()
    healthy_count = len(get_healthy_workers())

    if healthy_count >= 5:
        if load <= LIGHT_LOAD_THRESHOLD:
            mode = "light"
        elif load <= HEAVY_LOAD_THRESHOLD:
            mode = "medium"
        else:
            mode = "heavy"
    elif healthy_count >= 3:
        # Uses the shared constant rather than a literal 2, so both tiers
        # stay in step if the threshold is tuned.
        mode = "light" if load <= LIGHT_LOAD_THRESHOLD else "heavy"
    else:
        mode = "heavy"

    current_load_mode = mode
    return current_load_mode, load
def track_request():
    """Record one incoming request: bump the cluster total and log its arrival time."""
    cluster_stats["total_requests"] = cluster_stats["total_requests"] + 1
    arrived_at = time.time()
    request_timestamps.append(arrived_at)
| # ============================================================================ | |
| # Dashboard & WebSocket | |
| # ============================================================================ | |
async def broadcast_stats():
    """Send a stats snapshot to every connected dashboard client.

    Does nothing when no client is connected. Connections whose send fails
    are dropped from ``active_connections``.
    """
    if not active_connections:
        return

    mode, load = update_load_mode()
    uptime = time.time() - cluster_stats["uptime_start"]
    stats = {
        "timestamp": time.time(),
        "mode": mode,
        "load": load,
        "workers": [
            {
                # Short display name: the first label of the worker's host.
                "url": url.split("//")[1].split(".")[0],
                "healthy": status["healthy"],
                "active": status["active_requests"],
                "total": status["total_requests"],
                "tokens": status["total_tokens"],
                "latency": round(status["avg_latency"], 2),
                "role": status["role"],
                "version": status["version"] or "detecting...",
                "models": len(status["supports_models"]),
            }
            for url, status in worker_health.items()
        ],
        "cluster": {
            "total_requests": cluster_stats["total_requests"],
            "successful": cluster_stats["successful_requests"],
            "failed": cluster_stats["failed_requests"],
            "uptime": round(uptime, 0),
            "rps": round(cluster_stats["total_requests"] / uptime if uptime > 0 else 0, 2),
        },
    }

    disconnected = set()
    # Iterate over a snapshot: clients can connect or disconnect while a send
    # is awaited, and mutating a set during iteration raises RuntimeError.
    for ws in list(active_connections):
        try:
            await ws.send_json(stats)
        except Exception:
            # A failed send means the client is gone. `except Exception`
            # (not a bare except) lets task cancellation propagate.
            disconnected.add(ws)
    active_connections.difference_update(disconnected)
# The dashboard page opens its socket at "/ws" on the same host.
# NOTE(review): no route registration was visible for this handler; the path
# is taken from the dashboard's JavaScript - confirm.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one dashboard client: register it, push stats, wait for close.

    Incoming messages are read and ignored; reading is only there to notice
    when the client disconnects.
    """
    # Local import: fastapi is already a dependency of this file.
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    active_connections.add(websocket)
    try:
        # Give the new client data right away instead of waiting for the
        # next health-check pass.
        await broadcast_stats()
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal end of a dashboard session.
        pass
    except Exception as e:
        # Unexpected failure: log it, but never let it escape the handler.
        print(f"⚠️ WebSocket error: {e}")
    finally:
        active_connections.discard(websocket)
| # ============================================================================ | |
| # Distributed Generation | |
| # ============================================================================ | |
async def distributed_generation(
    generators: List[str],
    decoders: List[str],
    request_data: dict,
    endpoint: str = "generate"
):
    """Stream text produced by one generator worker and several decoders.

    The first generator streams token ids. They are grouped into small,
    numbered batches, the decoders turn batches into text in parallel, and
    the text is re-assembled in batch order before being yielded as SSE
    events of the form ``data: {"delta": ..., "text": ...}``.

    Args:
        generators: Candidate generator URLs; only the first one is used.
        decoders: URLs of the workers that decode token-id batches.
        request_data: JSON payload for the generator worker.
        endpoint: Worker endpoint to call ("generate" or "chat").

    Yields nothing when either worker list is empty.
    """
    if not generators or not decoders:
        return

    batch_size = 2  # token ids sent per /decode call
    # Items are (sequence_number, [token_id, ...]); None tells a decoder to stop.
    batch_queue = asyncio.Queue(maxsize=50)
    # Items are ("text", sequence_number, text) or ("done", decoder_id, None).
    text_queue = asyncio.Queue(maxsize=50)

    # Only the worker that actually generates is labelled, so no role is
    # left behind on workers this request never touches.
    gen_url = generators[0]
    worker_health[gen_url]["role"] = "generator"
    for dec_url in decoders:
        worker_health[dec_url]["role"] = "decoder"

    def release(url: str):
        """Drop one active request; mark the worker idle once it has none left."""
        state = worker_health[url]
        state["active_requests"] -= 1
        if state["active_requests"] <= 0:
            state["role"] = "idle"

    async def generate_tokens():
        worker_health[gen_url]["active_requests"] += 1
        next_seq = 0
        batch = []
        try:
            try:
                payload = {**request_data, "return_token_ids": True}
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(
                        "POST",
                        f"{gen_url}/{endpoint}",
                        json=payload
                    ) as response:
                        # Read by line, not by network chunk: one chunk may
                        # hold several SSE events or only part of one.
                        async for line in response.aiter_lines():
                            if not line.startswith("data: "):
                                continue
                            try:
                                event = json.loads(line[6:])
                            except ValueError:
                                continue  # not JSON; ignore this line
                            if not isinstance(event, dict):
                                continue
                            if "token_id" in event:
                                batch.append(event["token_id"])
                                if len(batch) >= batch_size:
                                    await batch_queue.put((next_seq, batch))
                                    next_seq += 1
                                    batch = []
                            elif "done" in event:
                                break
            except Exception as e:
                print(f"❌ Generator error: {e}")
            # Runs after success, error, or a stream that ended without a
            # "done" event - but not on cancellation: flush the remainder and
            # always release the decoders.
            if batch:
                await batch_queue.put((next_seq, batch))
            for _ in decoders:
                await batch_queue.put(None)
        finally:
            release(gen_url)

    async def decode_tokens(decoder_url: str, decoder_id: int):
        worker_health[decoder_url]["active_requests"] += 1
        try:
            try:
                # One client per decoder, reused for all of its batches.
                async with httpx.AsyncClient(timeout=10.0) as client:
                    while True:
                        item = await batch_queue.get()
                        if item is None:
                            break
                        seq, token_ids = item
                        try:
                            response = await client.post(
                                f"{decoder_url}/decode",
                                json={"token_ids": token_ids}
                            )
                            text = response.json()["text"]
                        except Exception as e:
                            print(f"❌ Decoder {decoder_id} error: {e}")
                            # Report the batch as empty so ordering does not
                            # stall, then retire this decoder; the remaining
                            # decoders keep draining the queue.
                            await text_queue.put(("text", seq, ""))
                            break
                        worker_health[decoder_url]["total_tokens"] += len(token_ids)
                        await text_queue.put(("text", seq, text))
            except Exception as e:
                print(f"❌ Decoder {decoder_id} error: {e}")
            await text_queue.put(("done", decoder_id, None))
        finally:
            release(decoder_url)

    gen_task = asyncio.create_task(generate_tokens())
    decoder_tasks = [
        asyncio.create_task(decode_tokens(dec_url, i))
        for i, dec_url in enumerate(decoders)
    ]

    accumulated_text = ""
    pending = {}  # decoded batches that arrived ahead of their turn
    next_to_emit = 0
    decoders_done = 0
    total_decoders = len(decoders)
    try:
        while decoders_done < total_decoders:
            msg_type, key, payload = await text_queue.get()
            if msg_type == "done":
                decoders_done += 1
                continue
            pending[key] = payload
            # Emit every batch that is now next in line.
            while next_to_emit in pending:
                delta = pending.pop(next_to_emit)
                next_to_emit += 1
                if not delta:
                    continue
                accumulated_text += delta
                yield f"data: {json.dumps({'delta': delta, 'text': accumulated_text})}\n\n"
    finally:
        # If the client went away early, nobody drains the queues any more;
        # cancel the workers instead of waiting for them to finish.
        tasks = [gen_task, *decoder_tasks]
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
    """Standard single-worker generation.

    Proxies the worker's streamed response chunk by chunk. If the request
    fails, a single ``data: {"error": ...}`` event is yielded instead of
    raising.

    Args:
        worker_url: Base URL of the worker that serves the request.
        request_data: JSON payload forwarded to the worker.
        endpoint: Worker endpoint to call ("generate" or "chat").
    """
    state = worker_health[worker_url]
    started = time.time()
    state["active_requests"] += 1
    state["role"] = "full"
    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream(
                "POST",
                f"{worker_url}/{endpoint}",
                json=request_data
            ) as response:
                async for chunk in response.aiter_text():
                    if chunk.strip():
                        yield chunk
    except Exception as e:
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
    finally:
        state["active_requests"] -= 1
        # Keep the per-worker counters shown on the dashboard up to date:
        # completed requests and the running mean of their duration.
        state["total_requests"] += 1
        elapsed = time.time() - started
        state["avg_latency"] += (elapsed - state["avg_latency"]) / state["total_requests"]
        # Only mark the worker idle when no other request is still running.
        if state["active_requests"] <= 0:
            state["role"] = "idle"
| # ============================================================================ | |
| # API Endpoints | |
| # ============================================================================ | |
# NOTE(review): no route registration is visible for this handler; the path it
# was served under cannot be determined from this file - confirm and restore.
async def dashboard():
    """Return the HTML page of the live cluster dashboard.

    The page opens a WebSocket to ``/ws`` on the same host and re-renders
    the load figures and one card per worker on every message it receives.
    """
    page = """
<!DOCTYPE html>
<html>
<head>
<title>SAM-Z-1 Cluster Control v5.0</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: 'Courier New', monospace;
background: linear-gradient(135deg, #0a0e27 0%, #1a1f3a 100%);
color: #00ff88;
min-height: 100vh;
overflow-x: hidden;
overflow-y: auto;
}
.container {
padding: 20px;
max-width: 1400px;
margin: 0 auto;
padding-bottom: 40px;
}
.header {
text-align: center;
margin-bottom: 30px;
padding: 20px;
background: rgba(0, 255, 136, 0.1);
border: 2px solid #00ff88;
border-radius: 10px;
box-shadow: 0 0 20px rgba(0, 255, 136, 0.3);
}
.header h1 {
font-size: 2.5em;
text-transform: uppercase;
letter-spacing: 5px;
text-shadow: 0 0 10px #00ff88;
animation: glow 2s ease-in-out infinite alternate;
}
@keyframes glow {
from { text-shadow: 0 0 10px #00ff88, 0 0 20px #00ff88; }
to { text-shadow: 0 0 20px #00ff88, 0 0 30px #00ff88, 0 0 40px #00ff88; }
}
.version-badge {
display: inline-block;
padding: 3px 10px;
border-radius: 12px;
font-size: 10px;
margin-left: 8px;
font-weight: bold;
}
.version-v5 {
background: rgba(0, 255, 136, 0.2);
border: 1px solid #00ff88;
color: #00ff88;
}
.version-v4 {
background: rgba(255, 165, 0, 0.2);
border: 1px solid #ffa500;
color: #ffa500;
}
.status-bar {
display: flex;
gap: 20px;
margin-bottom: 30px;
}
.stat-card {
flex: 1;
background: rgba(0, 255, 136, 0.05);
border: 1px solid #00ff88;
border-radius: 8px;
padding: 15px;
position: relative;
overflow: hidden;
}
.stat-label {
font-size: 0.8em;
opacity: 0.7;
text-transform: uppercase;
}
.stat-value {
font-size: 2em;
font-weight: bold;
margin-top: 5px;
}
.workers-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
.worker-card {
background: rgba(10, 14, 39, 0.8);
border: 2px solid #00ff88;
border-radius: 10px;
padding: 20px;
position: relative;
transition: all 0.3s;
}
.worker-card:hover {
transform: translateY(-5px);
box-shadow: 0 5px 30px rgba(0, 255, 136, 0.4);
}
.worker-card.offline {
border-color: #ff4444;
opacity: 0.6;
}
.worker-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
}
.worker-name {
font-size: 1.2em;
font-weight: bold;
}
.status-dot {
width: 12px;
height: 12px;
border-radius: 50%;
animation: pulse 2s infinite;
}
.status-dot.online {
background: #00ff88;
box-shadow: 0 0 10px #00ff88;
}
.status-dot.offline {
background: #ff4444;
box-shadow: 0 0 10px #ff4444;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.worker-stats {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 10px;
margin-top: 15px;
}
.worker-stat {
background: rgba(0, 255, 136, 0.05);
padding: 10px;
border-radius: 5px;
}
.worker-stat-label {
font-size: 0.7em;
opacity: 0.7;
}
.worker-stat-value {
font-size: 1.3em;
font-weight: bold;
margin-top: 3px;
}
.timestamp {
text-align: center;
margin-top: 20px;
opacity: 0.5;
font-size: 0.9em;
}
@media (max-width: 768px) {
.workers-grid { grid-template-columns: 1fr; }
.status-bar { flex-direction: column; }
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>⚡ SAM-Z-1 CLUSTER ⚡</h1>
<div>DISTRIBUTED COMPUTE SYSTEM v5.0 • AUTO VERSION DETECTION</div>
</div>
<div class="status-bar">
<div class="stat-card">
<div class="stat-label">Load Mode</div>
<div class="stat-value" id="mode">--</div>
</div>
<div class="stat-card">
<div class="stat-label">Current Load</div>
<div class="stat-value" id="load">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Total Requests</div>
<div class="stat-value" id="total-req">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Req/Sec</div>
<div class="stat-value" id="rps">0.00</div>
</div>
</div>
<div class="workers-grid" id="workers"></div>
<div class="timestamp" id="timestamp">Last update: --</div>
</div>
<script>
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
let ws;
function connectWebSocket() {
try {
ws = new WebSocket(`${protocol}//${window.location.host}/ws`);
ws.onopen = () => console.log('✅ WebSocket connected');
ws.onmessage = (event) => updateDashboard(JSON.parse(event.data));
ws.onerror = () => console.error('❌ WebSocket error');
ws.onclose = () => setTimeout(connectWebSocket, 3000);
} catch (e) {
console.error('Failed to connect WebSocket');
}
}
connectWebSocket();
function updateDashboard(data) {
document.getElementById('mode').textContent = data.mode.toUpperCase();
document.getElementById('load').textContent = data.load;
document.getElementById('total-req').textContent = data.cluster.total_requests;
document.getElementById('rps').textContent = data.cluster.rps;
const workersDiv = document.getElementById('workers');
workersDiv.innerHTML = data.workers.map(worker => `
<div class="worker-card ${worker.healthy ? '' : 'offline'}">
<div class="worker-header">
<div>
<div class="worker-name">${worker.url}</div>
<span class="version-badge version-${worker.version}">${worker.version.toUpperCase()}</span>
${worker.models > 0 ? `<span style="font-size:0.8em;opacity:0.7;margin-left:5px">${worker.models} models</span>` : ''}
</div>
<div class="status-dot ${worker.healthy ? 'online' : 'offline'}"></div>
</div>
<div class="worker-stats">
<div class="worker-stat">
<div class="worker-stat-label">Active</div>
<div class="worker-stat-value">${worker.active}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Total</div>
<div class="worker-stat-value">${worker.total}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Tokens</div>
<div class="worker-stat-value">${worker.tokens}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Role</div>
<div class="worker-stat-value" style="font-size:1em;">${worker.role}</div>
</div>
</div>
</div>
`).join('');
document.getElementById('timestamp').textContent =
`Last update: ${new Date().toLocaleTimeString()}`;
}
</script>
</body>
</html>
"""
    return page
# NOTE(review): no route registration is visible for this handler - confirm
# the intended path and restore it.
async def api_status():
    """Return a JSON summary of the cluster: load mode, worker counts, features."""
    mode, load = update_load_mode()
    healthy_count = len(get_healthy_workers())

    # Tally workers by detected version; undetected workers count as neither.
    versions = [state["version"] for state in worker_health.values()]

    summary = {
        "name": "SAM-Z-1 Distributed Cluster",
        "version": "5.0.0",
        "mode": mode,
        "current_load": load,
        "workers": len(WORKER_URLS),
        "healthy_workers": healthy_count,
        "v4_workers": versions.count("v4"),
        "v5_workers": versions.count("v5"),
        "features": [
            "distributed_compute",
            "smart_load_balancing",
            "auto_version_detection",
            "multi_model_support",
            "real_time_dashboard",
        ],
    }
    return summary
# NOTE(review): no route registration is visible for this handler - confirm
# the intended path and restore it.
async def health():
    """Report cluster health: "healthy" while at least one worker is healthy."""
    workers_up = len(get_healthy_workers())
    if workers_up > 0:
        verdict = "healthy"
    else:
        verdict = "unhealthy"
    return {"status": verdict, "workers_healthy": workers_up}
# NOTE(review): no route registration is visible for this handler - confirm
# the intended path and restore it.
async def list_models():
    """List the models offered by healthy v5 workers.

    Returns a dict with the sorted, de-duplicated model names and a default,
    which is "SAM-X-1-Nano" when a worker offers it and None otherwise.
    """
    offered = {
        model
        for state in worker_health.values()
        if state["healthy"] and state["version"] == "v5"
        for model in state["supports_models"]
    }
    default_model = "SAM-X-1-Nano" if "SAM-X-1-Nano" in offered else None
    return {"models": sorted(offered), "default": default_model}
def _sampling_payload(request) -> dict:
    """Build the payload fields shared by generate and chat requests."""
    payload = {
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        # The head node always streams from workers.
        "stream": True,
    }
    # Add model parameter for v5 workers
    if request.model:
        payload["model"] = request.model
    return payload


def _route_streaming_request(request_data: dict, model: Optional[str], endpoint: str, log_icon: str):
    """Pick worker(s) for a request and return the streaming response.

    Uses distributed generation when load is light and decoders are
    available, and a single least-busy worker otherwise.

    Raises:
        HTTPException: 503 when no worker is available at all.
    """
    mode, load = update_load_mode()

    # Get compatible workers
    compatible = get_workers_for_model(model)
    if not compatible:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(
            status_code=503,
            detail=f"No workers available for model: {model or 'default'}"
        )

    print(f"{log_icon} {mode.upper()} | Load: {load} | Model: {model or 'default'} | Workers: {len(compatible)}")

    try:
        stream = None
        if mode == "light" and len(compatible) >= 2:
            generators, decoders = select_distributed_workers(model)
            if decoders:
                stream = distributed_generation(generators, decoders, request_data, endpoint)
        if stream is None:
            worker = get_least_busy_worker(compatible)
            stream = heavy_load_generation(worker, request_data, endpoint)
        response = StreamingResponse(stream, media_type="text/event-stream")
    except Exception:
        cluster_stats["failed_requests"] += 1
        raise

    # Counted only once the response exists, so a failure above is never
    # recorded as both a success and a failure.
    cluster_stats["successful_requests"] += 1
    return response


# NOTE(review): no route registration is visible for this handler - confirm
# the intended path and restore it.
async def generate(request: GenerateRequest):
    """Generate text with automatic model routing"""
    track_request()
    request_data = {"prompt": request.prompt, **_sampling_payload(request)}
    return _route_streaming_request(request_data, request.model, "generate", "🎯")


# NOTE(review): no route registration is visible for this handler - confirm
# the intended path and restore it.
async def chat(request: ChatRequest):
    """Chat with automatic model routing"""
    track_request()
    request_data = {
        "messages": [{"role": m.role, "content": m.content} for m in request.messages],
        **_sampling_payload(request),
    }
    return _route_streaming_request(request_data, request.model, "chat", "💬")
| # ============================================================================ | |
| # Launch | |
| # ============================================================================ | |
if __name__ == "__main__":
    import uvicorn

    # Serve on every interface, port 7860.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)