Bc-AI commited on
Commit
bdd11d5
·
verified ·
1 Parent(s): 4ad8327

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +627 -176
app.py CHANGED
@@ -1,11 +1,11 @@
1
  """
2
- SAM-Z-1 Smart Load Balancing Cluster Head Node
3
- - Light load: parallel gen/decode split for max speed
4
- - Heavy load: 1 worker per request for throughput
5
  """
6
 
7
- from fastapi import FastAPI, HTTPException
8
- from fastapi.responses import StreamingResponse
9
  from pydantic import BaseModel
10
  import httpx
11
  import asyncio
@@ -15,7 +15,7 @@ from typing import List, Optional, Dict
15
  from collections import deque
16
  import random
17
 
18
- app = FastAPI(title="SAM-Z-1 Smart Cluster API", version="3.0.0")
19
 
20
  # ============================================================================
21
  # Configuration
@@ -26,17 +26,36 @@ WORKER_URLS = [
26
  "https://bc-ai-worker-sam-z-api.hf.space",
27
  ]
28
 
29
- HEALTH_CHECK_INTERVAL = 30
30
- LOAD_CHECK_WINDOW = 10 # seconds to measure load
31
 
32
- # Load thresholds
33
- LIGHT_LOAD_THRESHOLD = 2 # requests in window
34
- HEAVY_LOAD_THRESHOLD = 5 # requests in window
35
 
36
  # Worker state
37
- worker_health = {url: {"healthy": True, "last_check": 0, "active_requests": 0} for url in WORKER_URLS}
38
- request_timestamps = deque(maxlen=100) # track recent requests
39
- current_load_mode = "light" # "light" or "heavy"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # ============================================================================
42
  # Request Models
@@ -69,13 +88,10 @@ class ChatRequest(BaseModel):
69
  # ============================================================================
70
 
71
  def get_current_load() -> int:
72
- """Calculate current load based on recent requests"""
73
  now = time.time()
74
- # Count requests in the last LOAD_CHECK_WINDOW seconds
75
  return sum(1 for ts in request_timestamps if now - ts < LOAD_CHECK_WINDOW)
76
 
77
  def update_load_mode():
78
- """Update load mode based on current load"""
79
  global current_load_mode
80
  load = get_current_load()
81
 
@@ -83,37 +99,82 @@ def update_load_mode():
83
  current_load_mode = "light"
84
  elif load >= HEAVY_LOAD_THRESHOLD:
85
  current_load_mode = "heavy"
86
- # hysteresis zone between thresholds maintains current mode
87
 
88
  return current_load_mode, load
89
 
90
  def track_request():
91
- """Track a new request"""
92
  request_timestamps.append(time.time())
 
93
 
94
  def get_healthy_workers() -> List[str]:
95
- """Get list of healthy workers"""
96
  return [url for url, status in worker_health.items() if status["healthy"]]
97
 
98
  def get_least_busy_worker() -> Optional[str]:
99
- """Get worker with fewest active requests"""
100
  healthy = get_healthy_workers()
101
  if not healthy:
102
  return None
103
  return min(healthy, key=lambda url: worker_health[url]["active_requests"])
104
 
105
- def select_worker_pair() -> tuple:
106
- """Select 2 workers for parallel operation"""
107
  healthy = get_healthy_workers()
108
  if len(healthy) < 2:
109
- return (healthy[0], None) if len(healthy) == 1 else (None, None)
110
 
111
- # Sort by active requests, take 2 least busy
112
  sorted_workers = sorted(healthy, key=lambda url: worker_health[url]["active_requests"])
