Files changed (1) hide show
  1. app.py +372 -414
app.py CHANGED
@@ -1,4 +1,10 @@
1
  import os
 
 
 
 
 
 
2
  os.environ['KERAS_BACKEND'] = 'tensorflow'
3
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
4
 
@@ -9,6 +15,10 @@ from tokenizers import Tokenizer
9
  from huggingface_hub import hf_hub_download
10
  import json
11
  from abc import ABC, abstractmethod
 
 
 
 
12
 
13
  # ==============================================================================
14
  # Model Architecture (Must match training code)
@@ -237,6 +247,10 @@ class ModelBackend(ABC):
237
  @abstractmethod
238
  def get_info(self):
239
  pass
 
 
 
 
240
 
241
 
242
  class KerasBackend(ModelBackend):
@@ -256,6 +270,7 @@ class KerasBackend(ModelBackend):
256
  self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
257
 
258
  def predict(self, input_ids):
 
259
  inputs = np.array([input_ids], dtype=np.int32)
260
  logits = self.model(inputs, training=False)
261
  return logits[0, -1, :].numpy()
@@ -263,6 +278,9 @@ class KerasBackend(ModelBackend):
263
  def get_name(self):
264
  return self.display_name
265
 
 
 
 
266
  def get_info(self):
267
  info = f"{self.display_name}\n"
268
  info += f" Total params: {format_param_count(self.total_params)}\n"
@@ -274,186 +292,145 @@ class KerasBackend(ModelBackend):
274
 
275
 
276
  # ==============================================================================
277
- # EASY MODEL REGISTRY - ADD YOUR MODELS HERE!
278
  # ==============================================================================
279
  MODEL_REGISTRY = [
280
  # Format: (display_name, repo_id, weights_filename, config_filename)
281
- # Smaller models are ACTUALLY faster (fewer params = real speedup!)
282
-
283
  ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
284
  ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
285
  ("SAM-X-1-Mini 🚀 (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights.h5", "sam1_mini_config.json"),
286
  ("SAM-X-1-Nano ⚡⚡ (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano.weights.h5", "sam1_nano_config.json"),
287
  ]
288
 
289
- # To add a new model, just add a new line above! Format:
290
- # ("Display Name", "repo_id", "weights.h5", "config.json")
291
- # If config_filename is None, uses the default config
 
 
292
 
293
 
294
- # ==============================================================================
295
- # Load Models
296
- # ==============================================================================
297
- CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"
 
 
 
298
 
299
- print("="*80)
300
- print("🤖 SAM-X-1 Multi-Model Chat Interface".center(80))
301
- print("="*80)
302
-
303
- # Download config and tokenizer
304
- print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
305
- config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
306
- tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
307
-
308
- # Load config
309
- with open(config_path, 'r') as f:
310
- base_config = json.load(f)
311
-
312
- print(f"✅ Base config loaded")
313
-
314
- # Build base model config
315
- base_model_config = {
316
- 'vocab_size': base_config['vocab_size'],
317
- 'd_model': base_config['hidden_size'],
318
- 'n_heads': base_config['num_attention_heads'],
319
- 'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
320
- 'dropout': base_config.get('dropout', 0.0),
321
- 'max_len': base_config['max_position_embeddings'],
322
- 'rope_theta': base_config['rope_theta'],
323
- 'n_layers': base_config['num_hidden_layers']
324
- }
325
-
326
- # Recreate tokenizer
327
- print("\n🔤 Recreating tokenizer...")
328
- tokenizer = Tokenizer.from_pretrained("gpt2")
329
- eos_token = ""
330
- eos_token_id = tokenizer.token_to_id(eos_token)
331
-
332
- if eos_token_id is None:
333
- tokenizer.add_special_tokens([eos_token])
334
- eos_token_id = tokenizer.token_to_id(eos_token)
335
 
336
- custom_tokens = ["<think>", "<think/>"]
337
- for token in custom_tokens:
338
- if tokenizer.token_to_id(token) is None:
339
- tokenizer.add_special_tokens([token])
 
 
340
 
341
- tokenizer.no_padding()
342
- tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
343
- print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
344
 
345
- # Load all models from registry
346
- print("\n" + "="*80)
347
- print("📦 LOADING MODELS".center(80))
348
- print("="*80)
349
 
350
- available_models = {}
351
- dummy_input = tf.zeros((1, 1), dtype=tf.int32)
 
352
 
