| """ | |
| SAM-Z-1 Smart Load Balancing Cluster Head Node | |
| - Light load: parallel gen/decode split for max speed | |
| - Heavy load: 1 worker per request for throughput | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| import httpx | |
| import asyncio | |
| import json | |
| import time | |
| from typing import List, Optional, Dict | |
| from collections import deque | |
| import random | |
| app = FastAPI(title="SAM-Z-1 Smart Cluster API", version="3.0.0") | |
# ============================================================================
# Configuration
# ============================================================================
WORKER_URLS = [
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
]

HEALTH_CHECK_INTERVAL = 30
LOAD_CHECK_WINDOW = 10  # seconds to measure load

# Load thresholds
LIGHT_LOAD_THRESHOLD = 2  # requests in window
HEAVY_LOAD_THRESHOLD = 5  # requests in window

# Worker state
worker_health = {url: {"healthy": True, "last_check": 0, "active_requests": 0} for url in WORKER_URLS}
request_timestamps = deque(maxlen=100)  # track recent requests
current_load_mode = "light"  # "light" or "heavy"
# ============================================================================
# Request Models
# ============================================================================
class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = True


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = True
# ============================================================================
# Load Management
# ============================================================================
def get_current_load() -> int:
    """Calculate current load based on recent requests"""
    now = time.time()
    # Count requests in the last LOAD_CHECK_WINDOW seconds
    return sum(1 for ts in request_timestamps if now - ts < LOAD_CHECK_WINDOW)


def update_load_mode():
    """Update load mode based on current load"""
    global current_load_mode
    load = get_current_load()
    if load <= LIGHT_LOAD_THRESHOLD:
        current_load_mode = "light"
    elif load >= HEAVY_LOAD_THRESHOLD:
        current_load_mode = "heavy"
    # hysteresis zone between thresholds maintains current mode
    return current_load_mode, load


def track_request():
    """Track a new request"""
    request_timestamps.append(time.time())


def get_healthy_workers() -> List[str]:
    """Get list of healthy workers"""
    return [url for url, status in worker_health.items() if status["healthy"]]


def get_least_busy_worker() -> Optional[str]:
    """Get worker with fewest active requests"""
    healthy = get_healthy_workers()
    if not healthy:
        return None
    return min(healthy, key=lambda url: worker_health[url]["active_requests"])


def select_worker_pair() -> tuple:
    """Select 2 workers for parallel operation"""
    healthy = get_healthy_workers()
    if len(healthy) < 2:
        return (healthy[0], None) if len(healthy) == 1 else (None, None)
    # Sort by active requests, take the 2 least busy
    sorted_workers = sorted(healthy, key=lambda url: worker_health[url]["active_requests"])
    return (sorted_workers[0], sorted_workers[1])
async def check_worker_health(worker_url: str) -> bool:
    """Check if a worker is healthy"""
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{worker_url}/health")
            return response.status_code == 200
    except Exception:
        return False


async def health_check_loop():
    """Background task to check worker health"""
    while True:
        for worker_url in WORKER_URLS:
            healthy = await check_worker_health(worker_url)
            worker_health[worker_url]["healthy"] = healthy
            worker_health[worker_url]["last_check"] = time.time()
            status = "✅" if healthy else "❌"
            active = worker_health[worker_url]["active_requests"]
            print(f"{status} {worker_url}: {'healthy' if healthy else 'unhealthy'} | Active: {active}")
        mode, load = update_load_mode()
        print(f"📊 Load mode: {mode.upper()} | Current load: {load} req/{LOAD_CHECK_WINDOW}s")
        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
@app.on_event("startup")
async def startup_event():
    """Start health check loop on startup"""
    asyncio.create_task(health_check_loop())
# ============================================================================
# Generation Strategies
# ============================================================================
async def light_load_generation(
    generator_url: str,
    decoder_url: str,
    request_data: dict,
    endpoint: str = "generate"
):
    """
    LIGHT LOAD MODE: Split generation and decoding
    - Generator worker: produces token IDs only
    - Decoder worker: decodes token IDs to text
    This parallelizes the bottleneck!
    """
    # Queues for the gen -> decode pipeline
    token_queue = asyncio.Queue(maxsize=10)
    text_queue = asyncio.Queue(maxsize=10)
    async def generate_tokens():
        """Worker 1: Generate token IDs"""
        try:
            worker_health[generator_url]["active_requests"] += 1
            # Request token-IDs-only mode from the generator worker
            request_data_tokens = {**request_data, "return_token_ids": True}
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(
                    "POST",
                    f"{generator_url}/{endpoint}",
                    json=request_data_tokens
                ) as response:
                    # SSE events are newline-delimited, so iterate lines rather than raw chunks
                    async for line in response.aiter_lines():
                        if line.strip() and line.startswith("data: "):
                            try:
                                data = json.loads(line[6:])
                                if "token_id" in data:
                                    await token_queue.put(data["token_id"])
                                elif "done" in data:
                                    break
                            except json.JSONDecodeError:
                                pass
        except Exception as e:
            print(f"❌ Generator error: {e}")
        finally:
            # Always signal end-of-stream so the decoder can finish
            await token_queue.put(None)
            worker_health[generator_url]["active_requests"] -= 1
    async def decode_tokens():
        """Worker 2: Decode token IDs to text"""
        try:
            worker_health[decoder_url]["active_requests"] += 1
            batch = []
            batch_size = 5  # decode in small batches for speed
            while True:
                try:
                    token_id = await asyncio.wait_for(token_queue.get(), timeout=1.0)
                    if token_id is None:
                        # Decode any remaining batch
                        if batch:
                            async with httpx.AsyncClient(timeout=10.0) as client:
                                response = await client.post(
                                    f"{decoder_url}/decode",
                                    json={"token_ids": batch}
                                )
                                text = response.json()["text"]
                                await text_queue.put(("text", text))
                        await text_queue.put(("done", None))
                        break
                    batch.append(token_id)
                    # Decode batch when full
                    if len(batch) >= batch_size:
                        async with httpx.AsyncClient(timeout=10.0) as client:
                            response = await client.post(
                                f"{decoder_url}/decode",
                                json={"token_ids": batch}
                            )
                            text = response.json()["text"]
                            await text_queue.put(("text", text))
                        batch = []
                except asyncio.TimeoutError:
                    continue
        except Exception as e:
            print(f"❌ Decoder error: {e}")
            await text_queue.put(("done", None))
        finally:
            worker_health[decoder_url]["active_requests"] -= 1
    # Start both pipelines
    gen_task = asyncio.create_task(generate_tokens())
    dec_task = asyncio.create_task(decode_tokens())

    # Stream decoded text
    accumulated_text = ""
    try:
        while True:
            msg_type, data = await text_queue.get()
            if msg_type == "done":
                break
            if msg_type == "text":
                accumulated_text += data
                yield f"data: {json.dumps({'delta': data, 'text': accumulated_text})}\n\n"
    finally:
        await gen_task
        await dec_task
async def heavy_load_generation(
    worker_url: str,
    request_data: dict,
    endpoint: str = "generate"
):
    """
    HEAVY LOAD MODE: Single worker per request
    Standard streaming for max throughput
    """
    try:
        worker_health[worker_url]["active_requests"] += 1
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream(
                "POST",
                f"{worker_url}/{endpoint}",
                json=request_data
            ) as response:
                async for chunk in response.aiter_text():
                    if chunk.strip():
                        yield chunk
    except Exception as e:
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
    finally:
        worker_health[worker_url]["active_requests"] -= 1
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/")
async def root():
    """API info"""
    healthy_count = len(get_healthy_workers())
    mode, load = update_load_mode()
    return {
        "name": "SAM-Z-1 Smart Cluster API",
        "version": "3.0.0",
        "mode": mode,
        "current_load": load,
        "workers": len(WORKER_URLS),
        "healthy_workers": healthy_count,
        "features": [
            "smart_load_balancing",
            "parallel_gen_decode",
            "adaptive_routing"
        ],
        "load_strategy": {
            "light": "parallel gen/decode split for speed",
            "heavy": "1 worker per request for throughput"
        },
        "endpoints": {
            "generate": "/v1/generate",
            "chat": "/v1/chat",
            "health": "/health",
            "workers": "/workers",
            "stats": "/stats"
        }
    }
@app.get("/health")
async def health():
    """Health check endpoint"""
    healthy_count = len(get_healthy_workers())
    mode, load = update_load_mode()
    return {
        "status": "healthy" if healthy_count > 0 else "unhealthy",
        "workers_total": len(WORKER_URLS),
        "workers_healthy": healthy_count,
        "load_mode": mode,
        "current_load": load
    }
@app.get("/workers")
async def workers_status():
    """Get status of all workers"""
    return {
        "workers": [
            {
                "url": url,
                "healthy": status["healthy"],
                "active_requests": status["active_requests"],
                "last_check": status["last_check"]
            }
            for url, status in worker_health.items()
        ]
    }
@app.get("/stats")
async def stats():
    """Get cluster statistics"""
    mode, load = update_load_mode()
    return {
        "load_mode": mode,
        "current_load": load,
        "load_window_seconds": LOAD_CHECK_WINDOW,
        "thresholds": {
            "light": LIGHT_LOAD_THRESHOLD,
            "heavy": HEAVY_LOAD_THRESHOLD
        },
        "recent_requests": len(request_timestamps),
        "worker_stats": {
            url: {
                "healthy": status["healthy"],
                "active": status["active_requests"]
            }
            for url, status in worker_health.items()
        }
    }
@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """Generate text with smart load balancing"""
    track_request()
    mode, load = update_load_mode()
    healthy = get_healthy_workers()
    if not healthy:
        raise HTTPException(status_code=503, detail="No healthy workers available")

    request_data = {
        "prompt": request.prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "stream": True
    }
    print(f"🎯 Mode: {mode.upper()} | Load: {load} | Request: generate")

    if mode == "light" and len(healthy) >= 2:
        # LIGHT LOAD: parallel gen/decode
        generator, decoder = select_worker_pair()
        return StreamingResponse(
            light_load_generation(generator, decoder, request_data, "generate"),
            media_type="text/event-stream"
        )
    else:
        # HEAVY LOAD: single worker
        worker = get_least_busy_worker()
        return StreamingResponse(
            heavy_load_generation(worker, request_data, "generate"),
            media_type="text/event-stream"
        )
@app.post("/v1/chat")
async def chat(request: ChatRequest):
    """Chat completion with smart load balancing"""
    track_request()
    mode, load = update_load_mode()
    healthy = get_healthy_workers()
    if not healthy:
        raise HTTPException(status_code=503, detail="No healthy workers available")

    request_data = {
        "messages": [{"role": m.role, "content": m.content} for m in request.messages],
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "stream": True
    }
    print(f"🎯 Mode: {mode.upper()} | Load: {load} | Request: chat")

    if mode == "light" and len(healthy) >= 2:
        # LIGHT LOAD: parallel gen/decode
        generator, decoder = select_worker_pair()
        return StreamingResponse(
            light_load_generation(generator, decoder, request_data, "chat"),
            media_type="text/event-stream"
        )
    else:
        # HEAVY LOAD: single worker
        worker = get_least_busy_worker()
        return StreamingResponse(
            heavy_load_generation(worker, request_data, "chat"),
            media_type="text/event-stream"
        )
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )
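
# ============================================================================
# Example client (sketch)
# ============================================================================
# A minimal sketch of how a caller might consume the streaming /v1/generate
# endpoint above. It is not part of the server: the base URL
# (http://localhost:7860) and the prompt are hypothetical, and it assumes the
# head node emits "data: {...}" SSE lines with a "delta" field, as it does in
# light-load mode. Kept commented out so the module stays importable.
#
#   import asyncio
#   import json
#   import httpx
#
#   async def demo():
#       payload = {"prompt": "Hello, SAM-Z-1!", "max_tokens": 64, "stream": True}
#       async with httpx.AsyncClient(timeout=300.0) as client:
#           async with client.stream(
#               "POST", "http://localhost:7860/v1/generate", json=payload
#           ) as resp:
#               async for line in resp.aiter_lines():
#                   if line.startswith("data: "):
#                       event = json.loads(line[6:])
#                       if "delta" in event:
#                           print(event["delta"], end="", flush=True)
#
#   asyncio.run(demo())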