"""
SAM-Z-1 Cluster Head Node

Receives requests and distributes them to worker spaces.
"""

import asyncio
import json
import random
import time
from typing import Any, AsyncIterator, List, Optional, Set

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI(title="SAM-Z-1 Cluster API", version="1.0.0")


# ============================================================================
# Configuration
# ============================================================================

# Add your worker space URLs here
WORKER_URLS = [
    "https://your-username-sam-z1-worker1.hf.space",
    "https://your-username-sam-z1-worker2.hf.space",
    # Add more workers as needed
]

# Health check interval (seconds)
HEALTH_CHECK_INTERVAL = 30

# Timeouts (seconds): health probes must be quick, generation may be slow.
HEALTH_CHECK_TIMEOUT = 5.0
WORKER_REQUEST_TIMEOUT = 300.0

# Worker health status, keyed by worker URL. Every worker starts out marked
# healthy until the first health-check sweep reports otherwise.
worker_health = {url: {"healthy": True, "last_check": 0} for url in WORKER_URLS}

# Strong references to background tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected
# while it is still running.
_background_tasks: Set[asyncio.Task] = set()


# ============================================================================
# Request Models
# ============================================================================

class GenerateRequest(BaseModel):
    """Body of ``POST /v1/generate``; forwarded as JSON to a worker's ``/generate``."""

    prompt: str
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = False  # relay the worker's response as text/event-stream


class ChatMessage(BaseModel):
    """One turn of a chat conversation."""

    role: str  # "user" or "assistant"
    content: str


class ChatRequest(BaseModel):
    """Body of ``POST /v1/chat``; forwarded as JSON to a worker's ``/chat``."""

    messages: List[ChatMessage]
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = False  # relay the worker's response as text/event-stream


# ============================================================================
# Load Balancing & Health Checks
# ============================================================================

def get_healthy_workers() -> List[str]:
    """Return the URLs of all workers currently marked healthy."""
    return [url for url, status in worker_health.items() if status["healthy"]]


def select_worker() -> Optional[str]:
    """Pick one healthy worker uniformly at random.

    Returns:
        A worker URL, or ``None`` when no worker is currently healthy.
    """
    healthy = get_healthy_workers()
    if not healthy:
        return None
    return random.choice(healthy)


def _mark_unhealthy(worker_url: str) -> None:
    """Take a worker out of rotation until the health-check loop re-admits it."""
    worker_health[worker_url]["healthy"] = False


def _request_payload(model: BaseModel) -> dict:
    """Serialize a request model to a JSON-able dict on Pydantic v1 or v2."""
    # ``.dict()`` is deprecated in Pydantic v2 in favour of ``.model_dump()``.
    dump = getattr(model, "model_dump", None)
    return dump() if dump is not None else model.dict()


async def check_worker_health(worker_url: str) -> bool:
    """Probe ``GET {worker_url}/health``.

    Returns:
        ``True`` only if the worker answers with HTTP 200 within
        HEALTH_CHECK_TIMEOUT seconds, otherwise ``False``.
    """
    try:
        async with httpx.AsyncClient(timeout=HEALTH_CHECK_TIMEOUT) as client:
            response = await client.get(f"{worker_url}/health")
    except Exception:
        # Any failure (timeout, DNS, refused connection, bad URL) means the
        # worker is unusable. Unlike a bare ``except:``, this lets
        # asyncio.CancelledError propagate so the loop can be cancelled.
        return False
    return response.status_code == 200


async def health_check_loop() -> None:
    """Refresh ``worker_health`` for every worker, forever.

    A sweep runs every HEALTH_CHECK_INTERVAL seconds.
    """
    while True:
        # Probe all workers concurrently so unreachable workers (each taking
        # up to HEALTH_CHECK_TIMEOUT) do not delay one another.
        results = await asyncio.gather(
            *(check_worker_health(url) for url in WORKER_URLS)
        )
        checked_at = time.time()
        for worker_url, healthy in zip(WORKER_URLS, results):
            worker_health[worker_url]["healthy"] = healthy
            worker_health[worker_url]["last_check"] = checked_at

            status = "✅" if healthy else "❌"
            print(f"{status} Worker {worker_url}: {'healthy' if healthy else 'unhealthy'}")
        await asyncio.sleep(HEALTH_CHECK_INTERVAL)


@app.on_event("startup")
async def startup_event():
    """Start the health-check loop and keep a reference to its task."""
    task = asyncio.create_task(health_check_loop())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


@app.on_event("shutdown")
async def shutdown_event():
    """Cancel background tasks and wait for them to finish."""
    tasks = list(_background_tasks)
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


