Sam-Z-api / app.py
Bc-AI's picture
Update app.py
65ca8df verified
raw
history blame
7.91 kB
"""
SAM-Z-1 Cluster Head Node
Receives requests and distributes to worker spaces
"""
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import httpx
import asyncio
import json
import time
from typing import List, Optional
import random
# ASGI application object; the route and startup decorators in this file
# register their handlers on it.
app = FastAPI(title="SAM-Z-1 Cluster API", version="1.0.0")
# ============================================================================
# Configuration
# ============================================================================
# Base URLs of the worker spaces that requests are forwarded to.
# Add your worker space URLs here.
WORKER_URLS = [
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
    # Add more workers as needed
]

# Seconds the background health checker sleeps between two rounds of checks.
HEALTH_CHECK_INTERVAL = 30

# Per-worker health record, keyed by worker URL.
#   "healthy":    whether the worker is currently eligible for selection
#                 (every worker starts out as healthy)
#   "last_check": time.time() of the last health check, 0 before the first one
worker_health = {
    worker: {"healthy": True, "last_check": 0}
    for worker in WORKER_URLS
}
# ============================================================================
# Request Models
# ============================================================================
class GenerateRequest(BaseModel):
    """Request body accepted by ``POST /v1/generate``.

    The whole model is serialized and forwarded unchanged to the selected
    worker's ``/generate`` endpoint; only ``stream`` is interpreted by this
    head node (it selects a streaming or a plain JSON response).
    """

    # Required input text.
    prompt: str

    # Generation settings with their defaults; passed through to the worker.
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # When true the worker's output is relayed as a stream.
    stream: bool = False
class ChatMessage(BaseModel):
    """A single message of a chat conversation."""

    # Author of the message: "user" or "assistant".
    role: str
    # Text of the message.
    content: str
class ChatRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat``.

    The whole model is serialized and forwarded unchanged to the selected
    worker's ``/chat`` endpoint; only ``stream`` is interpreted by this head
    node (it selects a streaming or a plain JSON response).
    """

    # Conversation so far, as a list of messages.
    messages: List[ChatMessage]

    # Generation settings with their defaults; passed through to the worker.
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # When true the worker's output is relayed as a stream.
    stream: bool = False
# ============================================================================
# Load Balancing & Health Checks
# ============================================================================
def get_healthy_workers() -> List[str]:
    """Return the URLs of every worker whose health record is flagged healthy.

    The result follows the iteration order of ``worker_health`` and is empty
    when no worker is currently healthy.
    """
    return [url for url in worker_health if worker_health[url]["healthy"]]
def select_worker() -> Optional[str]:
    """Pick one healthy worker at random.

    Selection is uniform random over the currently healthy workers (it is
    NOT round-robin: no position is remembered between calls).

    Returns:
        The URL of the chosen worker, or ``None`` when no worker is healthy.
    """
    candidates = get_healthy_workers()
    # random.choice raises IndexError on an empty sequence, hence the guard.
    return random.choice(candidates) if candidates else None
async def check_worker_health(worker_url: str) -> bool:
    """Probe a worker's ``/health`` endpoint.

    Args:
        worker_url: Base URL of the worker (without a trailing path).

    Returns:
        ``True`` only if the worker answered with HTTP 200 within the
        5 second timeout; ``False`` on any other status or on any error
        while performing the request.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{worker_url}/health")
    except Exception:
        # Best-effort probe: any failure simply means "not healthy".
        # ``Exception`` (rather than a bare ``except``) lets task
        # cancellation and KeyboardInterrupt propagate instead of being
        # silently turned into an "unhealthy" result.
        return False
    return response.status_code == 200
async def health_check_loop():
    """Background task that refreshes ``worker_health`` forever.

    Each round probes every URL in ``WORKER_URLS``, records the result and
    the check time in ``worker_health``, prints one status line per worker,
    then sleeps ``HEALTH_CHECK_INTERVAL`` seconds. The coroutine never
    returns on its own; it ends only when its task is cancelled.
    """
    while True:
        # Probe all workers concurrently so that one slow or hung worker
        # (up to the probe timeout) does not delay the others' updates.
        # gather() returns results in the same order as WORKER_URLS.
        results = await asyncio.gather(
            *(check_worker_health(worker_url) for worker_url in WORKER_URLS)
        )
        for worker_url, healthy in zip(WORKER_URLS, results):
            worker_health[worker_url]["healthy"] = healthy
            worker_health[worker_url]["last_check"] = time.time()
            # Check-mark / cross-mark glyphs, written as escapes so the
            # output does not depend on the source file's encoding.
            status = "\u2705" if healthy else "\u274c"
            print(f"{status} Worker {worker_url}: {'healthy' if healthy else 'unhealthy'}")
        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
@app.on_event("startup")
async def startup_event():
    """Start the background health-check loop when the application starts."""
    # The event loop only keeps a weak reference to tasks, so a task whose
    # result is discarded may be garbage-collected while still running.
    # Keeping it on app.state holds a strong reference for the app's lifetime.
    app.state.health_check_task = asyncio.create_task(health_check_loop())
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/")
async def root():
    """Return API metadata: name, version, worker counts and endpoint paths."""
    endpoint_paths = {
        "generate": "/v1/generate",
        "chat": "/v1/chat",
        "health": "/health",
        "workers": "/workers",
    }
    return {
        "name": "SAM-Z-1 Cluster API",
        "version": "1.0.0",
        # Configured workers vs. those currently flagged healthy.
        "workers": len(WORKER_URLS),
        "healthy_workers": len(get_healthy_workers()),
        "endpoints": endpoint_paths,
    }
@app.get("/health")
async def health():
    """Report cluster health.

    The cluster counts as "healthy" as long as at least one worker is
    flagged healthy; the response also carries the total and healthy counts.
    """
    healthy_total = len(get_healthy_workers())
    if healthy_total > 0:
        overall = "healthy"
    else:
        overall = "unhealthy"
    return {
        "status": overall,
        "workers_total": len(WORKER_URLS),
        "workers_healthy": healthy_total,
    }
@app.get("/workers")
async def workers_status():
    """List every known worker with its health flag and last check time."""
    report = []
    for url, status in worker_health.items():
        report.append(
            {
                "url": url,
                "healthy": status["healthy"],
                "last_check": status["last_check"],
            }
        )
    return {"workers": report}
@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """Forward a text-generation request to one healthy worker.

    Args:
        request: Validated request body; it is serialized and sent unchanged
            to the worker's ``/generate`` endpoint.

    Returns:
        The worker's decoded JSON body, or, when ``request.stream`` is true,
        a ``StreamingResponse`` (``text/event-stream``) relaying the worker's
        output chunk by chunk.

    Raises:
        HTTPException: 503 if no worker is healthy, 504 if the worker times
            out (non-streaming), 500 on any other worker error
            (non-streaming).
    """
    worker_url = select_worker()
    if not worker_url:
        raise HTTPException(
            status_code=503,
            detail="No healthy workers available"
        )

    target = f"{worker_url}/generate"
    payload = request.dict()

    if request.stream:
        async def stream_from_worker():
            # The generator runs only after this handler has returned, so it
            # must own its HTTP client: a client opened in the handler would
            # already be closed by the time the first chunk is requested.
            try:
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(
                        "POST", target, json=payload
                    ) as response:
                        async for chunk in response.aiter_text():
                            yield chunk
            except httpx.TimeoutException:
                # Response headers are already sent, so no HTTP error can be
                # returned here; just take the worker out of rotation until
                # the next health check and let the stream abort.
                worker_health[worker_url]["healthy"] = False
                raise

        return StreamingResponse(
            stream_from_worker(),
            media_type="text/event-stream"
        )

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.post(target, json=payload)
            return response.json()
    except httpx.TimeoutException:
        # Take the worker out of rotation until the next health check marks
        # it healthy again. The request itself is not retried.
        worker_health[worker_url]["healthy"] = False
        raise HTTPException(
            status_code=504,
            detail="Worker timeout - request failed"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Worker error: {str(e)}"
        )
@app.post("/v1/chat")
async def chat(request: ChatRequest):
    """Forward a chat-completion request to one healthy worker.

    Args:
        request: Validated request body; it is serialized and sent unchanged
            to the worker's ``/chat`` endpoint.

    Returns:
        The worker's decoded JSON body, or, when ``request.stream`` is true,
        a ``StreamingResponse`` (``text/event-stream``) relaying the worker's
        output chunk by chunk.

    Raises:
        HTTPException: 503 if no worker is healthy, 504 if the worker times
            out (non-streaming), 500 on any other worker error
            (non-streaming).
    """
    worker_url = select_worker()
    if not worker_url:
        raise HTTPException(
            status_code=503,
            detail="No healthy workers available"
        )

    target = f"{worker_url}/chat"
    payload = request.dict()

    if request.stream:
        async def stream_from_worker():
            # The generator runs only after this handler has returned, so it
            # must own its HTTP client: a client opened in the handler would
            # already be closed by the time the first chunk is requested.
            try:
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(
                        "POST", target, json=payload
                    ) as response:
                        async for chunk in response.aiter_text():
                            yield chunk
            except httpx.TimeoutException:
                # Response headers are already sent, so no HTTP error can be
                # returned here; just take the worker out of rotation until
                # the next health check and let the stream abort.
                worker_health[worker_url]["healthy"] = False
                raise

        return StreamingResponse(
            stream_from_worker(),
            media_type="text/event-stream"
        )

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.post(target, json=payload)
            return response.json()
    except httpx.TimeoutException:
        # Take the worker out of rotation until the next health check marks
        # it healthy again. The request itself is not retried.
        worker_health[worker_url]["healthy"] = False
        raise HTTPException(
            status_code=504,
            detail="Worker timeout - request failed"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Worker error: {str(e)}"
        )
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly
    # as a script.
    import uvicorn

    # Serve the app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")