# Sam-Z-api / app.py
# Bc-AI's picture
# Update app.py
# cdc0a63 verified
"""
SAM-Z-1 Distributed Compute Cluster Head Node
- Smart load balancing with distributed compute
- Real-time status dashboard
"""
from fastapi import FastAPI, HTTPException, WebSocket
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import httpx
import asyncio
import json
import time
from typing import List, Optional, Dict
from collections import deque
import random
# ASGI application object; every route and event hook below registers on it.
app = FastAPI(
    title="SAM-Z-1 Distributed Cluster",
    version="4.0.0",
)
# ============================================================================
# Configuration
# ============================================================================
# Base URLs of the worker nodes this head node dispatches to.
WORKER_URLS = [
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
    "https://bc-ai-worker-3.hf.space",
    "https://bc-ai-worker-4.hf.space",
    "https://bc-ai-worker-5.hf.space",
    "https://bc-ai-worker-sam-z-api-2.hf.space",
]

# Seconds between health sweeps (kept short for the real-time dashboard).
HEALTH_CHECK_INTERVAL = 5
# Width, in seconds, of the sliding window used to measure current load.
LOAD_CHECK_WINDOW = 10
# Load thresholds, expressed as requests seen inside the window.
LIGHT_LOAD_THRESHOLD = 2
HEAVY_LOAD_THRESHOLD = 5

# Mutable per-worker state, keyed by worker URL.
# "role" is one of "generator", "decoder", "full" or "idle".
worker_health = {
    url: dict(
        healthy=True,
        last_check=0,
        active_requests=0,
        total_requests=0,
        total_tokens=0,
        avg_latency=0,
        role="idle",
    )
    for url in WORKER_URLS
}

# Arrival times of the most recent requests (at most 100 are retained).
request_timestamps = deque(maxlen=100)

# Last computed load mode: "light", "medium" or "heavy".
current_load_mode = "light"

# Cluster-wide counters plus the wall-clock time this process started.
cluster_stats = dict(
    total_requests=0,
    successful_requests=0,
    failed_requests=0,
    uptime_start=time.time(),
)

# Open dashboard WebSocket connections that receive stats broadcasts.
active_connections = set()
# ============================================================================
# Request Models
# ============================================================================
class GenerateRequest(BaseModel):
    """Request body accepted by the ``/v1/generate`` endpoint."""

    # Text to continue.
    prompt: str

    # Sampling options; the endpoint copies these into the worker payload.
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # Accepted from clients; the endpoint always sends "stream": True to workers.
    stream: bool = True
class ChatMessage(BaseModel):
    """One conversation turn inside a chat request."""

    # Speaker label for this turn.
    role: str
    # Text of the turn.
    content: str
class ChatRequest(BaseModel):
    """Request body accepted by the ``/v1/chat`` endpoint."""

    # Conversation so far, in order.
    messages: List[ChatMessage]

    # Sampling options; the endpoint copies these into the worker payload.
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # Accepted from clients; the endpoint always sends "stream": True to workers.
    stream: bool = True
# ============================================================================
# Load Management
# ============================================================================
def get_current_load() -> int:
    """Return how many tracked requests arrived within the last LOAD_CHECK_WINDOW seconds."""
    now = time.time()
    recent = 0
    for stamp in request_timestamps:
        if now - stamp < LOAD_CHECK_WINDOW:
            recent += 1
    return recent
