# NOTE: the "Spaces: Sleeping / Sleeping" lines here are Hugging Face Spaces
# page-status artifacts left over from extraction — they are not source code.
| """ | |
| SAM-Z-1 Distributed Compute Cluster Head Node | |
| - Smart load balancing with distributed compute | |
| - Real-time status dashboard | |
| """ | |
| from fastapi import FastAPI, HTTPException, WebSocket | |
| from fastapi.responses import StreamingResponse, HTMLResponse | |
| from pydantic import BaseModel | |
| import httpx | |
| import asyncio | |
| import json | |
| import time | |
| from typing import List, Optional, Dict | |
| from collections import deque | |
| import random | |
app = FastAPI(title="SAM-Z-1 Distributed Cluster", version="4.0.0")

# ============================================================================
# Configuration
# ============================================================================
# Worker Spaces that form the compute cluster.
WORKER_URLS = [
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
    "https://bc-ai-worker-3.hf.space",
    "https://bc-ai-worker-4.hf.space",
    "https://bc-ai-worker-5.hf.space",
    "https://bc-ai-worker-sam-z-api-2.hf.space"
]

HEALTH_CHECK_INTERVAL = 5   # seconds between health sweeps (fast for real-time dashboard)
LOAD_CHECK_WINDOW = 10      # sliding window (seconds) over which load is measured
LIGHT_LOAD_THRESHOLD = 2    # <= this many requests/window -> "light" mode
# FIX: MEDIUM_LOAD_THRESHOLD was referenced by update_load_mode() but never
# defined anywhere in the file, raising NameError at runtime. Defined here as
# the light/heavy boundary so existing references resolve.
MEDIUM_LOAD_THRESHOLD = 5
HEAVY_LOAD_THRESHOLD = 5    # above this many requests/window -> "heavy" mode

# Per-worker runtime state, keyed by worker URL.
worker_health = {
    url: {
        "healthy": True,
        "last_check": 0,
        "active_requests": 0,
        "total_requests": 0,
        "total_tokens": 0,
        "avg_latency": 0,
        "role": "idle"  # "generator", "decoder", "full", "idle"
    } for url in WORKER_URLS
}

# Timestamps of recent requests; maxlen bounds memory for load tracking.
request_timestamps = deque(maxlen=100)
current_load_mode = "light"  # "light", "medium", "heavy"

# Aggregate counters since process start.
cluster_stats = {
    "total_requests": 0,
    "successful_requests": 0,
    "failed_requests": 0,
    "uptime_start": time.time()
}

# Active WebSocket connections for real-time updates
active_connections = set()
| # ============================================================================ | |
| # Request Models | |
| # ============================================================================ | |
class GenerateRequest(BaseModel):
    """Request body for text generation: a raw prompt plus sampling parameters."""
    prompt: str                      # text to continue
    max_tokens: int = 512            # cap on generated tokens
    temperature: float = 0.8         # sampling temperature
    top_k: int = 40                  # top-k sampling cutoff
    top_p: float = 0.9               # nucleus sampling cutoff
    repetition_penalty: float = 1.1  # >1 penalizes repeated tokens
    stream: bool = True              # client preference; head node forwards stream=True regardless
class ChatMessage(BaseModel):
    """One chat turn."""
    role: str     # NOTE(review): role values are not validated here — presumably "user"/"assistant"
    content: str  # message text
class ChatRequest(BaseModel):
    """Request body for chat: a message history plus sampling parameters."""
    messages: List[ChatMessage]      # full conversation history
    max_tokens: int = 512            # cap on generated tokens
    temperature: float = 0.8         # sampling temperature
    top_k: int = 40                  # top-k sampling cutoff
    top_p: float = 0.9               # nucleus sampling cutoff
    repetition_penalty: float = 1.1  # >1 penalizes repeated tokens
    stream: bool = True              # client preference; head node forwards stream=True regardless