353
- for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
354
- try:
355
- print(f"\n⏳ Loading: {display_name}")
356
- print(f" Repo: {repo_id}")
357
- print(f" Weights: {weights_filename}")
358
-
359
- # Download weights
360
- weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
361
-
362
- # Load custom config if specified (for pruned models)
363
- if config_filename:
364
- print(f" Config: {config_filename}")
365
- custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
366
- with open(custom_config_path, 'r') as f:
367
- model_config = json.load(f)
368
- print(f" 📐 Custom architecture: {model_config['n_heads']} heads, {int(model_config['d_model'] * model_config['ff_mult'])} FFN dim")
369
- else:
370
- model_config = base_model_config.copy()
371
-
372
- # Create model with appropriate config
373
- model = SAM1Model(**model_config)
374
- model(dummy_input)
375
- model.load_weights(weights_path)
376
- model.trainable = False
377
-
378
- # Create backend
379
- backend = KerasBackend(model, display_name, display_name)
380
- available_models[display_name] = backend
381
-
382
- # Print stats
383
- print(f" ✅ Loaded successfully!")
384
- print(f" 📊 Parameters: {format_param_count(backend.total_params)}")
385
- print(f" 📊 Attention heads: {backend.n_heads}")
386
- print(f" 📊 FFN dimension: {backend.ff_dim}")
387
-
388
- except Exception as e:
389
- print(f" ⚠️ Failed to load: {e}")
390
- print(f" Skipping {display_name}...")
391
 
392
- if not available_models:
393
- raise RuntimeError("❌ No models loaded! Check your MODEL_REGISTRY configuration.")
394
 
