Bc-AI committed on
Commit 4ad8327 · verified · 1 Parent(s): 65ca8df

Update app.py

Files changed (1):
  1. app.py +317 -103
app.py CHANGED
@@ -1,36 +1,42 @@
  """
- SAM-Z-1 Cluster Head Node
- Receives requests and distributes to worker spaces
+ SAM-Z-1 Smart Load Balancing Cluster Head Node
+ - Light load: parallel gen/decode split for max speed
+ - Heavy load: 1 worker per request for throughput
  """

- from fastapi import FastAPI, HTTPException, Request
+ from fastapi import FastAPI, HTTPException
  from fastapi.responses import StreamingResponse
  from pydantic import BaseModel
  import httpx
  import asyncio
  import json
  import time
- from typing import List, Optional
+ from typing import List, Optional, Dict
+ from collections import deque
  import random

- app = FastAPI(title="SAM-Z-1 Cluster API", version="1.0.0")
+ app = FastAPI(title="SAM-Z-1 Smart Cluster API", version="3.0.0")

  # ============================================================================
  # Configuration
  # ============================================================================

- # Add your worker space URLs here
  WORKER_URLS = [
      "https://bc-ai-worker-2.hf.space",
      "https://bc-ai-worker-sam-z-api.hf.space",
-     # Add more workers as needed
  ]

- # Health check interval (seconds)
  HEALTH_CHECK_INTERVAL = 30
+ LOAD_CHECK_WINDOW = 10  # seconds to measure load

- # Worker health status
- worker_health = {url: {"healthy": True, "last_check": 0} for url in WORKER_URLS}
+ # Load thresholds
+ LIGHT_LOAD_THRESHOLD = 2  # requests in window
+ HEAVY_LOAD_THRESHOLD = 5  # requests in window
+
+ # Worker state
+ worker_health = {url: {"healthy": True, "last_check": 0, "active_requests": 0} for url in WORKER_URLS}
+ request_timestamps = deque(maxlen=100)  # track recent requests
+ current_load_mode = "light"  # "light" or "heavy"

  # ============================================================================
  # Request Models
@@ -43,10 +49,10 @@ class GenerateRequest(BaseModel):
      top_k: int = 40
      top_p: float = 0.9
      repetition_penalty: float = 1.1
-     stream: bool = False
+     stream: bool = True

  class ChatMessage(BaseModel):
-     role: str  # "user" or "assistant"
+     role: str
      content: str

  class ChatRequest(BaseModel):
@@ -56,22 +62,55 @@
      top_k: int = 40
      top_p: float = 0.9
      repetition_penalty: float = 1.1
-     stream: bool = False
+     stream: bool = True

  # ============================================================================
- # Load Balancing & Health Checks
+ # Load Management
  # ============================================================================

+ def get_current_load() -> int:
+     """Calculate current load based on recent requests"""
+     now = time.time()
+     # Count requests in the last LOAD_CHECK_WINDOW seconds
+     return sum(1 for ts in request_timestamps if now - ts < LOAD_CHECK_WINDOW)
+
+ def update_load_mode():
+     """Update load mode based on current load"""
+     global current_load_mode
+     load = get_current_load()
+
+     if load <= LIGHT_LOAD_THRESHOLD:
+         current_load_mode = "light"
+     elif load >= HEAVY_LOAD_THRESHOLD:
+         current_load_mode = "heavy"
+     # hysteresis zone between thresholds maintains current mode
+
+     return current_load_mode, load
+
+ def track_request():
+     """Track a new request"""
+     request_timestamps.append(time.time())
+
  def get_healthy_workers() -> List[str]:
      """Get list of healthy workers"""
      return [url for url, status in worker_health.items() if status["healthy"]]

- def select_worker() -> Optional[str]:
-     """Select a worker using round-robin on healthy workers"""
+ def get_least_busy_worker() -> Optional[str]:
+     """Get worker with fewest active requests"""
      healthy = get_healthy_workers()
      if not healthy:
          return None
-     return random.choice(healthy)  # You could also implement round-robin here
+     return min(healthy, key=lambda url: worker_health[url]["active_requests"])
+
+ def select_worker_pair() -> tuple:
+     """Select 2 workers for parallel operation"""
+     healthy = get_healthy_workers()
+     if len(healthy) < 2:
+         return (healthy[0], None) if len(healthy) == 1 else (None, None)
+
+     # Sort by active requests, take 2 least busy
+     sorted_workers = sorted(healthy, key=lambda url: worker_health[url]["active_requests"])
+     return (sorted_workers[0], sorted_workers[1])

  async def check_worker_health(worker_url: str) -> bool:
      """Check if a worker is healthy"""
@@ -91,7 +130,11 @@ async def health_check_loop():
              worker_health[worker_url]["last_check"] = time.time()

              status = "✅" if healthy else "❌"
-             print(f"{status} Worker {worker_url}: {'healthy' if healthy else 'unhealthy'}")
+             active = worker_health[worker_url]["active_requests"]
+             print(f"{status} {worker_url}: {'healthy' if healthy else 'unhealthy'} | Active: {active}")
+
+         mode, load = update_load_mode()
+         print(f"📊 Load mode: {mode.upper()} | Current load: {load} req/{LOAD_CHECK_WINDOW}s")

          await asyncio.sleep(HEALTH_CHECK_INTERVAL)

@@ -100,6 +143,156 @@ async def startup_event():
      """Start health check loop on startup"""
      asyncio.create_task(health_check_loop())

+ # ============================================================================
+ # Generation Strategies
+ # ============================================================================
+
+ async def light_load_generation(
+     generator_url: str,
+     decoder_url: str,
+     request_data: dict,
+     endpoint: str = "generate"
+ ):
+     """
+     LIGHT LOAD MODE: Split generation and decoding
+     - Generator worker: produces token IDs only
+     - Decoder worker: decodes token IDs to text
+     This parallelizes the bottleneck!
+     """
+
+     # Queues for pipeline
+     token_queue = asyncio.Queue(maxsize=10)
+     text_queue = asyncio.Queue(maxsize=10)
+
+     async def generate_tokens():
+         """Worker 1: Generate token IDs"""
+         try:
+             worker_health[generator_url]["active_requests"] += 1
+
+             # Request token IDs only mode
+             request_data_tokens = {**request_data, "return_token_ids": True}
+
+             async with httpx.AsyncClient(timeout=300.0) as client:
+                 async with client.stream(
+                     "POST",
+                     f"{generator_url}/{endpoint}",
+                     json=request_data_tokens
+                 ) as response:
+                     async for chunk in response.aiter_text():
+                         if chunk.strip() and chunk.startswith("data: "):
+                             try:
+                                 data = json.loads(chunk[6:])
+                                 if "token_id" in data:
+                                     await token_queue.put(data["token_id"])
+                                 elif "done" in data:
+                                     await token_queue.put(None)  # Signal end
+                                     break
+                             except:
+                                 pass
+         except Exception as e:
+             print(f"❌ Generator error: {e}")
+             await token_queue.put(None)
+         finally:
+             worker_health[generator_url]["active_requests"] -= 1
+
+     async def decode_tokens():
+         """Worker 2: Decode token IDs to text"""
+         try:
+             worker_health[decoder_url]["active_requests"] += 1
+
+             batch = []
+             batch_size = 5  # decode in small batches for speed
+
+             while True:
+                 try:
+                     token_id = await asyncio.wait_for(token_queue.get(), timeout=1.0)
+
+                     if token_id is None:
+                         # Decode remaining batch
+                         if batch:
+                             async with httpx.AsyncClient(timeout=10.0) as client:
+                                 response = await client.post(
+                                     f"{decoder_url}/decode",
+                                     json={"token_ids": batch}
+                                 )
+                                 text = response.json()["text"]
+                                 await text_queue.put(("text", text))
+
+                         await text_queue.put(("done", None))
+                         break
+
+                     batch.append(token_id)
+
+                     # Decode batch when full
+                     if len(batch) >= batch_size:
+                         async with httpx.AsyncClient(timeout=10.0) as client:
+                             response = await client.post(
+                                 f"{decoder_url}/decode",
+                                 json={"token_ids": batch}
+                             )
+                             text = response.json()["text"]
+                             await text_queue.put(("text", text))
+
+                         batch = []
+
+                 except asyncio.TimeoutError:
+                     continue
+
+         except Exception as e:
+             print(f"❌ Decoder error: {e}")
+             await text_queue.put(("done", None))
+         finally:
+             worker_health[decoder_url]["active_requests"] -= 1
+
+     # Start both pipelines
+     gen_task = asyncio.create_task(generate_tokens())
+     dec_task = asyncio.create_task(decode_tokens())
+
+     # Stream decoded text
+     accumulated_text = ""
+     try:
+         while True:
+             msg_type, data = await text_queue.get()
+
+             if msg_type == "done":
+                 break
+
+             if msg_type == "text":
+                 accumulated_text += data
+                 yield f"data: {json.dumps({'delta': data, 'text': accumulated_text})}\n\n"
+
+     finally:
+         await gen_task
+         await dec_task
+
+ async def heavy_load_generation(
+     worker_url: str,
+     request_data: dict,
+     endpoint: str = "generate"
+ ):
+     """
+     HEAVY LOAD MODE: Single worker per request
+     Standard streaming for max throughput
+     """
+     try:
+         worker_health[worker_url]["active_requests"] += 1
+
+         async with httpx.AsyncClient(timeout=300.0) as client:
+             async with client.stream(
+                 "POST",
+                 f"{worker_url}/{endpoint}",
+                 json=request_data
+             ) as response:
+                 async for chunk in response.aiter_text():
+                     if chunk.strip():
+                         yield chunk
+
+     except Exception as e:
+         yield f"data: {json.dumps({'error': str(e)})}\n\n"
+
+     finally:
+         worker_health[worker_url]["active_requests"] -= 1
+
  # ============================================================================
  # API Endpoints
  # ============================================================================
@@ -108,16 +301,30 @@ async def startup_event():
  async def root():
      """API info"""
      healthy_count = len(get_healthy_workers())
+     mode, load = update_load_mode()
+
      return {
-         "name": "SAM-Z-1 Cluster API",
-         "version": "1.0.0",
+         "name": "SAM-Z-1 Smart Cluster API",
+         "version": "3.0.0",
+         "mode": mode,
+         "current_load": load,
          "workers": len(WORKER_URLS),
          "healthy_workers": healthy_count,
+         "features": [
+             "smart_load_balancing",
+             "parallel_gen_decode",
+             "adaptive_routing"
+         ],
+         "load_strategy": {
+             "light": "parallel gen/decode split for speed",
+             "heavy": "1 worker per request for throughput"
+         },
          "endpoints": {
              "generate": "/v1/generate",
              "chat": "/v1/chat",
              "health": "/health",
-             "workers": "/workers"
+             "workers": "/workers",
+             "stats": "/stats"
          }
      }

@@ -125,10 +332,14 @@ async def root():
  async def health():
      """Health check endpoint"""
      healthy_count = len(get_healthy_workers())
+     mode, load = update_load_mode()
+
      return {
          "status": "healthy" if healthy_count > 0 else "unhealthy",
          "workers_total": len(WORKER_URLS),
-         "workers_healthy": healthy_count
+         "workers_healthy": healthy_count,
+         "load_mode": mode,
+         "current_load": load
      }

  @app.get("/workers")
@@ -139,107 +350,110 @@ async def workers_status():
          {
              "url": url,
              "healthy": status["healthy"],
+             "active_requests": status["active_requests"],
              "last_check": status["last_check"]
          }
          for url, status in worker_health.items()
      ]
  }

+ @app.get("/stats")
+ async def stats():
+     """Get cluster statistics"""
+     mode, load = update_load_mode()
+
+     return {
+         "load_mode": mode,
+         "current_load": load,
+         "load_window_seconds": LOAD_CHECK_WINDOW,
+         "thresholds": {
+             "light": LIGHT_LOAD_THRESHOLD,
+             "heavy": HEAVY_LOAD_THRESHOLD
+         },
+         "recent_requests": len(request_timestamps),
+         "worker_stats": {
+             url: {
+                 "healthy": status["healthy"],
+                 "active": status["active_requests"]
+             }
+             for url, status in worker_health.items()
+         }
+     }
+
  @app.post("/v1/generate")
  async def generate(request: GenerateRequest):
-     """Generate text from prompt"""
-     worker_url = select_worker()
-
-     if not worker_url:
-         raise HTTPException(
-             status_code=503,
-             detail="No healthy workers available"
-         )
-
-     try:
-         async with httpx.AsyncClient(timeout=300.0) as client:
-             if request.stream:
-                 # Streaming response
-                 async def stream_from_worker():
-                     async with client.stream(
-                         "POST",
-                         f"{worker_url}/generate",
-                         json=request.dict()
-                     ) as response:
-                         async for chunk in response.aiter_text():
-                             yield chunk
-
-                 return StreamingResponse(
-                     stream_from_worker(),
-                     media_type="text/event-stream"
-                 )
-             else:
-                 # Non-streaming response
-                 response = await client.post(
-                     f"{worker_url}/generate",
-                     json=request.dict()
-                 )
-                 return response.json()
-
-     except httpx.TimeoutException:
-         # Mark worker as unhealthy and retry with another
-         worker_health[worker_url]["healthy"] = False
-         raise HTTPException(
-             status_code=504,
-             detail="Worker timeout - request failed"
+     """Generate text with smart load balancing"""
+     track_request()
+     mode, load = update_load_mode()
+
+     healthy = get_healthy_workers()
+     if not healthy:
+         raise HTTPException(status_code=503, detail="No healthy workers available")
+
+     request_data = {
+         "prompt": request.prompt,
+         "max_tokens": request.max_tokens,
+         "temperature": request.temperature,
+         "top_k": request.top_k,
+         "top_p": request.top_p,
+         "repetition_penalty": request.repetition_penalty,
+         "stream": True
+     }
+
+     print(f"🎯 Mode: {mode.upper()} | Load: {load} | Request: generate")
+
+     if mode == "light" and len(healthy) >= 2:
+         # LIGHT LOAD: parallel gen/decode
+         generator, decoder = select_worker_pair()
+         return StreamingResponse(
+             light_load_generation(generator, decoder, request_data, "generate"),
+             media_type="text/event-stream"
          )
-     except Exception as e:
-         raise HTTPException(
-             status_code=500,
-             detail=f"Worker error: {str(e)}"
+     else:
+         # HEAVY LOAD: single worker
+         worker = get_least_busy_worker()
+         return StreamingResponse(
+             heavy_load_generation(worker, request_data, "generate"),
+             media_type="text/event-stream"
          )

  @app.post("/v1/chat")
  async def chat(request: ChatRequest):
-     """Chat completion endpoint"""
-     worker_url = select_worker()
-
-     if not worker_url:
-         raise HTTPException(
-             status_code=503,
-             detail="No healthy workers available"
-         )
-
-     try:
-         async with httpx.AsyncClient(timeout=300.0) as client:
-             if request.stream:
-                 # Streaming response
-                 async def stream_from_worker():
-                     async with client.stream(
-                         "POST",
-                         f"{worker_url}/chat",
-                         json=request.dict()
-                     ) as response:
-                         async for chunk in response.aiter_text():
-                             yield chunk
-
-                 return StreamingResponse(
-                     stream_from_worker(),
-                     media_type="text/event-stream"
-                 )
-             else:
-                 # Non-streaming response
-                 response = await client.post(
-                     f"{worker_url}/chat",
-                     json=request.dict()
-                 )
-                 return response.json()
-
-     except httpx.TimeoutException:
-         worker_health[worker_url]["healthy"] = False
-         raise HTTPException(
-             status_code=504,
-             detail="Worker timeout - request failed"
+     """Chat completion with smart load balancing"""
+     track_request()
+     mode, load = update_load_mode()
+
+     healthy = get_healthy_workers()
+     if not healthy:
+         raise HTTPException(status_code=503, detail="No healthy workers available")
+
+     request_data = {
+         "messages": [{"role": m.role, "content": m.content} for m in request.messages],
+         "max_tokens": request.max_tokens,
+         "temperature": request.temperature,
+         "top_k": request.top_k,
+         "top_p": request.top_p,
+         "repetition_penalty": request.repetition_penalty,
+         "stream": True
+     }
+
+     print(f"🎯 Mode: {mode.upper()} | Load: {load} | Request: chat")
+
+     if mode == "light" and len(healthy) >= 2:
+         # LIGHT LOAD: parallel gen/decode
+         generator, decoder = select_worker_pair()
+         return StreamingResponse(
+             light_load_generation(generator, decoder, request_data, "chat"),
+             media_type="text/event-stream"
          )
-     except Exception as e:
-         raise HTTPException(
-             status_code=500,
-             detail=f"Worker error: {str(e)}"
+     else:
+         # HEAVY LOAD: single worker
+         worker = get_least_busy_worker()
+         return StreamingResponse(
+             heavy_load_generation(worker, request_data, "chat"),
+             media_type="text/event-stream"
          )

  # ============================================================================
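
Note: light_load_generation() assumes a worker-side contract that is not part of this commit: the generator worker must honor "return_token_ids": true by streaming SSE frames of the form data: {"token_id": ...} and data: {"done": ...}, and the decoder worker must expose a POST /decode endpoint that accepts {"token_ids": [...]} and returns {"text": "..."}. A minimal sketch of what such a decoder endpoint could look like (the model name is a placeholder, not taken from this repo):

    from typing import List
    from fastapi import FastAPI
    from pydantic import BaseModel
    from transformers import AutoTokenizer

    app = FastAPI()
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model

    class DecodeRequest(BaseModel):
        token_ids: List[int]

    @app.post("/decode")
    async def decode(req: DecodeRequest):
        # The head node sends small batches of token IDs and reads back {"text": ...}
        return {"text": tokenizer.decode(req.token_ids)}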
 
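A minimal client sketch for the updated streaming endpoints (head-node URL is a placeholder). In light mode the head emits data: {"delta": ..., "text": ...} frames as shown in the diff; in heavy mode frames are proxied from the worker as-is, so their exact shape depends on the worker:

    import asyncio
    import json
    import httpx

    HEAD_URL = "https://your-head-node.hf.space"  # placeholder

    async def main():
        payload = {"prompt": "Hello, SAM!", "max_tokens": 64}
        async with httpx.AsyncClient(timeout=300.0) as client:
            # Stream light-mode SSE frames and print the text deltas
            async with client.stream("POST", f"{HEAD_URL}/v1/generate", json=payload) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        frame = json.loads(line[6:])
                        if "delta" in frame:
                            print(frame["delta"], end="", flush=True)
            # Inspect the cluster's current load mode via the new /stats endpoint
            stats = (await client.get(f"{HEAD_URL}/stats")).json()
            print("\nload mode:", stats["load_mode"])

    asyncio.run(main())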