| # ============================================================================ | |
| # Load Management | |
| # ============================================================================ | |
def get_current_load() -> int:
    """Count requests recorded within the last LOAD_CHECK_WINDOW seconds."""
    cutoff = time.time() - LOAD_CHECK_WINDOW
    recent = [ts for ts in request_timestamps if ts > cutoff]
    return len(recent)
def update_load_mode():
    """Recompute the global load mode from recent traffic and worker health.

    Returns:
        tuple: ``(mode, load)`` where mode is "light" | "medium" | "heavy"
        and load is the request count over the last LOAD_CHECK_WINDOW seconds.
    """
    global current_load_mode
    load = get_current_load()
    healthy_count = len(get_healthy_workers())
    # Adjust thresholds based on available workers
    if healthy_count >= 5:
        if load <= LIGHT_LOAD_THRESHOLD:
            current_load_mode = "light"   # 1 gen + 4 decoders
        elif load <= HEAVY_LOAD_THRESHOLD:
            # FIX: the original compared against MEDIUM_LOAD_THRESHOLD, a name
            # that is never defined in this file — any medium/heavy load with
            # >= 5 healthy workers raised NameError. HEAVY_LOAD_THRESHOLD is
            # the defined medium/heavy boundary.
            current_load_mode = "medium"  # 2 gens + 3 decoders OR parallel requests
        else:
            current_load_mode = "heavy"   # all workers independent
    elif healthy_count >= 3:
        if load <= 2:
            current_load_mode = "light"   # 1 gen + 2 decoders
        else:
            current_load_mode = "heavy"   # distribute requests
    else:
        current_load_mode = "heavy"       # fallback to simple distribution
    return current_load_mode, load
def track_request():
    """Record an incoming request for load measurement and cluster stats."""
    cluster_stats["total_requests"] += 1
    request_timestamps.append(time.time())
def get_healthy_workers() -> List[str]:
    """Return URLs of workers whose most recent health check succeeded."""
    healthy = []
    for url, state in worker_health.items():
        if state["healthy"]:
            healthy.append(url)
    return healthy
def get_least_busy_worker() -> Optional[str]:
    """Pick the healthy worker with the fewest in-flight requests, or None."""
    candidates = get_healthy_workers()
    if not candidates:
        return None

    def active_count(url: str) -> int:
        return worker_health[url]["active_requests"]

    return min(candidates, key=active_count)
def select_distributed_workers() -> tuple:
    """
    Select workers for distributed compute.
    Returns: (generators: List[str], decoders: List[str])
    """
    healthy = get_healthy_workers()
    if not healthy:
        return ([], [])
    if len(healthy) == 1:
        return ([healthy[0]], [])
    # Least busy first: the idlest worker becomes the generator.
    ranked = sorted(healthy, key=lambda url: worker_health[url]["active_requests"])
    # Up to 4 decoders (indices 1..4). Python slicing caps at the list end, so
    # this single expression reproduces the original 5/4/3/2-worker branches:
    #   >=5 -> 4 decoders, 4 -> 3, 3 -> 2, 2 -> 1.
    return ([ranked[0]], ranked[1:5])
async def broadcast_stats():
    """Build a stats snapshot and push it to every connected dashboard client.

    Dead connections are collected during the send loop and pruned afterwards
    so the set is not mutated while being iterated.
    """
    if not active_connections:
        return
    mode, load = update_load_mode()
    uptime = time.time() - cluster_stats["uptime_start"]
    stats = {
        "timestamp": time.time(),
        "mode": mode,
        "load": load,
        "workers": [
            {
                # Short display name: "https://foo.hf.space" -> "foo".
                # NOTE(review): assumes every URL contains "//" and a dotted host.
                "url": url.split("//")[1].split(".")[0],
                "healthy": status["healthy"],
                "active": status["active_requests"],
                "total": status["total_requests"],
                "tokens": status["total_tokens"],
                "latency": round(status["avg_latency"], 2),
                "role": status["role"]
            }
            for url, status in worker_health.items()
        ],
        "cluster": {
            "total_requests": cluster_stats["total_requests"],
            "successful": cluster_stats["successful_requests"],
            "failed": cluster_stats["failed_requests"],
            "uptime": round(uptime, 0),
            "rps": round(cluster_stats["total_requests"] / uptime if uptime > 0 else 0, 2)
        }
    }
    # Broadcast to all connections
    disconnected = set()
    for ws in active_connections:
        try:
            await ws.send_json(stats)
        except Exception:
            # FIX: was a bare `except:`, which also swallows
            # asyncio.CancelledError and can block task cancellation.
            disconnected.add(ws)
    # Remove disconnected
    active_connections.difference_update(disconnected)