113
- return (sorted_workers[0], sorted_workers[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  async def check_worker_health(worker_url: str) -> bool:
116
- """Check if a worker is healthy"""
117
  try:
118
  async with httpx.AsyncClient(timeout=5.0) as client:
119
  response = await client.get(f"{worker_url}/health")
@@ -122,54 +183,50 @@ async def check_worker_health(worker_url: str) -> bool:
122
  return False
123
 
124
  async def health_check_loop():
125
- """Background task to check worker health"""
126
  while True:
127
  for worker_url in WORKER_URLS:
128
  healthy = await check_worker_health(worker_url)
129
  worker_health[worker_url]["healthy"] = healthy
130
  worker_health[worker_url]["last_check"] = time.time()
131
-
132
- status = "✅" if healthy else "❌"
133
- active = worker_health[worker_url]["active_requests"]
134
- print(f"{status} {worker_url}: {'healthy' if healthy else 'unhealthy'} | Active: {active}")
135
 
136
- mode, load = update_load_mode()
137
- print(f"📊 Load mode: {mode.upper()} | Current load: {load} req/{LOAD_CHECK_WINDOW}s")
138
 
139
  await asyncio.sleep(HEALTH_CHECK_INTERVAL)
140
 
141
  @app.on_event("startup")
142
  async def startup_event():
143
- """Start health check loop on startup"""
144
  asyncio.create_task(health_check_loop())
145
 
146
  # ============================================================================
147
- # Generation Strategies
148
  # ============================================================================
149
 
150
- async def light_load_generation(
151
  generator_url: str,
152
- decoder_url: str,
 
153
  request_data: dict,
154
  endpoint: str = "generate"
155
  ):
156
  """
157
- LIGHT LOAD MODE: Split generation and decoding
158
- - Generator worker: produces token IDs only
159
- - Decoder worker: decodes token IDs to text
160
- This parallelizes the bottleneck!
161
  """
162
 
163
- # Queues for pipeline
164
- token_queue = asyncio.Queue(maxsize=10)
165
- text_queue = asyncio.Queue(maxsize=10)
 
 
 
 
 
166
 
167
  async def generate_tokens():
168
- """Worker 1: Generate token IDs"""
169
  try:
170
  worker_health[generator_url]["active_requests"] += 1
171
-
172
- # Request token IDs only mode
173
  request_data_tokens = {**request_data, "return_token_ids": True}
174
 
175
  async with httpx.AsyncClient(timeout=300.0) as client:
@@ -185,7 +242,7 @@ async def light_load_generation(
185
  if "token_id" in data:
186
  await token_queue.put(data["token_id"])
187
  elif "done" in data:
188
- await token_queue.put(None) # Signal end
189
  break
190
  except:
191
  pass
@@ -194,21 +251,21 @@ async def light_load_generation(
194
  await token_queue.put(None)
195
  finally:
196
  worker_health[generator_url]["active_requests"] -= 1
 
197
 
198
- async def decode_tokens():
199
- """Worker 2: Decode token IDs to text"""
200
  try:
201
  worker_health[decoder_url]["active_requests"] += 1
202
-
203
  batch = []
204
- batch_size = 5 # decode in small batches for speed
205
 
206
  while True:
207
  try:
208
  token_id = await asyncio.wait_for(token_queue.get(), timeout=1.0)
209
 
210
  if token_id is None:
211
- # Decode remaining batch
212
  if batch:
213
  async with httpx.AsyncClient(timeout=10.0) as client:
214
  response = await client.post(
@@ -217,13 +274,13 @@ async def light_load_generation(
217
  )
218
  text = response.json()["text"]
219
  await text_queue.put(("text", text))
 
220
 
221
- await text_queue.put(("done", None))
222
  break
223
 
224
  batch.append(token_id)
225
 
226
- # Decode batch when full
227
  if len(batch) >= batch_size:
228
  async with httpx.AsyncClient(timeout=10.0) as client:
229
  response = await client.post(
@@ -232,6 +289,7 @@ async def light_load_generation(
232
  )
233
  text = response.json()["text"]
234
  await text_queue.put(("text", text))
 
235
 
236
  batch = []
237
 
@@ -239,23 +297,31 @@ async def light_load_generation(
239
  continue
240
 
241
  except Exception as e:
242
- print(f"❌ Decoder error: {e}")
243
- await text_queue.put(("done", None))
244
  finally:
245
  worker_health[decoder_url]["active_requests"] -= 1
 
246
 
247
- # Start both pipelines
248
  gen_task = asyncio.create_task(generate_tokens())
249
- dec_task = asyncio.create_task(decode_tokens())
250
 
251
- # Stream decoded text
 
 
 
 
252
  accumulated_text = ""
 
 
 
253
  try:
254
- while True:
255
  msg_type, data = await text_queue.get()
256
 
257
  if msg_type == "done":
258
- break
 
259
 
260
  if msg_type == "text":
261
  accumulated_text += data
@@ -263,19 +329,15 @@ async def light_load_generation(
263
 
264
  finally:
265
  await gen_task
266
- await dec_task
 
 
267
 
268
- async def heavy_load_generation(
269
- worker_url: str,
270
- request_data: dict,
271
- endpoint: str = "generate"
272
- ):
273
- """
274
- HEAVY LOAD MODE: Single worker per request
275
- Standard streaming for max throughput
276
- """
277
  try:
278
  worker_health[worker_url]["active_requests"] += 1
 
279
 
280
  async with httpx.AsyncClient(timeout=300.0) as client:
281
  async with client.stream(
@@ -292,104 +354,481 @@ async def heavy_load_generation(
292
 
293
  finally:
294
  worker_health[worker_url]["active_requests"] -= 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  # ============================================================================
297
  # API Endpoints
298
  # ============================================================================
299
 
300
- @app.get("/")
301
- async def root():
302
- """API info"""
303
- healthy_count = len(get_healthy_workers())
304
  mode, load = update_load_mode()
 
305
 
306
  return {
307
- "name": "SAM-Z-1 Smart Cluster API",
308
- "version": "3.0.0",
309
  "mode": mode,
310
  "current_load": load,
311
  "workers": len(WORKER_URLS),
312
  "healthy_workers": healthy_count,
313
- "features": [
314
- "smart_load_balancing",
315
- "parallel_gen_decode",
316
- "adaptive_routing"
317
- ],
318
- "load_strategy": {
319
- "light": "parallel gen/decode split for speed",
320
- "heavy": "1 worker per request for throughput"
321
- },
322
- "endpoints": {
323
- "generate": "/v1/generate",
324
- "chat": "/v1/chat",
325
- "health": "/health",
326
- "workers": "/workers",
327
- "stats": "/stats"
328
- }
329
  }
330
 
331
  @app.get("/health")
332
  async def health():
333
- """Health check endpoint"""
334
  healthy_count = len(get_healthy_workers())
335
- mode, load = update_load_mode()
336
-
337
  return {
338
  "status": "healthy" if healthy_count > 0 else "unhealthy",
339
- "workers_total": len(WORKER_URLS),
340
- "workers_healthy": healthy_count,
341
- "load_mode": mode,
342
- "current_load": load
343
- }
344
-
345
- @app.get("/workers")
346
- async def workers_status():
347
- """Get status of all workers"""
348
- return {
349
- "workers": [
350
- {
351
- "url": url,
352
- "healthy": status["healthy"],
353
- "active_requests": status["active_requests"],
354
- "last_check": status["last_check"]
355
- }
356
- for url, status in worker_health.items()
357
- ]
358
- }
359
-
360
- @app.get("/stats")
361
- async def stats():
362
- """Get cluster statistics"""
363
- mode, load = update_load_mode()
364
-
365
- return {
366
- "load_mode": mode,
367
- "current_load": load,
368
- "load_window_seconds": LOAD_CHECK_WINDOW,
369
- "thresholds": {
370
- "light": LIGHT_LOAD_THRESHOLD,
371
- "heavy": HEAVY_LOAD_THRESHOLD
372
- },
373
- "recent_requests": len(request_timestamps),
374
- "worker_stats": {
375
- url: {
376
- "healthy": status["healthy"],
377
- "active": status["active_requests"]
378
- }
379
- for url, status in worker_health.items()
380
- }
381
  }
382
 
383
  @app.post("/v1/generate")
384
  async def generate(request: GenerateRequest):
385
- """Generate text with smart load balancing"""
386
-
387
  track_request()
388
  mode, load = update_load_mode()
389
 
390
  healthy = get_healthy_workers()
391
  if not healthy:
392
- raise HTTPException(status_code=503, detail="No healthy workers available")
 
393
 
394
  request_data = {
395
  "prompt": request.prompt,
@@ -401,33 +840,39 @@ async def generate(request: GenerateRequest):
401
  "stream": True
402
  }
403
 
404
- print(f"🎯 Mode: {mode.upper()} | Load: {load} | Request: generate")
405
 
406
- if mode == "light" and len(healthy) >= 2:
407
- # LIGHT LOAD: parallel gen/decode
408
- generator, decoder = select_worker_pair()
409
- return StreamingResponse(
410
- light_load_generation(generator, decoder, request_data, "generate"),
411
- media_type="text/event-stream"
412
- )
413
- else:
414
- # HEAVY LOAD: single worker
415
- worker = get_least_busy_worker()
416
- return StreamingResponse(
417
- heavy_load_generation(worker, request_data, "generate"),
418
- media_type="text/event-stream"
419
- )
 
 
 
 
 
 
420
 
421
  @app.post("/v1/chat")
422
  async def chat(request: ChatRequest):
423
- """Chat completion with smart load balancing"""
424
-
425
  track_request()
426
  mode, load = update_load_mode()
427
 
428
  healthy = get_healthy_workers()
429
  if not healthy:
430
- raise HTTPException(status_code=503, detail="No healthy workers available")
 
431
 
432
  request_data = {
433
  "messages": [{"role": m.role, "content": m.content} for m in request.messages],
@@ -439,22 +884,28 @@ async def chat(request: ChatRequest):
439
  "stream": True
440
  }
441
 
442
- print(f"🎯 Mode: {mode.upper()} | Load: {load} | Request: chat")
443
 
444
- if mode == "light" and len(healthy) >= 2:
445
- # LIGHT LOAD: parallel gen/decode
446
- generator, decoder = select_worker_pair()
447
- return StreamingResponse(
448
- light_load_generation(generator, decoder, request_data, "chat"),
449
- media_type="text/event-stream"
450
- )
451
- else:
452
- # HEAVY LOAD: single worker
453
- worker = get_least_busy_worker()
454
- return StreamingResponse(
455
- heavy_load_generation(worker, request_data, "chat"),
456
- media_type="text/event-stream"
457
- )
 
 
 
 
 
 
458
 
459
  # ============================================================================
460
  # Launch
 
1
  """
2
+ SAM-Z-1 Distributed Compute Cluster Head Node
3
+ - Smart load balancing with distributed compute
4
+ - Real-time status dashboard
5
  """
6
 
7
+ from fastapi import FastAPI, HTTPException, WebSocket
8
+ from fastapi.responses import StreamingResponse, HTMLResponse
9
  from pydantic import BaseModel
10
  import httpx
11
  import asyncio
 
15
  from collections import deque
16
  import random
17
 
18
+ app = FastAPI(title="SAM-Z-1 Distributed Cluster", version="4.0.0")
19
 
20
  # ============================================================================
21
  # Configuration
 
26
  "https://bc-ai-worker-sam-z-api.hf.space",
27
  ]
28
 
29
+ HEALTH_CHECK_INTERVAL = 5 # faster checks for real-time dashboard
30
+ LOAD_CHECK_WINDOW = 10
31
 
32
+ LIGHT_LOAD_THRESHOLD = 2
33
+ HEAVY_LOAD_THRESHOLD = 5
 
34
 
35
  # Worker state
36
+ worker_health = {
37
+ url: {
38
+ "healthy": True,
39
+ "last_check": 0,
40
+ "active_requests": 0,
41
+ "total_requests": 0,
42
+ "total_tokens": 0,
43
+ "avg_latency": 0,
44
+ "role": "idle" # "generator", "decoder", "full", "idle"
45
+ } for url in WORKER_URLS
46
+ }
47
+
48
+ request_timestamps = deque(maxlen=100)
49
+ current_load_mode = "light"
50
+ cluster_stats = {
51
+ "total_requests": 0,
52
+ "successful_requests": 0,
53
+ "failed_requests": 0,
54
+ "uptime_start": time.time()
55
+ }
56
+
57
+ # Active WebSocket connections for real-time updates
58
+ active_connections = set()
59
 
60
  # ============================================================================
61
  # Request Models
 
88
  # ============================================================================
89
 
90
  def get_current_load() -> int:
 
91
  now = time.time()
 
92
  return sum(1 for ts in request_timestamps if now - ts < LOAD_CHECK_WINDOW)
93
 
94
  def update_load_mode():
 
95
  global current_load_mode
96
  load = get_current_load()
97
 
 
99
  current_load_mode = "light"
100
  elif load >= HEAVY_LOAD_THRESHOLD:
101
  current_load_mode = "heavy"
 
102
 
103
  return current_load_mode, load
104
 
105
  def track_request():
 
106
  request_timestamps.append(time.time())
107
+ cluster_stats["total_requests"] += 1
108
 
109
  def get_healthy_workers() -> List[str]:
 
110
  return [url for url, status in worker_health.items() if status["healthy"]]
111
 
112
  def get_least_busy_worker() -> Optional[str]:
 
113
  healthy = get_healthy_workers()
114
  if not healthy:
115
  return None
116
  return min(healthy, key=lambda url: worker_health[url]["active_requests"])
117
 
118
+ def select_distributed_workers() -> tuple:
119
+ """Select workers for distributed compute"""
120
  healthy = get_healthy_workers()
121
  if len(healthy) < 2:
122
+ return (healthy[0], None, None) if len(healthy) == 1 else (None, None, None)
123
 
 
124
  sorted_workers = sorted(healthy, key=lambda url: worker_health[url]["active_requests"])
125
+
126
+ if len(healthy) >= 3:
127
+ # 3 workers: 1 generator, 2 decoders
128
+ return (sorted_workers[0], sorted_workers[1], sorted_workers[2])
129
+ else:
130
+ # 2 workers: 1 generator, 1 decoder
131
+ return (sorted_workers[0], sorted_workers[1], None)
132
+
133
+ async def broadcast_stats():
134
+ """Broadcast stats to all connected WebSocket clients"""
135
+ if not active_connections:
136
+ return
137
+
138
+ mode, load = update_load_mode()
139
+ uptime = time.time() - cluster_stats["uptime_start"]
140
+
141
+ stats = {
142
+ "timestamp": time.time(),
143
+ "mode": mode,
144
+ "load": load,
145
+ "workers": [
146
+ {
147
+ "url": url.split("//")[1].split(".")[0], # shorter name
148
+ "healthy": status["healthy"],
149
+ "active": status["active_requests"],
150
+ "total": status["total_requests"],
151
+ "tokens": status["total_tokens"],
152
+ "latency": round(status["avg_latency"], 2),
153
+ "role": status["role"]
154
+ }
155
+ for url, status in worker_health.items()
156
+ ],
157
+ "cluster": {
158
+ "total_requests": cluster_stats["total_requests"],
159
+ "successful": cluster_stats["successful_requests"],
160
+ "failed": cluster_stats["failed_requests"],
161
+ "uptime": round(uptime, 0),
162
+ "rps": round(cluster_stats["total_requests"] / uptime if uptime > 0 else 0, 2)
163
+ }
164
+ }
165
+
166
+ # Broadcast to all connections
167
+ disconnected = set()
168
+ for ws in active_connections:
169
+ try:
170
+ await ws.send_json(stats)
171
+ except:
172
+ disconnected.add(ws)
173
+
174
+ # Remove disconnected
175
+ active_connections.difference_update(disconnected)
176
 
177
  async def check_worker_health(worker_url: str) -> bool:
 
178
  try:
179
  async with httpx.AsyncClient(timeout=5.0) as client:
180
  response = await client.get(f"{worker_url}/health")
 
183
  return False
184
 
185
  async def health_check_loop():
 
186
  while True:
187
  for worker_url in WORKER_URLS:
188
  healthy = await check_worker_health(worker_url)
189
  worker_health[worker_url]["healthy"] = healthy
190
  worker_health[worker_url]["last_check"] = time.time()
 
 
 
 
191
 
192
+ # Broadcast to dashboard
193
+ await broadcast_stats()
194
 
195
  await asyncio.sleep(HEALTH_CHECK_INTERVAL)
196
 
197
  @app.on_event("startup")
198
  async def startup_event():
 
199
  asyncio.create_task(health_check_loop())
200
 
201
  # ============================================================================
202
+ # Distributed Compute Generation
203
  # ============================================================================
204
 
205
+ async def distributed_generation(
206
  generator_url: str,
207
+ decoder1_url: str,
208
+ decoder2_url: Optional[str],
209
  request_data: dict,
210
  endpoint: str = "generate"
211
  ):
212
  """
213
+ DISTRIBUTED COMPUTE MODE
214
+ - 1 worker generates token IDs
215
+ - 2 workers decode in parallel (load balanced)
 
216
  """
217
 
218
+ token_queue = asyncio.Queue(maxsize=20)
219
+ text_queue = asyncio.Queue(maxsize=20)
220
+
221
+ # Mark roles
222
+ worker_health[generator_url]["role"] = "generator"
223
+ worker_health[decoder1_url]["role"] = "decoder"
224
+ if decoder2_url:
225
+ worker_health[decoder2_url]["role"] = "decoder"
226
 
227
  async def generate_tokens():
 
228
  try:
229
  worker_health[generator_url]["active_requests"] += 1
 
 
230
  request_data_tokens = {**request_data, "return_token_ids": True}
231
 
232
  async with httpx.AsyncClient(timeout=300.0) as client:
 
242
  if "token_id" in data:
243
  await token_queue.put(data["token_id"])
244
  elif "done" in data:
245
+ await token_queue.put(None)
246
  break
247
  except:
248
  pass
 
251
  await token_queue.put(None)
252
  finally:
253
  worker_health[generator_url]["active_requests"] -= 1
254
+ worker_health[generator_url]["role"] = "idle"
255
 
256
+ async def decode_tokens(decoder_url: str):
257
+ """Decoder worker - processes tokens from queue"""
258
  try:
259
  worker_health[decoder_url]["active_requests"] += 1
 
260
  batch = []
261
+ batch_size = 3
262
 
263
  while True:
264
  try:
265
  token_id = await asyncio.wait_for(token_queue.get(), timeout=1.0)
266
 
267
  if token_id is None:
268
+ # Decode remaining
269
  if batch:
270
  async with httpx.AsyncClient(timeout=10.0) as client:
271
  response = await client.post(
 
274
  )
275
  text = response.json()["text"]
276
  await text_queue.put(("text", text))
277
+ worker_health[decoder_url]["total_tokens"] += len(batch)
278
 
279
+ await text_queue.put(("done", decoder_url))
280
  break
281
 
282
  batch.append(token_id)
283
 
 
284
  if len(batch) >= batch_size:
285
  async with httpx.AsyncClient(timeout=10.0) as client:
286
  response = await client.post(
 
289
  )
290
  text = response.json()["text"]
291
  await text_queue.put(("text", text))
292
+ worker_health[decoder_url]["total_tokens"] += len(batch)
293
 
294
  batch = []
295
 
 
297
  continue
298
 
299
  except Exception as e:
300
+ print(f"❌ Decoder {decoder_url} error: {e}")
301
+ await text_queue.put(("done", decoder_url))
302
  finally:
303
  worker_health[decoder_url]["active_requests"] -= 1
304
+ worker_health[decoder_url]["role"] = "idle"
305
 
306
+ # Start generator
307
  gen_task = asyncio.create_task(generate_tokens())
 
308
 
309
+ # Start decoder(s)
310
+ dec1_task = asyncio.create_task(decode_tokens(decoder1_url))
311
+ dec2_task = asyncio.create_task(decode_tokens(decoder2_url)) if decoder2_url else None
312
+
313
+ # Stream results
314
  accumulated_text = ""
315
+ decoders_done = 0
316
+ total_decoders = 2 if decoder2_url else 1
317
+
318
  try:
319
+ while decoders_done < total_decoders:
320
  msg_type, data = await text_queue.get()
321
 
322
  if msg_type == "done":
323
+ decoders_done += 1
324
+ continue
325
 
326
  if msg_type == "text":
327
  accumulated_text += data
 
329
 
330
  finally:
331
  await gen_task
332
+ await dec1_task
333
+ if dec2_task:
334
+ await dec2_task
335
 
336
+ async def heavy_load_generation(worker_url: str, request_data: dict, endpoint: str = "generate"):
337
+ """Standard single-worker generation"""
 
 
 
 
 
 
 
338
  try:
339
  worker_health[worker_url]["active_requests"] += 1
340
+ worker_health[worker_url]["role"] = "full"
341
 
342
  async with httpx.AsyncClient(timeout=300.0) as client:
343
  async with client.stream(
 
354
 
355
  finally:
356
  worker_health[worker_url]["active_requests"] -= 1
357
+ worker_health[worker_url]["role"] = "idle"
358
+
359
+ # ============================================================================
360
+ # Dashboard
361
+ # ============================================================================
362
+
363
+ @app.get("/", response_class=HTMLResponse)
364
+ async def dashboard():
365
+ """Real-time futuristic dashboard"""
366
+ return """
367
+ <!DOCTYPE html>
368
+ <html>
369
+ <head>
370
+ <title>SAM-Z-1 Cluster Control</title>
371
+ <style>
372
+ * {
373
+ margin: 0;
374
+ padding: 0;
375
+ box-sizing: border-box;
376
+ }
377
+
378
+ body {
379
+ font-family: 'Courier New', monospace;
380
+ background: linear-gradient(135deg, #0a0e27 0%, #1a1f3a 100%);
381
+ color: #00ff88;
382
+ overflow: hidden;
383
+ height: 100vh;
384
+ }
385
+
386
+ .container {
387
+ padding: 20px;
388
+ max-width: 1400px;
389
+ margin: 0 auto;
390
+ }
391
+
392
+ .header {
393
+ text-align: center;
394
+ margin-bottom: 30px;
395
+ padding: 20px;
396
+ background: rgba(0, 255, 136, 0.1);
397
+ border: 2px solid #00ff88;
398
+ border-radius: 10px;
399
+ box-shadow: 0 0 20px rgba(0, 255, 136, 0.3);
400
+ }
401
+
402
+ .header h1 {
403
+ font-size: 2.5em;
404
+ text-transform: uppercase;
405
+ letter-spacing: 5px;
406
+ text-shadow: 0 0 10px #00ff88;
407
+ animation: glow 2s ease-in-out infinite alternate;
408
+ }
409
+
410
+ @keyframes glow {
411
+ from { text-shadow: 0 0 10px #00ff88, 0 0 20px #00ff88; }
412
+ to { text-shadow: 0 0 20px #00ff88, 0 0 30px #00ff88, 0 0 40px #00ff88; }
413
+ }
414
+
415
+ .status-bar {
416
+ display: flex;
417
+ gap: 20px;
418
+ margin-bottom: 30px;
419
+ }
420
+
421
+ .stat-card {
422
+ flex: 1;
423
+ background: rgba(0, 255, 136, 0.05);
424
+ border: 1px solid #00ff88;
425
+ border-radius: 8px;
426
+ padding: 15px;
427
+ position: relative;
428
+ overflow: hidden;
429
+ }
430
+
431
+ .stat-card::before {
432
+ content: '';
433
+ position: absolute;
434
+ top: 0;
435
+ left: -100%;
436
+ width: 100%;
437
+ height: 100%;
438
+ background: linear-gradient(90deg, transparent, rgba(0, 255, 136, 0.2), transparent);
439
+ animation: scan 3s infinite;
440
+ }
441
+
442
+ @keyframes scan {
443
+ 0% { left: -100%; }
444
+ 100% { left: 100%; }
445
+ }
446
+
447
+ .stat-label {
448
+ font-size: 0.8em;
449
+ opacity: 0.7;
450
+ text-transform: uppercase;
451
+ }
452
+
453
+ .stat-value {
454
+ font-size: 2em;
455
+ font-weight: bold;
456
+ margin-top: 5px;
457
+ }
458
+
459
+ .mode-badge {
460
+ display: inline-block;
461
+ padding: 5px 15px;
462
+ border-radius: 20px;
463
+ font-size: 0.9em;
464
+ font-weight: bold;
465
+ text-transform: uppercase;
466
+ margin-top: 10px;
467
+ }
468
+
469
+ .mode-light {
470
+ background: rgba(0, 255, 136, 0.2);
471
+ border: 1px solid #00ff88;
472
+ color: #00ff88;
473
+ }
474
+
475
+ .mode-heavy {
476
+ background: rgba(255, 68, 68, 0.2);
477
+ border: 1px solid #ff4444;
478
+ color: #ff4444;
479
+ }
480
+
481
+ .workers-grid {
482
+ display: grid;
483
+ grid-template-columns: repeat(auto-fit, minmax(350px, 1fr));
484
+ gap: 20px;
485
+ margin-bottom: 30px;
486
+ }
487
+
488
+ .worker-card {
489
+ background: rgba(10, 14, 39, 0.8);
490
+ border: 2px solid #00ff88;
491
+ border-radius: 10px;
492
+ padding: 20px;
493
+ position: relative;
494
+ transition: all 0.3s;
495
+ }
496
+
497
+ .worker-card:hover {
498
+ transform: translateY(-5px);
499
+ box-shadow: 0 5px 30px rgba(0, 255, 136, 0.4);
500
+ }
501
+
502
+ .worker-card.offline {
503
+ border-color: #ff4444;
504
+ opacity: 0.6;
505
+ }
506
+
507
+ .worker-header {
508
+ display: flex;
509
+ justify-content: space-between;
510
+ align-items: center;
511
+ margin-bottom: 15px;
512
+ }
513
+
514
+ .worker-name {
515
+ font-size: 1.2em;
516
+ font-weight: bold;
517
+ }
518
+
519
+ .status-dot {
520
+ width: 12px;
521
+ height: 12px;
522
+ border-radius: 50%;
523
+ animation: pulse 2s infinite;
524
+ }
525
+
526
+ .status-dot.online {
527
+ background: #00ff88;
528
+ box-shadow: 0 0 10px #00ff88;
529
+ }
530
+
531
+ .status-dot.offline {
532
+ background: #ff4444;
533
+ box-shadow: 0 0 10px #ff4444;
534
+ }
535
+
536
+ @keyframes pulse {
537
+ 0%, 100% { opacity: 1; }
538
+ 50% { opacity: 0.5; }
539
+ }
540
+
541
+ .worker-stats {
542
+ display: grid;
543
+ grid-template-columns: repeat(2, 1fr);
544
+ gap: 10px;
545
+ margin-top: 15px;
546
+ }
547
+
548
+ .worker-stat {
549
+ background: rgba(0, 255, 136, 0.05);
550
+ padding: 10px;
551
+ border-radius: 5px;
552
+ }
553
+
554
+ .worker-stat-label {
555
+ font-size: 0.7em;
556
+ opacity: 0.7;
557
+ }
558
+
559
+ .worker-stat-value {
560
+ font-size: 1.3em;
561
+ font-weight: bold;
562
+ margin-top: 3px;
563
+ }
564
+
565
+ .role-badge {
566
+ display: inline-block;
567
+ padding: 3px 10px;
568
+ border-radius: 12px;
569
+ font-size: 0.75em;
570
+ margin-top: 10px;
571
+ font-weight: bold;
572
+ }
573
+
574
+ .role-generator {
575
+ background: rgba(255, 165, 0, 0.2);
576
+ border: 1px solid #ffa500;
577
+ color: #ffa500;
578
+ }
579
+
580
+ .role-decoder {
581
+ background: rgba(0, 191, 255, 0.2);
582
+ border: 1px solid #00bfff;
583
+ color: #00bfff;
584
+ }
585
+
586
+ .role-full {
587
+ background: rgba(138, 43, 226, 0.2);
588
+ border: 1px solid #8a2be2;
589
+ color: #8a2be2;
590
+ }
591
+
592
+ .role-idle {
593
+ background: rgba(128, 128, 128, 0.2);
594
+ border: 1px solid #808080;
595
+ color: #808080;
596
+ }
597
+
598
+ .progress-bar {
599
+ width: 100%;
600
+ height: 4px;
601
+ background: rgba(0, 255, 136, 0.1);
602
+ border-radius: 2px;
603
+ margin-top: 10px;
604
+ overflow: hidden;
605
+ }
606
+
607
+ .progress-fill {
608
+ height: 100%;
609
+ background: linear-gradient(90deg, #00ff88, #00ffff);
610
+ transition: width 0.3s;
611
+ box-shadow: 0 0 10px #00ff88;
612
+ }
613
+
614
+ .cluster-info {
615
+ background: rgba(0, 255, 136, 0.05);
616
+ border: 1px solid #00ff88;
617
+ border-radius: 8px;
618
+ padding: 20px;
619
+ }
620
+
621
+ .info-grid {
622
+ display: grid;
623
+ grid-template-columns: repeat(4, 1fr);
624
+ gap: 20px;
625
+ }
626
+
627
+ .info-item {
628
+ text-align: center;
629
+ }
630
+
631
+ .timestamp {
632
+ text-align: center;
633
+ margin-top: 20px;
634
+ opacity: 0.5;
635
+ font-size: 0.9em;
636
+ }
637
+ </style>
638
+ </head>
639
+ <body>
640
+ <div class="container">
641
+ <div class="header">
642
+ <h1>⚡ SAM-Z-1 CLUSTER ⚡</h1>
643
+ <div>DISTRIBUTED COMPUTE SYSTEM v4.0</div>
644
+ </div>
645
+
646
+ <div class="status-bar">
647
+ <div class="stat-card">
648
+ <div class="stat-label">Load Mode</div>
649
+ <div class="stat-value" id="mode">--</div>
650
+ <div class="mode-badge" id="mode-badge">INITIALIZING</div>
651
+ </div>
652
+ <div class="stat-card">
653
+ <div class="stat-label">Current Load</div>
654
+ <div class="stat-value" id="load">0</div>
655
+ <div class="stat-label">requests / 10s</div>
656
+ </div>
657
+ <div class="stat-card">
658
+ <div class="stat-label">Total Requests</div>
659
+ <div class="stat-value" id="total-req">0</div>
660
+ </div>
661
+ <div class="stat-card">
662
+ <div class="stat-label">Req/Sec</div>
663
+ <div class="stat-value" id="rps">0.00</div>
664
+ </div>
665
+ </div>
666
+
667
+ <div class="workers-grid" id="workers">
668
+ <!-- Workers populated by JS -->
669
+ </div>
670
+
671
+ <div class="cluster-info">
672
+ <div class="stat-label" style="margin-bottom: 15px;">CLUSTER STATISTICS</div>
673
+ <div class="info-grid">
674
+ <div class="info-item">
675
+ <div class="stat-label">Successful</div>
676
+ <div class="stat-value" style="font-size: 1.5em;" id="success">0</div>
677
+ </div>
678
+ <div class="info-item">
679
+ <div class="stat-label">Failed</div>
680
+ <div class="stat-value" style="font-size: 1.5em;" id="failed">0</div>
681
+ </div>
682
+ <div class="info-item">
683
+ <div class="stat-label">Uptime</div>
684
+ <div class="stat-value" style="font-size: 1.5em;" id="uptime">0s</div>
685
+ </div>
686
+ <div class="info-item">
687
+ <div class="stat-label">Healthy Workers</div>
688
+ <div class="stat-value" style="font-size: 1.5em;" id="healthy">0</div>
689
+ </div>
690
+ </div>
691
+ </div>
692
+
693
+ <div class="timestamp" id="timestamp">Last update: --</div>
694
+ </div>
695
+
696
+ <script>
697
+ const ws = new WebSocket(`ws://${window.location.host}/ws`);
698
+
699
+ ws.onmessage = (event) => {
700
+ const data = JSON.parse(event.data);
701
+ updateDashboard(data);
702
+ };
703
+
704
+ ws.onerror = () => {
705
+ console.error('WebSocket error');
706
+ };
707
+
708
+ function updateDashboard(data) {
709
+ // Mode
710
+ document.getElementById('mode').textContent = data.mode.toUpperCase();
711
+ const modeBadge = document.getElementById('mode-badge');
712
+ modeBadge.textContent = `${data.mode.toUpperCase()} MODE`;
713
+ modeBadge.className = `mode-badge mode-${data.mode}`;
714
+
715
+ // Stats
716
+ document.getElementById('load').textContent = data.load;
717
+ document.getElementById('total-req').textContent = data.cluster.total_requests;
718
+ document.getElementById('rps').textContent = data.cluster.rps;
719
+ document.getElementById('success').textContent = data.cluster.successful;
720
+ document.getElementById('failed').textContent = data.cluster.failed;
721
+ document.getElementById('uptime').textContent = formatUptime(data.cluster.uptime);
722
+
723
+ // Workers
724
+ const workersDiv = document.getElementById('workers');
725
+ const healthyCount = data.workers.filter(w => w.healthy).length;
726
+ document.getElementById('healthy').textContent = `${healthyCount}/${data.workers.length}`;
727
+
728
+ workersDiv.innerHTML = data.workers.map(worker => `
729
+ <div class="worker-card ${worker.healthy ? '' : 'offline'}">
730
+ <div class="worker-header">
731
+ <div class="worker-name">${worker.url}</div>
732
+ <div class="status-dot ${worker.healthy ? 'online' : 'offline'}"></div>
733
+ </div>
734
+ <div class="role-badge role-${worker.role}">${worker.role.toUpperCase()}</div>
735
+ <div class="worker-stats">
736
+ <div class="worker-stat">
737
+ <div class="worker-stat-label">Active</div>
738
+ <div class="worker-stat-value">${worker.active}</div>
739
+ </div>
740
+ <div class="worker-stat">
741
+ <div class="worker-stat-label">Total</div>
742
+ <div class="worker-stat-value">${worker.total}</div>
743
+ </div>
744
+ <div class="worker-stat">
745
+ <div class="worker-stat-label">Tokens</div>
746
+ <div class="worker-stat-value">${worker.tokens}</div>
747
+ </div>
748
+ <div class="worker-stat">
749
+ <div class="worker-stat-label">Latency</div>
750
+ <div class="worker-stat-value">${worker.latency}ms</div>
751
+ </div>
752
+ </div>
753
+ <div class="progress-bar">
754
+ <div class="progress-fill" style="width: ${Math.min(worker.active * 33, 100)}%"></div>
755
+ </div>
756
+ </div>
757
+ `).join('');
758
+
759
+ // Timestamp
760
+ const now = new Date();
761
+ document.getElementById('timestamp').textContent =
762
+ `Last update: ${now.toLocaleTimeString()}`;
763
+ }
764
+
765
+ function formatUptime(seconds) {
766
+ const h = Math.floor(seconds / 3600);
767
+ const m = Math.floor((seconds % 3600) / 60);
768
+ const s = Math.floor(seconds % 60);
769
+ return `${h}h ${m}m ${s}s`;
770
+ }
771
+ </script>
772
+ </body>
773
+ </html>
774
+ """
775
+
776
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket for real-time dashboard updates.

    Registers the connection in the module-level ``active_connections`` set,
    pushes one immediate stats snapshot via ``broadcast_stats()`` so the
    dashboard renders right away, then blocks reading client frames so the
    socket stays open until the client disconnects.
    """
    await websocket.accept()
    active_connections.add(websocket)

    try:
        # Send initial data immediately instead of waiting for the next
        # periodic broadcast tick.
        await broadcast_stats()

        # Keep the connection alive; inbound messages are ignored — the
        # read loop only serves to detect the client going away.
        while True:
            await websocket.receive_text()
    except Exception:
        # Client disconnects surface as exceptions here; nothing to do.
        # FIX: was a bare `except:`, which also swallowed BaseExceptions
        # such as asyncio.CancelledError and broke task cancellation.
        pass
    finally:
        # Always de-register so the broadcaster never writes to a dead socket.
        active_connections.discard(websocket)
793
 
794
  # ============================================================================
795
  # API Endpoints
796
  # ============================================================================
797
 
798
@app.get("/api/status")
async def api_status():
    """Return a JSON snapshot of cluster identity, load mode, and worker counts."""
    mode, load = update_load_mode()
    status = {
        "name": "SAM-Z-1 Distributed Cluster",
        "version": "4.0.0",
        "mode": mode,
        "current_load": load,
        "workers": len(WORKER_URLS),
        "healthy_workers": len(get_healthy_workers()),
        "features": ["distributed_compute", "smart_load_balancing", "real_time_dashboard"],
    }
    return status
813
 
814
@app.get("/health")
async def health():
    """Liveness probe: report healthy while at least one worker is reachable."""
    healthy_workers = get_healthy_workers()
    overall = "healthy" if healthy_workers else "unhealthy"
    return {
        "status": overall,
        "workers_healthy": len(healthy_workers),
    }
821
 
822
@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """Generate text with distributed compute.

    Routes to the distributed (generator + two decoders) pipeline under
    light load, otherwise to a single least-busy worker; both paths stream
    the result as server-sent events.
    """

    # Record this request's timestamp for the sliding-window load measurement.
    track_request()
    mode, load = update_load_mode()

    healthy = get_healthy_workers()
    if not healthy:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(status_code=503, detail="No healthy workers")

    request_data = {
        "prompt": request.prompt,
        # NOTE(review): the diff view elides several dict entries here
        # (original lines 835-839) — presumably the sampling parameters from
        # GenerateRequest; confirm against the full file.
        "stream": True
    }

    print(f"🎯 {mode.upper()} | Load: {load} | Healthy: {len(healthy)}")

    try:
        if mode == "light" and len(healthy) >= 2:
            # DISTRIBUTED MODE
            generator, decoder1, decoder2 = select_distributed_workers()
            # NOTE(review): success is counted before any tokens stream;
            # failures inside the streaming generator are not reflected here.
            cluster_stats["successful_requests"] += 1
            return StreamingResponse(
                distributed_generation(generator, decoder1, decoder2, request_data, "generate"),
                media_type="text/event-stream"
            )
        else:
            # HEAVY LOAD MODE
            worker = get_least_busy_worker()
            cluster_stats["successful_requests"] += 1
            return StreamingResponse(
                heavy_load_generation(worker, request_data, "generate"),
                media_type="text/event-stream"
            )
    except Exception as e:
        # Only covers errors raised while *constructing* the response; the
        # stream itself runs after this function returns.
        cluster_stats["failed_requests"] += 1
        raise
866
@app.post("/v1/chat")
async def chat(request: ChatRequest):
    """Chat with distributed compute.

    Same routing as /v1/generate — distributed generator/decoder split under
    light load, single least-busy worker under heavy load — but forwards the
    chat message list instead of a raw prompt.
    """

    # Record this request's timestamp for the sliding-window load measurement.
    track_request()
    mode, load = update_load_mode()

    healthy = get_healthy_workers()
    if not healthy:
        cluster_stats["failed_requests"] += 1
        raise HTTPException(status_code=503, detail="No healthy workers")

    request_data = {
        # Re-serialize Pydantic message models into plain dicts for the workers.
        "messages": [{"role": m.role, "content": m.content} for m in request.messages],
        # NOTE(review): the diff view elides several dict entries here
        # (original lines 879-883) — presumably the sampling parameters from
        # ChatRequest; confirm against the full file.
        "stream": True
    }

    print(f"💬 {mode.upper()} | Load: {load} | Healthy: {len(healthy)}")

    try:
        if mode == "light" and len(healthy) >= 2:
            # DISTRIBUTED MODE
            generator, decoder1, decoder2 = select_distributed_workers()
            # NOTE(review): success is counted before any tokens stream;
            # failures inside the streaming generator are not reflected here.
            cluster_stats["successful_requests"] += 1
            return StreamingResponse(
                distributed_generation(generator, decoder1, decoder2, request_data, "chat"),
                media_type="text/event-stream"
            )
        else:
            # HEAVY LOAD MODE
            worker = get_least_busy_worker()
            cluster_stats["successful_requests"] += 1
            return StreamingResponse(
                heavy_load_generation(worker, request_data, "chat"),
                media_type="text/event-stream"
            )
    except Exception as e:
        # Only covers errors raised while *constructing* the response; the
        # stream itself runs after this function returns.
        cluster_stats["failed_requests"] += 1
        raise
909
 
910
  # ============================================================================
911
  # Launch