"""
SAM-Z-1 Distributed Compute Cluster Head Node
- Smart load balancing with distributed compute
- Real-time status dashboard
"""

from fastapi import FastAPI, HTTPException, WebSocket
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import httpx
import asyncio
import json
import time
from typing import List, Optional, Dict
from collections import deque
import random

app = FastAPI(title="SAM-Z-1 Distributed Cluster", version="4.0.0")

# ============================================================================
# Configuration
# ============================================================================

WORKER_URLS = [
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
    "https://bc-ai-worker-3.hf.space",
    "https://bc-ai-worker-4.hf.space",
    "https://bc-ai-worker-5.hf.space",
    "https://bc-ai-worker-sam-z-api-2.hf.space",
]

HEALTH_CHECK_INTERVAL = 5  # seconds between sweeps; faster checks for real-time dashboard
HEALTH_CHECK_TIMEOUT = 5.0  # seconds allowed for one worker's /health probe
LOAD_CHECK_WINDOW = 10  # seconds of request history counted as "current load"
LIGHT_LOAD_THRESHOLD = 2  # load <= this -> "light"
HEAVY_LOAD_THRESHOLD = 5  # load > this -> "heavy" (when >= 5 workers are healthy)
# Upper bound of the "medium" band used by update_load_mode():
# LIGHT_LOAD_THRESHOLD < load <= MEDIUM_LOAD_THRESHOLD -> "medium".
MEDIUM_LOAD_THRESHOLD = HEAVY_LOAD_THRESHOLD

# Worker state, keyed by worker URL
worker_health = {
    url: {
        "healthy": True,
        "last_check": 0,
        "active_requests": 0,  # requests currently in flight on this worker
        "total_requests": 0,
        "total_tokens": 0,
        "avg_latency": 0,  # NOTE(review): nothing in this module appears to update this yet
        "role": "idle",  # "generator", "decoder", "full", "idle"
    }
    for url in WORKER_URLS
}

request_timestamps = deque(maxlen=100)  # arrival times of the most recent requests
current_load_mode = "light"  # "light", "medium", "heavy"

cluster_stats = {
    "total_requests": 0,
    "successful_requests": 0,
    "failed_requests": 0,
    "uptime_start": time.time(),
}

# Active WebSocket connections for real-time updates
active_connections = set()

# ============================================================================
# Request Models
# ============================================================================


class GenerateRequest(BaseModel):
    """Body of POST /v1/generate."""

    prompt: str
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = True


class ChatMessage(BaseModel):
    """One chat turn."""

    role: str
    content: str


class ChatRequest(BaseModel):
    """Body of POST /v1/chat."""

    messages: List[ChatMessage]
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = True


# ============================================================================
# Load Management
# ============================================================================


def get_current_load() -> int:
    """Return how many tracked requests arrived within the last LOAD_CHECK_WINDOW seconds."""
    now = time.time()
    return sum(1 for ts in request_timestamps if now - ts < LOAD_CHECK_WINDOW)


def update_load_mode():
    """Recompute the cluster load mode from recent traffic and healthy-worker count.

    Stores the result in the module-level ``current_load_mode``.

    Returns:
        (mode, load): mode is "light", "medium" or "heavy"; load is the request
        count from get_current_load().
    """
    global current_load_mode
    load = get_current_load()
    healthy_count = len(get_healthy_workers())

    # Thresholds depend on how many workers are available
    if healthy_count >= 5:
        if load <= LIGHT_LOAD_THRESHOLD:
            current_load_mode = "light"  # 1 gen + 4 decoders
        elif load <= MEDIUM_LOAD_THRESHOLD:
            current_load_mode = "medium"  # 2 gens + 3 decoders OR parallel requests
        else:
            current_load_mode = "heavy"  # all workers independent
    elif healthy_count >= 3:
        if load <= LIGHT_LOAD_THRESHOLD:
            current_load_mode = "light"  # 1 gen + 2 decoders
        else:
            current_load_mode = "heavy"  # distribute requests
    else:
        current_load_mode = "heavy"  # fallback to simple distribution

    return current_load_mode, load


def track_request():
    """Record an incoming request for load measurement and cluster totals."""
    request_timestamps.append(time.time())
    cluster_stats["total_requests"] += 1


