Bc-AI committed on
Commit
87997a1
·
verified ·
1 Parent(s): 0377525

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -43
app.py CHANGED
@@ -24,10 +24,6 @@ app = FastAPI(title="SAM-Z-1 Distributed Cluster", version="4.0.0")
24
  WORKER_URLS = [
25
  "https://bc-ai-worker-2.hf.space",
26
  "https://bc-ai-worker-sam-z-api.hf.space",
27
- "https://bc-ai-worker-3.hf.space",
28
- "https://bc-ai-worker-4.hf.space",
29
- "https://bc-ai-worker-5.hf.space"
30
-
31
  ]
32
 
33
  HEALTH_CHECK_INTERVAL = 5 # faster checks for real-time dashboard
@@ -132,19 +128,29 @@ def get_least_busy_worker() -> Optional[str]:
132
  return min(healthy, key=lambda url: worker_health[url]["active_requests"])
133
 
134
  def select_distributed_workers() -> tuple:
135
- """Select workers for distributed compute"""
 
 
 
136
  healthy = get_healthy_workers()
137
  if len(healthy) < 2:
138
- return (healthy[0], None, None) if len(healthy) == 1 else (None, None, None)
139
 
 
140
  sorted_workers = sorted(healthy, key=lambda url: worker_health[url]["active_requests"])
141
 
142
- if len(healthy) >= 3:
143
- # 3 workers: 1 generator, 2 decoders
144
- return (sorted_workers[0], sorted_workers[1], sorted_workers[2])
 
 
 
 
 
 
145
  else:
146
- # 2 workers: 1 generator, 1 decoder
147
- return (sorted_workers[0], sorted_workers[1], None)
148
 
149
  async def broadcast_stats():
150
  """Broadcast stats to all connected WebSocket clients"""
@@ -220,36 +226,40 @@ async def startup_event():
220
  # ============================================================================
221
 
222
  async def distributed_generation(
223
- generator_url: str,
224
- decoder1_url: str,
225
- decoder2_url: Optional[str],
226
  request_data: dict,
227
  endpoint: str = "generate"
228
  ):
229
  """
230
  DISTRIBUTED COMPUTE MODE
231
- - 1 worker generates token IDs
232
- - 2 workers decode in parallel (load balanced)
233
  """
234
 
235
- token_queue = asyncio.Queue(maxsize=20)
236
- text_queue = asyncio.Queue(maxsize=20)
 
 
 
237
 
238
  # Mark roles
239
- worker_health[generator_url]["role"] = "generator"
240
- worker_health[decoder1_url]["role"] = "decoder"
241
- if decoder2_url:
242
- worker_health[decoder2_url]["role"] = "decoder"
243
 
244
  async def generate_tokens():
 
 
245
  try:
246
- worker_health[generator_url]["active_requests"] += 1
247
  request_data_tokens = {**request_data, "return_token_ids": True}
248
 
249
  async with httpx.AsyncClient(timeout=300.0) as client:
250
  async with client.stream(
251
  "POST",
252
- f"{generator_url}/{endpoint}",
253
  json=request_data_tokens
254
  ) as response:
255
  async for chunk in response.aiter_text():
