# Spaces: Sleeping
"""
SAM-Z-1 Cluster Head Node

Receives requests and distributes them to worker spaces.
"""
import asyncio
import json
import random
import time
from typing import List, Optional

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# ASGI application object; also the target handed to uvicorn in the
# ``__main__`` guard of this module.
app = FastAPI(
    title="SAM-Z-1 Cluster API",
    version="1.0.0",
)
# ============================================================================
# Configuration
# ============================================================================

# Base URLs of the worker spaces that requests are forwarded to.
# Add more workers as needed.
WORKER_URLS = [
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
]

# Seconds to sleep between two health-check sweeps.
HEALTH_CHECK_INTERVAL = 30

# Per-worker health record, keyed by worker URL. Every worker starts out
# marked healthy with a ``last_check`` of 0 (never checked yet).
worker_health = dict(
    (url, {"healthy": True, "last_check": 0})
    for url in WORKER_URLS
)
# ============================================================================
# Request Models
# ============================================================================
class GenerateRequest(BaseModel):
    """Request body for the prompt-completion endpoint.

    The head node does not interpret the sampling fields itself; the
    whole model is serialised and forwarded to a worker as JSON.
    """

    # Only required field.
    prompt: str

    # Generation / sampling settings with their defaults.
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # When true the handler relays the worker output as a stream.
    stream: bool = False
class ChatMessage(BaseModel):
    """A single turn of a chat conversation."""

    # Speaker of the turn: "user" or "assistant" (not validated here).
    role: str
    # Text of the turn.
    content: str
class ChatRequest(BaseModel):
    """Request body for the chat-completion endpoint.

    As with ``GenerateRequest``, the sampling fields are passed through to
    the selected worker unchanged.
    """

    # Conversation so far; the only required field.
    messages: List[ChatMessage]

    # Generation / sampling settings with their defaults.
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # When true the handler relays the worker output as a stream.
    stream: bool = False
# ============================================================================
# Load Balancing & Health Checks
# ============================================================================
def get_healthy_workers() -> List[str]:
    """Return the URLs of all workers currently flagged as healthy.

    The order follows the iteration order of ``worker_health``.
    """
    alive = []
    for candidate, record in worker_health.items():
        if record["healthy"]:
            alive.append(candidate)
    return alive
def select_worker() -> Optional[str]:
    """Pick one healthy worker uniformly at random.

    Note: this is random selection, not round-robin; consecutive calls may
    return the same worker.

    Returns:
        The base URL of a healthy worker, or ``None`` when no worker is
        currently marked healthy.
    """
    healthy = get_healthy_workers()
    if not healthy:
        return None
    # Random choice needs no shared cursor state, unlike round-robin.
    return random.choice(healthy)
async def check_worker_health(worker_url: str) -> bool:
    """Probe a worker's ``/health`` endpoint.

    Args:
        worker_url: Base URL of the worker (no trailing ``/health``).

    Returns:
        ``True`` only if the worker answered with HTTP 200 within the
        5-second timeout; ``False`` on any other status or on any error.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{worker_url}/health")
            return response.status_code == 200
    except Exception:
        # Best-effort probe: any failure means "unhealthy". Catching
        # ``Exception`` rather than using a bare ``except`` lets
        # ``asyncio.CancelledError`` / ``KeyboardInterrupt`` propagate so the
        # surrounding task can still be cancelled and the process shut down.
        return False
async def health_check_loop():
    """Background task that refreshes ``worker_health`` forever.

    Each sweep probes every URL in ``WORKER_URLS`` in turn, records the
    result and the check time, prints a status line, then sleeps for
    ``HEALTH_CHECK_INTERVAL`` seconds. The coroutine never returns on its
    own; it ends only when its task is cancelled.
    """
    while True:
        for worker_url in WORKER_URLS:
            healthy = await check_worker_health(worker_url)
            worker_health[worker_url]["healthy"] = healthy
            worker_health[worker_url]["last_check"] = time.time()
            # Status markers were mis-encoded (UTF-8 bytes shown as Greek
            # letters); restored to the intended check / cross symbols.
            status = "✅" if healthy else "❌"
            print(f"{status} Worker {worker_url}: {'healthy' if healthy else 'unhealthy'}")
        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
@app.on_event("startup")
async def startup_event():
    """Start the health-check loop when the application starts.

    Registered as a startup handler so that the loop is scheduled on the
    server's running event loop.
    """
    # The event loop only holds weak references to tasks, so an unreferenced
    # task may be garbage-collected mid-run. Keep a strong reference on the
    # application state for the lifetime of the app.
    app.state.health_check_task = asyncio.create_task(health_check_loop())
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/")
async def root():
    """Return basic API information.

    Returns:
        A dict with the API name and version, the number of configured and
        currently healthy workers, and the paths of the public endpoints.
    """
    healthy_count = len(get_healthy_workers())
    return {
        "name": "SAM-Z-1 Cluster API",
        "version": "1.0.0",
        "workers": len(WORKER_URLS),
        "healthy_workers": healthy_count,
        "endpoints": {
            "generate": "/v1/generate",
            "chat": "/v1/chat",
            "health": "/health",
            "workers": "/workers",
        },
    }
@app.get("/health")
async def health():
    """Report the health of the cluster as a whole.

    Returns:
        A dict whose ``status`` is ``"healthy"`` when at least one worker is
        marked healthy and ``"unhealthy"`` otherwise, together with the
        total and healthy worker counts.
    """
    healthy_count = len(get_healthy_workers())
    return {
        "status": "healthy" if healthy_count > 0 else "unhealthy",
        "workers_total": len(WORKER_URLS),
        "workers_healthy": healthy_count,
    }
@app.get("/workers")
async def workers_status():
    """Return the recorded health state of every worker.

    Returns:
        A dict with a ``workers`` list; each entry carries the worker
        ``url``, its ``healthy`` flag and the ``last_check`` timestamp
        (0 until the first health check has run).
    """
    return {
        "workers": [
            {
                "url": url,
                "healthy": status["healthy"],
                "last_check": status["last_check"],
            }
            for url, status in worker_health.items()
        ]
    }
@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """Forward a prompt-completion request to a healthy worker.

    Args:
        request: Prompt and sampling settings; sent to the worker verbatim
            as JSON.

    Returns:
        The worker's decoded JSON reply, or a ``StreamingResponse`` relaying
        the worker's output when ``request.stream`` is true.

    Raises:
        HTTPException: 503 when no worker is healthy, 504 when the worker
            times out (the worker is then marked unhealthy; no retry is
            attempted), 500 on any other worker error.
    """
    worker_url = select_worker()
    if not worker_url:
        raise HTTPException(
            status_code=503,
            detail="No healthy workers available"
        )

    target = f"{worker_url}/generate"
    payload = request.dict()

    if request.stream:
        async def stream_from_worker():
            # The client must be opened inside the generator: the generator
            # body only runs after this handler has returned, so a client
            # owned by the handler would already be closed by then.
            try:
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(
                        "POST", target, json=payload
                    ) as response:
                        async for chunk in response.aiter_text():
                            yield chunk
            except httpx.TimeoutException:
                # Headers are already sent, so no HTTP error can be
                # returned; still record the worker as unhealthy.
                worker_health[worker_url]["healthy"] = False
                raise

        return StreamingResponse(
            stream_from_worker(),
            media_type="text/event-stream"
        )

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.post(target, json=payload)
            return response.json()
    except httpx.TimeoutException:
        # Mark the worker as unhealthy so later requests avoid it until the
        # next successful health check.
        worker_health[worker_url]["healthy"] = False
        raise HTTPException(
            status_code=504,
            detail="Worker timeout - request failed"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Worker error: {str(e)}"
        )
@app.post("/v1/chat")
async def chat(request: ChatRequest):
    """Forward a chat-completion request to a healthy worker.

    Args:
        request: Conversation messages and sampling settings; sent to the
            worker verbatim as JSON.

    Returns:
        The worker's decoded JSON reply, or a ``StreamingResponse`` relaying
        the worker's output when ``request.stream`` is true.

    Raises:
        HTTPException: 503 when no worker is healthy, 504 when the worker
            times out (the worker is then marked unhealthy; no retry is
            attempted), 500 on any other worker error.
    """
    worker_url = select_worker()
    if not worker_url:
        raise HTTPException(
            status_code=503,
            detail="No healthy workers available"
        )

    target = f"{worker_url}/chat"
    payload = request.dict()

    if request.stream:
        async def stream_from_worker():
            # The client must be opened inside the generator: the generator
            # body only runs after this handler has returned, so a client
            # owned by the handler would already be closed by then.
            try:
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(
                        "POST", target, json=payload
                    ) as response:
                        async for chunk in response.aiter_text():
                            yield chunk
            except httpx.TimeoutException:
                # Headers are already sent, so no HTTP error can be
                # returned; still record the worker as unhealthy.
                worker_health[worker_url]["healthy"] = False
                raise

        return StreamingResponse(
            stream_from_worker(),
            media_type="text/event-stream"
        )

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.post(target, json=payload)
            return response.json()
    except httpx.TimeoutException:
        # Mark the worker as unhealthy so later requests avoid it until the
        # next successful health check.
        worker_health[worker_url]["healthy"] = False
        raise HTTPException(
            status_code=504,
            detail="Worker timeout - request failed"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Worker error: {str(e)}"
        )
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    # Listen on every interface, port 7860, with info-level logging.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)