def get_healthy_workers() -> List[str]:
    """Return the URLs of workers whose last health check succeeded."""
    return [url for url, status in worker_health.items() if status["healthy"]]


def get_least_busy_worker() -> Optional[str]:
    """Return the healthy worker with the fewest active requests, or None if none are healthy."""
    healthy = get_healthy_workers()
    if not healthy:
        return None
    return min(healthy, key=lambda url: worker_health[url]["active_requests"])


def select_distributed_workers(max_decoders: int = 4) -> tuple:
    """
    Select workers for distributed compute.

    The least busy healthy worker becomes the single generator, and the next
    least busy workers (up to ``max_decoders``) become decoders.

    Returns:
        (generators: List[str], decoders: List[str]). This is ([w], []) when only
        one worker is healthy, and ([], []) when none are.
    """
    healthy = get_healthy_workers()
    if len(healthy) < 2:
        return ([healthy[0]], []) if healthy else ([], [])

    ranked = sorted(healthy, key=lambda url: worker_health[url]["active_requests"])
    # With the default of 4: >=5 healthy -> 1+4, 4 -> 1+3, 3 -> 1+2, 2 -> 1+1
    return ([ranked[0]], ranked[1:1 + max(1, max_decoders)])


async def broadcast_stats():
    """Push a cluster/worker snapshot to every connected dashboard WebSocket.

    Clients whose send fails are removed from ``active_connections``.
    """
    if not active_connections:
        return

    mode, load = update_load_mode()
    uptime = time.time() - cluster_stats["uptime_start"]

    stats = {
        "timestamp": time.time(),
        "mode": mode,
        "load": load,
        "workers": [
            {
                "url": url.split("//")[1].split(".")[0],  # shorter name
                "healthy": status["healthy"],
                "active": status["active_requests"],
                "total": status["total_requests"],
                "tokens": status["total_tokens"],
                "latency": round(status["avg_latency"], 2),
                "role": status["role"],
            }
            for url, status in worker_health.items()
        ],
        "cluster": {
            "total_requests": cluster_stats["total_requests"],
            "successful": cluster_stats["successful_requests"],
            "failed": cluster_stats["failed_requests"],
            "uptime": round(uptime, 0),
            "rps": round(cluster_stats["total_requests"] / uptime if uptime > 0 else 0, 2),
        },
    }

    disconnected = set()
    # Iterate over a snapshot. Each send awaits, and websocket handlers may add or
    # discard connections in the meantime, which would break iteration over the live set.
    for ws in list(active_connections):
        try:
            await ws.send_json(stats)
        except Exception:
            disconnected.add(ws)

    active_connections.difference_update(disconnected)


async def check_worker_health(worker_url: str) -> bool:
    """Return True if ``GET {worker_url}/health`` answers 200 within HEALTH_CHECK_TIMEOUT.

    Any request error counts as unhealthy. Cancellation still propagates.
    """
    try:
        async with httpx.AsyncClient(timeout=HEALTH_CHECK_TIMEOUT) as client:
            response = await client.get(f"{worker_url}/health")
            return response.status_code == 200
    except Exception:
        return False


async def health_check_loop():
    """Forever: probe every worker concurrently, record results, broadcast stats, sleep."""
    while True:
        try:
            # Probing concurrently keeps each sweep to about one timeout instead of N of them.
            results = await asyncio.gather(
                *(check_worker_health(url) for url in WORKER_URLS)
            )
            checked_at = time.time()
            for worker_url, healthy in zip(WORKER_URLS, results):
                worker_health[worker_url]["healthy"] = healthy
                worker_health[worker_url]["last_check"] = checked_at

            # Always broadcast stats to connected clients
            await broadcast_stats()
        except Exception as e:
            # Keep the background task alive; nothing else would restart it.
            print(f"❌ Health check loop error: {e}")

        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
@app.on_event("startup")
async def startup_event():
    """Start the background health-check loop.

    The task is kept on ``app.state`` because the event loop holds only weak
    references to tasks, so an unreferenced task may be garbage-collected.
    """
    app.state.health_task = asyncio.create_task(health_check_loop())


# ============================================================================
# Distributed Compute Generation
# ============================================================================