def update_load_mode() -> tuple:
    """Recompute the cluster load mode and store it in ``current_load_mode``.

    The mode depends on the current load (requests inside the sliding
    window) and on how many workers are healthy:

    - 5+ healthy workers: "light" up to LIGHT_LOAD_THRESHOLD, "medium" up to
      HEAVY_LOAD_THRESHOLD, "heavy" above that.
    - 3-4 healthy workers: "light" up to LIGHT_LOAD_THRESHOLD, else "heavy".
    - fewer than 3 healthy workers: always "heavy" (simple distribution).

    Returns:
        A ``(mode, load)`` tuple: the mode string and the load it was
        derived from.
    """
    global current_load_mode
    load = get_current_load()
    healthy_count = len(get_healthy_workers())

    if healthy_count >= 5:
        if load <= LIGHT_LOAD_THRESHOLD:
            mode = "light"  # 1 gen + 4 decoders
        elif load <= HEAVY_LOAD_THRESHOLD:
            # The original compared against MEDIUM_LOAD_THRESHOLD, which is
            # not defined anywhere in this module and raised NameError here.
            # HEAVY_LOAD_THRESHOLD marks where "heavy" begins, so it is the
            # upper bound of "medium".
            mode = "medium"  # 2 gens + 3 decoders OR parallel requests
        else:
            mode = "heavy"  # all workers independent
    elif healthy_count >= 3:
        if load <= LIGHT_LOAD_THRESHOLD:
            mode = "light"  # 1 gen + 2 decoders
        else:
            mode = "heavy"  # distribute requests
    else:
        mode = "heavy"  # fallback to simple distribution

    current_load_mode = mode
    return current_load_mode, load
def track_request():
    """Record the arrival of one request for load and cluster statistics."""
    arrival = time.time()
    request_timestamps.append(arrival)
    cluster_stats["total_requests"] = cluster_stats["total_requests"] + 1
def get_healthy_workers() -> List[str]:
    """Return the URLs of all workers currently flagged healthy, in registry order."""
    healthy_urls = []
    for url, state in worker_health.items():
        if state["healthy"]:
            healthy_urls.append(url)
    return healthy_urls
def get_least_busy_worker() -> Optional[str]:
    """Return the healthy worker with the fewest active requests, or None if none is healthy."""
    candidates = get_healthy_workers()
    if candidates:
        # min() keeps the first candidate on ties, i.e. registry order.
        return min(candidates, key=lambda url: worker_health[url]["active_requests"])
    return None
def select_distributed_workers() -> tuple:
    """
    Pick workers for distributed compute.

    Healthy workers are ranked by active request count (least busy first).
    The first becomes the generator and up to four of the following ones
    become decoders.

    Returns: (generators: List[str], decoders: List[str])
        Both lists are empty when no worker is healthy; the decoder list is
        empty when only one worker is healthy.
    """
    ranked = sorted(
        get_healthy_workers(),
        key=lambda url: worker_health[url]["active_requests"],
    )
    # Slicing covers every pool size: 0 -> ([], []), 1 -> ([w], []),
    # n >= 2 -> one generator plus min(n - 1, 4) decoders.
    return (ranked[:1], ranked[1:5])
async def broadcast_stats():
    """Send a stats snapshot to every connected dashboard WebSocket.

    Does nothing when no client is connected.  Otherwise it refreshes the
    load mode, builds one JSON-serialisable snapshot (mode, load, per-worker
    state, cluster counters) and sends it to each connection.  Connections
    whose send fails are dropped from ``active_connections``.
    """
    if not active_connections:
        return

    mode, load = update_load_mode()
    now = time.time()
    uptime = now - cluster_stats["uptime_start"]
    stats = {
        "timestamp": now,
        "mode": mode,
        "load": load,
        "workers": [
            {
                # Host label only, e.g. "bc-ai-worker-2" (shorter name).
                "url": url.split("//")[1].split(".")[0],
                "healthy": status["healthy"],
                "active": status["active_requests"],
                "total": status["total_requests"],
                "tokens": status["total_tokens"],
                "latency": round(status["avg_latency"], 2),
                "role": status["role"],
            }
            for url, status in worker_health.items()
        ],
        "cluster": {
            "total_requests": cluster_stats["total_requests"],
            "successful": cluster_stats["successful_requests"],
            "failed": cluster_stats["failed_requests"],
            "uptime": round(uptime, 0),
            "rps": round(cluster_stats["total_requests"] / uptime if uptime > 0 else 0, 2),
        },
    }

    disconnected = set()
    # Iterate over a snapshot: every send awaits, and other coroutines may
    # add or discard connections meanwhile, which would otherwise raise
    # "Set changed size during iteration".
    for ws in list(active_connections):
        try:
            await ws.send_json(stats)
        except Exception:
            # Best effort: a failed send means the client is gone.  Task
            # cancellation is deliberately not swallowed here.
            disconnected.add(ws)

    active_connections.difference_update(disconnected)
