Bc-AI committed on
Commit
3759d45
·
verified ·
1 Parent(s): 4b0bf97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +261 -48
app.py CHANGED
@@ -1,10 +1,10 @@
1
  """
2
- SAM-Z-1 Smart Worker Node
3
- Supports both full generation and gen/decode split modes
4
  """
5
 
6
  from fastapi import FastAPI, HTTPException
7
- from fastapi.responses import StreamingResponse
8
  from pydantic import BaseModel
9
  import tensorflow as tf
10
  import keras
@@ -17,10 +17,10 @@ import time
17
  from typing import List, Optional
18
  import asyncio
19
 
20
- app = FastAPI(title="SAM-Z-1 Smart Worker", version="3.0.0")
21
 
22
  # ============================================================================
23
- # Model Architecture (same as before)
24
  # ============================================================================
25
 
26
  @keras.saving.register_keras_serializable()
@@ -201,7 +201,7 @@ class SAM1Model(keras.Model):
201
  return base_config
202
 
203
  # ============================================================================
204
- # Global Variables
205
  # ============================================================================
206
 
207
  model = None
@@ -213,6 +213,14 @@ fast_forward = None
213
  MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
214
  CACHE_DIR = "./model_cache"
215
 
 
 
 
 
 
 
 
 
216
  # ============================================================================
217
  # Request Models
218
  # ============================================================================
@@ -225,7 +233,7 @@ class GenerateRequest(BaseModel):
225
  top_p: float = 0.9
226
  repetition_penalty: float = 1.1
227
  stream: bool = False
228
- return_token_ids: bool = False # NEW: for gen/decode split
229
 
230
  class ChatMessage(BaseModel):
231
  role: str
@@ -239,11 +247,14 @@ class ChatRequest(BaseModel):
239
  top_p: float = 0.9
240
  repetition_penalty: float = 1.1
241
  stream: bool = False
242
- return_token_ids: bool = False # NEW
243
 
244
  class DecodeRequest(BaseModel):
245
  token_ids: List[int]
246
 
 
 
 
247
  # ============================================================================
248
  # Generation Functions
249
  # ============================================================================
@@ -257,11 +268,7 @@ def generate_tokens(
257
  repetition_penalty: float = 1.1,
258
  return_token_ids: bool = False
259
  ):
260
- """
261
- Core generation function
262
- If return_token_ids=True, yields (token_id, None)
263
- If return_token_ids=False, yields (token_id, token_text)
264
- """
265
  global model, tokenizer, config, eos_token_id, fast_forward
266
 
267
  input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
@@ -314,7 +321,6 @@ def generate_tokens(
314
 
315
  token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
316
 
317
- # Yield token ID and optionally decoded text
318
  if return_token_ids:
319
  yield (next_token_id, None)
320
  else:
@@ -327,7 +333,6 @@ def generate_tokens(
327
  input_tensor = input_tensor[:, -config['max_position_embeddings']:]
328
 
329
  def format_chat_prompt(messages: List[ChatMessage]) -> str:
330
- """Format chat messages into prompt"""
331
  prompt = ""
332
  for msg in messages:
333
  if msg.role == "user":
@@ -339,60 +344,262 @@ def format_chat_prompt(messages: List[ChatMessage]) -> str:
339
  return prompt
340
 
341
  # ============================================================================
342
- # API Endpoints
343
  # ============================================================================
344
 
345
- @app.get("/")
346
- async def root():
347
- """Worker info"""
348
- return {
349
- "name": "SAM-Z-1 Smart Worker",
350
- "version": "3.0.0",
351
- "status": "ready" if model is not None else "loading",
352
- "model": MODEL_REPO,
353
- "features": ["full_generation", "token_only_mode", "decode_only_mode"],
354
- "endpoints": {
355
- "generate": "/generate",
356
- "chat": "/chat",
357
- "decode": "/decode",
358
- "health": "/health"
 
 
359
  }
360
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
  @app.get("/health")
363
  async def health():
364
- """Health check"""
365
  return {
366
  "status": "healthy" if model is not None else "loading",
367
  "model_loaded": model is not None
368
  }
369
 
 
 
 
 
 
 
 
 
 
 
 
370
  @app.post("/decode")
371
  async def decode(request: DecodeRequest):
372
- """
373
- DECODE ONLY endpoint
374
- Takes token IDs and returns decoded text
375
- This is the bottleneck we're parallelizing!
376
- """
377
  if tokenizer is None:
378
  raise HTTPException(status_code=503, detail="Tokenizer not loaded")
379
 
380
  try:
 
381
  text = tokenizer.decode(request.token_ids)
382
  return {"text": text}
383
  except Exception as e:
384
  raise HTTPException(status_code=500, detail=f"Decode error: {str(e)}")
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  @app.post("/generate")
387
  async def generate(request: GenerateRequest):
388
- """Generate text - supports both full gen and token-only mode"""
389
  if model is None:
390
- raise HTTPException(status_code=503, detail="Model not loaded yet")
391
 
 
392
  start_time = time.time()
393
 
394
  if request.stream:
395
- # Streaming response
396
  async def stream_tokens():
397
  generated_text = ""
398
  token_count = 0
@@ -408,12 +615,11 @@ async def generate(request: GenerateRequest):
408
  return_token_ids=request.return_token_ids
409
  ):
410
  token_count += 1
 
411
 
412
  if request.return_token_ids:
413
- # TOKEN-ONLY mode for gen/decode split
414
  yield f"data: {json.dumps({'token_id': token_id})}\n\n"
415
  else:
416
- # FULL mode with text
417
  generated_text += token_text
418
  yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n"
419
 
@@ -428,7 +634,6 @@ async def generate(request: GenerateRequest):
428
  return StreamingResponse(stream_tokens(), media_type="text/event-stream")
429
 
430
  else:
431
- # Non-streaming
432
  generated_text = ""
433
  token_count = 0
434
 
@@ -445,6 +650,7 @@ async def generate(request: GenerateRequest):
445
  if not request.return_token_ids:
446
  generated_text += token_text
447
  token_count += 1
 
448
 
449
  elapsed = time.time() - start_time
450
 
@@ -460,10 +666,11 @@ async def generate(request: GenerateRequest):
460
 
461
  @app.post("/chat")
462
  async def chat(request: ChatRequest):
463
- """Chat completion - supports both modes"""
464
  if model is None:
465
- raise HTTPException(status_code=503, detail="Model not loaded yet")
466
 
 
467
  prompt = format_chat_prompt(request.messages)
468
  start_time = time.time()
469
 
@@ -483,6 +690,7 @@ async def chat(request: ChatRequest):
483
  return_token_ids=request.return_token_ids
484
  ):
485
  token_count += 1
 
486
 
487
  if request.return_token_ids:
488
  yield f"data: {json.dumps({'token_id': token_id})}\n\n"
@@ -527,6 +735,7 @@ async def chat(request: ChatRequest):
527
  break
528
 
529
  token_count += 1
 
530
 
531
  elapsed = time.time() - start_time
532
 
@@ -544,12 +753,11 @@ async def chat(request: ChatRequest):
544
  raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
545
 
546
  # ============================================================================
547
- # Startup: Load Model
548
  # ============================================================================
549
 
550
  @app.on_event("startup")
551
  async def load_model():
552
- """Load model on startup"""
553
  global model, tokenizer, config, eos_token_id, fast_forward
554
 
555
  print("🚀 Loading SAM-Z-1 Model...")
@@ -619,7 +827,12 @@ async def load_model():
619
 
620
  fast_forward = optimized_forward
621
 
622
- print("✅ SAM-Z-1 Smart Worker ready! 🚀")
 
 
 
 
 
623
 
624
  except Exception as e:
625
  print(f"❌ Failed to load model: {e}")
 
1
  """
2
+ SAM-Z-1 Distributed Worker Node v4.0
3
+ Optimized for distributed gen/decode pipeline
4
  """
5
 
6
  from fastapi import FastAPI, HTTPException
7
+ from fastapi.responses import StreamingResponse, HTMLResponse
8
  from pydantic import BaseModel
9
  import tensorflow as tf
10
  import keras
 
17
  from typing import List, Optional
18
  import asyncio
19
 
20
+ app = FastAPI(title="SAM-Z-1 Distributed Worker", version="4.0.0")
21
 
22
  # ============================================================================
23
+ # Model Architecture
24
  # ============================================================================
25
 
26
  @keras.saving.register_keras_serializable()
 
201
  return base_config
202
 
203
  # ============================================================================
204
+ # Global State
205
  # ============================================================================
206
 
207
  model = None
 
213
  MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
214
  CACHE_DIR = "./model_cache"
215
 
216
+ # Stats
217
+ worker_stats = {
218
+ "total_requests": 0,
219
+ "total_tokens": 0,
220
+ "decode_requests": 0,
221
+ "uptime_start": time.time()
222
+ }
223
+
224
  # ============================================================================
225
  # Request Models
226
  # ============================================================================
 
233
  top_p: float = 0.9
234
  repetition_penalty: float = 1.1
235
  stream: bool = False
236
+ return_token_ids: bool = False
237
 
238
  class ChatMessage(BaseModel):
239
  role: str
 
247
  top_p: float = 0.9
248
  repetition_penalty: float = 1.1
249
  stream: bool = False
250
+ return_token_ids: bool = False
251
 
252
  class DecodeRequest(BaseModel):
253
  token_ids: List[int]
254
 
255
+ class BatchDecodeRequest(BaseModel):
256
+ batches: List[List[int]]
257
+
258
  # ============================================================================
259
  # Generation Functions
260
  # ============================================================================
 
268
  repetition_penalty: float = 1.1,
269
  return_token_ids: bool = False
270
  ):
271
+ """Core generation - yields (token_id, token_text or None)"""
 
 
 
 
272
  global model, tokenizer, config, eos_token_id, fast_forward
273
 
274
  input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
 
321
 
322
  token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
323
 
 
324
  if return_token_ids:
325
  yield (next_token_id, None)
326
  else:
 
333
  input_tensor = input_tensor[:, -config['max_position_embeddings']:]
334
 
335
  def format_chat_prompt(messages: List[ChatMessage]) -> str:
 
336
  prompt = ""
337
  for msg in messages:
338
  if msg.role == "user":
 
344
  return prompt
345
 
346
  # ============================================================================
347
+ # Status Page
348
  # ============================================================================
349
 
350
+ @app.get("/", response_class=HTMLResponse)
351
+ async def status_page():
352
+ """Worker status page"""
353
+ return """
354
+ <!DOCTYPE html>
355
+ <html>
356
+ <head>
357
+ <title>SAM-Z-1 Worker Node</title>
358
+ <style>
359
+ * { margin: 0; padding: 0; box-sizing: border-box; }
360
+ body {
361
+ font-family: 'Courier New', monospace;
362
+ background: linear-gradient(135deg, #1a1f3a 0%, #0a0e27 100%);
363
+ color: #00bfff;
364
+ padding: 20px;
365
+ min-height: 100vh;
366
  }
367
+ .container {
368
+ max-width: 900px;
369
+ margin: 0 auto;
370
+ }
371
+ .header {
372
+ text-align: center;
373
+ padding: 30px;
374
+ background: rgba(0, 191, 255, 0.1);
375
+ border: 2px solid #00bfff;
376
+ border-radius: 10px;
377
+ margin-bottom: 30px;
378
+ box-shadow: 0 0 20px rgba(0, 191, 255, 0.3);
379
+ }
380
+ .header h1 {
381
+ font-size: 2.5em;
382
+ text-transform: uppercase;
383
+ letter-spacing: 3px;
384
+ animation: glow 2s ease-in-out infinite alternate;
385
+ }
386
+ @keyframes glow {
387
+ from { text-shadow: 0 0 10px #00bfff; }
388
+ to { text-shadow: 0 0 20px #00bfff, 0 0 30px #00bfff; }
389
+ }
390
+ .badge {
391
+ display: inline-block;
392
+ padding: 5px 15px;
393
+ border-radius: 15px;
394
+ font-size: 0.9em;
395
+ margin-top: 10px;
396
+ }
397
+ .badge-ready {
398
+ background: rgba(0, 255, 136, 0.2);
399
+ border: 1px solid #00ff88;
400
+ color: #00ff88;
401
+ }
402
+ .badge-loading {
403
+ background: rgba(255, 165, 0, 0.2);
404
+ border: 1px solid #ffa500;
405
+ color: #ffa500;
406
+ }
407
+ .stats-grid {
408
+ display: grid;
409
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
410
+ gap: 20px;
411
+ margin-bottom: 30px;
412
+ }
413
+ .stat-card {
414
+ background: rgba(0, 191, 255, 0.05);
415
+ border: 1px solid #00bfff;
416
+ border-radius: 8px;
417
+ padding: 20px;
418
+ text-align: center;
419
+ }
420
+ .stat-label {
421
+ font-size: 0.8em;
422
+ opacity: 0.7;
423
+ text-transform: uppercase;
424
+ margin-bottom: 10px;
425
+ }
426
+ .stat-value {
427
+ font-size: 2em;
428
+ font-weight: bold;
429
+ }
430
+ .features {
431
+ background: rgba(0, 191, 255, 0.05);
432
+ border: 1px solid #00bfff;
433
+ border-radius: 8px;
434
+ padding: 20px;
435
+ }
436
+ .features h3 {
437
+ margin-bottom: 15px;
438
+ }
439
+ .feature-list {
440
+ list-style: none;
441
+ padding: 0;
442
+ }
443
+ .feature-list li {
444
+ padding: 10px;
445
+ margin: 5px 0;
446
+ background: rgba(0, 191, 255, 0.1);
447
+ border-radius: 5px;
448
+ }
449
+ .feature-list li:before {
450
+ content: "⚡ ";
451
+ color: #00ff88;
452
+ }
453
+ .timestamp {
454
+ text-align: center;
455
+ margin-top: 20px;
456
+ opacity: 0.5;
457
+ }
458
+ </style>
459
+ </head>
460
+ <body>
461
+ <div class="container">
462
+ <div class="header">
463
+ <h1>⚙️ WORKER NODE ⚙️</h1>
464
+ <div>SAM-Z-1 Distributed Worker v4.0</div>
465
+ <div class="badge" id="status-badge">CHECKING STATUS...</div>
466
+ </div>
467
+
468
+ <div class="stats-grid" id="stats">
469
+ <div class="stat-card">
470
+ <div class="stat-label">Total Requests</div>
471
+ <div class="stat-value" id="total-req">--</div>
472
+ </div>
473
+ <div class="stat-card">
474
+ <div class="stat-label">Total Tokens</div>
475
+ <div class="stat-value" id="total-tokens">--</div>
476
+ </div>
477
+ <div class="stat-card">
478
+ <div class="stat-label">Decode Requests</div>
479
+ <div class="stat-value" id="decode-req">--</div>
480
+ </div>
481
+ <div class="stat-card">
482
+ <div class="stat-label">Uptime</div>
483
+ <div class="stat-value" id="uptime">--</div>
484
+ </div>
485
+ </div>
486
+
487
+ <div class="features">
488
+ <h3>🚀 CAPABILITIES</h3>
489
+ <ul class="feature-list">
490
+ <li>Full Text Generation</li>
491
+ <li>Token-Only Mode (for distributed pipeline)</li>
492
+ <li>High-Speed Batch Decoding</li>
493
+ <li>Chat Completion</li>
494
+ <li>Streaming & Non-Streaming</li>
495
+ </ul>
496
+ </div>
497
+
498
+ <div class="timestamp" id="timestamp">Initializing...</div>
499
+ </div>
500
+
501
+ <script>
502
+ async function updateStats() {
503
+ try {
504
+ const response = await fetch('/health');
505
+ const data = await response.json();
506
+
507
+ const badge = document.getElementById('status-badge');
508
+ if (data.model_loaded) {
509
+ badge.textContent = '✅ READY FOR INFERENCE';
510
+ badge.className = 'badge badge-ready';
511
+ } else {
512
+ badge.textContent = '⏳ LOADING MODEL...';
513
+ badge.className = 'badge badge-loading';
514
+ }
515
+
516
+ // Fetch stats
517
+ const statsRes = await fetch('/stats');
518
+ const stats = await statsRes.json();
519
+
520
+ document.getElementById('total-req').textContent = stats.total_requests;
521
+ document.getElementById('total-tokens').textContent = stats.total_tokens;
522
+ document.getElementById('decode-req').textContent = stats.decode_requests;
523
+
524
+ const uptime = Math.floor(stats.uptime);
525
+ const h = Math.floor(uptime / 3600);
526
+ const m = Math.floor((uptime % 3600) / 60);
527
+ const s = uptime % 60;
528
+ document.getElementById('uptime').textContent = `${h}h ${m}m ${s}s`;
529
+
530
+ document.getElementById('timestamp').textContent =
531
+ `Last update: ${new Date().toLocaleTimeString()}`;
532
+ } catch (e) {
533
+ console.error('Failed to update stats:', e);
534
+ }
535
+ }
536
+
537
+ // Update every second
538
+ setInterval(updateStats, 1000);
539
+ updateStats();
540
+ </script>
541
+ </body>
542
+ </html>
543
+ """
544
+
545
+ # ============================================================================
546
+ # API Endpoints
547
+ # ============================================================================
548
 
549
  @app.get("/health")
550
  async def health():
 
551
  return {
552
  "status": "healthy" if model is not None else "loading",
553
  "model_loaded": model is not None
554
  }
555
 
556
+ @app.get("/stats")
557
+ async def stats():
558
+ uptime = time.time() - worker_stats["uptime_start"]
559
+ return {
560
+ "total_requests": worker_stats["total_requests"],
561
+ "total_tokens": worker_stats["total_tokens"],
562
+ "decode_requests": worker_stats["decode_requests"],
563
+ "uptime": uptime,
564
+ "tokens_per_second": worker_stats["total_tokens"] / uptime if uptime > 0 else 0
565
+ }
566
+
567
  @app.post("/decode")
568
  async def decode(request: DecodeRequest):
569
+ """Fast single decode"""
 
 
 
 
570
  if tokenizer is None:
571
  raise HTTPException(status_code=503, detail="Tokenizer not loaded")
572
 
573
  try:
574
+ worker_stats["decode_requests"] += 1
575
  text = tokenizer.decode(request.token_ids)
576
  return {"text": text}
577
  except Exception as e:
578
  raise HTTPException(status_code=500, detail=f"Decode error: {str(e)}")
579
 
580
+ @app.post("/decode/batch")
581
+ async def batch_decode(request: BatchDecodeRequest):
582
+ """Optimized batch decoding for distributed pipeline"""
583
+ if tokenizer is None:
584
+ raise HTTPException(status_code=503, detail="Tokenizer not loaded")
585
+
586
+ try:
587
+ worker_stats["decode_requests"] += len(request.batches)
588
+ results = [tokenizer.decode(batch) for batch in request.batches]
589
+ return {"texts": results}
590
+ except Exception as e:
591
+ raise HTTPException(status_code=500, detail=f"Batch decode error: {str(e)}")
592
+
593
  @app.post("/generate")
594
  async def generate(request: GenerateRequest):
595
+ """Generate text"""
596
  if model is None:
597
+ raise HTTPException(status_code=503, detail="Model not loaded")
598
 
599
+ worker_stats["total_requests"] += 1
600
  start_time = time.time()
601
 
602
  if request.stream:
 
603
  async def stream_tokens():
604
  generated_text = ""
605
  token_count = 0
 
615
  return_token_ids=request.return_token_ids
616
  ):
617
  token_count += 1
618
+ worker_stats["total_tokens"] += 1
619
 
620
  if request.return_token_ids:
 
621
  yield f"data: {json.dumps({'token_id': token_id})}\n\n"
622
  else:
 
623
  generated_text += token_text
624
  yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n"
625
 
 
634
  return StreamingResponse(stream_tokens(), media_type="text/event-stream")
635
 
636
  else:
 
637
  generated_text = ""
638
  token_count = 0
639
 
 
650
  if not request.return_token_ids:
651
  generated_text += token_text
652
  token_count += 1
653
+ worker_stats["total_tokens"] += 1
654
 
655
  elapsed = time.time() - start_time
656
 
 
666
 
667
  @app.post("/chat")
668
  async def chat(request: ChatRequest):
669
+ """Chat completion"""
670
  if model is None:
671
+ raise HTTPException(status_code=503, detail="Model not loaded")
672
 
673
+ worker_stats["total_requests"] += 1
674
  prompt = format_chat_prompt(request.messages)
675
  start_time = time.time()
676
 
 
690
  return_token_ids=request.return_token_ids
691
  ):
692
  token_count += 1
693
+ worker_stats["total_tokens"] += 1
694
 
695
  if request.return_token_ids:
696
  yield f"data: {json.dumps({'token_id': token_id})}\n\n"
 
735
  break
736
 
737
  token_count += 1
738
+ worker_stats["total_tokens"] += 1
739
 
740
  elapsed = time.time() - start_time
741
 
 
753
  raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
754
 
755
  # ============================================================================
756
+ # Model Loading
757
  # ============================================================================
758
 
759
  @app.on_event("startup")
760
  async def load_model():
 
761
  global model, tokenizer, config, eos_token_id, fast_forward
762
 
763
  print("🚀 Loading SAM-Z-1 Model...")
 
827
 
828
  fast_forward = optimized_forward
829
 
830
+ print("✅ SAM-Z-1 Distributed Worker ready! 🚀")
831
+ print("🔥 Features enabled:")
832
+ print(" - Full text generation")
833
+ print(" - Token-only mode (distributed pipeline)")
834
+ print(" - Batch decoding optimization")
835
+ print(" - Streaming support")
836
 
837
  except Exception as e:
838
  print(f"❌ Failed to load model: {e}")