Update app.py

app.py CHANGED
@@ -24,10 +24,6 @@ app = FastAPI(title="SAM-Z-1 Distributed Cluster", version="4.0.0")
 WORKER_URLS = [
     "https://bc-ai-worker-2.hf.space",
     "https://bc-ai-worker-sam-z-api.hf.space",
-    "https://bc-ai-worker-3.hf.space",
-    "https://bc-ai-worker-4.hf.space",
-    "https://bc-ai-worker-5.hf.space"
-
 ]
 
 HEALTH_CHECK_INTERVAL = 5  # faster checks for real-time dashboard
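Note: the cluster is trimmed from five workers to two; bc-ai-worker-3, -4 and -5 leave the rotation. The health-check loop itself is outside this diff, so the sketch below only illustrates the kind of poller a constant like HEALTH_CHECK_INTERVAL drives; the /health route and the "healthy" flag are assumptions, not code from app.py.

# Sketch only - assumed /health route and worker_health["healthy"] flag
import asyncio
import httpx

async def health_check_loop():
    while True:
        async with httpx.AsyncClient(timeout=5.0) as client:
            for url in WORKER_URLS:
                try:
                    resp = await client.get(f"{url}/health")
                    worker_health[url]["healthy"] = (resp.status_code == 200)
                except Exception:
                    worker_health[url]["healthy"] = False
        await asyncio.sleep(HEALTH_CHECK_INTERVAL)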
@@ -132,19 +128,29 @@ def get_least_busy_worker() -> Optional[str]:
     return min(healthy, key=lambda url: worker_health[url]["active_requests"])
 
 def select_distributed_workers() -> tuple:
-    """…
+    """
+    Select workers for distributed compute
+    Returns: (generators: List[str], decoders: List[str])
+    """
     healthy = get_healthy_workers()
     if len(healthy) < 2:
-        return (healthy[0], …
+        return ([healthy[0]], []) if len(healthy) == 1 else ([], [])
 
+    # Sort by least busy
     sorted_workers = sorted(healthy, key=lambda url: worker_health[url]["active_requests"])
 
-    if len(healthy) >= …
-        # …
-        return (sorted_workers[0], sorted_workers[1]…
+    if len(healthy) >= 5:
+        # OPTIMAL: 1 generator, 4 decoders
+        return ([sorted_workers[0]], sorted_workers[1:5])
+    elif len(healthy) == 4:
+        # 1 generator, 3 decoders
+        return ([sorted_workers[0]], sorted_workers[1:4])
+    elif len(healthy) == 3:
+        # 1 generator, 2 decoders
+        return ([sorted_workers[0]], sorted_workers[1:3])
     else:
-        # …
-        return (sorted_workers[0], sorted_workers[1]…
+        # 1 generator, 1 decoder
+        return ([sorted_workers[0]], [sorted_workers[1]])
 
 async def broadcast_stats():
     """Broadcast stats to all connected WebSocket clients"""
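The rewritten selector always hands the least-busy worker the generator role and fans the next-busiest out as decoders, capped at four. The whole if/elif ladder is equivalent to a single slice, as this self-contained sketch shows (stub names and load data, for illustration only):

# Standalone sketch of the same role-split rule (stub data)
from typing import Dict, List, Tuple

def split_roles(healthy: List[str], load: Dict[str, int]) -> Tuple[List[str], List[str]]:
    if len(healthy) < 2:
        return (healthy[:1], [])           # one worker: generator only; zero: nothing
    ranked = sorted(healthy, key=lambda url: load[url])
    return ([ranked[0]], ranked[1:5])      # least busy generates, next 1-4 decode

load = {f"w{i}": i for i in range(5)}
print(split_roles(list(load), load))       # (['w0'], ['w1', 'w2', 'w3', 'w4'])

Behaviorally this matches the new function for every cluster size; the explicit elif branches in the diff just make the four-decoder cap easier to read.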
@@ -220,36 +226,40 @@ async def startup_event():
 # ============================================================================
 
 async def distributed_generation(
-    …
-    …
-    decoder2_url: Optional[str],
+    generators: List[str],
+    decoders: List[str],
     request_data: dict,
     endpoint: str = "generate"
 ):
     """
     DISTRIBUTED COMPUTE MODE
-    - …
-    - …
+    - Generator(s) produce token IDs
+    - Multiple decoders process in parallel (load balanced)
     """
 
-    …
-    …
+    if not generators or not decoders:
+        return
+
+    token_queue = asyncio.Queue(maxsize=50)
+    text_queue = asyncio.Queue(maxsize=50)
 
     # Mark roles
-    …
-    …
-    …
-    worker_health[…
+    for gen_url in generators:
+        worker_health[gen_url]["role"] = "generator"
+    for dec_url in decoders:
+        worker_health[dec_url]["role"] = "decoder"
 
     async def generate_tokens():
+        """Generator worker(s)"""
+        gen_url = generators[0]  # primary generator
         try:
-            worker_health[…
+            worker_health[gen_url]["active_requests"] += 1
             request_data_tokens = {**request_data, "return_token_ids": True}
 
             async with httpx.AsyncClient(timeout=300.0) as client:
                 async with client.stream(
                     "POST",
-                    f"{…
+                    f"{gen_url}/{endpoint}",
                     json=request_data_tokens
                 ) as response:
                     async for chunk in response.aiter_text():
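Both queues are bounded (maxsize=50), which is what gives the pipeline backpressure: a generator that outruns the decoders suspends on put() instead of buffering without limit. A minimal illustration of that asyncio.Queue behavior:

# asyncio.Queue(maxsize=N): put() suspends once N items are waiting
import asyncio

async def demo():
    q = asyncio.Queue(maxsize=1)
    await q.put("a")                        # fits
    try:
        await asyncio.wait_for(q.put("b"), timeout=0.1)
    except asyncio.TimeoutError:
        print("put() blocked: queue is full")   # a consumer must drain it first

asyncio.run(demo())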
@@ -259,30 +269,33 @@ async def distributed_generation(
                         if "token_id" in data:
                             await token_queue.put(data["token_id"])
                         elif "done" in data:
-                            …
+                            # Send done signal for each decoder
+                            for _ in decoders:
+                                await token_queue.put(None)
                             break
                     except:
                         pass
         except Exception as e:
             print(f"❌ Generator error: {e}")
-            …
+            for _ in decoders:
+                await token_queue.put(None)
         finally:
-            worker_health[…
-            worker_health[…
+            worker_health[gen_url]["active_requests"] -= 1
+            worker_health[gen_url]["role"] = "idle"
 
-    async def decode_tokens(decoder_url: str):
-        """Decoder worker - processes tokens from queue"""
+    async def decode_tokens(decoder_url: str, decoder_id: int):
+        """Decoder worker - processes tokens from shared queue"""
         try:
             worker_health[decoder_url]["active_requests"] += 1
             batch = []
-            batch_size = …
+            batch_size = 2  # smaller batches for faster streaming
 
             while True:
                 try:
-                    token_id = await asyncio.wait_for(token_queue.get(), timeout=…
+                    token_id = await asyncio.wait_for(token_queue.get(), timeout=2.0)
 
                     if token_id is None:
-                        # Decode remaining
+                        # Decode remaining batch
                         if batch:
                             async with httpx.AsyncClient(timeout=10.0) as client:
                                 response = await client.post(
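The shutdown handshake is the classic one-sentinel-per-consumer pattern: the generator enqueues one None per decoder, and each decoder exits on the first None it pulls, so every consumer terminates exactly once even though they share a queue. A compact standalone sketch of the pattern (generic names, not the app's code):

# One sentinel per consumer: each worker exits on its own None
import asyncio

async def producer(q: asyncio.Queue, items, n_consumers: int):
    for item in items:
        await q.put(item)
    for _ in range(n_consumers):
        await q.put(None)                   # one shutdown signal per consumer

async def consumer(q: asyncio.Queue, name: str):
    while True:
        item = await q.get()
        if item is None:                    # our sentinel - stop
            break
        print(f"{name} got {item}")

async def main():
    q = asyncio.Queue(maxsize=50)
    await asyncio.gather(
        producer(q, range(5), n_consumers=2),
        consumer(q, "dec0"),
        consumer(q, "dec1"),
    )

asyncio.run(main())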
@@ -293,11 +306,12 @@ async def distributed_generation(
                                 await text_queue.put(("text", text))
                                 worker_health[decoder_url]["total_tokens"] += len(batch)
 
-                        await text_queue.put(("done", …
+                        await text_queue.put(("done", decoder_id))
                         break
 
                     batch.append(token_id)
 
+                    # Decode when batch is full
                     if len(batch) >= batch_size:
                         async with httpx.AsyncClient(timeout=10.0) as client:
                             response = await client.post(
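batch_size drops to 2, trading more HTTP round-trips for lower time-to-first-text. The flush rule has two triggers, a full batch or the sentinel; the reset after a full-batch flush falls in lines this hunk hides, so the stripped-down version below assumes it (the decode() stub stands in for the decoder HTTP call):

# Flush on full batch; flush the remainder when the sentinel arrives
import asyncio

async def decode(batch):                    # stand-in for the decoder HTTP call
    print("decoded", batch)

async def drain(q: asyncio.Queue, batch_size: int = 2):
    batch = []
    while True:
        token_id = await q.get()
        if token_id is None:                # sentinel: flush what's left, exit
            if batch:
                await decode(batch)
            break
        batch.append(token_id)
        if len(batch) >= batch_size:        # full batch: flush and reset
            await decode(batch)
            batch = []

async def main():
    q = asyncio.Queue()
    for t in [1, 2, 3, None]:
        await q.put(t)
    await drain(q)                          # decoded [1, 2], then decoded [3]

asyncio.run(main())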
@@ -314,8 +328,8 @@ async def distributed_generation(
                     continue
 
                 except Exception as e:
-                    print(f"❌ Decoder {…
-                    await text_queue.put(("done", …
+                    print(f"❌ Decoder {decoder_id} error: {e}")
+                    await text_queue.put(("done", decoder_id))
         finally:
             worker_health[decoder_url]["active_requests"] -= 1
             worker_health[decoder_url]["role"] = "idle"
@@ -323,14 +337,16 @@ async def distributed_generation(
     # Start generator
     gen_task = asyncio.create_task(generate_tokens())
 
-    # Start …
-    …
-    …
+    # Start all decoders
+    decoder_tasks = [
+        asyncio.create_task(decode_tokens(dec_url, i))
+        for i, dec_url in enumerate(decoders)
+    ]
 
     # Stream results
     accumulated_text = ""
     decoders_done = 0
-    total_decoders = …
+    total_decoders = len(decoders)
 
     try:
         while decoders_done < total_decoders:
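Every decoder reports completion as a ("done", decoder_id) tuple, and the coordinator loop runs until it has counted one per decoder; text arrives interleaved as ("text", chunk) tuples. The loop body sits between hunks, so this standalone sketch only mirrors the done-counting protocol the tuples imply:

# Counting "done" markers from N consumers (standalone sketch)
import asyncio

async def coordinator(text_queue: asyncio.Queue, total_decoders: int) -> str:
    accumulated_text = ""
    decoders_done = 0
    while decoders_done < total_decoders:
        kind, payload = await text_queue.get()
        if kind == "text":
            accumulated_text += payload
        elif kind == "done":
            decoders_done += 1              # one marker per decoder
    return accumulated_text

async def main():
    q = asyncio.Queue()
    for msg in [("text", "Hel"), ("done", 0), ("text", "lo"), ("done", 1)]:
        await q.put(msg)
    print(await coordinator(q, 2))          # Hello

asyncio.run(main())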
@@ -346,9 +362,8 @@ async def distributed_generation(
 
     finally:
         await gen_task
-        …
-        …
-        await dec2_task
+        for task in decoder_tasks:
+            await task
 
 async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
     """Standard single-worker generation"""
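End to end, the new flow is: pick roles with select_distributed_workers(), fan out with distributed_generation() when at least one generator and one decoder exist, otherwise fall back to heavy_load_generation(). The call site is not part of this diff, so the wiring below is an assumed usage sketch with an illustrative payload:

# Hypothetical call site (not shown in this diff)
generators, decoders = select_distributed_workers()
request_data = {"prompt": "Hello", "max_tokens": 64}    # illustrative payload

if generators and decoders:
    result = distributed_generation(generators, decoders, request_data)
elif generators:
    result = heavy_load_generation(generators[0], request_data)
# both are async; the route handler would await or iterate `result`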