async def check_worker_health(worker_url: str) -> bool:
    """Probe ``{worker_url}/health`` and report whether it answered HTTP 200.

    Args:
        worker_url: Base URL of the worker to probe.

    Returns:
        True when the worker answered with status 200 within 5 seconds;
        False on any other status or on any request error.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{worker_url}/health")
    except Exception:
        # Any failure counts as unhealthy.  Unlike a bare ``except:``, this
        # lets task cancellation and KeyboardInterrupt propagate.
        return False
    return response.status_code == 200
async def health_check_loop():
    """Probe every worker forever and push fresh stats to dashboards.

    Each sweep probes all workers concurrently, records the results in
    ``worker_health``, broadcasts stats, then sleeps HEALTH_CHECK_INTERVAL
    seconds.  An error in one sweep is logged and the loop carries on, so
    monitoring does not stop permanently.
    """
    while True:
        try:
            # Probe concurrently so one slow or dead worker (5 s timeout)
            # does not delay the checks of all the others.
            results = await asyncio.gather(
                *(check_worker_health(worker_url) for worker_url in WORKER_URLS)
            )
            checked_at = time.time()
            for worker_url, healthy in zip(WORKER_URLS, results):
                worker_health[worker_url]["healthy"] = healthy
                worker_health[worker_url]["last_check"] = checked_at

            # Always broadcast stats to connected clients
            await broadcast_stats()
        except Exception as exc:
            # Cancellation is not an Exception subclass, so shutdown still works.
            print(f"Health check loop error: {exc}")

        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
@app.on_event("startup")
async def startup_event():
    """Start the background health-check loop when the application starts."""
    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    app.state.health_check_task = asyncio.create_task(health_check_loop())
# ============================================================================
# Distributed Compute Generation
# ============================================================================
async def distributed_generation(
    generators: List[str],
    decoders: List[str],
    request_data: dict,
    endpoint: str = "generate",
):
    """
    DISTRIBUTED COMPUTE MODE
    - The primary generator streams token IDs
    - Decoders turn contiguous token batches into text in parallel

    Token IDs are grouped, in generation order, into numbered batches.  Any
    free decoder takes the next batch from a shared queue, and decoded text
    is re-sequenced by batch number before it is yielded, so the output
    order always matches the generation order.

    Args:
        generators: Worker URLs to tag as generators; only the first one is
            actually called.
        decoders: Worker URLs whose ``/decode`` endpoint receives batches.
        request_data: JSON payload for the generator; it is sent with
            ``return_token_ids`` set to True.
        endpoint: Path on the generator worker to POST to.

    Yields:
        Server-sent-event strings of the form
        ``data: {"delta": <new text>, "text": <all text so far>}``.
        Nothing is yielded when either worker list is empty.
    """
    if not generators or not decoders:
        return

    batch_size = 2  # smaller batches for faster streaming
    # Items: (sequence number, [token ids]); None tells one decoder to stop.
    batch_queue = asyncio.Queue(maxsize=50)
    # Items: (sequence number, text, or None for a failed batch); a bare
    # None marks the end of the stream.
    text_queue = asyncio.Queue(maxsize=50)

    # Mark roles
    for gen_url in generators:
        worker_health[gen_url]["role"] = "generator"
    for dec_url in decoders:
        worker_health[dec_url]["role"] = "decoder"

    async def generate_tokens():
        """Stream token IDs from the primary generator into numbered batches."""
        gen_url = generators[0]  # primary generator
        worker_health[gen_url]["active_requests"] += 1
        next_batch = 0
        pending = []
        try:
            try:
                payload = {**request_data, "return_token_ids": True}
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(
                        "POST",
                        f"{gen_url}/{endpoint}",
                        json=payload,
                    ) as response:
                        # Parse line by line: raw network chunks may split an
                        # event or carry several events at once.
                        async for line in response.aiter_lines():
                            if not line.startswith("data: "):
                                continue
                            try:
                                event = json.loads(line[6:])
                            except ValueError:
                                continue  # not JSON; ignore this line
                            if not isinstance(event, dict):
                                continue
                            if "token_id" in event:
                                pending.append(event["token_id"])
                                if len(pending) >= batch_size:
                                    await batch_queue.put((next_batch, pending))
                                    next_batch += 1
                                    pending = []
                            elif "done" in event:
                                break
            except Exception as e:
                print(f"❌ Generator error: {e}")

            # Reached on normal completion, on "done", on a silent end of
            # stream and after an error - but not on cancellation.
            if pending:
                await batch_queue.put((next_batch, pending))
            for _ in decoders:
                await batch_queue.put(None)
        finally:
            worker_health[gen_url]["active_requests"] -= 1
            worker_health[gen_url]["role"] = "idle"

    async def decode_tokens(decoder_url: str, decoder_id: int):
        """Decode batches from the shared queue until told to stop."""
        worker_health[decoder_url]["active_requests"] += 1
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                while True:
                    item = await batch_queue.get()
                    if item is None:
                        break
                    seq, token_ids = item
                    try:
                        response = await client.post(
                            f"{decoder_url}/decode",
                            json={"token_ids": token_ids},
                        )
                        text = response.json()["text"]
                    except Exception as e:
                        print(f"❌ Decoder {decoder_id} error: {e}")
                        # Report the batch as lost so later batches are not
                        # held back, then retire this decoder.
                        await text_queue.put((seq, None))
                        break
                    worker_health[decoder_url]["total_tokens"] += len(token_ids)
                    await text_queue.put((seq, text))
        finally:
            worker_health[decoder_url]["active_requests"] -= 1
            worker_health[decoder_url]["role"] = "idle"

    # Start generator
    gen_task = asyncio.create_task(generate_tokens())

    # Start all decoders
    decoder_tasks = [
        asyncio.create_task(decode_tokens(dec_url, i))
        for i, dec_url in enumerate(decoders)
    ]

    async def close_text_stream():
        """Signal end of stream once every decoder has finished or failed."""
        await asyncio.gather(*decoder_tasks, return_exceptions=True)
        await text_queue.put(None)

    closer_task = asyncio.create_task(close_text_stream())
    all_tasks = [gen_task, *decoder_tasks, closer_task]

    # Stream results
    accumulated_text = ""
    waiting = {}  # decoded batches that arrived ahead of their turn
    next_seq = 0
    try:
        while True:
            message = await text_queue.get()
            if message is None:
                break
            seq, text = message
            waiting[seq] = text
            # Emit every batch that is now next in line.
            while next_seq in waiting:
                piece = waiting.pop(next_seq)
                next_seq += 1
                if piece is None:
                    continue  # batch was lost to a decoder failure
                accumulated_text += piece
                yield f"data: {json.dumps({'delta': piece, 'text': accumulated_text})}\n\n"

        # A decoder that died unexpectedly may leave a gap; still emit what
        # was decoded after it, in order.
        for seq in sorted(waiting):
            piece = waiting[seq]
            if piece is None:
                continue
            accumulated_text += piece
            yield f"data: {json.dumps({'delta': piece, 'text': accumulated_text})}\n\n"
    finally:
        # Also runs when the client disconnects: cancel rather than wait,
        # because the tasks may be blocked on full queues.
        for task in all_tasks:
            task.cancel()
        await asyncio.gather(*all_tasks, return_exceptions=True)
async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
    """
    Relay a streaming generation from a single worker.

    The worker is marked "full" while the request runs and "idle" afterwards.
    Non-blank response chunks are yielded unchanged; if anything fails, one
    ``data: {"error": ...}`` event is yielded instead of raising.
    """
    try:
        worker_health[worker_url]["active_requests"] += 1
        worker_health[worker_url]["role"] = "full"

        target = f"{worker_url}/{endpoint}"
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream("POST", target, json=request_data) as upstream:
                async for piece in upstream.aiter_text():
                    if not piece.strip():
                        continue
                    yield piece
    except Exception as exc:
        yield f"data: {json.dumps({'error': str(exc)})}\n\n"
    finally:
        worker_health[worker_url]["active_requests"] -= 1
        worker_health[worker_url]["role"] = "idle"
# ============================================================================
# Dashboard
# ============================================================================
@app.get("/", response_class=HTMLResponse)
async def dashboard():
    """Serve the single-page real-time cluster dashboard.

    The page opens a WebSocket to ``/ws`` and re-renders on every stats
    message.  If the socket errors it falls back to polling ``/api/status``
    (and ``/workers`` for per-worker details) once per second.

    Returns:
        The complete HTML document (inline CSS and JavaScript) as a string.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>SAM-Z-1 Cluster Control</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Courier New', monospace;