async def check_worker_health(worker_url: str) -> bool:
    """Return True iff GET {worker_url}/health answers HTTP 200 within 5 s."""
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{worker_url}/health")
            return response.status_code == 200
    except Exception:
        # FIX: narrowed from a bare `except:` so asyncio.CancelledError still
        # propagates and the health-check task can be cancelled cleanly.
        return False
async def health_check_loop():
    """Background task: probe every worker, broadcast stats, sleep, repeat."""
    while True:
        for url in WORKER_URLS:
            is_up = await check_worker_health(url)
            state = worker_health[url]
            state["healthy"] = is_up
            state["last_check"] = time.time()
        # Push a fresh snapshot to any connected dashboard clients.
        await broadcast_stats()
        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
async def startup_event():
    """Spawn the background health-check loop.

    NOTE(review): no @app.on_event("startup") decorator is visible in this
    chunk — confirm this coroutine is actually registered with FastAPI,
    otherwise the health loop never starts. The created task is also not
    retained anywhere; presumably the event loop keeps it alive — verify.
    """
    asyncio.create_task(health_check_loop())
| # ============================================================================ | |
| # Distributed Compute Generation | |
| # ============================================================================ | |
async def distributed_generation(
    generators: List[str],
    decoders: List[str],
    request_data: dict,
    endpoint: str = "generate"
):
    """
    DISTRIBUTED COMPUTE MODE
    - Generator(s) produce token IDs
    - Multiple decoders process in parallel (load balanced)

    Async generator yielding SSE "data: {...}\\n\\n" lines for StreamingResponse.
    Pipeline: one generator task streams token IDs into token_queue; each
    decoder task pulls IDs, batches them, POSTs to its worker's /decode, and
    pushes decoded text into text_queue, which this function drains and yields.

    NOTE(review): with more than one decoder pulling from the shared
    token_queue, decoded fragments can reach text_queue out of token order, so
    the accumulated text may be interleaved — confirm whether workers/clients
    tolerate this.
    """
    if not generators or not decoders:
        return
    # Bounded queues provide back-pressure between stages.
    token_queue = asyncio.Queue(maxsize=50)
    text_queue = asyncio.Queue(maxsize=50)
    # Mark roles (dashboard display only)
    for gen_url in generators:
        worker_health[gen_url]["role"] = "generator"
    for dec_url in decoders:
        worker_health[dec_url]["role"] = "decoder"

    async def generate_tokens():
        """Generator worker(s)"""
        gen_url = generators[0]  # primary generator; extra generators are unused here
        try:
            worker_health[gen_url]["active_requests"] += 1
            # Ask the worker for raw token IDs instead of decoded text.
            request_data_tokens = {**request_data, "return_token_ids": True}
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(
                    "POST",
                    f"{gen_url}/{endpoint}",
                    json=request_data_tokens
                ) as response:
                    async for chunk in response.aiter_text():
                        # Worker streams SSE lines: "data: {json}"
                        if chunk.strip() and chunk.startswith("data: "):
                            try:
                                data = json.loads(chunk[6:])
                                if "token_id" in data:
                                    await token_queue.put(data["token_id"])
                                elif "done" in data:
                                    # Send done signal for each decoder
                                    # (None is the per-decoder sentinel).
                                    for _ in decoders:
                                        await token_queue.put(None)
                                    break
                            except:
                                # Malformed/partial SSE chunk: skip it.
                                # NOTE(review): bare except — also hides cancellation.
                                pass
        except Exception as e:
            print(f"β Generator error: {e}")
            # Unblock all decoders so they can finish.
            for _ in decoders:
                await token_queue.put(None)
        finally:
            worker_health[gen_url]["active_requests"] -= 1
            worker_health[gen_url]["role"] = "idle"

    async def decode_tokens(decoder_url: str, decoder_id: int):
        """Decoder worker - processes tokens from shared queue"""
        try:
            worker_health[decoder_url]["active_requests"] += 1
            batch = []
            batch_size = 2  # smaller batches for faster streaming
            while True:
                try:
                    token_id = await asyncio.wait_for(token_queue.get(), timeout=2.0)
                    if token_id is None:
                        # Sentinel: decode remaining batch, signal done, exit.
                        if batch:
                            async with httpx.AsyncClient(timeout=10.0) as client:
                                response = await client.post(
                                    f"{decoder_url}/decode",
                                    json={"token_ids": batch}
                                )
                                text = response.json()["text"]
                                await text_queue.put(("text", text))
                                worker_health[decoder_url]["total_tokens"] += len(batch)
                        await text_queue.put(("done", decoder_id))
                        break
                    batch.append(token_id)
                    # Decode when batch is full
                    if len(batch) >= batch_size:
                        async with httpx.AsyncClient(timeout=10.0) as client:
                            response = await client.post(
                                f"{decoder_url}/decode",
                                json={"token_ids": batch}
                            )
                            text = response.json()["text"]
                            await text_queue.put(("text", text))
                            worker_health[decoder_url]["total_tokens"] += len(batch)
                            batch = []
                except asyncio.TimeoutError:
                    # Queue idle; keep waiting for more tokens or the sentinel.
                    continue
        except Exception as e:
            print(f"β Decoder {decoder_id} error: {e}")
            # Still signal done so the drain loop below can terminate.
            await text_queue.put(("done", decoder_id))
        finally:
            worker_health[decoder_url]["active_requests"] -= 1
            worker_health[decoder_url]["role"] = "idle"

    # Start generator
    gen_task = asyncio.create_task(generate_tokens())
    # Start all decoders
    decoder_tasks = [
        asyncio.create_task(decode_tokens(dec_url, i))
        for i, dec_url in enumerate(decoders)
    ]
    # Stream results until every decoder has reported "done".
    accumulated_text = ""
    decoders_done = 0
    total_decoders = len(decoders)
    try:
        while decoders_done < total_decoders:
            msg_type, data = await text_queue.get()
            if msg_type == "done":
                decoders_done += 1
                continue
            if msg_type == "text":
                accumulated_text += data
                yield f"data: {json.dumps({'delta': data, 'text': accumulated_text})}\n\n"
    finally:
        # Ensure background tasks are finished before the generator closes.
        # NOTE(review): if the client disconnects mid-stream these awaits run
        # inside GeneratorExit handling — confirm tasks terminate promptly.
        await gen_task
        for task in decoder_tasks:
            await task