395
- print(f"\n✅ Successfully loaded {len(available_models)} model(s)")
396
- print(f" Device: {'GPU' if len(tf.config.list_physical_devices('GPU')) > 0 else 'CPU'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
- current_backend = list(available_models.values())[0]
 
 
 
 
399
 
400
- # ==============================================================================
401
- # Important Note About Pruning and Speed
402
- # ==============================================================================
403
- print("\n" + "="*80)
404
- print("💡 ABOUT PRUNING & SPEED".center(80))
405
- print("="*80)
406
- print("""
407
- 📌 Does pruning reduce parameter count?
408
- YES and NO:
409
- • Total param count stays the same (architecture unchanged)
410
- • BUT pruned weights are set to ZERO (sparse weights)
411
- • Active/non-zero params are reduced significantly
412
-
413
- 📌 Does pruning speed up inference?
414
- IT DEPENDS:
415
- • Dense operations (regular matrix multiply): NO speedup by default
416
- • Need sparse kernels or hardware support for actual speedup
417
- • HOWEVER: Smaller active weights = better cache utilization
418
- • Less computation on zeros = potential speedup on some hardware
419
-
420
- 📌 What DOES speed things up reliably?
421
- ✅ Quantization (FP16, INT8) - smaller types = faster compute
422
- ✅ Fewer layers (layer pruning)
423
- ✅ Smaller hidden dimensions (width reduction)
424
- ✅ Knowledge distillation to smaller architecture
425
-
426
- 📌 Why use structured pruning then?
427
- ✅ Reduces memory footprint (especially with sparse storage)
428
- ✅ Can be combined with quantization for real speedups
429
- ✅ Preserves quality better than aggressive dimension reduction
430
- ✅ Foundation for converting to truly smaller architecture
431
- """)
432
-
433
- def generate_response_stream(prompt, temperature=0.7, backend=None):
434
  """Generate response and yield tokens one by one for streaming."""
435
- if backend is None:
436
- backend = current_backend
437
 
 
 
 
438
  encoded_prompt = tokenizer.encode(prompt)
439
  input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
440
  generated = input_ids.copy()
441
 
442
  current_text = ""
443
- in_thinking = False
444
-
445
- # Get max_len from the backend's model config
446
  max_len = backend.model.cfg['max_len']
447
 
448
- for _ in range(512):
449
- current_input = generated[-max_len:]
 
450
 
451
  # Get logits from selected backend
452
  next_token_logits = backend.predict(current_input)
453
 
454
  if temperature > 0:
 
455
  next_token_logits = next_token_logits / temperature
456
- top_k_indices = np.argpartition(next_token_logits, -50)[-50:]
 
457
  top_k_logits = next_token_logits[top_k_indices]
458
  top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
459
  top_k_probs /= top_k_probs.sum()
@@ -466,299 +443,280 @@ def generate_response_stream(prompt, temperature=0.7, backend=None):
466
 
467
  generated.append(int(next_token))
468
 
 
469
  new_text = tokenizer.decode(generated[len(input_ids):])
 
470
  if len(new_text) > len(current_text):
471
  new_chunk = new_text[len(current_text):]
472
  current_text = new_text
473
 
474
- if "<think>" in new_chunk:
475
- in_thinking = True
476
- elif "</think>" in new_chunk or "<think/>" in new_chunk:
477
- in_thinking = False
478
-
479
  yield new_chunk, in_thinking
 
 
 
480
 
481
  # ==============================================================================
482
- # Gradio Interface
483
  # ==============================================================================
484
- if __name__ == "__main__":
485
- import gradio as gr
486
-
487
- custom_css = """
488
- .chat-container {
489
- height: 600px;
490
- overflow-y: auto;
491
- padding: 20px;
492
- background: #ffffff;
493
- }
494
-
495
- .user-message {
496
- background: #f7f7f8;
497
- padding: 16px;
498
- margin: 12px 0;
499
- border-radius: 8px;
500
- }
501
-
502
- .assistant-message {
503
- background: #ffffff;
504
- padding: 16px;
505
- margin: 12px 0;
506
- border-radius: 8px;
507
- border-left: 3px solid #10a37f;
508
- }
509
-
510
- .message-content {
511
- color: #353740;
512
- line-height: 1.6;
513
- font-size: 15px;
514
- }
515
-
516
- .message-header {
517
- font-weight: 600;
518
- margin-bottom: 8px;
519
- color: #353740;
520
- font-size: 14px;
521
- }
522
-
523
- .thinking-content {
524
- color: #6b7280;
525
- font-style: italic;
526
- border-left: 3px solid #d1d5db;
527
- padding-left: 12px;
528
- margin: 8px 0;
529
- background: #f9fafb;
530
- padding: 8px 12px;
531
- border-radius: 4px;
532
- }
533
-
534
- .input-row {
535
- background: #ffffff;
536
- padding: 12px;
537
- border-radius: 8px;
538
- margin-top: 12px;
539
- border: 1px solid #e5e7eb;
540
- }
541
-
542
- .gradio-container {
543
- max-width: 900px !important;
544
- margin: auto !important;
545
- }
546
-
547
- .announcement-banner {
548
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
549
- color: white;
550
- padding: 16px 24px;
551
- border-radius: 12px;
552
- margin-bottom: 20px;
553
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
554
- text-align: center;
555
- font-size: 16px;
556
- font-weight: 500;
557
- animation: slideIn 0.5s ease-out;
558
- }
559
 
560
- @keyframes slideIn {
561
- from {
562
- opacity: 0;
563
- transform: translateY(-20px);
564
- }
565
- to {
566
- opacity: 1;
567
- transform: translateY(0);
568
- }
569
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570
 
571
- .announcement-banner strong {
572
- font-weight: 700;
573
- font-size: 18px;
574
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
 
576
- .settings-panel {
577
- background: #f9fafb;
578
- padding: 16px;
579
- border-radius: 8px;
580
- margin-bottom: 12px;
581
- border: 1px solid #e5e7eb;
582
- }
 
 
 
583
 
584
- .model-info {
585
- background: #f0f9ff;
586
- border: 1px solid #bae6fd;
587
- padding: 12px;
588
- border-radius: 8px;
589
- margin-top: 8px;
590
- font-size: 13px;
591
- font-family: monospace;
592
- white-space: pre-line;
593
- }
594
- """
595
 
596
- def format_message_html(role, content, show_thinking=True):
597
- """Format a single message as HTML."""
598
- role_class = "user-message" if role == "user" else "assistant-message"
599
- role_name = "You" if role == "user" else "SAM-X-1"
600
-
601
- thinking = ""
602
- answer = ""
603
-
604
- if "<think>" in content:
605
- parts = content.split("<think>", 1)
606
- before_think = parts[0].strip()
607
-
608
- if len(parts) > 1:
609
- after_think = parts[1]
610
-
611
- if "</think>" in after_think:
612
- think_parts = after_think.split("</think>", 1)
613
- thinking = think_parts[0].strip()
614
- answer = (before_think + " " + think_parts[1]).strip()
615
- elif "<think/>" in after_think:
616
- think_parts = after_think.split("<think/>", 1)
617
- thinking = think_parts[0].strip()
618
- answer = (before_think + " " + think_parts[1]).strip()
619
- else:
620
- thinking = after_think.strip()
621
- answer = before_think
622
- else:
623
- answer = before_think
624
- else:
625
- answer = content
626
-
627
- html = f'<div class="{role_class}">'
628
- html += f'<div class="message-header">{role_name}</div>'
629
- html += f'<div class="message-content">'
630
-
631
- if thinking and show_thinking:
632
- html += f'<div class="thinking-content">💭 {thinking}</div>'
633
-
634
- if answer:
635
- html += f'<div>{answer}</div>'
636
-
637
- html += '</div></div>'
638
- return html
639
-
640
- def render_history(history, show_thinking):
641
- """Render chat history as HTML."""
642
- html = ""
643
- for msg in history:
644
- html += format_message_html(msg["role"], msg["content"], show_thinking)
645
- return html
646
-
647
- def send_message(message, history, show_thinking, temperature, model_choice):
648
- if not message.strip():
649
- yield history, "", render_history(history, show_thinking), ""
650
- return
651
-
652
- # Switch backend based on selection
653
- backend = available_models[model_choice]
654
-
655
- # Add user message
656
- history.append({"role": "user", "content": message})
657
- yield history, "", render_history(history, show_thinking), backend.get_info()
658
-
659
- # Generate prompt
660
- prompt = f"User: {message}\nSam: <think>"
661
-
662
- # Start assistant message
663
- history.append({"role": "assistant", "content": "<think>"})
664
-
665
- # Stream response
666
- for new_chunk, in_thinking in generate_response_stream(prompt, temperature, backend):
667
- history[-1]["content"] += new_chunk
668
- yield history, "", render_history(history, show_thinking), backend.get_info()
669
-
670
- # Create Gradio interface
671
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="slate")) as demo:
672
- # Announcement Banner
673
- gr.HTML("""
674
- <div class="announcement-banner">
675
- 🎉 <strong>NEW UPDATE:</strong> Multiple model variants now available!
676
- Choose Fast/Mini/Nano for <strong>30-250% speed boost</strong>! ⚡
677
- The models marked with (BETA) are not useful yet. <strong>They are still in development!</strong>
678
- </div>
679
- """)
680
-
681
- gr.Markdown("# 🤖 SAM-X-1 Multi-Model Chat")
682
-
683
- # Settings panel
684
- with gr.Accordion("⚙️ Settings", open=False):
685
- with gr.Row():
686
- model_selector = gr.Dropdown(
687
- choices=list(available_models.keys()),
688
- value=list(available_models.keys())[0],
689
- label="Model Selection",
690
- info="Choose your speed/quality tradeoff"
691
- )
692
-
693
- model_info_box = gr.Textbox(
694
- label="Selected Model Info",
695
- value=list(available_models.values())[0].get_info(),
696
- interactive=False,
697
- lines=4,
698
- elem_classes=["model-info"]
699
- )
700
 
701
- with gr.Row():
702
- temperature_slider = gr.Slider(
703
- minimum=0.0,
704
- maximum=2.0,
705
- value=0.7,
706
- step=0.1,
707
- label="Temperature",
708
- info="Higher = more creative, Lower = more focused"
709
- )
710
- show_thinking_checkbox = gr.Checkbox(
711
- label="Show Thinking Process",
712
- value=True,
713
- info="Display model's reasoning"
714
  )
715
-
716
- # Chat state and display
717
- chatbot_state = gr.State([])
718
- chat_html = gr.HTML(value="", elem_classes=["chat-container"])
719
-
720
- # Input area
721
- with gr.Row(elem_classes=["input-row"]):
722
- msg_input = gr.Textbox(
723
- placeholder="Ask me anything...",
724
- show_label=False,
725
- container=False,
726
- scale=9
727
  )
728
- send_btn = gr.Button("Send", variant="primary", scale=1)
729
 
730
- with gr.Row():
731
- clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
732
-
733
- # Event handlers
734
- msg_input.submit(
735
- send_message,
736
- inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
737
- outputs=[chatbot_state, msg_input, chat_html, model_info_box]
 
 
 
 
 
 
 
 
 
 
738
  )
739
-
740
- send_btn.click(
741
- send_message,
742
- inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
743
- outputs=[chatbot_state, msg_input, chat_html, model_info_box]
 
 
 
 
 
 
 
 
 
744
  )
 
 
745
 
746
- clear_btn.click(
747
- lambda: ([], ""),
748
- outputs=[chatbot_state, chat_html]
 
 
 
 
 
 
 
 
 
 
 
749
  )
750
 
751
- show_thinking_checkbox.change(
752
- lambda h, st: render_history(h, st),
753
- inputs=[chatbot_state, show_thinking_checkbox],
754
- outputs=[chat_html]
755
- )
756
 