# ============================================================================
# Request Forwarding
# ============================================================================

async def _stream_from_worker(
    worker_url: str, endpoint: str, payload: dict
) -> AsyncIterator[str]:
    """Relay a worker's response body chunk by chunk.

    The HTTP client is created inside this generator so that it stays open
    for as long as StreamingResponse is consuming chunks.
    """
    try:
        async with httpx.AsyncClient(timeout=WORKER_REQUEST_TIMEOUT) as client:
            async with client.stream(
                "POST", f"{worker_url}{endpoint}", json=payload
            ) as response:
                async for chunk in response.aiter_text():
                    yield chunk
    except httpx.TransportError:
        # The response status was already fixed when StreamingResponse was
        # created, so an HTTP error cannot be returned here. Take the worker
        # out of rotation and let the stream abort.
        _mark_unhealthy(worker_url)
        raise


async def _forward(endpoint: str, request: BaseModel, stream: bool) -> Any:
    """Send ``request`` to a healthy worker's ``endpoint`` and relay the reply.

    Args:
        endpoint: Worker path, e.g. ``"/generate"``.
        request: Validated request model; sent to the worker as JSON.
        stream: Relay the worker's body as ``text/event-stream`` if true,
            otherwise return the worker's decoded JSON.

    Raises:
        HTTPException: 503 if no worker is healthy; 504 on worker timeout;
            502 if the worker replies with an HTTP error status; 500 for any
            other failure talking to the worker.
    """
    worker_url = select_worker()
    if not worker_url:
        raise HTTPException(
            status_code=503,
            detail="No healthy workers available"
        )

    payload = _request_payload(request)

    if stream:
        return StreamingResponse(
            _stream_from_worker(worker_url, endpoint, payload),
            media_type="text/event-stream"
        )

    try:
        async with httpx.AsyncClient(timeout=WORKER_REQUEST_TIMEOUT) as client:
            response = await client.post(f"{worker_url}{endpoint}", json=payload)
    except httpx.TimeoutException as exc:
        # Mark the worker unhealthy; the health-check loop re-admits it once
        # /health answers again. (No automatic retry is attempted.)
        _mark_unhealthy(worker_url)
        raise HTTPException(
            status_code=504,
            detail="Worker timeout - request failed"
        ) from exc
    except httpx.TransportError as exc:
        # Network-level failure (refused connection, dropped socket, ...).
        _mark_unhealthy(worker_url)
        raise HTTPException(
            status_code=500,
            detail=f"Worker error: {str(exc)}"
        ) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Worker error: {str(exc)}"
        ) from exc

    # Do not relay a worker-side failure to the caller as a 200.
    if response.is_error:
        raise HTTPException(
            status_code=502,
            detail=f"Worker error: HTTP {response.status_code}"
        )

    try:
        return response.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Worker error: {str(exc)}"
        ) from exc


# ============================================================================
# API Endpoints
# ============================================================================

@app.get("/")
async def root():
    """API info: worker counts and available endpoints."""
    healthy_count = len(get_healthy_workers())
    return {
        "name": "SAM-Z-1 Cluster API",
        "version": "1.0.0",
        "workers": len(WORKER_URLS),
        "healthy_workers": healthy_count,
        "endpoints": {
            "generate": "/v1/generate",
            "chat": "/v1/chat",
            "health": "/health",
            "workers": "/workers"
        }
    }


@app.get("/health")
async def health():
    """Report "healthy" while at least one worker is healthy."""
    healthy_count = len(get_healthy_workers())
    return {
        "status": "healthy" if healthy_count > 0 else "unhealthy",
        "workers_total": len(WORKER_URLS),
        "workers_healthy": healthy_count
    }


@app.get("/workers")
async def workers_status():
    """Return the last known health status of every configured worker."""
    return {
        "workers": [
            {
                "url": url,
                "healthy": status["healthy"],
                "last_check": status["last_check"]
            }
            for url, status in worker_health.items()
        ]
    }


@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """Generate text from a prompt on one of the healthy workers."""
    return await _forward("/generate", request, request.stream)


@app.post("/v1/chat")
async def chat(request: ChatRequest):
    """Chat completion, served by one of the healthy workers."""
    return await _forward("/chat", request, request.stream)


# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    # uvicorn is imported only inside this guard, so importing this module
    # (e.g. from another ASGI server) does not import uvicorn.
    from uvicorn import run as serve_app

    # Serve the cluster API on all interfaces, port 7860.
    serve_app(app, host="0.0.0.0", port=7860, log_level="info")