async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
    """Proxy one request to a single worker and relay its SSE stream verbatim.

    On any failure an SSE error event is emitted instead of raising; the
    worker's active-request count and role are always restored.
    """
    try:
        worker_health[worker_url]["active_requests"] += 1
        worker_health[worker_url]["role"] = "full"
        async with httpx.AsyncClient(timeout=300.0) as client:
            stream_ctx = client.stream(
                "POST",
                f"{worker_url}/{endpoint}",
                json=request_data
            )
            async with stream_ctx as response:
                async for chunk in response.aiter_text():
                    if not chunk.strip():
                        continue
                    yield chunk
    except Exception as e:
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
    finally:
        worker_health[worker_url]["active_requests"] -= 1
        worker_health[worker_url]["role"] = "idle"
| # ============================================================================ | |
| # Dashboard | |
| # ============================================================================ | |
async def dashboard():
    """Serve the self-contained real-time dashboard page (HTML/CSS/JS).

    The page opens a WebSocket to /ws for push updates and falls back to
    polling /api/status plus /workers every second if the socket fails.

    NOTE(review): no @app.get(...) decorator is visible in this chunk — the
    route registration (and HTMLResponse response_class) must exist elsewhere;
    confirm. The polling fallback also fetches a /workers endpoint that is not
    defined in this file — verify it exists in the full application.
    """
    # The string below is returned verbatim; do not edit without checking the
    # matching field names produced by broadcast_stats().
    return """
<!DOCTYPE html>
<html>
<head>
<title>SAM-Z-1 Cluster Control</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Courier New', monospace;
background: linear-gradient(135deg, #0a0e27 0%, #1a1f3a 100%);
color: #00ff88;
min-height: 100vh;
overflow-x: hidden;
overflow-y: auto;
}
.container {
padding: 20px;
max-width: 1400px;
margin: 0 auto;
padding-bottom: 40px;
}
.header {
text-align: center;
margin-bottom: 30px;
padding: 20px;
background: rgba(0, 255, 136, 0.1);
border: 2px solid #00ff88;
border-radius: 10px;
box-shadow: 0 0 20px rgba(0, 255, 136, 0.3);
}
.header h1 {
font-size: 2.5em;
text-transform: uppercase;
letter-spacing: 5px;
text-shadow: 0 0 10px #00ff88;
animation: glow 2s ease-in-out infinite alternate;
}
@keyframes glow {
from { text-shadow: 0 0 10px #00ff88, 0 0 20px #00ff88; }
to { text-shadow: 0 0 20px #00ff88, 0 0 30px #00ff88, 0 0 40px #00ff88; }
}
.status-bar {
display: flex;
gap: 20px;
margin-bottom: 30px;
}
.stat-card {
flex: 1;
background: rgba(0, 255, 136, 0.05);
border: 1px solid #00ff88;
border-radius: 8px;
padding: 15px;
position: relative;
overflow: hidden;
}
.stat-card::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(0, 255, 136, 0.2), transparent);
animation: scan 3s infinite;
}
@keyframes scan {
0% { left: -100%; }
100% { left: 100%; }
}
.stat-label {
font-size: 0.8em;
opacity: 0.7;
text-transform: uppercase;
}
.stat-value {
font-size: 2em;
font-weight: bold;
margin-top: 5px;
}
.mode-badge {
display: inline-block;
padding: 5px 15px;
border-radius: 20px;
font-size: 0.9em;
font-weight: bold;
text-transform: uppercase;
margin-top: 10px;
}
.mode-light {
background: rgba(0, 255, 136, 0.2);
border: 1px solid #00ff88;
color: #00ff88;
}
.mode-heavy {
background: rgba(255, 68, 68, 0.2);
border: 1px solid #ff4444;
color: #ff4444;
}
.workers-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
@media (max-width: 768px) {
.workers-grid {
grid-template-columns: 1fr;
}
.status-bar {
flex-direction: column;
}
.info-grid {
grid-template-columns: repeat(2, 1fr);
}
.header h1 {
font-size: 1.5em;
letter-spacing: 2px;
}
}
.worker-card {
background: rgba(10, 14, 39, 0.8);
border: 2px solid #00ff88;
border-radius: 10px;
padding: 20px;
position: relative;
transition: all 0.3s;
}
.worker-card:hover {
transform: translateY(-5px);
box-shadow: 0 5px 30px rgba(0, 255, 136, 0.4);
}
.worker-card.offline {
border-color: #ff4444;
opacity: 0.6;
}
.worker-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
}
.worker-name {
font-size: 1.2em;
font-weight: bold;
}
.status-dot {
width: 12px;
height: 12px;
border-radius: 50%;
animation: pulse 2s infinite;
}
.status-dot.online {
background: #00ff88;
box-shadow: 0 0 10px #00ff88;
}
.status-dot.offline {
background: #ff4444;
box-shadow: 0 0 10px #ff4444;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.worker-stats {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 10px;
margin-top: 15px;
}
.worker-stat {
background: rgba(0, 255, 136, 0.05);
padding: 10px;
border-radius: 5px;
}
.worker-stat-label {
font-size: 0.7em;
opacity: 0.7;
}
.worker-stat-value {
font-size: 1.3em;
font-weight: bold;
margin-top: 3px;
}
.role-badge {
display: inline-block;
padding: 3px 10px;
border-radius: 12px;
font-size: 0.75em;
margin-top: 10px;
font-weight: bold;
}
.role-generator {
background: rgba(255, 165, 0, 0.2);
border: 1px solid #ffa500;
color: #ffa500;
}
.role-decoder {
background: rgba(0, 191, 255, 0.2);
border: 1px solid #00bfff;
color: #00bfff;
}
.role-full {
background: rgba(138, 43, 226, 0.2);
border: 1px solid #8a2be2;
color: #8a2be2;
}
.role-idle {
background: rgba(128, 128, 128, 0.2);
border: 1px solid #808080;
color: #808080;
}
.progress-bar {
width: 100%;
height: 4px;
background: rgba(0, 255, 136, 0.1);
border-radius: 2px;
margin-top: 10px;
overflow: hidden;
}
.progress-fill {
height: 100%;
background: linear-gradient(90deg, #00ff88, #00ffff);
transition: width 0.3s;
box-shadow: 0 0 10px #00ff88;
}
.cluster-info {
background: rgba(0, 255, 136, 0.05);
border: 1px solid #00ff88;
border-radius: 8px;
padding: 20px;
}
.info-grid {
display: grid;
grid-template-columns: repeat(4, 1fr);
gap: 20px;
}
.info-item {
text-align: center;
}
.timestamp {
text-align: center;
margin-top: 20px;
opacity: 0.5;
font-size: 0.9em;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>β‘ SAM-Z-1 CLUSTER β‘</h1>
<div>DISTRIBUTED COMPUTE SYSTEM v4.0</div>
</div>
<div class="status-bar">
<div class="stat-card">
<div class="stat-label">Load Mode</div>
<div class="stat-value" id="mode">--</div>
<div class="mode-badge" id="mode-badge">INITIALIZING</div>
</div>
<div class="stat-card">
<div class="stat-label">Current Load</div>
<div class="stat-value" id="load">0</div>
<div class="stat-label">requests / 10s</div>
</div>
<div class="stat-card">
<div class="stat-label">Total Requests</div>
<div class="stat-value" id="total-req">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Req/Sec</div>
<div class="stat-value" id="rps">0.00</div>
</div>
</div>
<div class="workers-grid" id="workers">
<!-- Workers populated by JS -->
</div>
<div class="cluster-info">
<div class="stat-label" style="margin-bottom: 15px;">CLUSTER STATISTICS</div>
<div class="info-grid">
<div class="info-item">
<div class="stat-label">Successful</div>
<div class="stat-value" style="font-size: 1.5em;" id="success">0</div>
</div>
<div class="info-item">
<div class="stat-label">Failed</div>
<div class="stat-value" style="font-size: 1.5em;" id="failed">0</div>
</div>
<div class="info-item">
<div class="stat-label">Uptime</div>
<div class="stat-value" style="font-size: 1.5em;" id="uptime">0s</div>
</div>
<div class="info-item">
<div class="stat-label">Healthy Workers</div>
<div class="stat-value" style="font-size: 1.5em;" id="healthy">0</div>
</div>
</div>
</div>
<div class="timestamp" id="timestamp">Last update: --</div>
</div>
<script>
// Use wss:// for HTTPS, ws:// for HTTP
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
let ws;
let usePolling = false;
function connectWebSocket() {
try {
ws = new WebSocket(`${protocol}//${window.location.host}/ws`);
ws.onopen = () => {
console.log('β WebSocket connected');
usePolling = false;
};
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
updateDashboard(data);
};
ws.onerror = (error) => {
console.error('β WebSocket error, switching to polling');
usePolling = true;
startPolling();
};
ws.onclose = () => {
console.log('π WebSocket disconnected');
if (!usePolling) {
setTimeout(connectWebSocket, 3000);
}
};
} catch (e) {
console.error('Failed to connect WebSocket, using polling');
usePolling = true;
startPolling();
}
}
async function pollStats() {
if (!usePolling) return;
try {
const response = await fetch('/api/status');
const data = await response.json();
// Fetch worker stats too
const workersRes = await fetch('/workers');
const workersData = await workersRes.json();
// Format data like WebSocket
const formattedData = {
timestamp: Date.now() / 1000,
mode: data.mode,
load: data.current_load,
workers: workersData.workers.map(w => ({
url: w.url.split("//")[1].split(".")[0],
healthy: w.healthy,
active: w.active_requests || 0,
total: 0,
tokens: 0,
latency: 0,
role: "idle"
})),
cluster: {
total_requests: 0,
successful: 0,
failed: 0,
uptime: 0,
rps: 0
}
};
updateDashboard(formattedData);
} catch (e) {
console.error('Polling error:', e);
}
}
function startPolling() {
pollStats();
setInterval(pollStats, 1000);
}
// Try WebSocket first
connectWebSocket();
function updateDashboard(data) {
// Mode
document.getElementById('mode').textContent = data.mode.toUpperCase();
const modeBadge = document.getElementById('mode-badge');
modeBadge.textContent = `${data.mode.toUpperCase()} MODE`;
modeBadge.className = `mode-badge mode-${data.mode}`;
// Stats
document.getElementById('load').textContent = data.load;
document.getElementById('total-req').textContent = data.cluster.total_requests;
document.getElementById('rps').textContent = data.cluster.rps;
document.getElementById('success').textContent = data.cluster.successful;
document.getElementById('failed').textContent = data.cluster.failed;
document.getElementById('uptime').textContent = formatUptime(data.cluster.uptime);
// Workers
const workersDiv = document.getElementById('workers');
const healthyCount = data.workers.filter(w => w.healthy).length;
document.getElementById('healthy').textContent = `${healthyCount}/${data.workers.length}`;
workersDiv.innerHTML = data.workers.map(worker => `
<div class="worker-card ${worker.healthy ? '' : 'offline'}">
<div class="worker-header">
<div class="worker-name">${worker.url}</div>
<div class="status-dot ${worker.healthy ? 'online' : 'offline'}"></div>
</div>
<div class="role-badge role-${worker.role}">${worker.role.toUpperCase()}</div>
<div class="worker-stats">
<div class="worker-stat">
<div class="worker-stat-label">Active</div>
<div class="worker-stat-value">${worker.active}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Total</div>
<div class="worker-stat-value">${worker.total}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Tokens</div>
<div class="worker-stat-value">${worker.tokens}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Latency</div>
<div class="worker-stat-value">${worker.latency}ms</div>
</div>
</div>
<div class="progress-bar">
<div class="progress-fill" style="width: ${Math.min(worker.active * 33, 100)}%"></div>
</div>
</div>
`).join('');
// Timestamp
const now = new Date();
document.getElementById('timestamp').textContent =
`Last update: ${now.toLocaleTimeString()}`;
}
function formatUptime(seconds) {
const h = Math.floor(seconds / 3600);
const m = Math.floor((seconds % 3600) / 60);
const s = Math.floor(seconds % 60);
return `${h}h ${m}m ${s}s`;
}
</script>
</body>
</html>
"""
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket for real-time dashboard updates.

    Registers the socket, pushes an initial snapshot, then idles reading
    client pings until disconnect. The connection is always removed from the
    broadcast set on exit.
    NOTE(review): no @app.websocket("/ws") decorator is visible in this
    chunk — confirm the route is registered elsewhere.
    """
    await websocket.accept()
    active_connections.add(websocket)
    try:
        # Send initial data
        await broadcast_stats()
        # Keep connection alive; receive_text raises on client disconnect.
        while True:
            await websocket.receive_text()
    except Exception:
        # FIX: was a bare `except:`, which also swallows
        # asyncio.CancelledError during server shutdown.
        pass
    finally:
        active_connections.discard(websocket)
| # ============================================================================ | |
| # API Endpoints | |
| # ============================================================================ | |
async def api_status():
    """JSON status summary used by the dashboard's polling fallback."""
    mode, load = update_load_mode()
    healthy = get_healthy_workers()
    return {
        "name": "SAM-Z-1 Distributed Cluster",
        "version": "4.0.0",
        "mode": mode,
        "current_load": load,
        "workers": len(WORKER_URLS),
        "healthy_workers": len(healthy),
        "features": ["distributed_compute", "smart_load_balancing", "real_time_dashboard"]
    }