background: linear-gradient(135deg, #0a0e27 0%, #1a1f3a 100%);
color: #00ff88;
min-height: 100vh;
overflow-x: hidden;
overflow-y: auto;
}
.container {
padding: 20px;
max-width: 1400px;
margin: 0 auto;
padding-bottom: 40px;
}
.header {
text-align: center;
margin-bottom: 30px;
padding: 20px;
background: rgba(0, 255, 136, 0.1);
border: 2px solid #00ff88;
border-radius: 10px;
box-shadow: 0 0 20px rgba(0, 255, 136, 0.3);
}
.header h1 {
font-size: 2.5em;
text-transform: uppercase;
letter-spacing: 5px;
text-shadow: 0 0 10px #00ff88;
animation: glow 2s ease-in-out infinite alternate;
}
@keyframes glow {
from { text-shadow: 0 0 10px #00ff88, 0 0 20px #00ff88; }
to { text-shadow: 0 0 20px #00ff88, 0 0 30px #00ff88, 0 0 40px #00ff88; }
}
.status-bar {
display: flex;
gap: 20px;
margin-bottom: 30px;
}
.stat-card {
flex: 1;
background: rgba(0, 255, 136, 0.05);
border: 1px solid #00ff88;
border-radius: 8px;
padding: 15px;
position: relative;
overflow: hidden;
}
.stat-card::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(0, 255, 136, 0.2), transparent);
animation: scan 3s infinite;
}
@keyframes scan {
0% { left: -100%; }
100% { left: 100%; }
}
.stat-label {
font-size: 0.8em;
opacity: 0.7;
text-transform: uppercase;
}
.stat-value {
font-size: 2em;
font-weight: bold;
margin-top: 5px;
}
.mode-badge {
display: inline-block;
padding: 5px 15px;
border-radius: 20px;
font-size: 0.9em;
font-weight: bold;
text-transform: uppercase;
margin-top: 10px;
}
.mode-light {
background: rgba(0, 255, 136, 0.2);
border: 1px solid #00ff88;
color: #00ff88;
}
.mode-medium {
background: rgba(255, 165, 0, 0.2);
border: 1px solid #ffa500;
color: #ffa500;
}
.mode-heavy {
background: rgba(255, 68, 68, 0.2);
border: 1px solid #ff4444;
color: #ff4444;
}
.workers-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
@media (max-width: 768px) {
.workers-grid {
grid-template-columns: 1fr;
}
.status-bar {
flex-direction: column;
}
.info-grid {
grid-template-columns: repeat(2, 1fr);
}
.header h1 {
font-size: 1.5em;
letter-spacing: 2px;
}
}
.worker-card {
background: rgba(10, 14, 39, 0.8);
border: 2px solid #00ff88;
border-radius: 10px;
padding: 20px;
position: relative;
transition: all 0.3s;
}
.worker-card:hover {
transform: translateY(-5px);
box-shadow: 0 5px 30px rgba(0, 255, 136, 0.4);
}
.worker-card.offline {
border-color: #ff4444;
opacity: 0.6;
}
.worker-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
}
.worker-name {
font-size: 1.2em;
font-weight: bold;
}
.status-dot {
width: 12px;
height: 12px;
border-radius: 50%;
animation: pulse 2s infinite;
}
.status-dot.online {
background: #00ff88;
box-shadow: 0 0 10px #00ff88;
}
.status-dot.offline {
background: #ff4444;
box-shadow: 0 0 10px #ff4444;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.worker-stats {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 10px;
margin-top: 15px;
}
.worker-stat {
background: rgba(0, 255, 136, 0.05);
padding: 10px;
border-radius: 5px;
}
.worker-stat-label {
font-size: 0.7em;
opacity: 0.7;
}
.worker-stat-value {
font-size: 1.3em;
font-weight: bold;
margin-top: 3px;
}
.role-badge {
display: inline-block;
padding: 3px 10px;
border-radius: 12px;
font-size: 0.75em;
margin-top: 10px;
font-weight: bold;
}
.role-generator {
background: rgba(255, 165, 0, 0.2);
border: 1px solid #ffa500;
color: #ffa500;
}
.role-decoder {
background: rgba(0, 191, 255, 0.2);
border: 1px solid #00bfff;
color: #00bfff;
}
.role-full {
background: rgba(138, 43, 226, 0.2);
border: 1px solid #8a2be2;
color: #8a2be2;
}
.role-idle {
background: rgba(128, 128, 128, 0.2);
border: 1px solid #808080;
color: #808080;
}
.progress-bar {
width: 100%;
height: 4px;
background: rgba(0, 255, 136, 0.1);
border-radius: 2px;
margin-top: 10px;
overflow: hidden;
}
.progress-fill {
height: 100%;
background: linear-gradient(90deg, #00ff88, #00ffff);
transition: width 0.3s;
box-shadow: 0 0 10px #00ff88;
}
.cluster-info {
background: rgba(0, 255, 136, 0.05);
border: 1px solid #00ff88;
border-radius: 8px;
padding: 20px;
}
.info-grid {
display: grid;
grid-template-columns: repeat(4, 1fr);
gap: 20px;
}
.info-item {
text-align: center;
}
.timestamp {
text-align: center;
margin-top: 20px;
opacity: 0.5;
font-size: 0.9em;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>⚡ SAM-Z-1 CLUSTER ⚡</h1>
<div>DISTRIBUTED COMPUTE SYSTEM v4.0</div>
</div>
<div class="status-bar">
<div class="stat-card">
<div class="stat-label">Load Mode</div>
<div class="stat-value" id="mode">--</div>
<div class="mode-badge" id="mode-badge">INITIALIZING</div>
</div>
<div class="stat-card">
<div class="stat-label">Current Load</div>
<div class="stat-value" id="load">0</div>
<div class="stat-label">requests / 10s</div>
</div>
<div class="stat-card">
<div class="stat-label">Total Requests</div>
<div class="stat-value" id="total-req">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Req/Sec</div>
<div class="stat-value" id="rps">0.00</div>
</div>
</div>
<div class="workers-grid" id="workers">
<!-- Workers populated by JS -->
</div>
<div class="cluster-info">
<div class="stat-label" style="margin-bottom: 15px;">CLUSTER STATISTICS</div>
<div class="info-grid">
<div class="info-item">
<div class="stat-label">Successful</div>
<div class="stat-value" style="font-size: 1.5em;" id="success">0</div>
</div>
<div class="info-item">
<div class="stat-label">Failed</div>
<div class="stat-value" style="font-size: 1.5em;" id="failed">0</div>
</div>
<div class="info-item">
<div class="stat-label">Uptime</div>
<div class="stat-value" style="font-size: 1.5em;" id="uptime">0s</div>
</div>
<div class="info-item">
<div class="stat-label">Healthy Workers</div>
<div class="stat-value" style="font-size: 1.5em;" id="healthy">0</div>
</div>
</div>
</div>
<div class="timestamp" id="timestamp">Last update: --</div>
</div>
<script>
// Use wss:// for HTTPS, ws:// for HTTP
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
let ws;
let usePolling = false;
function connectWebSocket() {
try {
ws = new WebSocket(`${protocol}//${window.location.host}/ws`);
ws.onopen = () => {
console.log('✅ WebSocket connected');
usePolling = false;
};
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
updateDashboard(data);
};
ws.onerror = (error) => {
console.error('❌ WebSocket error, switching to polling');
usePolling = true;
startPolling();
};
ws.onclose = () => {
console.log('🔌 WebSocket disconnected');
if (!usePolling) {
setTimeout(connectWebSocket, 3000);
}
};
} catch (e) {
console.error('Failed to connect WebSocket, using polling');
usePolling = true;
startPolling();
}
}
async function pollStats() {
if (!usePolling) return;
try {
const response = await fetch('/api/status');
if (!response.ok) {
throw new Error(`/api/status returned ${response.status}`);
}
const data = await response.json();
// Worker details are optional: if /workers is missing or fails,
// still show mode and load instead of dropping the whole update
let workers = [];
try {
const workersRes = await fetch('/workers');
if (workersRes.ok) {
const workersData = await workersRes.json();
workers = workersData.workers || [];
}
} catch (e) {
console.warn('Worker stats unavailable:', e);
}
// Format data like WebSocket
const formattedData = {
timestamp: Date.now() / 1000,
mode: data.mode,
load: data.current_load,
workers: workers.map(w => ({
url: w.url.split("//")[1].split(".")[0],
healthy: w.healthy,
active: w.active_requests || 0,
total: 0,
tokens: 0,
latency: 0,
role: "idle"
})),
cluster: {
total_requests: 0,
successful: 0,
failed: 0,
uptime: 0,
rps: 0
}
};
updateDashboard(formattedData);
} catch (e) {
console.error('Polling error:', e);
}
}
function startPolling() {
pollStats();
setInterval(pollStats, 1000);
}
// Try WebSocket first
connectWebSocket();
function updateDashboard(data) {
// Mode
document.getElementById('mode').textContent = data.mode.toUpperCase();
const modeBadge = document.getElementById('mode-badge');
modeBadge.textContent = `${data.mode.toUpperCase()} MODE`;
modeBadge.className = `mode-badge mode-${data.mode}`;
// Stats
document.getElementById('load').textContent = data.load;
document.getElementById('total-req').textContent = data.cluster.total_requests;
document.getElementById('rps').textContent = data.cluster.rps;
document.getElementById('success').textContent = data.cluster.successful;
document.getElementById('failed').textContent = data.cluster.failed;
document.getElementById('uptime').textContent = formatUptime(data.cluster.uptime);
// Workers
const workersDiv = document.getElementById('workers');
const healthyCount = data.workers.filter(w => w.healthy).length;
document.getElementById('healthy').textContent = `${healthyCount}/${data.workers.length}`;
workersDiv.innerHTML = data.workers.map(worker => `
<div class="worker-card ${worker.healthy ? '' : 'offline'}">
<div class="worker-header">
<div class="worker-name">${worker.url}</div>
<div class="status-dot ${worker.healthy ? 'online' : 'offline'}"></div>
</div>
<div class="role-badge role-${worker.role}">${worker.role.toUpperCase()}</div>
<div class="worker-stats">
<div class="worker-stat">
<div class="worker-stat-label">Active</div>
<div class="worker-stat-value">${worker.active}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Total</div>
<div class="worker-stat-value">${worker.total}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Tokens</div>
<div class="worker-stat-value">${worker.tokens}</div>
</div>
<div class="worker-stat">
<div class="worker-stat-label">Latency</div>
<div class="worker-stat-value">${worker.latency}ms</div>
</div>
</div>
<div class="progress-bar">
<div class="progress-fill" style="width: ${Math.min(worker.active * 33, 100)}%"></div>
</div>
</div>
`).join('');
// Timestamp
const now = new Date();
document.getElementById('timestamp').textContent =
`Last update: ${now.toLocaleTimeString()}`;
}
function formatUptime(seconds) {
const h = Math.floor(seconds / 3600);
const m = Math.floor((seconds % 3600) / 60);
const s = Math.floor(seconds % 60);
return `${h}h ${m}m ${s}s`;
}
</script>
</body>
</html>
"""
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket for real-time dashboard updates.

    Accepts the connection, registers it for broadcasts, triggers an
    immediate stats broadcast, then waits on incoming messages only to
    notice when the client goes away.  The connection is always
    unregistered on exit.

    Args:
        websocket: The incoming client connection.
    """
    await websocket.accept()
    active_connections.add(websocket)
    try:
        # Send initial data
        await broadcast_stats()

        # Keep connection alive; incoming messages are ignored.
        while True:
            await websocket.receive_text()
    except Exception:
        # A disconnect or any send/receive failure simply ends the session.
        # Unlike a bare ``except:``, task cancellation still propagates.
        pass
    finally:
        active_connections.discard(websocket)
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/api/status")
async def api_status():
    """Return a JSON summary of the cluster: identity, load mode, load and worker counts."""
    mode, load = update_load_mode()

    summary = {
        "name": "SAM-Z-1 Distributed Cluster",
        "version": "4.0.0",
        "mode": mode,
        "current_load": load,
    }
    summary["workers"] = len(WORKER_URLS)
    summary["healthy_workers"] = len(get_healthy_workers())
    summary["features"] = [
        "distributed_compute",
        "smart_load_balancing",
        "real_time_dashboard",
    ]
    return summary
@app.get("/health")
async def health():
    """Report "healthy" when at least one worker is healthy, else "unhealthy"."""
    alive = len(get_healthy_workers())
    if alive > 0:
        verdict = "healthy"
    else:
        verdict = "unhealthy"
    return {"status": verdict, "workers_healthy": alive}
@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """
    Stream generated text for a prompt.

    In "light" mode with at least two healthy workers the request runs in
    distributed mode (one generator, several decoders).  Otherwise it goes
    to the least busy single worker.  Responds 503 when no worker is healthy.
    """
    track_request()
    mode, load = update_load_mode()
    healthy = get_healthy_workers()

    if not healthy:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(status_code=503, detail="No healthy workers")

    # Worker payload: the prompt plus sampling options, always streamed.
    request_data = {"prompt": request.prompt}
    for option in ("max_tokens", "temperature", "top_k", "top_p", "repetition_penalty"):
        request_data[option] = getattr(request, option)
    request_data["stream"] = True

    print(f"🎯 {mode.upper()} | Load: {load} | Workers: {len(healthy)}")

    try:
        stream = None
        if mode == "light" and len(healthy) >= 2:
            # DISTRIBUTED MODE - 1 gen + multiple decoders
            generators, decoders = select_distributed_workers()
            if decoders:
                stream = distributed_generation(generators, decoders, request_data, "generate")

        if stream is None:
            # HEAVY/FALLBACK - single worker
            stream = heavy_load_generation(get_least_busy_worker(), request_data, "generate")

        cluster_stats["successful_requests"] += 1
        return StreamingResponse(stream, media_type="text/event-stream")
    except Exception:
        cluster_stats["failed_requests"] += 1
        raise
@app.post("/v1/chat")
async def chat(request: ChatRequest):
    """
    Stream a chat completion for a list of messages.

    In "light" mode with at least two healthy workers the request runs in
    distributed mode (one generator, several decoders).  Otherwise it goes
    to the least busy single worker.  Responds 503 when no worker is healthy.
    """
    track_request()
    mode, load = update_load_mode()
    healthy = get_healthy_workers()

    if not healthy:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(status_code=503, detail="No healthy workers")

    # Worker payload: plain-dict messages plus sampling options, always streamed.
    request_data = {
        "messages": [{"role": msg.role, "content": msg.content} for msg in request.messages],
    }
    for option in ("max_tokens", "temperature", "top_k", "top_p", "repetition_penalty"):
        request_data[option] = getattr(request, option)
    request_data["stream"] = True

    print(f"πŸ’¬ {mode.upper()} | Load: {load} | Workers: {len(healthy)}")

    try:
        stream = None
        if mode == "light" and len(healthy) >= 2:
            # DISTRIBUTED MODE - 1 gen + multiple decoders
            generators, decoders = select_distributed_workers()
            if decoders:
                stream = distributed_generation(generators, decoders, request_data, "chat")

        if stream is None:
            # HEAVY/FALLBACK - single worker
            stream = heavy_load_generation(get_least_busy_worker(), request_data, "chat")

        cluster_stats["successful_requests"] += 1
        return StreamingResponse(stream, media_type="text/event-stream")
    except Exception:
        cluster_stats["failed_requests"] += 1
        raise
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    # Run the head node directly, listening on all interfaces on port 7860.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")