@@ -259,30 +269,33 @@ async def distributed_generation(
259
  if "token_id" in data:
260
  await token_queue.put(data["token_id"])
261
  elif "done" in data:
262
- await token_queue.put(None)
 
 
263
  break
264
  except:
265
  pass
266
  except Exception as e:
267
  print(f"❌ Generator error: {e}")
268
- await token_queue.put(None)
 
269
  finally:
270
- worker_health[generator_url]["active_requests"] -= 1
271
- worker_health[generator_url]["role"] = "idle"
272
 
273
- async def decode_tokens(decoder_url: str):
274
- """Decoder worker - processes tokens from queue"""
275
  try:
276
  worker_health[decoder_url]["active_requests"] += 1
277
  batch = []
278
- batch_size = 3
279
 
280
  while True:
281
  try:
282
- token_id = await asyncio.wait_for(token_queue.get(), timeout=1.0)
283
 
284
  if token_id is None:
285
- # Decode remaining
286
  if batch:
287
  async with httpx.AsyncClient(timeout=10.0) as client:
288
  response = await client.post(
@@ -293,11 +306,12 @@ async def distributed_generation(
293
  await text_queue.put(("text", text))
294
  worker_health[decoder_url]["total_tokens"] += len(batch)
295
 
296
- await text_queue.put(("done", decoder_url))
297
  break
298
 
299
  batch.append(token_id)
300
 
 
301
  if len(batch) >= batch_size:
302
  async with httpx.AsyncClient(timeout=10.0) as client:
303
  response = await client.post(
@@ -314,8 +328,8 @@ async def distributed_generation(
314
  continue
315
 
316
  except Exception as e:
317
- print(f"❌ Decoder {decoder_url} error: {e}")
318
- await text_queue.put(("done", decoder_url))
319
  finally:
320
  worker_health[decoder_url]["active_requests"] -= 1
321
  worker_health[decoder_url]["role"] = "idle"
@@ -323,14 +337,16 @@ async def distributed_generation(
323
  # Start generator
324
  gen_task = asyncio.create_task(generate_tokens())
325
 
326
- # Start decoder(s)
327
- dec1_task = asyncio.create_task(decode_tokens(decoder1_url))
328
- dec2_task = asyncio.create_task(decode_tokens(decoder2_url)) if decoder2_url else None
 
 
329
 
330
  # Stream results
331
  accumulated_text = ""
332
  decoders_done = 0
333
- total_decoders = 2 if decoder2_url else 1
334
 
335
  try:
336
  while decoders_done < total_decoders:
@@ -346,9 +362,8 @@ async def distributed_generation(
346
 
347
  finally:
348
  await gen_task
349
- await dec1_task
350
- if dec2_task:
351
- await dec2_task
352
 
353
  async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
354
  """Standard single-worker generation"""
 
24
# Worker endpoints participating in the distributed cluster.
WORKER_URLS = [
    "https://bc-ai-worker-2.hf.space",
    "https://bc-ai-worker-sam-z-api.hf.space",
]

# Poll worker health every 5 seconds so the real-time dashboard stays fresh.
HEALTH_CHECK_INTERVAL = 5
 
128
  return min(healthy, key=lambda url: worker_health[url]["active_requests"])
129
 
130
def select_distributed_workers() -> tuple:
    """
    Select workers for distributed compute.

    Returns:
        tuple: (generators: List[str], decoders: List[str])
            - 0 healthy workers:  ([], [])
            - 1 healthy worker:   ([worker], [])  -- no decoders available
            - 2+ healthy workers: 1 generator plus up to 4 decoders,
              chosen least-busy first.
    """
    healthy = get_healthy_workers()

    # Degenerate cases: distribution needs at least 2 workers.
    if not healthy:
        return ([], [])
    if len(healthy) == 1:
        return ([healthy[0]], [])

    # Least busy first, so the idlest worker becomes the generator.
    sorted_workers = sorted(
        healthy, key=lambda url: worker_health[url]["active_requests"]
    )

    # 1 generator + up to 4 decoders. Slicing clips at the list end, so this
    # single expression covers the 2-, 3-, 4- and 5+-worker cases that were
    # previously spelled out as separate if/elif branches.
    return ([sorted_workers[0]], sorted_workers[1:5])
154
 
155
  async def broadcast_stats():
156
  """Broadcast stats to all connected WebSocket clients"""
 
226
  # ============================================================================
227
 
228
  async def distributed_generation(
229
+ generators: List[str],
230
+ decoders: List[str],
 
231
  request_data: dict,
232
  endpoint: str = "generate"
233
  ):
234
  """
235
  DISTRIBUTED COMPUTE MODE
236
+ - Generator(s) produce token IDs
237
+ - Multiple decoders process in parallel (load balanced)
238
  """
239
 
240
+ if not generators or not decoders:
241
+ return
242
+
243
+ token_queue = asyncio.Queue(maxsize=50)
244
+ text_queue = asyncio.Queue(maxsize=50)
245
 
246
  # Mark roles
247
+ for gen_url in generators:
248
+ worker_health[gen_url]["role"] = "generator"
249
+ for dec_url in decoders:
250
+ worker_health[dec_url]["role"] = "decoder"
251
 
252
  async def generate_tokens():
253
+ """Generator worker(s)"""
254
+ gen_url = generators[0] # primary generator
255
  try:
256
+ worker_health[gen_url]["active_requests"] += 1
257
  request_data_tokens = {**request_data, "return_token_ids": True}
258
 
259
  async with httpx.AsyncClient(timeout=300.0) as client:
260
  async with client.stream(
261
  "POST",
262
+ f"{gen_url}/{endpoint}",
263
  json=request_data_tokens
264
  ) as response:
265
  async for chunk in response.aiter_text():
 
269
  if "token_id" in data:
270
  await token_queue.put(data["token_id"])
271
  elif "done" in data:
272
+ # Send done signal for each decoder
273
+ for _ in decoders:
274
+ await token_queue.put(None)
275
  break
276
  except:
277
  pass
278
  except Exception as e:
279
  print(f"❌ Generator error: {e}")
280
+ for _ in decoders:
281
+ await token_queue.put(None)
282
  finally:
283
+ worker_health[gen_url]["active_requests"] -= 1
284
+ worker_health[gen_url]["role"] = "idle"
285
 
286
+ async def decode_tokens(decoder_url: str, decoder_id: int):
287
+ """Decoder worker - processes tokens from shared queue"""
288
  try:
289
  worker_health[decoder_url]["active_requests"] += 1
290
  batch = []
291
+ batch_size = 2 # smaller batches for faster streaming
292
 
293
  while True:
294
  try:
295
+ token_id = await asyncio.wait_for(token_queue.get(), timeout=2.0)
296
 
297
  if token_id is None:
298
+ # Decode remaining batch
299
  if batch:
300
  async with httpx.AsyncClient(timeout=10.0) as client:
301
  response = await client.post(
 
306
  await text_queue.put(("text", text))
307
  worker_health[decoder_url]["total_tokens"] += len(batch)
308
 
309
+ await text_queue.put(("done", decoder_id))
310
  break
311
 
312
  batch.append(token_id)
313
 
314
+ # Decode when batch is full
315
  if len(batch) >= batch_size:
316
  async with httpx.AsyncClient(timeout=10.0) as client:
317
  response = await client.post(
 
328
  continue
329
 
330
  except Exception as e:
331
+ print(f"❌ Decoder {decoder_id} error: {e}")
332
+ await text_queue.put(("done", decoder_id))
333
  finally:
334
  worker_health[decoder_url]["active_requests"] -= 1
335
  worker_health[decoder_url]["role"] = "idle"
 
337
  # Start generator
338
  gen_task = asyncio.create_task(generate_tokens())
339
 
340
+ # Start all decoders
341
+ decoder_tasks = [
342
+ asyncio.create_task(decode_tokens(dec_url, i))
343
+ for i, dec_url in enumerate(decoders)
344
+ ]
345
 
346
  # Stream results
347
  accumulated_text = ""
348
  decoders_done = 0
349
+ total_decoders = len(decoders)
350
 
351
  try:
352
  while decoders_done < total_decoders:
 
362
 
363
  finally:
364
  await gen_task
365
+ for task in decoder_tasks:
366
+ await task
 
367
 
368
  async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
369
  """Standard single-worker generation"""