async def health():
    """Head-node health: "healthy" while at least one worker is reachable."""
    workers_up = len(get_healthy_workers())
    overall = "healthy" if workers_up > 0 else "unhealthy"
    return {"status": overall, "workers_healthy": workers_up}
async def generate(request: GenerateRequest):
    """Generate text, choosing distributed or single-worker mode by load.

    Light load with >=2 healthy workers uses the generator+decoders pipeline;
    otherwise the least-busy worker serves the whole request. Raises 503 when
    no worker is healthy.
    """
    track_request()
    mode, load = update_load_mode()
    healthy = get_healthy_workers()
    if not healthy:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(status_code=503, detail="No healthy workers")
    # Forward the client's sampling settings; streaming is always forced on.
    request_data = {
        "prompt": request.prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "stream": True
    }
    print(f"π― {mode.upper()} | Load: {load} | Workers: {len(healthy)}")
    try:
        if mode == "light" and len(healthy) >= 2:
            # DISTRIBUTED MODE - 1 gen + multiple decoders
            generators, decoders = select_distributed_workers()
            if decoders:
                # NOTE(review): counted as successful before the stream runs;
                # mid-stream failures are not reflected here — confirm intent.
                cluster_stats["successful_requests"] += 1
                return StreamingResponse(
                    distributed_generation(generators, decoders, request_data, "generate"),
                    media_type="text/event-stream"
                )
        # HEAVY/FALLBACK - single worker
        worker = get_least_busy_worker()
        cluster_stats["successful_requests"] += 1
        return StreamingResponse(
            heavy_load_generation(worker, request_data, "generate"),
            media_type="text/event-stream"
        )
    except Exception:
        # FIX: dropped the unused `as e` binding; the exception is re-raised
        # unchanged after the failure counter is updated.
        cluster_stats["failed_requests"] += 1
        raise
async def chat(request: ChatRequest):
    """Chat completion, choosing distributed or single-worker mode by load.

    Mirrors generate(): light load with >=2 healthy workers uses the
    generator+decoders pipeline; otherwise the least-busy worker serves the
    request alone. Raises 503 when no worker is healthy.
    """
    track_request()
    mode, load = update_load_mode()
    healthy = get_healthy_workers()
    if not healthy:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(status_code=503, detail="No healthy workers")
    # Serialize the message history; streaming is always forced on.
    request_data = {
        "messages": [{"role": m.role, "content": m.content} for m in request.messages],
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "stream": True
    }
    print(f"π¬ {mode.upper()} | Load: {load} | Workers: {len(healthy)}")
    try:
        if mode == "light" and len(healthy) >= 2:
            # DISTRIBUTED MODE - 1 gen + multiple decoders
            generators, decoders = select_distributed_workers()
            if decoders:
                # NOTE(review): counted as successful before the stream runs;
                # mid-stream failures are not reflected here — confirm intent.
                cluster_stats["successful_requests"] += 1
                return StreamingResponse(
                    distributed_generation(generators, decoders, request_data, "chat"),
                    media_type="text/event-stream"
                )
        # HEAVY/FALLBACK - single worker
        worker = get_least_busy_worker()
        cluster_stats["successful_requests"] += 1
        return StreamingResponse(
            heavy_load_generation(worker, request_data, "chat"),
            media_type="text/event-stream"
        )
    except Exception:
        # FIX: dropped the unused `as e` binding; the exception is re-raised
        # unchanged after the failure counter is updated.
        cluster_stats["failed_requests"] += 1
        raise
| # ============================================================================ | |
| # Launch | |
| # ============================================================================ | |
if __name__ == "__main__":
    import uvicorn
    # Bind all interfaces on port 7860 (the Hugging Face Spaces convention).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )