Bc-AI committed on
Commit
25388aa
·
verified ·
1 Parent(s): 120f320

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -398
app.py CHANGED
@@ -1,9 +1,6 @@
1
  """
2
- SAM-Z-1 Distributed Worker Node v5.0
3
- - Supports BOTH old SAM-Z-1 AND 4 new SAM-X-1 models
4
- - Different tokenizers and vocabularies per model family
5
- - Auto version detection
6
- - Backward compatible with v4 head nodes
7
  """
8
 
9
  from fastapi import FastAPI, HTTPException
@@ -17,56 +14,10 @@ import os
17
  from tokenizers import Tokenizer
18
  import numpy as np
19
  import time
20
- from typing import List, Optional, Dict
21
  import asyncio
22
 
23
- app = FastAPI(title="SAM-Z-1 Distributed Worker", version="5.0.0")
24
-
25
- # ============================================================================
26
- # Configuration - ALL 5 MODELS
27
- # ============================================================================
28
-
29
- MODEL_REGISTRY = {
30
- # Original SAM-Z-1 (keep this!)
31
- "SAM-Z-1": {
32
- "repo": "Smilyai-labs/Sam-Z-1-tensorflow",
33
- "weights": "ckpt.weights.h5",
34
- "config": "config.json",
35
- "tokenizer_repo": "Smilyai-labs/Sam-Z-1-tensorflow",
36
- "family": "sam-z" # Different tokenizer family
37
- },
38
- # New SAM-X-1 family (different tokenizer!)
39
- "SAM-X-1-Large": {
40
- "repo": "Smilyai-labs/Sam-1x-instruct",
41
- "weights": "ckpt.weights.h5",
42
- "config": None,
43
- "tokenizer_repo": "Smilyai-labs/Sam-1-large-it-0002",
44
- "family": "sam-x"
45
- },
46
- "SAM-X-1-Fast": {
47
- "repo": "Smilyai-labs/Sam-X-1-fast",
48
- "weights": "sam1_fast_finetuned.weights.h5",
49
- "config": "sam1_fast_finetuned_config.json",
50
- "tokenizer_repo": "Smilyai-labs/Sam-1-large-it-0002",
51
- "family": "sam-x"
52
- },
53
- "SAM-X-1-Mini": {
54
- "repo": "Smilyai-labs/Sam-X-1-Mini",
55
- "weights": "sam1_mini_finetuned.weights.h5",
56
- "config": "sam1_mini_finetuned_config.json",
57
- "tokenizer_repo": "Smilyai-labs/Sam-1-large-it-0002",
58
- "family": "sam-x"
59
- },
60
- "SAM-X-1-Nano": {
61
- "repo": "Smilyai-labs/Sam-X-1-Nano",
62
- "weights": "sam1_nano_finetuned.weights.h5",
63
- "config": "sam1_nano_finetuned_config.json",
64
- "tokenizer_repo": "Smilyai-labs/Sam-1-large-it-0002",
65
- "family": "sam-x"
66
- }
67
- }
68
-
69
- CACHE_DIR = "./model_cache"
70
 
71
  # ============================================================================
72
  # Model Architecture
@@ -250,19 +201,24 @@ class SAM1Model(keras.Model):
250
  return base_config
251
 
252
  # ============================================================================
253
- # Global State - Separate tokenizers per family!
254
  # ============================================================================
255
 
256
- loaded_models = {} # Dict[model_name, (model, fast_forward, config, tokenizer, eos_token_id)]
257
- tokenizer_cache = {} # Dict[family, (tokenizer, eos_token_id)]
258
- current_model = None
 
 
 
 
 
259
 
 
260
  worker_stats = {
261
  "total_requests": 0,
262
  "total_tokens": 0,
263
  "decode_requests": 0,
264
- "uptime_start": time.time(),
265
- "model_usage": {}
266
  }
267
 
268
  # ============================================================================
@@ -278,7 +234,6 @@ class GenerateRequest(BaseModel):
278
  repetition_penalty: float = 1.1
279
  stream: bool = False
280
  return_token_ids: bool = False
281
- model: Optional[str] = None
282
 
283
  class ChatMessage(BaseModel):
284
  role: str
@@ -293,70 +248,12 @@ class ChatRequest(BaseModel):
293
  repetition_penalty: float = 1.1
294
  stream: bool = False
295
  return_token_ids: bool = False
296
- model: Optional[str] = None
297
 
298
  class DecodeRequest(BaseModel):
299
  token_ids: List[int]
300
- model: Optional[str] = None # Need to know which tokenizer to use!
301
 
302
  class BatchDecodeRequest(BaseModel):
303
  batches: List[List[int]]
304
- model: Optional[str] = None
305
-
306
- # ============================================================================
307
- # Tokenizer Management
308
- # ============================================================================
309
-
310
- async def load_tokenizer(family: str, repo: str) -> tuple:
311
- """Load tokenizer for a model family"""
312
- if family in tokenizer_cache:
313
- return tokenizer_cache[family]
314
-
315
- print(f" πŸ”€ Loading tokenizer for {family} family from {repo}...")
316
-
317
- try:
318
- from transformers import AutoTokenizer
319
-
320
- hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
321
- custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
322
- hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
323
-
324
- os.makedirs(f"./temp_tokenizer_{family}", exist_ok=True)
325
- hf_tokenizer.save_pretrained(f"./temp_tokenizer_{family}")
326
- tokenizer = Tokenizer.from_file(f"./temp_tokenizer_{family}/tokenizer.json")
327
-
328
- eos_token = "<|endoftext|>"
329
- eos_token_id = tokenizer.token_to_id(eos_token)
330
-
331
- if eos_token_id is None:
332
- tokenizer.add_special_tokens([eos_token])
333
- eos_token_id = tokenizer.token_to_id(eos_token)
334
-
335
- tokenizer_cache[family] = (tokenizer, eos_token_id)
336
- print(f" βœ… Tokenizer ready (vocab size: {tokenizer.get_vocab_size()}, EOS: {eos_token_id})")
337
-
338
- return tokenizer, eos_token_id
339
-
340
- except Exception as e:
341
- print(f" ⚠️ Tokenizer load failed: {e}")
342
- raise
343
-
344
- def get_tokenizer_for_model(model_name: str):
345
- """Get the correct tokenizer for a model"""
346
- if not model_name or model_name not in loaded_models:
347
- model_name = current_model
348
-
349
- if model_name in loaded_models:
350
- _, _, _, tokenizer, eos_id = loaded_models[model_name]
351
- return tokenizer, eos_id
352
-
353
- # Fallback to first available
354
- if loaded_models:
355
- first_model = list(loaded_models.keys())[0]
356
- _, _, _, tokenizer, eos_id = loaded_models[first_model]
357
- return tokenizer, eos_id
358
-
359
- raise HTTPException(status_code=503, detail="No models loaded")
360
 
361
  # ============================================================================
362
  # Generation Functions
@@ -369,22 +266,11 @@ def generate_tokens(
369
  top_k: int = 40,
370
  top_p: float = 0.9,
371
  repetition_penalty: float = 1.1,
372
- return_token_ids: bool = False,
373
- model_name: Optional[str] = None
374
  ):
375
- """Core generation with correct tokenizer per model"""
376
- global loaded_models, current_model
377
 
378
- # Select model
379
- if model_name and model_name in loaded_models:
380
- model, fast_forward, config, tokenizer, eos_token_id = loaded_models[model_name]
381
- elif current_model:
382
- model, fast_forward, config, tokenizer, eos_token_id = loaded_models[current_model]
383
- else:
384
- model_name = list(loaded_models.keys())[0]
385
- model, fast_forward, config, tokenizer, eos_token_id = loaded_models[model_name]
386
-
387
- # Encode with model's tokenizer
388
  input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
389
 
390
  if len(input_ids) == 0:
@@ -463,29 +349,26 @@ def format_chat_prompt(messages: List[ChatMessage]) -> str:
463
 
464
  @app.get("/", response_class=HTMLResponse)
465
  async def status_page():
466
- models_html = ""
467
- for model_name in loaded_models.keys():
468
- usage = worker_stats["model_usage"].get(model_name, 0)
469
- _, _, _, tokenizer, _ = loaded_models[model_name]
470
- vocab_size = tokenizer.get_vocab_size()
471
- models_html += f'<li><strong>{model_name}</strong> - Vocab: {vocab_size} - Used: {usage}x</li>'
472
-
473
- return f"""
474
  <!DOCTYPE html>
475
  <html>
476
  <head>
477
- <title>SAM Worker v5.0 - Multi-Model</title>
478
  <style>
479
- * {{ margin: 0; padding: 0; box-sizing: border-box; }}
480
- body {{
481
  font-family: 'Courier New', monospace;
482
  background: linear-gradient(135deg, #1a1f3a 0%, #0a0e27 100%);
483
  color: #00bfff;
484
  padding: 20px;
485
  min-height: 100vh;
486
- }}
487
- .container {{ max-width: 1000px; margin: 0 auto; }}
488
- .header {{
 
 
 
489
  text-align: center;
490
  padding: 30px;
491
  background: rgba(0, 191, 255, 0.1);
@@ -493,77 +376,93 @@ async def status_page():
493
  border-radius: 10px;
494
  margin-bottom: 30px;
495
  box-shadow: 0 0 20px rgba(0, 191, 255, 0.3);
496
- }}
497
- .header h1 {{
498
  font-size: 2.5em;
499
  text-transform: uppercase;
500
  letter-spacing: 3px;
501
  animation: glow 2s ease-in-out infinite alternate;
502
- }}
503
- @keyframes glow {{
504
- from {{ text-shadow: 0 0 10px #00bfff; }}
505
- to {{ text-shadow: 0 0 20px #00bfff, 0 0 30px #00bfff; }}
506
- }}
507
- .badge {{
508
  display: inline-block;
509
  padding: 5px 15px;
510
  border-radius: 15px;
511
  font-size: 0.9em;
512
- margin: 5px;
513
- }}
514
- .badge-v5 {{
515
  background: rgba(0, 255, 136, 0.2);
516
  border: 1px solid #00ff88;
517
  color: #00ff88;
518
- }}
519
- .badge-multi {{
520
  background: rgba(255, 165, 0, 0.2);
521
  border: 1px solid #ffa500;
522
  color: #ffa500;
523
- }}
524
- .stats-grid {{
525
  display: grid;
526
  grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
527
  gap: 20px;
528
  margin-bottom: 30px;
529
- }}
530
- .stat-card {{
531
  background: rgba(0, 191, 255, 0.05);
532
  border: 1px solid #00bfff;
533
  border-radius: 8px;
534
  padding: 20px;
535
  text-align: center;
536
- }}
537
- .stat-label {{ font-size: 0.8em; opacity: 0.7; text-transform: uppercase; margin-bottom: 10px; }}
538
- .stat-value {{ font-size: 2em; font-weight: bold; }}
539
- .features {{
 
 
 
 
 
 
 
 
540
  background: rgba(0, 191, 255, 0.05);
541
  border: 1px solid #00bfff;
542
  border-radius: 8px;
543
  padding: 20px;
544
- margin-bottom: 20px;
545
- }}
546
- .features h3 {{ margin-bottom: 15px; }}
547
- .feature-list {{ list-style: none; padding: 0; }}
548
- .feature-list li {{
 
 
 
 
549
  padding: 10px;
550
  margin: 5px 0;
551
  background: rgba(0, 191, 255, 0.1);
552
  border-radius: 5px;
553
- border-left: 3px solid #00ff88;
554
- }}
555
- .timestamp {{ text-align: center; margin-top: 20px; opacity: 0.5; }}
 
 
 
 
 
 
 
556
  </style>
557
  </head>
558
  <body>
559
  <div class="container">
560
  <div class="header">
561
  <h1>βš™οΈ WORKER NODE βš™οΈ</h1>
562
- <div>SAM-Z-1 Distributed Worker v5.0</div>
563
- <div>
564
- <span class="badge badge-v5">V5 PROTOCOL</span>
565
- <span class="badge badge-multi">{len(loaded_models)} MODELS</span>
566
- </div>
567
  </div>
568
 
569
  <div class="stats-grid" id="stats">
@@ -585,23 +484,14 @@ async def status_page():
585
  </div>
586
  </div>
587
 
588
- <div class="features">
589
- <h3>πŸ€– LOADED MODELS ({len(loaded_models)})</h3>
590
- <ul class="feature-list">
591
- {models_html}
592
- </ul>
593
- </div>
594
-
595
  <div class="features">
596
  <h3>πŸš€ CAPABILITIES</h3>
597
  <ul class="feature-list">
598
- <li>βœ… Original SAM-Z-1 (preserved)</li>
599
- <li>βœ… 4 new SAM-X-1 models</li>
600
- <li>βœ… Separate tokenizers per family</li>
601
- <li>βœ… Multi-model selection</li>
602
- <li>βœ… Token & batch decoding</li>
603
- <li>βœ… Streaming support</li>
604
- <li>βœ… Auto version detection</li>
605
  </ul>
606
  </div>
607
 
@@ -609,8 +499,21 @@ async def status_page():
609
  </div>
610
 
611
  <script>
612
- async function updateStats() {{
613
- try {{
 
 
 
 
 
 
 
 
 
 
 
 
 
614
  const statsRes = await fetch('/stats');
615
  const stats = await statsRes.json();
616
 
@@ -622,15 +525,16 @@ async def status_page():
622
  const h = Math.floor(uptime / 3600);
623
  const m = Math.floor((uptime % 3600) / 60);
624
  const s = uptime % 60;
625
- document.getElementById('uptime').textContent = `${{h}}h ${{m}}m ${{s}}s`;
626
 
627
  document.getElementById('timestamp').textContent =
628
- `Last update: ${{new Date().toLocaleTimeString()}}`;
629
- }} catch (e) {{
630
  console.error('Failed to update stats:', e);
631
- }}
632
- }}
633
 
 
634
  setInterval(updateStats, 1000);
635
  updateStats();
636
  </script>
@@ -645,38 +549,8 @@ async def status_page():
645
  @app.get("/health")
646
  async def health():
647
  return {
648
- "status": "healthy" if loaded_models else "loading",
649
- "model_loaded": len(loaded_models) > 0,
650
- "models_count": len(loaded_models)
651
- }
652
-
653
- @app.get("/info")
654
- async def worker_info():
655
- """Worker information for version detection"""
656
- return {
657
- "version": "v5",
658
- "models": list(loaded_models.keys()),
659
- "features": [
660
- "multi_model",
661
- "model_selection",
662
- "separate_tokenizers",
663
- "token_generation",
664
- "batch_decoding",
665
- "streaming"
666
- ],
667
- "model_families": {
668
- "sam-z": [m for m, info in MODEL_REGISTRY.items() if info["family"] == "sam-z"],
669
- "sam-x": [m for m, info in MODEL_REGISTRY.items() if info["family"] == "sam-x"]
670
- }
671
- }
672
-
673
- @app.get("/models")
674
- async def list_models():
675
- """List available models"""
676
- return {
677
- "models": list(loaded_models.keys()),
678
- "default": current_model,
679
- "count": len(loaded_models)
680
  }
681
 
682
  @app.get("/stats")
@@ -687,16 +561,17 @@ async def stats():
687
  "total_tokens": worker_stats["total_tokens"],
688
  "decode_requests": worker_stats["decode_requests"],
689
  "uptime": uptime,
690
- "tokens_per_second": worker_stats["total_tokens"] / uptime if uptime > 0 else 0,
691
- "model_usage": worker_stats["model_usage"]
692
  }
693
 
694
  @app.post("/decode")
695
  async def decode(request: DecodeRequest):
696
- """Fast single decode - uses correct tokenizer"""
 
 
 
697
  try:
698
  worker_stats["decode_requests"] += 1
699
- tokenizer, _ = get_tokenizer_for_model(request.model)
700
  text = tokenizer.decode(request.token_ids)
701
  return {"text": text}
702
  except Exception as e:
@@ -704,10 +579,12 @@ async def decode(request: DecodeRequest):
704
 
705
  @app.post("/decode/batch")
706
  async def batch_decode(request: BatchDecodeRequest):
707
- """Optimized batch decoding - uses correct tokenizer"""
 
 
 
708
  try:
709
  worker_stats["decode_requests"] += len(request.batches)
710
- tokenizer, _ = get_tokenizer_for_model(request.model)
711
  results = [tokenizer.decode(batch) for batch in request.batches]
712
  return {"texts": results}
713
  except Exception as e:
@@ -715,15 +592,9 @@ async def batch_decode(request: BatchDecodeRequest):
715
 
716
  @app.post("/generate")
717
  async def generate(request: GenerateRequest):
718
- """Generate text with model selection"""
719
- if not loaded_models:
720
- raise HTTPException(status_code=503, detail="No models loaded")
721
-
722
- # Track model usage
723
- model_name = request.model or current_model
724
- if model_name not in worker_stats["model_usage"]:
725
- worker_stats["model_usage"][model_name] = 0
726
- worker_stats["model_usage"][model_name] += 1
727
 
728
  worker_stats["total_requests"] += 1
729
  start_time = time.time()
@@ -741,8 +612,7 @@ async def generate(request: GenerateRequest):
741
  top_k=request.top_k,
742
  top_p=request.top_p,
743
  repetition_penalty=request.repetition_penalty,
744
- return_token_ids=request.return_token_ids,
745
- model_name=request.model
746
  ):
747
  token_count += 1
748
  worker_stats["total_tokens"] += 1
@@ -756,7 +626,7 @@ async def generate(request: GenerateRequest):
756
  await asyncio.sleep(0.001)
757
 
758
  elapsed = time.time() - start_time
759
- yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed, 'model': model_name})}\n\n"
760
 
761
  except Exception as e:
762
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
@@ -775,8 +645,7 @@ async def generate(request: GenerateRequest):
775
  top_k=request.top_k,
776
  top_p=request.top_p,
777
  repetition_penalty=request.repetition_penalty,
778
- return_token_ids=request.return_token_ids,
779
- model_name=request.model
780
  ):
781
  if not request.return_token_ids:
782
  generated_text += token_text
@@ -789,8 +658,7 @@ async def generate(request: GenerateRequest):
789
  "text": generated_text,
790
  "tokens": token_count,
791
  "time": elapsed,
792
- "tokens_per_second": token_count / elapsed if elapsed > 0 else 0,
793
- "model": model_name
794
  }
795
 
796
  except Exception as e:
@@ -798,15 +666,9 @@ async def generate(request: GenerateRequest):
798
 
799
  @app.post("/chat")
800
  async def chat(request: ChatRequest):
801
- """Chat completion with model selection"""
802
- if not loaded_models:
803
- raise HTTPException(status_code=503, detail="No models loaded")
804
-
805
- # Track model usage
806
- model_name = request.model or current_model
807
- if model_name not in worker_stats["model_usage"]:
808
- worker_stats["model_usage"][model_name] = 0
809
- worker_stats["model_usage"][model_name] += 1
810
 
811
  worker_stats["total_requests"] += 1
812
  prompt = format_chat_prompt(request.messages)
@@ -825,8 +687,7 @@ async def chat(request: ChatRequest):
825
  top_k=request.top_k,
826
  top_p=request.top_p,
827
  repetition_penalty=request.repetition_penalty,
828
- return_token_ids=request.return_token_ids,
829
- model_name=request.model
830
  ):
831
  token_count += 1
832
  worker_stats["total_tokens"] += 1
@@ -845,7 +706,7 @@ async def chat(request: ChatRequest):
845
  await asyncio.sleep(0.001)
846
 
847
  elapsed = time.time() - start_time
848
- yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed, 'model': model_name})}\n\n"
849
 
850
  except Exception as e:
851
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
@@ -864,8 +725,7 @@ async def chat(request: ChatRequest):
864
  top_k=request.top_k,
865
  top_p=request.top_p,
866
  repetition_penalty=request.repetition_penalty,
867
- return_token_ids=request.return_token_ids,
868
- model_name=request.model
869
  ):
870
  if not request.return_token_ids:
871
  generated_text += token_text
@@ -886,8 +746,7 @@ async def chat(request: ChatRequest):
886
  },
887
  "tokens": token_count,
888
  "time": elapsed,
889
- "tokens_per_second": token_count / elapsed if elapsed > 0 else 0,
890
- "model": model_name
891
  }
892
 
893
  except Exception as e:
@@ -897,152 +756,86 @@ async def chat(request: ChatRequest):
897
  # Model Loading
898
  # ============================================================================
899
 
900
- async def load_single_model(model_name: str, model_info: dict) -> bool:
901
- """Load a single model with its tokenizer"""
902
- global loaded_models, current_model
 
 
903
 
904
  try:
905
- print(f"\n⏳ Loading: {model_name} ({model_info['family']} family)")
906
- print(f" Repo: {model_info['repo']}")
907
- print(f" Weights: {model_info['weights']}")
908
 
909
- # Load tokenizer for this family
910
- tokenizer, eos_token_id = await load_tokenizer(
911
- model_info['family'],
912
- model_info['tokenizer_repo']
913
- )
914
-
915
- # Load config
916
- if model_info['config']:
917
- print(f" Config: {model_info['config']}")
918
- config_path = hf_hub_download(
919
- repo_id=model_info['repo'],
920
- filename=model_info['config'],
921
- cache_dir=CACHE_DIR
922
- )
923
- with open(config_path, 'r') as f:
924
- config_raw = json.load(f)
925
- else:
926
- # Load base config for Large model
927
- print(f" Loading base config from tokenizer repo...")
928
- config_path = hf_hub_download(
929
- repo_id=model_info['tokenizer_repo'],
930
- filename="config.json",
931
- cache_dir=CACHE_DIR
932
- )
933
- with open(config_path, 'r') as f:
934
- config_raw = json.load(f)
935
 
936
- # Convert to model format
937
- model_config = {
938
- 'vocab_size': config_raw['vocab_size'],
939
- 'd_model': config_raw['hidden_size'],
940
- 'n_heads': config_raw['num_attention_heads'],
941
- 'ff_mult': config_raw['intermediate_size'] / config_raw['hidden_size'],
942
- 'dropout': config_raw.get('dropout', 0.0),
943
- 'max_len': config_raw['max_position_embeddings'],
944
- 'rope_theta': config_raw['rope_theta'],
945
- 'n_layers': config_raw['num_hidden_layers']
946
- }
947
 
948
- # Add for config object
949
- model_config['max_position_embeddings'] = config_raw['max_position_embeddings']
950
 
951
- print(f" πŸ“ Architecture: {model_config['n_layers']} layers, {model_config['n_heads']} heads")
 
952
 
953
- # Load weights
954
- weights_path = hf_hub_download(
955
- repo_id=model_info['repo'],
956
- filename=model_info['weights'],
957
- cache_dir=CACHE_DIR
958
- )
959
 
960
- # Build model
961
- model = SAM1Model(**model_config)
962
- dummy_input = tf.zeros((1, 1), dtype=tf.int32)
963
- model(dummy_input)
964
- model.load_weights(weights_path)
965
- model.trainable = False
966
 
967
- # Create optimized forward pass
968
- @tf.function(
969
- input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)],
970
- jit_compile=True,
971
- reduce_retracing=True
972
- )
973
- def fast_predict(inputs):
974
- return model(inputs, training=False)
975
 
976
- # Warm up
977
- print(f" πŸ”₯ Warming up...")
978
- dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
979
- _ = fast_predict(dummy)
980
 
981
- # Store model with its tokenizer
982
- loaded_models[model_name] = (model, fast_predict, model_config, tokenizer, eos_token_id)
983
 
984
- # Set as default if first
985
- if current_model is None:
986
- current_model = model_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
987
 
988
- # Count parameters
989
- total_params = sum(np.prod(w.shape) for w in model.weights)
990
- if total_params >= 1e9:
991
- param_str = f"{total_params/1e9:.2f}B"
992
- elif total_params >= 1e6:
993
- param_str = f"{total_params/1e6:.2f}M"
994
  else:
995
- param_str = f"{total_params/1e3:.2f}K"
996
-
997
- print(f" βœ… Loaded successfully!")
998
- print(f" πŸ“Š Parameters: {param_str}")
999
- print(f" πŸ”€ Tokenizer vocab: {tokenizer.get_vocab_size()}")
1000
-
1001
- return True
1002
-
1003
- except Exception as e:
1004
- print(f" ⚠️ Failed to load {model_name}: {e}")
1005
- import traceback
1006
- traceback.print_exc()
1007
- return False
1008
-
1009
- @app.on_event("startup")
1010
- async def load_models():
1011
- global loaded_models, current_model
1012
-
1013
- print("="*80)
1014
- print("πŸš€ SAM-Z-1 Worker Node v5.0 - Multi-Model with Separate Tokenizers".center(80))
1015
- print("="*80)
1016
-
1017
- try:
1018
- # Load all models
1019
- print("\n" + "="*80)
1020
- print("πŸ“¦ LOADING ALL 5 MODELS".center(80))
1021
- print("="*80)
1022
-
1023
- loaded_count = 0
1024
- for model_name, model_info in MODEL_REGISTRY.items():
1025
- success = await load_single_model(model_name, model_info)
1026
- if success:
1027
- loaded_count += 1
1028
-
1029
- if loaded_count == 0:
1030
- raise RuntimeError("❌ No models loaded successfully!")
1031
 
1032
- print(f"\n{'='*80}")
1033
- print(f"βœ… Successfully loaded {loaded_count}/{len(MODEL_REGISTRY)} models")
1034
- print(f"πŸ“Œ Default model: {current_model}")
1035
 
1036
- # Show tokenizer families
1037
- print(f"\nπŸ”€ Tokenizer Families:")
1038
- print(f" SAM-Z family: {len([m for m, i in MODEL_REGISTRY.items() if i['family'] == 'sam-z'])} model(s)")
1039
- print(f" SAM-X family: {len([m for m, i in MODEL_REGISTRY.items() if i['family'] == 'sam-x'])} model(s)")
1040
 
1041
- print(f"\nπŸš€ Worker ready for inference!")
1042
- print(f"{'='*80}\n")
 
 
 
 
1043
 
1044
  except Exception as e:
1045
- print(f"\n❌ Failed to initialize worker: {e}")
1046
  import traceback
1047
  traceback.print_exc()
1048
  raise
 
1
  """
2
+ SAM-Z-1 Distributed Worker Node v4.0
3
+ Optimized for distributed gen/decode pipeline
 
 
 
4
  """
5
 
6
  from fastapi import FastAPI, HTTPException
 
14
  from tokenizers import Tokenizer
15
  import numpy as np
16
  import time
17
+ from typing import List, Optional
18
  import asyncio
19
 
20
+ app = FastAPI(title="SAM-Z-1 Distributed Worker", version="4.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # ============================================================================
23
  # Model Architecture
 
201
  return base_config
202
 
203
  # ============================================================================
204
+ # Global State
205
  # ============================================================================
206
 
207
+ model = None
208
+ tokenizer = None
209
+ config = None
210
+ eos_token_id = None
211
+ fast_forward = None
212
+
213
+ MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
214
+ CACHE_DIR = "./model_cache"
215
 
216
+ # Stats
217
  worker_stats = {
218
  "total_requests": 0,
219
  "total_tokens": 0,
220
  "decode_requests": 0,
221
+ "uptime_start": time.time()
 
222
  }
223
 
224
  # ============================================================================
 
234
  repetition_penalty: float = 1.1
235
  stream: bool = False
236
  return_token_ids: bool = False
 
237
 
238
  class ChatMessage(BaseModel):
239
  role: str
 
248
  repetition_penalty: float = 1.1
249
  stream: bool = False
250
  return_token_ids: bool = False
 
251
 
252
  class DecodeRequest(BaseModel):
253
  token_ids: List[int]
 
254
 
255
  class BatchDecodeRequest(BaseModel):
256
  batches: List[List[int]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  # ============================================================================
259
  # Generation Functions
 
266
  top_k: int = 40,
267
  top_p: float = 0.9,
268
  repetition_penalty: float = 1.1,
269
+ return_token_ids: bool = False
 
270
  ):
271
+ """Core generation - yields (token_id, token_text or None)"""
272
+ global model, tokenizer, config, eos_token_id, fast_forward
273
 
 
 
 
 
 
 
 
 
 
 
274
  input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
275
 
276
  if len(input_ids) == 0:
 
349
 
350
  @app.get("/", response_class=HTMLResponse)
351
  async def status_page():
352
+ """Worker status page"""
353
+ return """
 
 
 
 
 
 
354
  <!DOCTYPE html>
355
  <html>
356
  <head>
357
+ <title>SAM-Z-1 Worker Node</title>
358
  <style>
359
+ * { margin: 0; padding: 0; box-sizing: border-box; }
360
+ body {
361
  font-family: 'Courier New', monospace;
362
  background: linear-gradient(135deg, #1a1f3a 0%, #0a0e27 100%);
363
  color: #00bfff;
364
  padding: 20px;
365
  min-height: 100vh;
366
+ }
367
+ .container {
368
+ max-width: 900px;
369
+ margin: 0 auto;
370
+ }
371
+ .header {
372
  text-align: center;
373
  padding: 30px;
374
  background: rgba(0, 191, 255, 0.1);
 
376
  border-radius: 10px;
377
  margin-bottom: 30px;
378
  box-shadow: 0 0 20px rgba(0, 191, 255, 0.3);
379
+ }
380
+ .header h1 {
381
  font-size: 2.5em;
382
  text-transform: uppercase;
383
  letter-spacing: 3px;
384
  animation: glow 2s ease-in-out infinite alternate;
385
+ }
386
+ @keyframes glow {
387
+ from { text-shadow: 0 0 10px #00bfff; }
388
+ to { text-shadow: 0 0 20px #00bfff, 0 0 30px #00bfff; }
389
+ }
390
+ .badge {
391
  display: inline-block;
392
  padding: 5px 15px;
393
  border-radius: 15px;
394
  font-size: 0.9em;
395
+ margin-top: 10px;
396
+ }
397
+ .badge-ready {
398
  background: rgba(0, 255, 136, 0.2);
399
  border: 1px solid #00ff88;
400
  color: #00ff88;
401
+ }
402
+ .badge-loading {
403
  background: rgba(255, 165, 0, 0.2);
404
  border: 1px solid #ffa500;
405
  color: #ffa500;
406
+ }
407
+ .stats-grid {
408
  display: grid;
409
  grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
410
  gap: 20px;
411
  margin-bottom: 30px;
412
+ }
413
+ .stat-card {
414
  background: rgba(0, 191, 255, 0.05);
415
  border: 1px solid #00bfff;
416
  border-radius: 8px;
417
  padding: 20px;
418
  text-align: center;
419
+ }
420
+ .stat-label {
421
+ font-size: 0.8em;
422
+ opacity: 0.7;
423
+ text-transform: uppercase;
424
+ margin-bottom: 10px;
425
+ }
426
+ .stat-value {
427
+ font-size: 2em;
428
+ font-weight: bold;
429
+ }
430
+ .features {
431
  background: rgba(0, 191, 255, 0.05);
432
  border: 1px solid #00bfff;
433
  border-radius: 8px;
434
  padding: 20px;
435
+ }
436
+ .features h3 {
437
+ margin-bottom: 15px;
438
+ }
439
+ .feature-list {
440
+ list-style: none;
441
+ padding: 0;
442
+ }
443
+ .feature-list li {
444
  padding: 10px;
445
  margin: 5px 0;
446
  background: rgba(0, 191, 255, 0.1);
447
  border-radius: 5px;
448
+ }
449
+ .feature-list li:before {
450
+ content: "⚑ ";
451
+ color: #00ff88;
452
+ }
453
+ .timestamp {
454
+ text-align: center;
455
+ margin-top: 20px;
456
+ opacity: 0.5;
457
+ }
458
  </style>
459
  </head>
460
  <body>
461
  <div class="container">
462
  <div class="header">
463
  <h1>βš™οΈ WORKER NODE βš™οΈ</h1>
464
+ <div>SAM-Z-1 Distributed Worker v4.0</div>
465
+ <div class="badge" id="status-badge">CHECKING STATUS...</div>
 
 
 
466
  </div>
467
 
468
  <div class="stats-grid" id="stats">
 
484
  </div>
485
  </div>
486
 
 
 
 
 
 
 
 
487
  <div class="features">
488
  <h3>πŸš€ CAPABILITIES</h3>
489
  <ul class="feature-list">
490
+ <li>Full Text Generation</li>
491
+ <li>Token-Only Mode (for distributed pipeline)</li>
492
+ <li>High-Speed Batch Decoding</li>
493
+ <li>Chat Completion</li>
494
+ <li>Streaming & Non-Streaming</li>
 
 
495
  </ul>
496
  </div>
497
 
 
499
  </div>
500
 
501
  <script>
502
+ async function updateStats() {
503
+ try {
504
+ const response = await fetch('/health');
505
+ const data = await response.json();
506
+
507
+ const badge = document.getElementById('status-badge');
508
+ if (data.model_loaded) {
509
+ badge.textContent = 'βœ… READY FOR INFERENCE';
510
+ badge.className = 'badge badge-ready';
511
+ } else {
512
+ badge.textContent = '⏳ LOADING MODEL...';
513
+ badge.className = 'badge badge-loading';
514
+ }
515
+
516
+ // Fetch stats
517
  const statsRes = await fetch('/stats');
518
  const stats = await statsRes.json();
519
 
 
525
  const h = Math.floor(uptime / 3600);
526
  const m = Math.floor((uptime % 3600) / 60);
527
  const s = uptime % 60;
528
+ document.getElementById('uptime').textContent = `${h}h ${m}m ${s}s`;
529
 
530
  document.getElementById('timestamp').textContent =
531
+ `Last update: ${new Date().toLocaleTimeString()}`;
532
+ } catch (e) {
533
  console.error('Failed to update stats:', e);
534
+ }
535
+ }
536
 
537
+ // Update every second
538
  setInterval(updateStats, 1000);
539
  updateStats();
540
  </script>
 
549
  @app.get("/health")
550
  async def health():
551
  return {
552
+ "status": "healthy" if model is not None else "loading",
553
+ "model_loaded": model is not None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  }
555
 
556
  @app.get("/stats")
 
561
  "total_tokens": worker_stats["total_tokens"],
562
  "decode_requests": worker_stats["decode_requests"],
563
  "uptime": uptime,
564
+ "tokens_per_second": worker_stats["total_tokens"] / uptime if uptime > 0 else 0
 
565
  }
566
 
567
  @app.post("/decode")
568
  async def decode(request: DecodeRequest):
569
+ """Fast single decode"""
570
+ if tokenizer is None:
571
+ raise HTTPException(status_code=503, detail="Tokenizer not loaded")
572
+
573
  try:
574
  worker_stats["decode_requests"] += 1
 
575
  text = tokenizer.decode(request.token_ids)
576
  return {"text": text}
577
  except Exception as e:
 
579
 
580
  @app.post("/decode/batch")
581
  async def batch_decode(request: BatchDecodeRequest):
582
+ """Optimized batch decoding for distributed pipeline"""
583
+ if tokenizer is None:
584
+ raise HTTPException(status_code=503, detail="Tokenizer not loaded")
585
+
586
  try:
587
  worker_stats["decode_requests"] += len(request.batches)
 
588
  results = [tokenizer.decode(batch) for batch in request.batches]
589
  return {"texts": results}
590
  except Exception as e:
 
592
 
593
  @app.post("/generate")
594
  async def generate(request: GenerateRequest):
595
+ """Generate text"""
596
+ if model is None:
597
+ raise HTTPException(status_code=503, detail="Model not loaded")
 
 
 
 
 
 
598
 
599
  worker_stats["total_requests"] += 1
600
  start_time = time.time()
 
612
  top_k=request.top_k,
613
  top_p=request.top_p,
614
  repetition_penalty=request.repetition_penalty,
615
+ return_token_ids=request.return_token_ids
 
616
  ):
617
  token_count += 1
618
  worker_stats["total_tokens"] += 1
 
626
  await asyncio.sleep(0.001)
627
 
628
  elapsed = time.time() - start_time
629
+ yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
630
 
631
  except Exception as e:
632
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
 
645
  top_k=request.top_k,
646
  top_p=request.top_p,
647
  repetition_penalty=request.repetition_penalty,
648
+ return_token_ids=request.return_token_ids
 
649
  ):
650
  if not request.return_token_ids:
651
  generated_text += token_text
 
658
  "text": generated_text,
659
  "tokens": token_count,
660
  "time": elapsed,
661
+ "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
 
662
  }
663
 
664
  except Exception as e:
 
666
 
667
  @app.post("/chat")
668
  async def chat(request: ChatRequest):
669
+ """Chat completion"""
670
+ if model is None:
671
+ raise HTTPException(status_code=503, detail="Model not loaded")
 
 
 
 
 
 
672
 
673
  worker_stats["total_requests"] += 1
674
  prompt = format_chat_prompt(request.messages)
 
687
  top_k=request.top_k,
688
  top_p=request.top_p,
689
  repetition_penalty=request.repetition_penalty,
690
+ return_token_ids=request.return_token_ids
 
691
  ):
692
  token_count += 1
693
  worker_stats["total_tokens"] += 1
 
706
  await asyncio.sleep(0.001)
707
 
708
  elapsed = time.time() - start_time
709
+ yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
710
 
711
  except Exception as e:
712
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
 
725
  top_k=request.top_k,
726
  top_p=request.top_p,
727
  repetition_penalty=request.repetition_penalty,
728
+ return_token_ids=request.return_token_ids
 
729
  ):
730
  if not request.return_token_ids:
731
  generated_text += token_text
 
746
  },
747
  "tokens": token_count,
748
  "time": elapsed,
749
+ "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
 
750
  }
751
 
752
  except Exception as e:
 
756
  # Model Loading
757
  # ============================================================================
758
 
759
+ @app.on_event("startup")
760
+ async def load_model():
761
+ global model, tokenizer, config, eos_token_id, fast_forward
762
+
763
+ print("🚀 Loading SAM-Z-1 Model...")
764
 
765
  try:
766
+ config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
 
 
767
 
768
+ try:
769
+ weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
770
+ print("✅ Found checkpoint weights")
771
+ use_checkpoint = True
772
+ except:
773
+ print("⚠️ Checkpoint not found, using model.keras")
774
+ model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
775
+ use_checkpoint = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
 
777
+ with open(config_path, 'r') as f:
778
+ config = json.load(f)
 
 
 
 
 
 
 
 
 
779
 
780
+ print(f"📦 Config loaded: {config['num_hidden_layers']} layers")
 
781
 
782
+ print("📦 Creating tokenizer...")
783
+ from transformers import AutoTokenizer
784
 
785
+ hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
786
+ custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
787
+ hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
 
 
 
788
 
789
+ os.makedirs("./temp_tokenizer", exist_ok=True)
790
+ hf_tokenizer.save_pretrained("./temp_tokenizer")
791
+ tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
 
 
 
792
 
793
+ eos_token_id = config.get('eos_token_id', 50256)
 
 
 
 
 
 
 
794
 
795
+ print(f"✅ Tokenizer ready: vocab size {tokenizer.get_vocab_size()}")
 
 
 
796
 
797
+ print("🔄 Loading model...")
 
798
 
799
+ if use_checkpoint:
800
+ model_config = {
801
+ 'vocab_size': config['vocab_size'],
802
+ 'd_model': config['hidden_size'],
803
+ 'n_layers': config['num_hidden_layers'],
804
+ 'n_heads': config['num_attention_heads'],
805
+ 'ff_mult': config['intermediate_size'] / config['hidden_size'],
806
+ 'max_len': config['max_position_embeddings'],
807
+ 'dropout': 0.1,
808
+ 'rope_theta': config['rope_theta']
809
+ }
810
+
811
+ model = SAM1Model(config=model_config)
812
+ dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
813
+ _ = model(dummy_input, training=False)
814
+
815
+ print(f"✅ Architecture built: {model.count_params():,} parameters")
816
+
817
+ model.load_weights(weights_path)
818
+ print("✅ Weights loaded!")
819
 
 
 
 
 
 
 
820
  else:
821
+ model = keras.models.load_model(model_path, compile=False)
822
+ print("✅ Model loaded!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
823
 
824
+ @tf.function(reduce_retracing=True)
825
+ def optimized_forward(input_tensor):
826
+ return model(input_tensor, training=False)
827
 
828
+ fast_forward = optimized_forward
 
 
 
829
 
830
+ print("✅ SAM-Z-1 Distributed Worker ready! 🚀")
831
+ print("🔥 Features enabled:")
832
+ print(" - Full text generation")
833
+ print(" - Token-only mode (distributed pipeline)")
834
+ print(" - Batch decoding optimization")
835
+ print(" - Streaming support")
836
 
837
  except Exception as e:
838
+ print(f"❌ Failed to load model: {e}")
839
  import traceback
840
  traceback.print_exc()
841
  raise