async def distributed_generation(
    generators: List[str],
    decoders: List[str],
    request_data: dict,
    endpoint: str = "generate",
    batch_size: int = 2,
):
    """
    DISTRIBUTED COMPUTE MODE
    - The primary generator streams token IDs, which are grouped into numbered
      batches of ``batch_size`` (small batches keep streaming responsive)
    - Decoders pull batches from a shared queue and decode them in parallel
    - Decoded text is re-ordered by batch number before it is streamed out, so
      parallel decoding cannot scramble the output

    Yields SSE lines of the form ``data: {"delta": ..., "text": ...}\\n\\n``.
    Yields nothing if either worker list is empty.
    """
    if not generators or not decoders:
        return

    batch_size = max(1, batch_size)
    # Bounded queues apply back-pressure between generator, decoders and client.
    token_queue: asyncio.Queue = asyncio.Queue(maxsize=50)  # (seq, [token_id, ...]) or None sentinel
    text_queue: asyncio.Queue = asyncio.Queue(maxsize=50)  # ("text", seq, str) | ("done", decoder_id, None)

    # Mark roles
    for gen_url in generators:
        worker_health[gen_url]["role"] = "generator"
    for dec_url in decoders:
        worker_health[dec_url]["role"] = "decoder"

    async def generate_tokens():
        """Stream token IDs from the primary generator into numbered batches.

        Always finishes by queueing one None sentinel per decoder, unless it is
        cancelled, so decoders cannot wait forever.
        """
        gen_url = generators[0]  # primary generator
        stats = worker_health[gen_url]
        stats["active_requests"] += 1
        stats["total_requests"] += 1
        batch: List[int] = []
        seq = 0
        try:
            request_data_tokens = {**request_data, "return_token_ids": True}
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(
                    "POST", f"{gen_url}/{endpoint}", json=request_data_tokens
                ) as response:
                    response.raise_for_status()
                    # Parse line by line. Raw text chunks may hold several events or split one.
                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        try:
                            data = json.loads(line[6:])
                        except json.JSONDecodeError:
                            continue
                        if not isinstance(data, dict):
                            continue
                        if "token_id" in data:
                            batch.append(data["token_id"])
                            if len(batch) >= batch_size:
                                await token_queue.put((seq, batch))
                                seq += 1
                                batch = []
                        elif "done" in data:
                            break
        except Exception as e:
            print(f"❌ Generator error: {e}")
        finally:
            stats["active_requests"] -= 1
            stats["role"] = "idle"

        # These sends sit outside `finally` on purpose. On cancellation (client gone)
        # they are skipped instead of blocking on a full queue.
        if batch:
            await token_queue.put((seq, batch))
        for _ in decoders:
            await token_queue.put(None)

    async def decode_tokens(decoder_url: str, decoder_id: int):
        """Decoder worker: decode batches from the shared queue until a None sentinel.

        Every batch taken produces exactly one ("text", seq, ...) message. Text is
        empty if decoding failed, so the re-ordering consumer never stalls on it.
        """
        stats = worker_health[decoder_url]
        stats["active_requests"] += 1
        stats["total_requests"] += 1
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                while True:
                    item = await token_queue.get()
                    if item is None:
                        break
                    seq, batch = item
                    text = ""
                    try:
                        response = await client.post(
                            f"{decoder_url}/decode", json={"token_ids": batch}
                        )
                        response.raise_for_status()
                        text = response.json()["text"]
                        stats["total_tokens"] += len(batch)
                    except Exception as e:
                        print(f"❌ Decoder {decoder_id} error: {e}")
                    await text_queue.put(("text", seq, text))
        except Exception as e:
            print(f"❌ Decoder {decoder_id} error: {e}")
        finally:
            stats["active_requests"] -= 1
            stats["role"] = "idle"

        await text_queue.put(("done", decoder_id, None))

    def sse_event(delta: str, text: str) -> str:
        return f"data: {json.dumps({'delta': delta, 'text': text})}\n\n"

    tasks = [asyncio.create_task(generate_tokens())]
    tasks += [
        asyncio.create_task(decode_tokens(dec_url, i))
        for i, dec_url in enumerate(decoders)
    ]

    accumulated_text = ""
    pending: Dict[int, str] = {}  # decoded batches that arrived ahead of next_seq
    next_seq = 0
    decoders_done = 0
    try:
        while decoders_done < len(decoders):
            msg_type, key, text = await text_queue.get()
            if msg_type == "done":
                decoders_done += 1
                continue

            pending[key] = text
            # Emit every batch that is now contiguous with what was already sent
            while next_seq in pending:
                delta = pending.pop(next_seq)
                next_seq += 1
                if delta:
                    accumulated_text += delta
                    yield sse_event(delta, accumulated_text)

        # A batch lost with a crashed decoder leaves a gap. Flush what remains, in order.
        for seq in sorted(pending):
            delta = pending[seq]
            if delta:
                accumulated_text += delta
                yield sse_event(delta, accumulated_text)
    finally:
        # On client disconnect or error, stop helpers instead of waiting on them forever.
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
    """Standard single-worker generation: proxy the worker's event stream verbatim.

    A non-200 reply or a request error is reported to the client as a final
    ``data: {"error": ...}`` event.
    """
    stats = worker_health[worker_url]
    stats["active_requests"] += 1
    stats["total_requests"] += 1
    stats["role"] = "full"
    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream(
                "POST", f"{worker_url}/{endpoint}", json=request_data
            ) as response:
                if response.status_code != 200:
                    detail = (await response.aread()).decode("utf-8", errors="replace")
                    message = f"worker returned HTTP {response.status_code}: {detail[:200]}"
                    yield f"data: {json.dumps({'error': message})}\n\n"
                    return
                async for chunk in response.aiter_text():
                    # Forward whitespace-only chunks too. They may carry the
                    # "\n\n" separators that frame SSE events.
                    if chunk:
                        yield chunk
    except Exception as e:
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
    finally:
        stats["active_requests"] -= 1
        stats["role"] = "idle"


