# Sam-Z-api / app.py
# Hugging Face file-viewer header (not code): author Bc-AI, commit 74ffe1c
# "Update app.py" (verified), 31.4 kB — raw / history / blame links.
"""
SAM-Z-1 Distributed Compute Cluster Head Node v5.0
- Smart load balancing with distributed compute
- Real-time status dashboard
- Auto-detects worker version (v4 vs v5)
- Supports 4 new models with backward compatibility
"""
from fastapi import FastAPI, HTTPException, WebSocket
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import httpx
import asyncio
import json
import time
from typing import List, Optional, Dict
from collections import deque
import random
app = FastAPI(title="SAM-Z-1 Distributed Cluster", version="5.0.0")
# ============================================================================
# Configuration
# ============================================================================
# Worker base URLs, in dashboard order. dict.fromkeys() drops repeated entries
# while keeping first-seen order: worker_health below is keyed by URL, so a
# duplicate adds no capacity — it is only health-checked twice per cycle and
# inflates the worker count reported by /api/status.
# NOTE(review): bc-ai-worker-sam-z-api was listed twice here; if a different
# worker was intended for the second slot, add its URL.
WORKER_URLS = list(dict.fromkeys([
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
    "https://bc-ai-worker-3.hf.space",
    "https://bc-ai-worker-4.hf.space",
    "https://bc-ai-worker-5.hf.space",
]))
HEALTH_CHECK_INTERVAL = 5  # seconds slept between health-check cycles
LOAD_CHECK_WINDOW = 10  # seconds of request history counted as "current load"
LIGHT_LOAD_THRESHOLD = 2  # max recent requests still considered "light"
HEAVY_LOAD_THRESHOLD = 5  # max recent requests for "medium" (5+ healthy workers)

# New models added in v5; used as the default model list when a worker's
# /info response omits "models".
NEW_MODELS = [
    "SAM-X-1-Large",
    "SAM-X-1-Fast",
    "SAM-X-1-Mini",
    "SAM-X-1-Nano"
]

# Worker state with version detection, keyed by worker URL.
worker_health = {
    url: {
        "healthy": True,  # starts True; overwritten by the first health check
        "last_check": 0,  # time.time() of the most recent health check
        "active_requests": 0,  # in-flight requests; drives least-busy selection
        "total_requests": 0,  # shown as "Total" on the dashboard
        "total_tokens": 0,  # tokens decoded by this worker in distributed mode
        "avg_latency": 0,  # NOTE(review): not updated anywhere in this file
        "role": "idle",  # "idle" | "generator" | "decoder" | "full"
        "version": None,  # Will be auto-detected: "v4" or "v5"
        "supports_models": []  # Models this worker supports
    } for url in WORKER_URLS
}

# Arrival times of the last 100 requests (see get_current_load).
request_timestamps = deque(maxlen=100)
# Last mode computed by update_load_mode(): "light", "medium" or "heavy".
current_load_mode = "light"
cluster_stats = {
    "total_requests": 0,
    "successful_requests": 0,
    "failed_requests": 0,
    "uptime_start": time.time()
}
# Dashboard WebSocket connections that receive broadcast_stats() snapshots.
active_connections = set()
# ============================================================================
# Request Models
# ============================================================================
class GenerateRequest(BaseModel):
    # JSON body of POST /v1/generate. The sampling fields below are copied
    # into the request sent to the selected worker(s).
    prompt: str
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    # NOTE(review): accepted but not forwarded — the head node always sends
    # "stream": True to workers and always answers with an SSE stream.
    stream: bool = True
    # NEW: Model selection. When set, it is forwarded to workers as "model" and
    # restricts which workers get_workers_for_model() considers compatible.
    model: Optional[str] = None


class ChatMessage(BaseModel):
    # One conversation turn; forwarded to workers as {"role", "content"}.
    role: str
    content: str


class ChatRequest(BaseModel):
    # JSON body of POST /v1/chat; same sampling/model fields as GenerateRequest.
    messages: List[ChatMessage]
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    # NOTE(review): accepted but not forwarded — workers always get "stream": True.
    stream: bool = True
    model: Optional[str] = None  # NEW: Model selection (see GenerateRequest.model)
# ============================================================================
# Worker Version Detection
# ============================================================================
async def _get_json_object(client: httpx.AsyncClient, url: str) -> Optional[dict]:
    """GET *url*; return the JSON body if it is a 200 response holding a JSON object, else None."""
    try:
        response = await client.get(url)
        if response.status_code != 200:
            return None
        data = response.json()
    except (httpx.HTTPError, ValueError):
        # Transport/timeout errors and non-JSON bodies mean "endpoint not usable".
        # Narrow on purpose: a bare except would also swallow CancelledError.
        return None
    return data if isinstance(data, dict) else None


async def detect_worker_version(worker_url: str) -> tuple:
    """
    Detect worker version and supported models.

    Probes, in order:
      1. ``GET {worker_url}/info`` -> ``{"version": ..., "models": [...]}``;
         missing keys default to "v5" and NEW_MODELS respectively.
      2. ``GET {worker_url}/models`` -> ``{"models": [...]}``; a non-empty list
         marks the worker as "v5".
    Anything else (errors, non-200 status, non-JSON or non-object JSON, empty
    model list) falls back to ``("v4", [])``, i.e. no model selection.

    Returns: (version: str, supported_models: List[str])
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            # Try to get worker info endpoint (v5 feature)
            info = await _get_json_object(client, f"{worker_url}/info")
            if info is not None:
                return info.get("version", "v5"), info.get("models", NEW_MODELS)
            # Try to get models list (v5 feature)
            listing = await _get_json_object(client, f"{worker_url}/models")
            if listing is not None:
                models = listing.get("models", [])
                if models:
                    return "v5", models
    except Exception as e:
        print(f"⚠️ Version detection failed for {worker_url}: {e}")
    # Fallback: worker is v4 (no model selection)
    return "v4", []
async def check_worker_health(worker_url: str) -> bool:
    """Return True when ``GET {worker_url}/health`` answers HTTP 200 within 5 s.

    Any failure (connection error, timeout, bad URL, ...) counts as unhealthy.
    ``except Exception`` rather than a bare ``except`` lets
    ``asyncio.CancelledError`` propagate so the health loop can be cancelled.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{worker_url}/health")
    except Exception:
        return False
    return response.status_code == 200
async def health_check_loop():
    """Health check with version detection, run forever as a background task.

    Every HEALTH_CHECK_INTERVAL seconds, each distinct URL in WORKER_URLS is
    probed concurrently with check_worker_health(). The first time a worker is
    seen healthy, its version and model list are detected once with
    detect_worker_version(). A snapshot is then pushed to dashboard clients via
    broadcast_stats().
    """

    async def refresh(worker_url: str) -> None:
        state = worker_health[worker_url]
        healthy = await check_worker_health(worker_url)
        state["healthy"] = healthy
        state["last_check"] = time.time()
        # Only detect on a reachable worker. Probing an offline (e.g. sleeping)
        # worker always falls back to "v4", and because detection only runs
        # while version is None, that wrong answer would stick forever.
        if healthy and state["version"] is None:
            version, models = await detect_worker_version(worker_url)
            state["version"] = version
            state["supports_models"] = models
            name = worker_url.split('//')[1].split('.')[0]
            print(f"✅ Worker: {name} | Version: {version} | Models: {len(models)}")

    while True:
        try:
            # dict.fromkeys: probe each URL once even if the list repeats one.
            await asyncio.gather(*(refresh(url) for url in dict.fromkeys(WORKER_URLS)))
            await broadcast_stats()
        except Exception as e:
            # Keep monitoring alive; an escaped error would silently end this task.
            print(f"⚠️ Health check cycle failed: {e}")
        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
@app.on_event("startup")
async def startup_event():
    """Start the background health-check loop when the app starts."""
    # The event loop keeps only weak references to tasks; holding a strong
    # reference on app.state prevents the loop from being garbage-collected
    # mid-execution.
    app.state.health_check_task = asyncio.create_task(health_check_loop())
# ============================================================================
# Smart Worker Selection
# ============================================================================
def get_workers_for_model(model_name: Optional[str]) -> List[str]:
    """Return the healthy workers that should serve *model_name*.

    With no model requested, every healthy worker qualifies. Otherwise a
    worker qualifies when it was detected as "v4" (no model selection, serves
    its default model) or as "v5" with *model_name* in its model list. If none
    qualify, all healthy workers are returned as a fallback.
    """
    healthy = get_healthy_workers()
    if not model_name:
        return healthy

    def can_serve(url: str) -> bool:
        info = worker_health[url]
        if info["version"] == "v4":
            return True
        return info["version"] == "v5" and model_name in info["supports_models"]

    matching = [url for url in healthy if can_serve(url)]
    return matching or healthy
def get_healthy_workers() -> List[str]:
    """List the URLs whose last health check passed, in worker_health order."""
    return [url for url in worker_health if worker_health[url]["healthy"]]
def get_least_busy_worker(worker_list: List[str] = None) -> Optional[str]:
    """Pick the candidate with the fewest in-flight requests.

    Args:
        worker_list: candidate URLs; ``None`` means "all healthy workers".
    Returns:
        The chosen URL (the earliest one wins ties), or ``None`` when there
        are no candidates.
    """
    candidates = get_healthy_workers() if worker_list is None else worker_list
    best_url = None
    best_load = None
    for url in candidates:
        load = worker_health[url]["active_requests"]
        if best_load is None or load < best_load:
            best_url, best_load = url, load
    return best_url
def select_distributed_workers(model_name: Optional[str] = None) -> tuple:
    """
    Split the workers compatible with *model_name* into generator/decoder roles.

    Candidates are ordered by in-flight requests. The sort is stable, so ties
    keep configuration order. The least busy worker becomes the only generator,
    and the next (up to) four become decoders.

    Returns: (generators: List[str], decoders: List[str]) — both empty when no
    worker is available; decoders empty when only one is.
    """
    ranked = sorted(
        get_workers_for_model(model_name),
        key=lambda url: worker_health[url]["active_requests"],
    )
    return ranked[:1], ranked[1:5]
# ============================================================================
# Load Management
# ============================================================================
def get_current_load() -> int:
    """Count tracked requests that arrived less than LOAD_CHECK_WINDOW seconds ago."""
    now = time.time()
    recent = 0
    for stamp in request_timestamps:
        if now - stamp < LOAD_CHECK_WINDOW:
            recent += 1
    return recent
def update_load_mode():
    """Recompute the global ``current_load_mode`` from traffic and worker health.

    Returns:
        ``(mode, load)`` — *mode* is "light", "medium" or "heavy". *load* is
        get_current_load(), the number of requests in the last
        LOAD_CHECK_WINDOW seconds.

    Mode by number of healthy workers:
      * 5 or more: "light" up to LIGHT_LOAD_THRESHOLD, "medium" up to
        HEAVY_LOAD_THRESHOLD, otherwise "heavy";
      * 3-4: "light" up to LIGHT_LOAD_THRESHOLD, otherwise "heavy";
      * fewer: always "heavy".
    Only "light" lets /v1/generate and /v1/chat use distributed generation.
    """
    global current_load_mode
    load = get_current_load()
    healthy_count = len(get_healthy_workers())
    if healthy_count >= 5:
        if load <= LIGHT_LOAD_THRESHOLD:
            current_load_mode = "light"
        elif load <= HEAVY_LOAD_THRESHOLD:
            current_load_mode = "medium"
        else:
            current_load_mode = "heavy"
    elif healthy_count >= 3:
        # Use the shared threshold (not a literal) so tuning
        # LIGHT_LOAD_THRESHOLD affects every cluster size consistently.
        if load <= LIGHT_LOAD_THRESHOLD:
            current_load_mode = "light"
        else:
            current_load_mode = "heavy"
    else:
        current_load_mode = "heavy"
    return current_load_mode, load
def track_request():
    """Record one incoming request for load tracking and the cluster counter."""
    arrived_at = time.time()
    request_timestamps.append(arrived_at)
    cluster_stats["total_requests"] = cluster_stats["total_requests"] + 1
# ============================================================================
# Dashboard & WebSocket
# ============================================================================
async def broadcast_stats():
    """Push a cluster/worker snapshot to every connected dashboard WebSocket.

    Does nothing when no dashboard is connected. Also refreshes the load mode
    via update_load_mode(). Sockets whose send fails are removed from
    ``active_connections``.
    """
    if not active_connections:
        return
    mode, load = update_load_mode()
    uptime = time.time() - cluster_stats["uptime_start"]
    stats = {
        "timestamp": time.time(),
        "mode": mode,
        "load": load,
        "workers": [
            {
                "url": url.split("//")[1].split(".")[0],
                "healthy": status["healthy"],
                "active": status["active_requests"],
                "total": status["total_requests"],
                "tokens": status["total_tokens"],
                "latency": round(status["avg_latency"], 2),
                "role": status["role"],
                "version": status["version"] or "detecting...",
                "models": len(status["supports_models"])
            }
            for url, status in worker_health.items()
        ],
        "cluster": {
            "total_requests": cluster_stats["total_requests"],
            "successful": cluster_stats["successful_requests"],
            "failed": cluster_stats["failed_requests"],
            "uptime": round(uptime, 0),
            "rps": round(cluster_stats["total_requests"] / uptime if uptime > 0 else 0, 2)
        }
    }
    disconnected = set()
    # Iterate over a snapshot: each await yields to other tasks, and
    # websocket_endpoint adds/discards sockets concurrently, which would raise
    # "RuntimeError: Set changed size during iteration" on the live set.
    for ws in list(active_connections):
        try:
            await ws.send_json(stats)
        except Exception:
            # Closed or broken socket; CancelledError still propagates.
            disconnected.add(ws)
    active_connections.difference_update(disconnected)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Dashboard feed: register the socket and keep it open until the client leaves.

    Stats are pushed by broadcast_stats(); an initial snapshot is sent right
    after connecting. Inbound messages are read only to notice disconnects.
    """
    await websocket.accept()
    active_connections.add(websocket)
    try:
        await broadcast_stats()
        while True:
            await websocket.receive_text()
    except Exception:
        # Client disconnect (raised by receive_text) or a broken socket: the
        # connection is finished either way. Unlike a bare except, this lets
        # task cancellation propagate.
        pass
    finally:
        active_connections.discard(websocket)
# ============================================================================
# Distributed Generation
# ============================================================================
async def distributed_generation(
    generators: List[str],
    decoders: List[str],
    request_data: dict,
    endpoint: str = "generate"
):
    """Distributed compute with v4/v5 compatibility.

    Splits one request across workers:

    * ``generators[0]`` streams the completion from ``/{endpoint}`` with
      ``return_token_ids=True``. Every SSE ``data:`` event carrying
      ``token_id`` is collected, and the ids are grouped into numbered batches
      of ``batch_size`` tokens.
    * every URL in ``decoders`` pulls batches from a shared queue and turns
      them into text via ``POST /decode`` (``{"token_ids": [...]}`` ->
      ``{"text": ...}``).
    * decoded batches are put back into generation order, since decoders run
      concurrently and can finish out of order. They are yielded as SSE
      events ``data: {"delta": <new text>, "text": <all text so far>}``.

    Yields nothing if ``generators`` or ``decoders`` is empty. A batch whose
    decode fails is logged and skipped so the rest of the stream continues.
    On early exit (e.g. the client disconnects) all worker tasks are cancelled
    instead of awaited, so a producer blocked on a full queue cannot hang
    cleanup.
    """
    if not generators or not decoders:
        return

    batch_size = 2
    gen_url = generators[0]
    # Work items for decoders: (seq, token_ids). None tells one decoder to stop.
    # Bounded so a fast generator waits for slower decoders (backpressure).
    token_queue = asyncio.Queue(maxsize=50)
    # Results: ("text", seq, decoded_text_or_None) or ("done", decoder_id, None).
    # Unbounded (at most one item per batch plus one per decoder) so decoders
    # can always report without awaiting, even while being cancelled.
    text_queue = asyncio.Queue()

    for url in generators:
        worker_health[url]["role"] = "generator"
    for url in decoders:
        worker_health[url]["role"] = "decoder"

    async def generate_tokens():
        """Stream token ids from the generator worker into token_queue in numbered batches."""
        gen_state = worker_health[gen_url]
        gen_state["active_requests"] += 1
        gen_state["total_requests"] += 1
        seq = 0
        batch = []
        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(
                    "POST",
                    f"{gen_url}/{endpoint}",
                    json={**request_data, "return_token_ids": True},
                ) as response:
                    # aiter_lines() re-frames the body on newlines; raw text
                    # chunks can hold several SSE events or only part of one.
                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        try:
                            event = json.loads(line[6:])
                        except ValueError:
                            continue  # malformed event; skip it
                        if not isinstance(event, dict):
                            continue
                        if "token_id" in event:
                            batch.append(event["token_id"])
                            if len(batch) >= batch_size:
                                await token_queue.put((seq, batch))
                                seq += 1
                                batch = []
                        elif "done" in event:
                            break
        except Exception as e:
            print(f"❌ Generator error: {e}")
        finally:
            gen_state["active_requests"] -= 1
            gen_state["role"] = "idle"
        # Reached after "done", after a normal end of stream, and after an
        # error (but not on cancellation): flush the partial batch, then stop
        # every decoder. Previously a stream that ended without "done" never
        # stopped the decoders and the request hung.
        if batch:
            await token_queue.put((seq, batch))
        for _ in decoders:
            await token_queue.put(None)

    async def decode_tokens(decoder_url: str, decoder_id: int):
        """Decode batches from token_queue on decoder_url until a None sentinel arrives."""
        dec_state = worker_health[decoder_url]
        dec_state["active_requests"] += 1
        dec_state["total_requests"] += 1
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                while True:
                    item = await token_queue.get()
                    if item is None:
                        break
                    seq, token_ids = item
                    text = None
                    try:
                        response = await client.post(
                            f"{decoder_url}/decode",
                            json={"token_ids": token_ids},
                        )
                        response.raise_for_status()
                        decoded = response.json()["text"]
                        if not isinstance(decoded, str):
                            raise TypeError(f"/decode returned non-string text: {decoded!r}")
                        text = decoded
                        dec_state["total_tokens"] += len(token_ids)
                    except Exception as e:
                        print(f"❌ Decoder {decoder_id} error: {e}")
                    # Report every seq, even a failed one (text=None), so the
                    # in-order reassembly below never waits for it forever.
                    text_queue.put_nowait(("text", seq, text))
        finally:
            dec_state["active_requests"] -= 1
            dec_state["role"] = "idle"
            text_queue.put_nowait(("done", decoder_id, None))

    gen_task = asyncio.create_task(generate_tokens())
    decoder_tasks = [
        asyncio.create_task(decode_tokens(dec_url, i))
        for i, dec_url in enumerate(decoders)
    ]

    accumulated_text = ""
    pending: Dict[int, Optional[str]] = {}  # decoded batches waiting on an earlier seq
    next_seq = 0
    decoders_done = 0
    try:
        while decoders_done < len(decoders):
            kind, key, text = await text_queue.get()
            if kind == "done":
                decoders_done += 1
                continue
            pending[key] = text
            # Emit every batch that is now contiguous with what was already sent.
            while next_seq in pending:
                delta = pending.pop(next_seq)
                next_seq += 1
                if delta is not None:
                    accumulated_text += delta
                    yield f"data: {json.dumps({'delta': delta, 'text': accumulated_text})}\n\n"
    finally:
        # Cancel rather than await: after a client disconnect the producers may
        # be blocked on a full queue that nobody will drain.
        for task in (gen_task, *decoder_tasks):
            task.cancel()
        await asyncio.gather(gen_task, *decoder_tasks, return_exceptions=True)
async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
    """Standard single-worker generation: relay one worker's SSE stream.

    POSTs *request_data* to ``{worker_url}/{endpoint}`` and yields the response
    body text as it arrives. Chunks are forwarded unmodified, including
    whitespace-only ones: the blank line that ends an SSE event can arrive on
    its own, and dropping it would merge adjacent events on the client.

    Failures are reported as a final ``data: {"error": ...}`` event instead of
    being raised, because the streaming response has already started. This
    covers transport errors and HTTP status >= 400 from the worker.
    """
    state = worker_health[worker_url]
    state["active_requests"] += 1
    state["total_requests"] += 1
    state["role"] = "full"
    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream(
                "POST",
                f"{worker_url}/{endpoint}",
                json=request_data
            ) as response:
                if response.status_code >= 400:
                    # An error body (often plain JSON) is not SSE; wrap it.
                    body = (await response.aread()).decode("utf-8", errors="replace")
                    message = f"worker returned HTTP {response.status_code}: {body[:500]}"
                    yield f"data: {json.dumps({'error': message})}\n\n"
                    return
                async for chunk in response.aiter_text():
                    if chunk:
                        yield chunk
    except Exception as e:
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
    finally:
        state["active_requests"] -= 1
        state["role"] = "idle"
# ============================================================================
# API Endpoints
# ============================================================================
# Static page for the live dashboard. Its script opens a WebSocket to /ws,
# redraws on every message (the snapshots sent by broadcast_stats()) and
# reconnects 3 s after the socket closes.
_DASHBOARD_HTML = """
<!DOCTYPE html>
<html>
<head>
<title>SAM-Z-1 Cluster Control v5.0</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: 'Courier New', monospace;
background: linear-gradient(135deg, #0a0e27 0%, #1a1f3a 100%);
color: #00ff88;
min-height: 100vh;
overflow-x: hidden;
overflow-y: auto;
}
.container {
padding: 20px;
max-width: 1400px;
margin: 0 auto;
padding-bottom: 40px;
}
.header {
text-align: center;
margin-bottom: 30px;
padding: 20px;
background: rgba(0, 255, 136, 0.1);
border: 2px solid #00ff88;
border-radius: 10px;
box-shadow: 0 0 20px rgba(0, 255, 136, 0.3);
}
.header h1 {
font-size: 2.5em;
text-transform: uppercase;
letter-spacing: 5px;
text-shadow: 0 0 10px #00ff88;
animation: glow 2s ease-in-out infinite alternate;
}
@keyframes glow {
from { text-shadow: 0 0 10px #00ff88, 0 0 20px #00ff88; }
to { text-shadow: 0 0 20px #00ff88, 0 0 30px #00ff88, 0 0 40px #00ff88; }
}
.version-badge {
display: inline-block;
padding: 3px 10px;
border-radius: 12px;
font-size: 10px;
margin-left: 8px;
font-weight: bold;
}
.version-v5 {
background: rgba(0, 255, 136, 0.2);
border: 1px solid #00ff88;
color: #00ff88;
}
.version-v4 {
background: rgba(255, 165, 0, 0.2);
border: 1px solid #ffa500;
color: #ffa500;
}
.status-bar {
display: flex;
gap: 20px;
margin-bottom: 30px;
}
.stat-card {
flex: 1;
background: rgba(0, 255, 136, 0.05);
border: 1px solid #00ff88;
border-radius: 8px;
padding: 15px;
position: relative;
overflow: hidden;
}
.stat-label {
font-size: 0.8em;
opacity: 0.7;
text-transform: uppercase;
}
.stat-value {
font-size: 2em;
font-weight: bold;
margin-top: 5px;
}
.workers-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
.worker-card {
background: rgba(10, 14, 39, 0.8);
border: 2px solid #00ff88;
border-radius: 10px;
padding: 20px;
position: relative;
transition: all 0.3s;
}
.worker-card:hover {
transform: translateY(-5px);
box-shadow: 0 5px 30px rgba(0, 255, 136, 0.4);
}
.worker-card.offline {
border-color: #ff4444;
opacity: 0.6;
}
.worker-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
}
.worker-name {
font-size: 1.2em;
font-weight: bold;
}
.status-dot {
width: 12px;
height: 12px;
border-radius: 50%;
animation: pulse 2s infinite;
}
.status-dot.online {
background: #00ff88;
box-shadow: 0 0 10px #00ff88;
}
.status-dot.offline {
background: #ff4444;
box-shadow: 0 0 10px #ff4444;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.worker-stats {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 10px;
margin-top: 15px;
}
.worker-stat {
background: rgba(0, 255, 136, 0.05);
padding: 10px;
border-radius: 5px;
}
.worker-stat-label {
font-size: 0.7em;
opacity: 0.7;
}
.worker-stat-value {
font-size: 1.3em;
font-weight: bold;
margin-top: 3px;
}
.timestamp {
text-align: center;
margin-top: 20px;
opacity: 0.5;
font-size: 0.9em;
}
@media (max-width: 768px) {
.workers-grid { grid-template-columns: 1fr; }
.status-bar { flex-direction: column; }
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>⚡ SAM-Z-1 CLUSTER ⚡</h1>
<div>DISTRIBUTED COMPUTE SYSTEM v5.0 • AUTO VERSION DETECTION</div>
</div>
<div class="status-bar">
<div class="stat-card">
<div class="stat-label">Load Mode</div>
<div class="stat-value" id="mode">--</div>
</div>
<div class="stat-card">
<div class="stat-label">Current Load</div>
<div class="stat-value" id="load">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Total Requests</div>
<div class="stat-value" id="total-req">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Req/Sec</div>
<div class="stat-value" id="rps">0.00</div>
</div>
</div>
<div class="workers-grid" id="workers"></div>
<div class="timestamp" id="timestamp">Last update: --</div>
</div>
<script>
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
let ws;
function connectWebSocket() {
try {
ws = new WebSocket(`${protocol}//${window.location.host}/ws`);
ws.onopen = () => console.log('✅ WebSocket connected');
ws.onmessage = (event) => updateDashboard(JSON.parse(event.data));
ws.onerror = () => console.error('❌ WebSocket error');
ws.onclose = () => setTimeout(connectWebSocket, 3000);
} catch (e) {
console.error('Failed to connect WebSocket');
}
}
connectWebSocket();
function updateDashboard(data) {
document.getElementById('mode').textContent = data.mode.toUpperCase();
document.getElementById('load').textContent = data.load;
document.getElementById('total-req').textContent = data.cluster.total_requests;
document.getElementById('rps').textContent = data.cluster.rps;
const workersDiv = document.getElementById('workers');
workersDiv.innerHTML = data.workers.map(worker => `
<div class="worker-card ${worker.healthy ? '' : 'offline'}">
<div class="worker-header">
<div>
<div class="worker-name">${worker.url}</div>
<span class="version-badge version-${worker.version}">${worker.version.toUpperCase()}</span>
${worker.models > 0 ? `<span style="font-size:0.8em;opacity:0.7;margin-left:5px">${worker.models} models</span>` : ''}
</div>
<div class="status-dot ${worker.healthy ? 'online' : 'offline'}"></div>
</div>
<div class="worker-stats">
<div class="worker-stat">
<div class="worker-stat-label">Active</div>
<div class="worker-stat-value">${worker.active}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Total</div>
<div class="worker-stat-value">${worker.total}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Tokens</div>
<div class="worker-stat-value">${worker.tokens}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Role</div>
<div class="worker-stat-value" style="font-size:1em;">${worker.role}</div>
</div>
</div>
</div>
`).join('');
document.getElementById('timestamp').textContent =
`Last update: ${new Date().toLocaleTimeString()}`;
}
</script>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def dashboard():
    """Serve the real-time cluster dashboard page (with worker version badges)."""
    return _DASHBOARD_HTML
@app.get("/api/status")
async def api_status():
    """Summarize the cluster: load mode, worker counts (total/healthy/v4/v5), features."""
    mode, load = update_load_mode()
    detected_versions = [state["version"] for state in worker_health.values()]
    feature_list = [
        "distributed_compute",
        "smart_load_balancing",
        "auto_version_detection",
        "multi_model_support",
        "real_time_dashboard",
    ]
    return {
        "name": "SAM-Z-1 Distributed Cluster",
        "version": "5.0.0",
        "mode": mode,
        "current_load": load,
        "workers": len(WORKER_URLS),
        "healthy_workers": len(get_healthy_workers()),
        "v4_workers": detected_versions.count("v4"),
        "v5_workers": detected_versions.count("v5"),
        "features": feature_list,
    }
@app.get("/health")
async def health():
    """Head-node health: "healthy" while at least one worker passes its check."""
    healthy_total = len(get_healthy_workers())
    overall = "unhealthy"
    if healthy_total > 0:
        overall = "healthy"
    return {"status": overall, "workers_healthy": healthy_total}
@app.get("/models")
async def list_models():
    """List all available models across all workers.

    Only healthy workers detected as "v5" contribute. "default" is
    "SAM-X-1-Nano" when some such worker offers it, otherwise None.
    """
    available = {
        model
        for state in worker_health.values()
        if state["healthy"] and state["version"] == "v5"
        for model in state["supports_models"]
    }
    default_model = "SAM-X-1-Nano" if "SAM-X-1-Nano" in available else None
    return {"models": sorted(available), "default": default_model}
@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """Generate text with automatic model routing.

    Records the request, then finds the workers able to serve
    ``request.model`` and streams Server-Sent Events back. In "light" mode
    with at least two compatible workers, the work is split into one
    generator and decoder workers when decoders are available. Otherwise the
    least busy compatible worker serves the whole request.

    Raises:
        HTTPException: 503 when no compatible worker is available.
    """
    track_request()
    mode, load = update_load_mode()

    compatible = get_workers_for_model(request.model)
    model_label = request.model or 'default'
    if not compatible:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(
            status_code=503,
            detail=f"No workers available for model: {model_label}",
        )

    # Workers always receive a streaming request, regardless of request.stream.
    request_data = {
        "prompt": request.prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "stream": True,
    }
    if request.model:
        # Explicit model choice is forwarded to the worker(s).
        request_data["model"] = request.model

    print(f"🎯 {mode.upper()} | Load: {load} | Model: {model_label} | Workers: {len(compatible)}")

    try:
        stream = None
        if mode == "light" and len(compatible) >= 2:
            generators, decoders = select_distributed_workers(request.model)
            if decoders:
                stream = distributed_generation(generators, decoders, request_data, "generate")
        if stream is None:
            chosen = get_least_busy_worker(compatible)
            stream = heavy_load_generation(chosen, request_data, "generate")
        cluster_stats["successful_requests"] += 1
        return StreamingResponse(stream, media_type="text/event-stream")
    except Exception:
        cluster_stats["failed_requests"] += 1
        raise
@app.post("/v1/chat")
async def chat(request: ChatRequest):
    """Chat with automatic model routing.

    Same routing as /v1/generate, but the request carries the conversation
    ``messages`` and targets each worker's ``/chat`` endpoint.

    Raises:
        HTTPException: 503 when no compatible worker is available.
    """
    track_request()
    mode, load = update_load_mode()

    compatible = get_workers_for_model(request.model)
    model_label = request.model or 'default'
    if not compatible:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(
            status_code=503,
            detail=f"No workers available for model: {model_label}",
        )

    conversation = [{"role": msg.role, "content": msg.content} for msg in request.messages]
    # Workers always receive a streaming request, regardless of request.stream.
    request_data = {
        "messages": conversation,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "stream": True,
    }
    if request.model:
        # Explicit model choice is forwarded to the worker(s).
        request_data["model"] = request.model

    print(f"💬 {mode.upper()} | Load: {load} | Model: {model_label} | Workers: {len(compatible)}")

    try:
        stream = None
        if mode == "light" and len(compatible) >= 2:
            generators, decoders = select_distributed_workers(request.model)
            if decoders:
                stream = distributed_generation(generators, decoders, request_data, "chat")
        if stream is None:
            chosen = get_least_busy_worker(compatible)
            stream = heavy_load_generation(chosen, request_data, "chat")
        cluster_stats["successful_requests"] += 1
        return StreamingResponse(stream, media_type="text/event-stream")
    except Exception:
        cluster_stats["failed_requests"] += 1
        raise
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces on port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run(app, **server_options)