757
- # Update model info when selection changes
758
- model_selector.change(
759
- lambda choice: available_models[choice].get_info(),
760
- inputs=[model_selector],
761
- outputs=[model_info_box]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
762
  )
 
 
763
 
764
- demo.launch(debug=True, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import time
3
+ import uuid
4
+ from datetime import datetime
5
+ from typing import List, Optional, Union, Dict, Any, Generator, Tuple
6
+
7
+ # Set environment variables for Keras/TensorFlow
8
  os.environ['KERAS_BACKEND'] = 'tensorflow'
9
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
10
 
 
15
  from huggingface_hub import hf_hub_download
16
  import json
17
  from abc import ABC, abstractmethod
18
+ from fastapi import FastAPI, HTTPException, status
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.responses import StreamingResponse, JSONResponse
21
+ from pydantic import BaseModel, Field
22
 
23
  # ==============================================================================
24
  # Model Architecture (Must match training code)
 
247
  @abstractmethod
248
  def get_info(self):
249
  pass
250
+
251
+ @abstractmethod
252
+ def get_model(self) -> SAM1Model:
253
+ pass
254
 
255
 
256
  class KerasBackend(ModelBackend):
 
270
  self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
271
 
272
  def predict(self, input_ids):
273
+ # NOTE: This predicts the next token based on the input sequence
274
  inputs = np.array([input_ids], dtype=np.int32)
275
  logits = self.model(inputs, training=False)
276
  return logits[0, -1, :].numpy()
 
278
  def get_name(self):
279
  return self.display_name
280
 
281
+ def get_model(self) -> SAM1Model:
282
+ return self.model
283
+
284
  def get_info(self):
285
  info = f"{self.display_name}\n"
286
  info += f" Total params: {format_param_count(self.total_params)}\n"
 
292
 
293
 
294
  # ==============================================================================
295
+ # Model Registry and Asset Loading
296
  # ==============================================================================
297
  MODEL_REGISTRY = [
298
  # Format: (display_name, repo_id, weights_filename, config_filename)
 
 
299
  ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
300
  ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
301
  ("SAM-X-1-Mini 🚀 (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights.h5", "sam1_mini_config.json"),
302
  ("SAM-X-1-Nano ⚡⚡ (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano.weights.h5", "sam1_nano_config.json"),
303
  ]
304
 
305
+ CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"
306
+ available_models: Dict[str, KerasBackend] = {}
307
+ tokenizer: Optional[Tokenizer] = None
308
+ eos_token_id: Optional[int] = None
309
+ DEFAULT_SYSTEM_PROMPT = "You are a helpful and friendly assistant named SAM-X-1. Answer the user's request. You must prepend your answer with '<think>' and end your thoughts with '</think>' or '<think/>' followed by your actual response."
310
 
311
 
312
+ def load_all_assets():
313
+ """Load config, tokenizer, and all models."""
314
+ global tokenizer, eos_token_id, available_models, DEFAULT_SYSTEM_PROMPT
315
+
316
+ print("="*80)
317
+ print("🤖 SAM-X-1 API Backend Loading".center(80))
318
+ print("="*80)
319
 
320
+ # Download config and tokenizer
321
+ print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
322
+ config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
323
+
324
+ # Load config
325
+ with open(config_path, 'r') as f:
326
+ base_config = json.load(f)
327
+
328
+ print(f"✅ Base config loaded")
329
+
330
+ # Build base model config
331
+ base_model_config = {
332
+ 'vocab_size': base_config['vocab_size'],
333
+ 'd_model': base_config['hidden_size'],
334
+ 'n_heads': base_config['num_attention_heads'],
335
+ 'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
336
+ 'dropout': base_config.get('dropout', 0.0),
337
+ 'max_len': base_config['max_position_embeddings'],
338
+ 'rope_theta': base_config['rope_theta'],
339
+ 'n_layers': base_config['num_hidden_layers']
340
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
+ # Recreate tokenizer
343
+ print("\n🔤 Recreating tokenizer...")
344
+ # NOTE: The original code uses "gpt2" to load the tokenizer architecture.
345
+ tokenizer = Tokenizer.from_pretrained("gpt2")
346
+ eos_token = ""
347
+ eos_token_id = tokenizer.token_to_id(eos_token)
348
 
349
+ if eos_token_id is None:
350
+ tokenizer.add_special_tokens([eos_token])
351
+ eos_token_id = tokenizer.token_to_id(eos_token)
352
 
353
+ custom_tokens = ["<think>", "<think/>", "</think>"]
354
+ for token in custom_tokens:
355
+ if tokenizer.token_to_id(token) is None:
356
+ tokenizer.add_special_tokens([token])
357
 
358
+ tokenizer.no_padding()
359
+ tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
360
+ print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
361
 
362
+ # Load all models from registry
363
+ print("\n" + "="*80)
364
+ print("📦 LOADING MODELS".center(80))
365
+ print("="*80)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
+ dummy_input = tf.zeros((1, 1), dtype=tf.int32)
 
368
 
369
+ for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
370
+ try:
371
+ print(f"\n⏳ Loading: {display_name}")
372
+
373
+ # Download weights
374
+ weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
375
+
376
+ # Load custom config if specified (for pruned models)
377
+ if config_filename:
378
+ custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
379
+ with open(custom_config_path, 'r') as f:
380
+ model_config = json.load(f)
381
+ else:
382
+ model_config = base_model_config.copy()
383
+
384
+ # Create model with appropriate config
385
+ model = SAM1Model(**model_config)
386
+ model(dummy_input)
387
+ model.load_weights(weights_path)
388
+ model.trainable = False
389
+
390
+ # Create backend
391
+ backend = KerasBackend(model, display_name, display_name)
392
+ available_models[display_name] = backend
393
+
394
+ # Print stats
395
+ print(f" ✅ Loaded successfully!")
396
+ print(f" 📊 Parameters: {format_param_count(backend.total_params)}")
397
+
398
+ except Exception as e:
399
+ print(f" ⚠️ Failed to load {display_name}: {e}")
400
+ print(f" Skipping {display_name}...")
401
 
402
+ if not available_models:
403
+ # NOTE: In a real system, you might want a graceful fallback. Here, we must exit.
404
+ print("FATAL: No models loaded! Check your MODEL_REGISTRY configuration.")
405
+ # We raise a RuntimeError but let the startup event handle the final failure
406
+ # to ensure the FastAPI application runs the event loop.
407
 
408
+ def generate_response_stream(prompt: str, temperature: float, backend: KerasBackend, max_new_tokens: int = 512) -> Generator[Tuple[str, bool], None, None]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  """Generate response and yield tokens one by one for streaming."""
 
 
410
 
411
+ if tokenizer is None or eos_token_id is None:
412
+ raise RuntimeError("Tokenizer not loaded.")
413
+
414
  encoded_prompt = tokenizer.encode(prompt)
415
  input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
416
  generated = input_ids.copy()
417
 
418
  current_text = ""
419
+ # Use max_len from the model config
 
 
420
  max_len = backend.model.cfg['max_len']
421
 
422
+ for _ in range(max_new_tokens):
423
+ # Sliding window for context
424
+ current_input = generated[-max_len:]
425
 
426
  # Get logits from selected backend
427
  next_token_logits = backend.predict(current_input)
428
 
429
  if temperature > 0:
430
+ # Top-K sampling
431
  next_token_logits = next_token_logits / temperature
432
+ top_k = 50
433
+ top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
434
  top_k_logits = next_token_logits[top_k_indices]
435
  top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
436
  top_k_probs /= top_k_probs.sum()
 
443
 
444
  generated.append(int(next_token))
445
 
446
+ # Decode the newly generated part
447
  new_text = tokenizer.decode(generated[len(input_ids):])
448
+
449
  if len(new_text) > len(current_text):
450
  new_chunk = new_text[len(current_text):]
451
  current_text = new_text
452
 
453
+ # Simple check for thinking tags
454
+ in_thinking = "<think>" in current_text and not ( "</think>" in current_text or "<think/>" in current_text)
455
+
 
 
456
  yield new_chunk, in_thinking
457
+
458
+ yield "", False # End of stream
459
+
460
 
461
  # ==============================================================================
462
+ # FastAPI API & Pydantic Schemas (OpenAI Style)
463
  # ==============================================================================
464
+
465
+ # --- Pydantic Schemas for OpenAI API Compatibility ---
466
+ class ChatMessage(BaseModel):
467
+ role: str
468
+ content: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
+ class ChatCompletionRequest(BaseModel):
471
+ model: str = Field(..., description="The ID of the model to use.")
472
+ messages: List[ChatMessage] = Field(..., description="A list of messages comprising the conversation.")
473
+ temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature.")
474
+ max_tokens: Optional[int] = Field(512, ge=1, description="The maximum number of tokens to generate.")
475
+ stream: Optional[bool] = Field(False, description="Whether to stream the response.")
476
+
477
+ # OpenAI Response Structure: Chunk for Streaming
478
+ class ChatCompletionChunkChoice(BaseModel):
479
+ index: int = 0
480
+ delta: Dict[str, Optional[str]]
481
+ finish_reason: Optional[str] = None
482
+
483
+ class ChatCompletionChunk(BaseModel):
484
+ id: str
485
+ object: str = "chat.completion.chunk"
486
+ created: int = Field(default_factory=lambda: int(time.time()))
487
+ model: str
488
+ choices: List[ChatCompletionChunkChoice]
489
+
490
+ # OpenAI Response Structure: Full Response
491
+ class ChatCompletionUsage(BaseModel):
492
+ prompt_tokens: int
493
+ completion_tokens: int
494
+ total_tokens: int
495
+
496
+ class ChatCompletionChoice(BaseModel):
497
+ index: int = 0
498
+ message: ChatMessage
499
+ finish_reason: Optional[str] = None
500
+
501
+ class ChatCompletion(BaseModel):
502
+ id: str
503
+ object: str = "chat.completion"
504
+ created: int = Field(default_factory=lambda: int(time.time()))
505
+ model: str
506
+ choices: List[ChatCompletionChoice]
507
+ usage: ChatCompletionUsage
508
+
509
+ # Model Listing Response
510
+ class ModelCard(BaseModel):
511
+ id: str
512
+ object: str = "model"
513
+ created: int = Field(default_factory=lambda: int(time.time()))
514
+ owned_by: str = "SAM-X-1 Team"
515
+
516
+ class ModelList(BaseModel):
517
+ object: str = "list"
518
+ data: List[ModelCard]
519
+
520
+
521
+ # --- FastAPI Application ---
522
+ app = FastAPI(
523
+ title="SAM-X-1 Keras API (OpenAI-Style)",
524
+ description="A production-ready FastAPI backend for the SAM-X-1 Keras model.",
525
+ version="1.0.0",
526
+ )
527
+
528
+ # Production-grade CORS middleware
529
+ app.add_middleware(
530
+ CORSMiddleware,
531
+ allow_origins=["*"],
532
+ allow_credentials=True,
533
+ allow_methods=["*"],
534
+ allow_headers=["*"],
535
+ )
536
+
537
+
538
+ @app.on_event("startup")
539
+ async def startup_event():
540
+ """Load models and tokenizer when the FastAPI app starts."""
541
+ try:
542
+ load_all_assets()
543
+ except Exception as e:
544
+ # Print the error and allow FastAPI to start, but subsequent requests will fail
545
+ print(f"FATAL: Failed to load assets during startup: {e}")
546
+ pass
547
+
548
+ @app.get("/v1/models", response_model=ModelList)
549
+ async def list_models():
550
+ """Endpoint to list all available models."""
551
+ models_data = [
552
+ ModelCard(id=name, created=int(time.time()))
553
+ for name in available_models.keys()
554
+ ]
555
+ return ModelList(data=models_data)
556
+
557
+
558
+ def build_prompt_from_messages(messages: List[ChatMessage], system_prompt: str) -> str:
559
+ """Constructs the model's instruction-style prompt from a list of messages."""
560
+ prompt = f"System: {system_prompt}\n"
561
 
562
+ for message in messages:
563
+ role = message.role.capitalize()
564
+ content = message.content.strip()
565
+
566
+ if role == "User":
567
+ prompt += f"{role}: {content}\n"
568
+ elif role == "Assistant":
569
+ prompt += f"Sam: {content}\n"
570
+
571
+ prompt += "Sam: <think>"
572
+ return prompt
573
+
574
+
575
+ def format_sse_chunk(chunk: ChatCompletionChunk) -> str:
576
+ """Formats a Pydantic object as a Server-Sent Event (SSE) data block."""
577
+ return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
578
+
579
def streaming_generator(request: ChatCompletionRequest, backend: KerasBackend, full_prompt: str) -> Generator[str, None, None]:
    """Stream LLM output as OpenAI-compatible Server-Sent Events.

    Yields, in order: a role chunk, one content chunk per generated piece of
    text, a terminal chunk carrying ``finish_reason="stop"``, and the
    ``data: [DONE]`` delimiter required by the SSE chat-completions protocol.

    Args:
        request: Validated chat-completion request (model, temperature,
            max_tokens, ...).
        backend: Loaded model backend used for generation.
        full_prompt: Fully formatted prompt text to condition generation on.
    """
    model_name = request.model
    chat_id = f"chatcmpl-{uuid.uuid4().hex}"
    max_new_tokens = request.max_tokens or 512

    # 1. Initial chunk announcing the assistant role.
    yield format_sse_chunk(
        ChatCompletionChunk(
            id=chat_id,
            model=model_name,
            choices=[ChatCompletionChunkChoice(index=0, delta={"role": "assistant"})]
        )
    )

    # 2. Stream content chunks as the backend produces them. A generation
    # failure is logged and we fall through to the normal terminators, so the
    # client always receives a well-formed end of stream.
    try:
        for new_chunk, _ in generate_response_stream(full_prompt, request.temperature, backend, max_new_tokens):
            if not new_chunk:
                continue

            yield format_sse_chunk(
                ChatCompletionChunk(
                    id=chat_id,
                    model=model_name,
                    choices=[ChatCompletionChunkChoice(index=0, delta={"content": new_chunk})]
                )
            )
    except Exception as e:
        # Best-effort: an HTTP error cannot be raised mid-stream, so log and
        # terminate the stream gracefully below.
        print(f"Error during streaming generation: {e}")

    # 3. Terminal chunk. Per the OpenAI spec, token usage belongs on the full
    # (non-streaming) response, not on stream chunks, so none is attached.
    yield format_sse_chunk(
        ChatCompletionChunk(
            id=chat_id,
            model=model_name,
            choices=[ChatCompletionChunkChoice(index=0, delta={}, finish_reason="stop")]
        )
    )

    # 4. Required SSE end-of-stream delimiter.
    yield "data: [DONE]\n\n"
635
+
636
+
637
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Chat-completions endpoint supporting streaming and non-streaming modes.

    Validates the requested model and payload, formats the chat history into
    the model's prompt template, then either streams SSE chunks or returns a
    single OpenAI-compatible ``ChatCompletion`` JSON body.

    Raises:
        HTTPException: 400 for an unknown model or an empty messages array;
            503 when model assets have not been loaded.
    """
    # 1. Model validation.
    if request.model not in available_models:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Model '{request.model}' not found in registry. Available models: {list(available_models.keys())}"
        )
    if tokenizer is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model assets not loaded.")

    backend = available_models[request.model]

    # 2. Prompt formatting.
    if not request.messages:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Messages array cannot be empty.")

    full_prompt = build_prompt_from_messages(request.messages, DEFAULT_SYSTEM_PROMPT)

    # 3. Streaming response: delegate everything to the SSE generator.
    if request.stream:
        return StreamingResponse(
            streaming_generator(request, backend, full_prompt),
            media_type="text/event-stream"
        )

    # 4. Non-streaming (blocking) response. max_tokens is only needed on this
    # path; the streaming generator derives its own limit from the request.
    max_new_tokens = request.max_tokens or 512

    # Drain the generator to completion and join the chunks once.
    full_response_text = "".join(
        new_chunk
        for new_chunk, _ in generate_response_stream(full_prompt, request.temperature, backend, max_new_tokens)
    )

    response_message = ChatMessage(role="assistant", content=full_response_text.strip())

    # Token counts are approximations obtained by re-encoding the final text.
    prompt_token_count = len(tokenizer.encode(full_prompt).ids)
    completion_token_count = len(tokenizer.encode(full_response_text).ids)

    completion_response = ChatCompletion(
        id=f"chatcmpl-{uuid.uuid4().hex}",
        model=request.model,
        choices=[ChatCompletionChoice(
            message=response_message,
            # Simplified: reporting "length" would require the generator to
            # signal whether max_new_tokens was actually hit.
            finish_reason="stop"
        )],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_token_count,
            completion_tokens=completion_token_count,
            total_tokens=prompt_token_count + completion_token_count
        )
    )

    return JSONResponse(content=completion_response.model_dump(exclude_none=True))
696
 
697
# ==============================================================================
# Execution Block
# ==============================================================================
if __name__ == "__main__":
    # Load all model assets up front so startup failures surface before the
    # server binds its port.
    try:
        load_all_assets()
    except RuntimeError as e:
        # Report the failure and exit non-zero. SystemExit is the explicit
        # form; the `exit()` builtin comes from the site module and is meant
        # for interactive use (absent under `python -S`).
        print(e)
        raise SystemExit(1)

    import uvicorn

    # NOTE: workers=1 for TensorFlow/Keras stability in standalone scripts.
    # For robust production, use gunicorn to manage multiple uvicorn processes.
    uvicorn.run(
        "__main__:app",
        host="0.0.0.0",
        port=8000,
        log_level="info",
        workers=1,
        # reload=True  # Uncomment for development
    )