# ============================================================================
# Dashboard
# ============================================================================

_DASHBOARD_HTML = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>SAM-Z-1 Cluster Control</title>
<style>
  body { background: #05070d; color: #d7f9ff; font-family: monospace; margin: 2rem; }
  h1 { color: #00e5ff; letter-spacing: .2em; margin: 0; }
  h2 { color: #00e5ff; font-size: 1rem; letter-spacing: .15em; }
  .sub { color: #7a8ca3; margin-bottom: 1.5rem; }
  .grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 1rem; margin-bottom: 1.5rem; }
  .card { border: 1px solid #00e5ff55; border-radius: 8px; padding: 1rem; background: #0b1220; }
  .label { color: #7a8ca3; font-size: .8rem; text-transform: uppercase; }
  .value { font-size: 1.6rem; margin-top: .3rem; }
  .worker.down { opacity: .45; border-color: #ff4d6d; }
  footer { color: #7a8ca3; font-size: .8rem; }
</style>
</head>
<body>
<h1>⚡ SAM-Z-1 CLUSTER ⚡</h1>
<div class="sub">DISTRIBUTED COMPUTE SYSTEM v4.0</div>
<div class="grid">
  <div class="card"><div class="label">Load Mode</div><div class="value" id="mode">--</div><div class="label" id="status">INITIALIZING</div></div>
  <div class="card"><div class="label">Current Load</div><div class="value" id="load">0</div><div class="label">requests / 10s</div></div>
  <div class="card"><div class="label">Total Requests</div><div class="value" id="total">0</div></div>
  <div class="card"><div class="label">Req/Sec</div><div class="value" id="rps">0.00</div></div>
</div>
<h2>CLUSTER STATISTICS</h2>
<div class="grid">
  <div class="card"><div class="label">Successful</div><div class="value" id="successful">0</div></div>
  <div class="card"><div class="label">Failed</div><div class="value" id="failed">0</div></div>
  <div class="card"><div class="label">Uptime</div><div class="value" id="uptime">0s</div></div>
  <div class="card"><div class="label">Healthy Workers</div><div class="value" id="healthy">0</div></div>
</div>
<div class="grid" id="workers"></div>
<footer>Last update: <span id="updated">--</span></footer>
<script>
(function () {
  const $ = (id) => document.getElementById(id);
  function render(s) {
    $("mode").textContent = String(s.mode).toUpperCase();
    $("status").textContent = "ONLINE";
    $("load").textContent = s.load;
    $("total").textContent = s.cluster.total_requests;
    $("rps").textContent = Number(s.cluster.rps).toFixed(2);
    $("successful").textContent = s.cluster.successful;
    $("failed").textContent = s.cluster.failed;
    $("uptime").textContent = s.cluster.uptime + "s";
    $("healthy").textContent = s.workers.filter((w) => w.healthy).length + " / " + s.workers.length;
    const box = $("workers");
    box.replaceChildren();
    for (const w of s.workers) {
      const card = document.createElement("div");
      card.className = "card worker" + (w.healthy ? "" : " down");
      card.textContent = `${w.url} | ${w.role} | active ${w.active} | total ${w.total} | tokens ${w.tokens}`;
      box.appendChild(card);
    }
    $("updated").textContent = new Date(s.timestamp * 1000).toLocaleTimeString();
  }
  function connect() {
    const proto = location.protocol === "https:" ? "wss:" : "ws:";
    const ws = new WebSocket(proto + "//" + location.host + "/ws");
    ws.onmessage = (e) => render(JSON.parse(e.data));
    ws.onclose = () => { $("status").textContent = "RECONNECTING"; setTimeout(connect, 2000); };
  }
  connect();
})();
</script>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def dashboard():
    """Real-time futuristic dashboard (live data arrives over the /ws WebSocket)."""
    return _DASHBOARD_HTML


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket for real-time dashboard updates.

    Registers the client in ``active_connections`` until it disconnects. Stats are
    pushed by broadcast_stats(), and incoming messages are read and ignored.
    """
    await websocket.accept()
    active_connections.add(websocket)
    try:
        # Send initial data
        await broadcast_stats()
        # Keep connection alive; receive_text() raises once the client disconnects
        while True:
            await websocket.receive_text()
    except Exception:
        # Disconnects and send/receive failures simply end this session.
        pass
    finally:
        active_connections.discard(websocket)


# ============================================================================
# API Endpoints
# ============================================================================


@app.get("/api/status")
async def api_status():
    """JSON API for status."""
    mode, load = update_load_mode()
    healthy_count = len(get_healthy_workers())
    return {
        "name": "SAM-Z-1 Distributed Cluster",
        "version": "4.0.0",
        "mode": mode,
        "current_load": load,
        "workers": len(WORKER_URLS),
        "healthy_workers": healthy_count,
        "features": ["distributed_compute", "smart_load_balancing", "real_time_dashboard"],
    }


@app.get("/health")
async def health():
    """Report this head node as healthy when at least one worker is healthy."""
    healthy_count = len(get_healthy_workers())
    return {
        "status": "healthy" if healthy_count > 0 else "unhealthy",
        "workers_healthy": healthy_count,
    }


def _route_request(request_data: dict, endpoint: str, icon: str) -> StreamingResponse:
    """Track a request and stream it through distributed or single-worker generation.

    "light" mode with at least 2 healthy workers uses 1 generator plus decoders.
    Every other mode ("medium" and "heavy") goes to the least busy single worker.
    Note that ``successful_requests`` counts requests that were dispatched; errors
    inside the stream are reported to the client as events and are not counted here.

    Raises:
        HTTPException(503): no healthy workers.
    """
    track_request()
    mode, load = update_load_mode()
    healthy = get_healthy_workers()

    if not healthy:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(status_code=503, detail="No healthy workers")

    print(f"{icon} {mode.upper()} | Load: {load} | Workers: {len(healthy)}")

    try:
        if mode == "light" and len(healthy) >= 2:
            # DISTRIBUTED MODE - 1 gen + multiple decoders
            generators, decoders = select_distributed_workers()
            if decoders:
                response = StreamingResponse(
                    distributed_generation(generators, decoders, request_data, endpoint),
                    media_type="text/event-stream",
                )
                cluster_stats["successful_requests"] += 1
                return response

        # HEAVY/FALLBACK - single worker
        worker = get_least_busy_worker()
        response = StreamingResponse(
            heavy_load_generation(worker, request_data, endpoint),
            media_type="text/event-stream",
        )
        cluster_stats["successful_requests"] += 1
        return response
    except Exception:
        cluster_stats["failed_requests"] += 1
        raise


@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """Generate text with distributed compute."""
    request_data = {
        "prompt": request.prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "stream": True,
    }
    return _route_request(request_data, "generate", "🎯")


@app.post("/v1/chat")
async def chat(request: ChatRequest):
    """Chat with distributed compute."""
    request_data = {
        "messages": [{"role": m.role, "content": m.content} for m in request.messages],
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "stream": True,
    }
    return _route_request(request_data, "chat", "💬")


# ============================================================================
# Launch
# ============================================================================

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )