Bc-AI commited on
Commit
0a83aff
·
verified ·
1 Parent(s): c62ac67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +620 -620
app.py CHANGED
@@ -15,275 +15,275 @@ from abc import ABC, abstractmethod
15
  # ==============================================================================
16
  @keras.saving.register_keras_serializable()
17
  class RotaryEmbedding(keras.layers.Layer):
18
-     def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
19
-         super().__init__(**kwargs)
20
-         self.dim = dim
21
-         self.max_len = max_len
22
-         self.theta = theta
23
-         self.built_cache = False
24
-
25
-     def build(self, input_shape):
26
-         if not self.built_cache:
27
-             inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
28
-             t = tf.range(self.max_len, dtype=tf.float32)
29
-             freqs = tf.einsum("i,j->ij", t, inv_freq)
30
-             emb = tf.concat([freqs, freqs], axis=-1)
31
-
32
-             self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
33
-             self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
34
-             self.built_cache = True
35
-         super().build(input_shape)
36
-
37
-     def rotate_half(self, x):
38
-         x1, x2 = tf.split(x, 2, axis=-1)
39
-         return tf.concat([-x2, x1], axis=-1)
40
-
41
-     def call(self, q, k):
42
-         seq_len = tf.shape(q)[2]
43
-         dtype = q.dtype
44
-         cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
45
-         sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
46
-
47
-         q_rotated = (q * cos) + (self.rotate_half(q) * sin)
48
-         k_rotated = (k * cos) + (self.rotate_half(k) * sin)
49
-
50
-         return q_rotated, k_rotated
51
-
52
-     def get_config(self):
53
-         config = super().get_config()
54
-         config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
55
-         return config
56
 
57
 
58
  @keras.saving.register_keras_serializable()
59
  class RMSNorm(keras.layers.Layer):
60
-     def __init__(self, epsilon=1e-5, **kwargs):
61
-         super().__init__(**kwargs)
62
-         self.epsilon = epsilon
63
 
64
-     def build(self, input_shape):
65
-         self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
66
 
67
-     def call(self, x):
68
-         variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
69
-         return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
70
 
71
-     def get_config(self):
72
-         config = super().get_config()
73
-         config.update({"epsilon": self.epsilon})
74
-         return config
75
 
76
 
77
  @keras.saving.register_keras_serializable()
78
  class TransformerBlock(keras.layers.Layer):
79
-     def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
80
-         super().__init__(**kwargs)
81
-         self.d_model = d_model
82
-         self.n_heads = n_heads
83
-         self.ff_dim = ff_dim
84
-         self.dropout_rate = dropout
85
-         self.max_len = max_len
86
-         self.rope_theta = rope_theta
87
-         self.head_dim = d_model // n_heads
88
-         self.layer_idx = layer_idx
89
-
90
-         self.pre_attn_norm = RMSNorm()
91
-         self.pre_ffn_norm = RMSNorm()
92
-
93
-         self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
94
-         self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
95
-         self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
96
-         self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
97
-
98
-         self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
99
-
100
-         self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
101
-         self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
102
-         self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
103
-
104
-         self.dropout = keras.layers.Dropout(dropout)
105
-
106
-     def call(self, x, training=None):
107
-         B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
108
-         dtype = x.dtype
109
-
110
-         res = x
111
-         y = self.pre_attn_norm(x)
112
-
113
-         q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
114
-         k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
115
-         v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
116
-
117
-         q, k = self.rope(q, k)
118
-
119
-         scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
120
-
121
-         mask = tf.where(
122
-             tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
123
-             tf.constant(-1e9, dtype=dtype),
124
-             tf.constant(0.0, dtype=dtype)
125
-         )
126
-         scores += mask
127
-         attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
128
-
129
-         attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
130
-         x = res + self.dropout(self.out_proj(attn), training=training)
131
-
132
-         res = x
133
-         y = self.pre_ffn_norm(x)
134
-         ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
135
-
136
-         return res + self.dropout(ffn, training=training)
137
-
138
-     def get_config(self):
139
-         config = super().get_config()
140
-         config.update({
141
-             "d_model": self.d_model,
142
-             "n_heads": self.n_heads,
143
-             "ff_dim": self.ff_dim,
144
-             "dropout": self.dropout_rate,
145
-             "max_len": self.max_len,
146
-             "rope_theta": self.rope_theta,
147
-             "layer_idx": self.layer_idx
148
-         })
149
-         return config
150
 
151
 
152
  @keras.saving.register_keras_serializable()
153
  class SAM1Model(keras.Model):
154
-     def __init__(self, **kwargs):
155
-         super().__init__()
156
-         if 'config' in kwargs and isinstance(kwargs['config'], dict):
157
-             self.cfg = kwargs['config']
158
-         elif 'vocab_size' in kwargs:
159
-             self.cfg = kwargs
160
-         else:
161
-             self.cfg = kwargs.get('cfg', kwargs)
162
 
163
-         self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
164
 
165
-         ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
166
-         block_args = {
167
-             'd_model': self.cfg['d_model'],
168
-             'n_heads': self.cfg['n_heads'],
169
-             'ff_dim': ff_dim,
170
-             'dropout': self.cfg['dropout'],
171
-             'max_len': self.cfg['max_len'],
172
-             'rope_theta': self.cfg['rope_theta']
173
-         }
174
 
175
-         self.blocks = []
176
-         for i in range(self.cfg['n_layers']):
177
-             block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
178
-             self.blocks.append(block)
179
 
180
-         self.norm = RMSNorm(name="final_norm")
181
-         self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
182
 
183
-     def call(self, input_ids, training=None):
184
-         x = self.embed(input_ids)
185
 
186
-         for block in self.blocks:
187
-             x = block(x, training=training)
188
 
189
-         return self.lm_head(self.norm(x))
190
 
191
-     def get_config(self):
192
-         base_config = super().get_config()
193
-         base_config['config'] = self.cfg
194
-         return base_config
195
 
196
 
197
  # ==============================================================================
198
  # Helper Functions
199
  # ==============================================================================
200
  def count_parameters(model):
201
-     """Count total and non-zero parameters in model."""
202
-     total_params = 0
203
-     non_zero_params = 0
204
-     
205
-     for weight in model.weights:
206
-         w = weight.numpy()
207
-         total_params += w.size
208
-         non_zero_params += np.count_nonzero(w)
209
-     
210
-     return total_params, non_zero_params
211
 
212
 
213
  def format_param_count(count):
214
-     """Format parameter count in human readable format."""
215
-     if count >= 1e9:
216
-         return f"{count/1e9:.2f}B"
217
-     elif count >= 1e6:
218
-         return f"{count/1e6:.2f}M"
219
-     elif count >= 1e3:
220
-         return f"{count/1e3:.2f}K"
221
-     else:
222
-         return str(count)
223
 
224
 
225
  # ==============================================================================
226
  # Model Backend Interface
227
  # ==============================================================================
228
  class ModelBackend(ABC):
229
-     @abstractmethod
230
-     def predict(self, input_ids):
231
-         pass
232
-     
233
-     @abstractmethod
234
-     def get_name(self):
235
-         pass
236
-     
237
-     @abstractmethod
238
-     def get_info(self):
239
-         pass
240
 
241
 
242
  class KerasBackend(ModelBackend):
243
-     def __init__(self, model, name, display_name):
244
-         self.model = model
245
-         self.name = name
246
-         self.display_name = display_name
247
-         
248
-         # Count parameters
249
-         total, non_zero = count_parameters(model)
250
-         self.total_params = total
251
-         self.non_zero_params = non_zero
252
-         self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
253
-         
254
-         # Calculate actual model config for speed estimation
255
-         self.n_heads = model.cfg.get('n_heads', 0)
256
-         self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
257
-     
258
-     def predict(self, input_ids):
259
-         inputs = np.array([input_ids], dtype=np.int32)
260
-         logits = self.model(inputs, training=False)
261
-         return logits[0, -1, :].numpy()
262
-     
263
-     def get_name(self):
264
-         return self.display_name
265
-     
266
-     def get_info(self):
267
-         info = f"{self.display_name}\n"
268
-         info += f"  Total params: {format_param_count(self.total_params)}\n"
269
-         info += f"  Attention heads: {self.n_heads}\n"
270
-         info += f"  FFN dimension: {self.ff_dim}\n"
271
-         if self.sparsity > 1:
272
-             info += f"  Sparsity: {self.sparsity:.1f}%\n"
273
-         return info
274
 
275
 
276
  # ==============================================================================
277
  # EASY MODEL REGISTRY - ADD YOUR MODELS HERE!
278
  # ==============================================================================
279
  MODEL_REGISTRY = [
280
-     # Format: (display_name, repo_id, weights_filename, config_filename)
281
-     # Smaller models are ACTUALLY faster (fewer params = real speedup!)
282
-     
283
-     ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
284
-     ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
285
-     ("SAM-X-1-Mini 🚀 (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights.h5", "sam1_mini_config.json"),
286
-     ("SAM-X-1-Nano ⚡⚡ (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano.weights.h5", "sam1_nano_config.json"),
287
  ]
288
 
289
  # To add a new model, just add a new line above! Format:
@@ -307,20 +307,20 @@ tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tok
307
 
308
  # Load config
309
  with open(config_path, 'r') as f:
310
-     base_config = json.load(f)
311
 
312
  print(f"✅ Base config loaded")
313
 
314
  # Build base model config
315
  base_model_config = {
316
-     'vocab_size': base_config['vocab_size'],
317
-     'd_model': base_config['hidden_size'],
318
-     'n_heads': base_config['num_attention_heads'],
319
-     'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
320
-     'dropout': base_config.get('dropout', 0.0),
321
-     'max_len': base_config['max_position_embeddings'],
322
-     'rope_theta': base_config['rope_theta'],
323
-     'n_layers': base_config['num_hidden_layers']
324
  }
325
 
326
  # Recreate tokenizer
@@ -330,13 +330,13 @@ eos_token = ""
330
  eos_token_id = tokenizer.token_to_id(eos_token)
331
 
332
  if eos_token_id is None:
333
-     tokenizer.add_special_tokens([eos_token])
334
-     eos_token_id = tokenizer.token_to_id(eos_token)
335
 
336
  custom_tokens = ["<think>", "<think/>"]
337
  for token in custom_tokens:
338
-     if tokenizer.token_to_id(token) is None:
339
-         tokenizer.add_special_tokens([token])
340
 
341
  tokenizer.no_padding()
342
  tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
@@ -351,49 +351,49 @@ available_models = {}
351
  dummy_input = tf.zeros((1, 1), dtype=tf.int32)
352
 
353
  for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
354
-     try:
355
-         print(f"\n⏳ Loading: {display_name}")
356
-         print(f"   Repo: {repo_id}")
357
-         print(f"   Weights: {weights_filename}")
358
-         
359
-         # Download weights
360
-         weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
361
-         
362
-         # Load custom config if specified (for pruned models)
363
-         if config_filename:
364
-             print(f"   Config: {config_filename}")
365
-             custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
366
-             with open(custom_config_path, 'r') as f:
367
-                 model_config = json.load(f)
368
-             print(f"   📐 Custom architecture: {model_config['n_heads']} heads, {int(model_config['d_model'] * model_config['ff_mult'])} FFN dim")
369
-         else:
370
-             model_config = base_model_config.copy()
371
-         
372
-         # Create model with appropriate config
373
-         model = SAM1Model(**model_config)
374
-         model(dummy_input)
375
-         model.load_weights(weights_path)
376
-         model.trainable = False
377
-         
378
-         # Create backend
379
-         backend = KerasBackend(model, display_name, display_name)
380
-         available_models[display_name] = backend
381
-         
382
-         # Print stats
383
-         print(f"   ✅ Loaded successfully!")
384
-         print(f"   📊 Parameters: {format_param_count(backend.total_params)}")
385
-         print(f"   📊 Attention heads: {backend.n_heads}")
386
-         print(f"   📊 FFN dimension: {backend.ff_dim}")
387
-         
388
-     except Exception as e:
389
-         print(f"   ⚠️  Failed to load: {e}")
390
-         print(f"   Skipping {display_name}...")
391
 
392
  if not available_models:
393
-     raise RuntimeError("❌ No models loaded! Check your MODEL_REGISTRY configuration.")
394
 
395
  print(f"\n✅ Successfully loaded {len(available_models)} model(s)")
396
- print(f"   Device: {'GPU' if len(tf.config.list_physical_devices('GPU')) > 0 else 'CPU'}")
397
 
398
  current_backend = list(available_models.values())[0]
399
 
@@ -405,360 +405,360 @@ print("💡 ABOUT PRUNING & SPEED".center(80))
405
  print("="*80)
406
  print("""
407
  📌 Does pruning reduce parameter count?
408
-    YES and NO:
409
-    • Total param count stays the same (architecture unchanged)
410
-    • BUT pruned weights are set to ZERO (sparse weights)
411
-    • Active/non-zero params are reduced significantly
412
-    
413
  📌 Does pruning speed up inference?
414
-    IT DEPENDS:
415
-    • Dense operations (regular matrix multiply): NO speedup by default
416
-    • Need sparse kernels or hardware support for actual speedup
417
-    • HOWEVER: Smaller active weights = better cache utilization
418
-    • Less computation on zeros = potential speedup on some hardware
419
-    
420
  📌 What DOES speed things up reliably?
421
-    ✅ Quantization (FP16, INT8) - smaller types = faster compute
422
-    ✅ Fewer layers (layer pruning)
423
-    ✅ Smaller hidden dimensions (width reduction)
424
-    ✅ Knowledge distillation to smaller architecture
425
-    
426
  📌 Why use structured pruning then?
427
-    ✅ Reduces memory footprint (especially with sparse storage)
428
-    ✅ Can be combined with quantization for real speedups
429
-    ✅ Preserves quality better than aggressive dimension reduction
430
-    ✅ Foundation for converting to truly smaller architecture
431
  """)
432
 
433
  def generate_response_stream(prompt, temperature=0.7, backend=None):
434
-     """Generate response and yield tokens one by one for streaming."""
435
-     if backend is None:
436
-         backend = current_backend
437
-     
438
-     encoded_prompt = tokenizer.encode(prompt)
439
-     input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
440
-     generated = input_ids.copy()
441
-
442
-     current_text = ""
443
-     in_thinking = False
444
-     
445
-     # Get max_len from the backend's model config
446
-     max_len = backend.model.cfg['max_len']
447
-
448
-     for _ in range(512):
449
-         current_input = generated[-max_len:]
450
-         
451
-         # Get logits from selected backend
452
-         next_token_logits = backend.predict(current_input)
453
-
454
-         if temperature > 0:
455
-             next_token_logits = next_token_logits / temperature
456
-             top_k_indices = np.argpartition(next_token_logits, -50)[-50:]
457
-             top_k_logits = next_token_logits[top_k_indices]
458
-             top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
459
-             top_k_probs /= top_k_probs.sum()
460
-             next_token = top_k_indices[np.random.choice(len(top_k_indices), p=top_k_probs)]
461
-         else:
462
-             next_token = np.argmax(next_token_logits)
463
-
464
-         if next_token == eos_token_id:
465
-             break
466
-
467
-         generated.append(int(next_token))
468
-
469
-         new_text = tokenizer.decode(generated[len(input_ids):])
470
-         if len(new_text) > len(current_text):
471
-             new_chunk = new_text[len(current_text):]
472
-             current_text = new_text
473
-             
474
-             if "<think>" in new_chunk:
475
-                 in_thinking = True
476
-             elif "</think>" in new_chunk or "<think/>" in new_chunk:
477
-                 in_thinking = False
478
-                 
479
-             yield new_chunk, in_thinking
480
 
481
  # ==============================================================================
482
  # Gradio Interface
483
  # ==============================================================================
484
  if __name__ == "__main__":
485
-     import gradio as gr
486
-
487
-     custom_css = """
488
-     .chat-container {
489
-         height: 600px;
490
-         overflow-y: auto;
491
-         padding: 20px;
492
-         background: #ffffff;
493
-     }
494
-     
495
-     .user-message {
496
-         background: #f7f7f8;
497
-         padding: 16px;
498
-         margin: 12px 0;
499
-         border-radius: 8px;
500
-     }
501
-     
502
-     .assistant-message {
503
-         background: #ffffff;
504
-         padding: 16px;
505
-         margin: 12px 0;
506
-         border-radius: 8px;
507
-         border-left: 3px solid #10a37f;
508
-     }
509
-     
510
-     .message-content {
511
-         color: #353740;
512
-         line-height: 1.6;
513
-         font-size: 15px;
514
-     }
515
-     
516
-     .message-header {
517
-         font-weight: 600;
518
-         margin-bottom: 8px;
519
-         color: #353740;
520
-         font-size: 14px;
521
-     }
522
-     
523
-     .thinking-content {
524
-         color: #6b7280;
525
-         font-style: italic;
526
-         border-left: 3px solid #d1d5db;
527
-         padding-left: 12px;
528
-         margin: 8px 0;
529
-         background: #f9fafb;
530
-         padding: 8px 12px;
531
-         border-radius: 4px;
532
-     }
533
-     
534
-     .input-row {
535
-         background: #ffffff;
536
-         padding: 12px;
537
-         border-radius: 8px;
538
-         margin-top: 12px;
539
-         border: 1px solid #e5e7eb;
540
-     }
541
-     
542
-     .gradio-container {
543
-         max-width: 900px !important;
544
-         margin: auto !important;
545
-     }
546
-     
547
-     .announcement-banner {
548
-         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
549
-         color: white;
550
-         padding: 16px 24px;
551
-         border-radius: 12px;
552
-         margin-bottom: 20px;
553
-         box-shadow: 0 4px 6px rgba(0,0,0,0.1);
554
-         text-align: center;
555
-         font-size: 16px;
556
-         font-weight: 500;
557
-         animation: slideIn 0.5s ease-out;
558
-     }
559
-     
560
-     @keyframes slideIn {
561
-         from {
562
-             opacity: 0;
563
-             transform: translateY(-20px);
564
-         }
565
-         to {
566
-             opacity: 1;
567
-             transform: translateY(0);
568
-         }
569
-     }
570
-     
571
-     .announcement-banner strong {
572
-         font-weight: 700;
573
-         font-size: 18px;
574
-     }
575
-     
576
-     .settings-panel {
577
-         background: #f9fafb;
578
-         padding: 16px;
579
-         border-radius: 8px;
580
-         margin-bottom: 12px;
581
-         border: 1px solid #e5e7eb;
582
-     }
583
-     
584
-     .model-info {
585
-         background: #f0f9ff;
586
-         border: 1px solid #bae6fd;
587
-         padding: 12px;
588
-         border-radius: 8px;
589
-         margin-top: 8px;
590
-         font-size: 13px;
591
-         font-family: monospace;
592
-         white-space: pre-line;
593
-     }
594
-     """
595
-
596
-     def format_message_html(role, content, show_thinking=True):
597
-         """Format a single message as HTML."""
598
-         role_class = "user-message" if role == "user" else "assistant-message"
599
-         role_name = "You" if role == "user" else "SAM-X-1"
600
-         
601
-         thinking = ""
602
-         answer = ""
603
-         
604
-         if "<think>" in content:
605
-             parts = content.split("<think>", 1)
606
-             before_think = parts[0].strip()
607
-             
608
-             if len(parts) > 1:
609
-                 after_think = parts[1]
610
-                 
611
-                 if "</think>" in after_think:
612
-                     think_parts = after_think.split("</think>", 1)
613
-                     thinking = think_parts[0].strip()
614
-                     answer = (before_think + " " + think_parts[1]).strip()
615
-                 elif "<think/>" in after_think:
616
-                     think_parts = after_think.split("<think/>", 1)
617
-                     thinking = think_parts[0].strip()
618
-                     answer = (before_think + " " + think_parts[1]).strip()
619
-                 else:
620
-                     thinking = after_think.strip()
621
-                     answer = before_think
622
-             else:
623
-                 answer = before_think
624
-         else:
625
-             answer = content
626
-         
627
-         html = f'<div class="{role_class}">'
628
-         html += f'<div class="message-header">{role_name}</div>'
629
-         html += f'<div class="message-content">'
630
-         
631
-         if thinking and show_thinking:
632
-             html += f'<div class="thinking-content">💭 {thinking}</div>'
633
-         
634
-         if answer:
635
-             html += f'<div>{answer}</div>'
636
-         
637
-         html += '</div></div>'
638
-         return html
639
-
640
-     def render_history(history, show_thinking):
641
-         """Render chat history as HTML."""
642
-         html = ""
643
-         for msg in history:
644
-             html += format_message_html(msg["role"], msg["content"], show_thinking)
645
-         return html
646
-
647
-     def send_message(message, history, show_thinking, temperature, model_choice):
648
-         if not message.strip():
649
-             yield history, "", render_history(history, show_thinking), ""
650
-             return
651
-         
652
-         # Switch backend based on selection
653
-         backend = available_models[model_choice]
654
-         
655
-         # Add user message
656
-         history.append({"role": "user", "content": message})
657
-         yield history, "", render_history(history, show_thinking), backend.get_info()
658
-         
659
-         # Generate prompt
660
-         prompt = f"User: {message}\nSam:   <think>"
661
-         
662
-         # Start assistant message
663
-         history.append({"role": "assistant", "content": "<think>"})
664
-         
665
-         # Stream response
666
-         for new_chunk, in_thinking in generate_response_stream(prompt, temperature, backend):
667
-             history[-1]["content"] += new_chunk
668
-             yield history, "", render_history(history, show_thinking), backend.get_info()
669
-
670
-     # Create Gradio interface
671
-     with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="slate")) as demo:
672
-         # Announcement Banner
673
-         gr.HTML("""
674
-         <div class="announcement-banner">
675
-             🎉 <strong>NEW UPDATE:</strong> Multiple model variants now available! 
676
-             Choose Fast/Mini/Nano for <strong>30-250% speed boost</strong>! ⚡
677
-             The models marked with (BETA) are not useful yet. <strong>They are still in development!</strong>
678
-         </div>
679
-         """)
680
-         
681
-         gr.Markdown("# 🤖 SAM-X-1 Multi-Model Chat")
682
-         
683
-         # Settings panel
684
-         with gr.Accordion("⚙️ Settings", open=False):
685
-             with gr.Row():
686
-                 model_selector = gr.Dropdown(
687
-                     choices=list(available_models.keys()),
688
-                     value=list(available_models.keys())[0],
689
-                     label="Model Selection",
690
-                     info="Choose your speed/quality tradeoff"
691
-                 )
692
-             
693
-             model_info_box = gr.Textbox(
694
-                 label="Selected Model Info",
695
-                 value=list(available_models.values())[0].get_info(),
696
-                 interactive=False,
697
-                 lines=4,
698
-                 elem_classes=["model-info"]
699
-             )
700
-             
701
-             with gr.Row():
702
-                 temperature_slider = gr.Slider(
703
-                     minimum=0.0,
704
-                     maximum=2.0,
705
-                     value=0.7,
706
-                     step=0.1,
707
-                     label="Temperature",
708
-                     info="Higher = more creative, Lower = more focused"
709
-                 )
710
-                 show_thinking_checkbox = gr.Checkbox(
711
-                     label="Show Thinking Process",
712
-                     value=True,
713
-                     info="Display model's reasoning"
714
-                 )
715
-         
716
-         # Chat state and display
717
-         chatbot_state = gr.State([])
718
-         chat_html = gr.HTML(value="", elem_classes=["chat-container"])
719
-         
720
-         # Input area
721
-         with gr.Row(elem_classes=["input-row"]):
722
-             msg_input = gr.Textbox(
723
-                 placeholder="Ask me anything...",
724
-                 show_label=False,
725
-                 container=False,
726
-                 scale=9
727
-             )
728
-             send_btn = gr.Button("Send", variant="primary", scale=1)
729
-         
730
-         with gr.Row():
731
-             clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
732
-         
733
-         # Event handlers
734
-         msg_input.submit(
735
-             send_message,
736
-             inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
737
-             outputs=[chatbot_state, msg_input, chat_html, model_info_box]
738
-         )
739
-         
740
-         send_btn.click(
741
-             send_message,
742
-             inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
743
-             outputs=[chatbot_state, msg_input, chat_html, model_info_box]
744
-         )
745
-         
746
-         clear_btn.click(
747
-             lambda: ([], ""),
748
-             outputs=[chatbot_state, chat_html]
749
-         )
750
-         
751
-         show_thinking_checkbox.change(
752
-             lambda h, st: render_history(h, st),
753
-             inputs=[chatbot_state, show_thinking_checkbox],
754
-             outputs=[chat_html]
755
-         )
756
-         
757
-         # Update model info when selection changes
758
-         model_selector.change(
759
-             lambda choice: available_models[choice].get_info(),
760
-             inputs=[model_selector],
761
-             outputs=[model_info_box]
762
-         )
763
-
764
-     demo.launch(debug=True, share=True)
 
15
  # ==============================================================================
16
  @keras.saving.register_keras_serializable()
17
  class RotaryEmbedding(keras.layers.Layer):
18
+ def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.dim = dim
21
+ self.max_len = max_len
22
+ self.theta = theta
23
+ self.built_cache = False
24
+
25
+ def build(self, input_shape):
26
+ if not self.built_cache:
27
+ inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
28
+ t = tf.range(self.max_len, dtype=tf.float32)
29
+ freqs = tf.einsum("i,j->ij", t, inv_freq)
30
+ emb = tf.concat([freqs, freqs], axis=-1)
31
+
32
+ self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
33
+ self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
34
+ self.built_cache = True
35
+ super().build(input_shape)
36
+
37
+ def rotate_half(self, x):
38
+ x1, x2 = tf.split(x, 2, axis=-1)
39
+ return tf.concat([-x2, x1], axis=-1)
40
+
41
+ def call(self, q, k):
42
+ seq_len = tf.shape(q)[2]
43
+ dtype = q.dtype
44
+ cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
45
+ sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
46
+
47
+ q_rotated = (q * cos) + (self.rotate_half(q) * sin)
48
+ k_rotated = (k * cos) + (self.rotate_half(k) * sin)
49
+
50
+ return q_rotated, k_rotated
51
+
52
+ def get_config(self):
53
+ config = super().get_config()
54
+ config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
55
+ return config
56
 
57
 
58
  @keras.saving.register_keras_serializable()
59
  class RMSNorm(keras.layers.Layer):
60
+ def __init__(self, epsilon=1e-5, **kwargs):
61
+ super().__init__(**kwargs)
62
+ self.epsilon = epsilon
63
 
64
+ def build(self, input_shape):
65
+ self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
66
 
67
+ def call(self, x):
68
+ variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
69
+ return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
70
 
71
+ def get_config(self):
72
+ config = super().get_config()
73
+ config.update({"epsilon": self.epsilon})
74
+ return config
75
 
76
 
77
  @keras.saving.register_keras_serializable()
78
  class TransformerBlock(keras.layers.Layer):
79
+ def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
80
+ super().__init__(**kwargs)
81
+ self.d_model = d_model
82
+ self.n_heads = n_heads
83
+ self.ff_dim = ff_dim
84
+ self.dropout_rate = dropout
85
+ self.max_len = max_len
86
+ self.rope_theta = rope_theta
87
+ self.head_dim = d_model // n_heads
88
+ self.layer_idx = layer_idx
89
+
90
+ self.pre_attn_norm = RMSNorm()
91
+ self.pre_ffn_norm = RMSNorm()
92
+
93
+ self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
94
+ self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
95
+ self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
96
+ self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
97
+
98
+ self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
99
+
100
+ self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
101
+ self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
102
+ self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
103
+
104
+ self.dropout = keras.layers.Dropout(dropout)
105
+
106
+ def call(self, x, training=None):
107
+ B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
108
+ dtype = x.dtype
109
+
110
+ res = x
111
+ y = self.pre_attn_norm(x)
112
+
113
+ q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
114
+ k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
115
+ v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
116
+
117
+ q, k = self.rope(q, k)
118
+
119
+ scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
120
+
121
+ mask = tf.where(
122
+ tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
123
+ tf.constant(-1e9, dtype=dtype),
124
+ tf.constant(0.0, dtype=dtype)
125
+ )
126
+ scores += mask
127
+ attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
128
+
129
+ attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
130
+ x = res + self.dropout(self.out_proj(attn), training=training)
131
+
132
+ res = x
133
+ y = self.pre_ffn_norm(x)
134
+ ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
135
+
136
+ return res + self.dropout(ffn, training=training)
137
+
138
+ def get_config(self):
139
+ config = super().get_config()
140
+ config.update({
141
+ "d_model": self.d_model,
142
+ "n_heads": self.n_heads,
143
+ "ff_dim": self.ff_dim,
144
+ "dropout": self.dropout_rate,
145
+ "max_len": self.max_len,
146
+ "rope_theta": self.rope_theta,
147
+ "layer_idx": self.layer_idx
148
+ })
149
+ return config
150
 
151
 
152
  @keras.saving.register_keras_serializable()
153
  class SAM1Model(keras.Model):
154
+ def __init__(self, **kwargs):
155
+ super().__init__()
156
+ if 'config' in kwargs and isinstance(kwargs['config'], dict):
157
+ self.cfg = kwargs['config']
158
+ elif 'vocab_size' in kwargs:
159
+ self.cfg = kwargs
160
+ else:
161
+ self.cfg = kwargs.get('cfg', kwargs)
162
 
163
+ self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
164
 
165
+ ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
166
+ block_args = {
167
+ 'd_model': self.cfg['d_model'],
168
+ 'n_heads': self.cfg['n_heads'],
169
+ 'ff_dim': ff_dim,
170
+ 'dropout': self.cfg['dropout'],
171
+ 'max_len': self.cfg['max_len'],
172
+ 'rope_theta': self.cfg['rope_theta']
173
+ }
174
 
175
+ self.blocks = []
176
+ for i in range(self.cfg['n_layers']):
177
+ block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
178
+ self.blocks.append(block)
179
 
180
+ self.norm = RMSNorm(name="final_norm")
181
+ self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
182
 
183
+ def call(self, input_ids, training=None):
184
+ x = self.embed(input_ids)
185
 
186
+ for block in self.blocks:
187
+ x = block(x, training=training)
188
 
189
+ return self.lm_head(self.norm(x))
190
 
191
+ def get_config(self):
192
+ base_config = super().get_config()
193
+ base_config['config'] = self.cfg
194
+ return base_config
195
 
196
 
197
  # ==============================================================================
198
  # Helper Functions
199
  # ==============================================================================
200
  def count_parameters(model):
201
+ """Count total and non-zero parameters in model."""
202
+ total_params = 0
203
+ non_zero_params = 0
204
+
205
+ for weight in model.weights:
206
+ w = weight.numpy()
207
+ total_params += w.size
208
+ non_zero_params += np.count_nonzero(w)
209
+
210
+ return total_params, non_zero_params
211
 
212
 
213
  def format_param_count(count):
214
+ """Format parameter count in human readable format."""
215
+ if count >= 1e9:
216
+ return f"{count/1e9:.2f}B"
217
+ elif count >= 1e6:
218
+ return f"{count/1e6:.2f}M"
219
+ elif count >= 1e3:
220
+ return f"{count/1e3:.2f}K"
221
+ else:
222
+ return str(count)
223
 
224
 
225
  # ==============================================================================
226
  # Model Backend Interface
227
  # ==============================================================================
228
  class ModelBackend(ABC):
229
+ @abstractmethod
230
+ def predict(self, input_ids):
231
+ pass
232
+
233
+ @abstractmethod
234
+ def get_name(self):
235
+ pass
236
+
237
+ @abstractmethod
238
+ def get_info(self):
239
+ pass
240
 
241
 
242
  class KerasBackend(ModelBackend):
243
+ def __init__(self, model, name, display_name):
244
+ self.model = model
245
+ self.name = name
246
+ self.display_name = display_name
247
+
248
+ # Count parameters
249
+ total, non_zero = count_parameters(model)
250
+ self.total_params = total
251
+ self.non_zero_params = non_zero
252
+ self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
253
+
254
+ # Calculate actual model config for speed estimation
255
+ self.n_heads = model.cfg.get('n_heads', 0)
256
+ self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
257
+
258
+ def predict(self, input_ids):
259
+ inputs = np.array([input_ids], dtype=np.int32)
260
+ logits = self.model(inputs, training=False)
261
+ return logits[0, -1, :].numpy()
262
+
263
+ def get_name(self):
264
+ return self.display_name
265
+
266
+ def get_info(self):
267
+ info = f"{self.display_name}\n"
268
+ info += f" Total params: {format_param_count(self.total_params)}\n"
269
+ info += f" Attention heads: {self.n_heads}\n"
270
+ info += f" FFN dimension: {self.ff_dim}\n"
271
+ if self.sparsity > 1:
272
+ info += f" Sparsity: {self.sparsity:.1f}%\n"
273
+ return info
274
 
275
 
276
  # ==============================================================================
277
  # EASY MODEL REGISTRY - ADD YOUR MODELS HERE!
278
  # ==============================================================================
279
  MODEL_REGISTRY = [
280
+ # Format: (display_name, repo_id, weights_filename, config_filename)
281
+ # Smaller models are ACTUALLY faster (fewer params = real speedup!)
282
+
283
+ ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
284
+ ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
285
+ ("SAM-X-1-Mini 🚀 (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights.h5", "sam1_mini_config.json"),
286
+ ("SAM-X-1-Nano ⚡⚡ (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano.weights.h5", "sam1_nano_config.json"),
287
  ]
288
 
289
  # To add a new model, just add a new line above! Format:
 
307
 
308
  # Load config
309
  with open(config_path, 'r') as f:
310
+ base_config = json.load(f)
311
 
312
  print(f"✅ Base config loaded")
313
 
314
  # Build base model config
315
  base_model_config = {
316
+ 'vocab_size': base_config['vocab_size'],
317
+ 'd_model': base_config['hidden_size'],
318
+ 'n_heads': base_config['num_attention_heads'],
319
+ 'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
320
+ 'dropout': base_config.get('dropout', 0.0),
321
+ 'max_len': base_config['max_position_embeddings'],
322
+ 'rope_theta': base_config['rope_theta'],
323
+ 'n_layers': base_config['num_hidden_layers']
324
  }
325
 
326
  # Recreate tokenizer
 
330
  eos_token_id = tokenizer.token_to_id(eos_token)
331
 
332
  if eos_token_id is None:
333
+ tokenizer.add_special_tokens([eos_token])
334
+ eos_token_id = tokenizer.token_to_id(eos_token)
335
 
336
  custom_tokens = ["<think>", "<think/>"]
337
  for token in custom_tokens:
338
+ if tokenizer.token_to_id(token) is None:
339
+ tokenizer.add_special_tokens([token])
340
 
341
  tokenizer.no_padding()
342
  tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
 
351
  dummy_input = tf.zeros((1, 1), dtype=tf.int32)
352
 
353
  for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
354
+ try:
355
+ print(f"\n⏳ Loading: {display_name}")
356
+ print(f" Repo: {repo_id}")
357
+ print(f" Weights: {weights_filename}")
358
+
359
+ # Download weights
360
+ weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
361
+
362
+ # Load custom config if specified (for pruned models)
363
+ if config_filename:
364
+ print(f" Config: {config_filename}")
365
+ custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
366
+ with open(custom_config_path, 'r') as f:
367
+ model_config = json.load(f)
368
+ print(f" 📐 Custom architecture: {model_config['n_heads']} heads, {int(model_config['d_model'] * model_config['ff_mult'])} FFN dim")
369
+ else:
370
+ model_config = base_model_config.copy()
371
+
372
+ # Create model with appropriate config
373
+ model = SAM1Model(**model_config)
374
+ model(dummy_input)
375
+ model.load_weights(weights_path)
376
+ model.trainable = False
377
+
378
+ # Create backend
379
+ backend = KerasBackend(model, display_name, display_name)
380
+ available_models[display_name] = backend
381
+
382
+ # Print stats
383
+ print(f" ✅ Loaded successfully!")
384
+ print(f" 📊 Parameters: {format_param_count(backend.total_params)}")
385
+ print(f" 📊 Attention heads: {backend.n_heads}")
386
+ print(f" 📊 FFN dimension: {backend.ff_dim}")
387
+
388
+ except Exception as e:
389
+ print(f" ⚠️ Failed to load: {e}")
390
+ print(f" Skipping {display_name}...")
391
 
392
  if not available_models:
393
+ raise RuntimeError("❌ No models loaded! Check your MODEL_REGISTRY configuration.")
394
 
395
  print(f"\n✅ Successfully loaded {len(available_models)} model(s)")
396
+ print(f" Device: {'GPU' if len(tf.config.list_physical_devices('GPU')) > 0 else 'CPU'}")
397
 
398
  current_backend = list(available_models.values())[0]
399
 
 
405
  print("="*80)
406
  print("""
407
  📌 Does pruning reduce parameter count?
408
+ YES and NO:
409
+ • Total param count stays the same (architecture unchanged)
410
+ • BUT pruned weights are set to ZERO (sparse weights)
411
+ • Active/non-zero params are reduced significantly
412
+
413
  📌 Does pruning speed up inference?
414
+ IT DEPENDS:
415
+ • Dense operations (regular matrix multiply): NO speedup by default
416
+ • Need sparse kernels or hardware support for actual speedup
417
+ • HOWEVER: Smaller active weights = better cache utilization
418
+ • Less computation on zeros = potential speedup on some hardware
419
+
420
  📌 What DOES speed things up reliably?
421
+ ✅ Quantization (FP16, INT8) - smaller types = faster compute
422
+ ✅ Fewer layers (layer pruning)
423
+ ✅ Smaller hidden dimensions (width reduction)
424
+ ✅ Knowledge distillation to smaller architecture
425
+
426
  📌 Why use structured pruning then?
427
+ ✅ Reduces memory footprint (especially with sparse storage)
428
+ ✅ Can be combined with quantization for real speedups
429
+ ✅ Preserves quality better than aggressive dimension reduction
430
+ ✅ Foundation for converting to truly smaller architecture
431
  """)
432
 
433
  def generate_response_stream(prompt, temperature=0.7, backend=None):
434
+ """Generate response and yield tokens one by one for streaming."""
435
+ if backend is None:
436
+ backend = current_backend
437
+
438
+ encoded_prompt = tokenizer.encode(prompt)
439
+ input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
440
+ generated = input_ids.copy()
441
+
442
+ current_text = ""
443
+ in_thinking = False
444
+
445
+ # Get max_len from the backend's model config
446
+ max_len = backend.model.cfg['max_len']
447
+
448
+ for _ in range(512):
449
+ current_input = generated[-max_len:]
450
+
451
+ # Get logits from selected backend
452
+ next_token_logits = backend.predict(current_input)
453
+
454
+ if temperature > 0:
455
+ next_token_logits = next_token_logits / temperature
456
+ top_k_indices = np.argpartition(next_token_logits, -50)[-50:]
457
+ top_k_logits = next_token_logits[top_k_indices]
458
+ top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
459
+ top_k_probs /= top_k_probs.sum()
460
+ next_token = top_k_indices[np.random.choice(len(top_k_indices), p=top_k_probs)]
461
+ else:
462
+ next_token = np.argmax(next_token_logits)
463
+
464
+ if next_token == eos_token_id:
465
+ break
466
+
467
+ generated.append(int(next_token))
468
+
469
+ new_text = tokenizer.decode(generated[len(input_ids):])
470
+ if len(new_text) > len(current_text):
471
+ new_chunk = new_text[len(current_text):]
472
+ current_text = new_text
473
+
474
+ if "<think>" in new_chunk:
475
+ in_thinking = True
476
+ elif "</think>" in new_chunk or "<think/>" in new_chunk:
477
+ in_thinking = False
478
+
479
+ yield new_chunk, in_thinking
480
 
481
  # ==============================================================================
482
  # Gradio Interface
483
  # ==============================================================================
484
  if __name__ == "__main__":
485
+ import gradio as gr
486
+
487
+ custom_css = """
488
+ .chat-container {
489
+ height: 600px;
490
+ overflow-y: auto;
491
+ padding: 20px;
492
+ background: #ffffff;
493
+ }
494
+
495
+ .user-message {
496
+ background: #f7f7f8;
497
+ padding: 16px;
498
+ margin: 12px 0;
499
+ border-radius: 8px;
500
+ }
501
+
502
+ .assistant-message {
503
+ background: #ffffff;
504
+ padding: 16px;
505
+ margin: 12px 0;
506
+ border-radius: 8px;
507
+ border-left: 3px solid #10a37f;
508
+ }
509
+
510
+ .message-content {
511
+ color: #353740;
512
+ line-height: 1.6;
513
+ font-size: 15px;
514
+ }
515
+
516
+ .message-header {
517
+ font-weight: 600;
518
+ margin-bottom: 8px;
519
+ color: #353740;
520
+ font-size: 14px;
521
+ }
522
+
523
+ .thinking-content {
524
+ color: #6b7280;
525
+ font-style: italic;
526
+ border-left: 3px solid #d1d5db;
527
+ padding-left: 12px;
528
+ margin: 8px 0;
529
+ background: #f9fafb;
530
+ padding: 8px 12px;
531
+ border-radius: 4px;
532
+ }
533
+
534
+ .input-row {
535
+ background: #ffffff;
536
+ padding: 12px;
537
+ border-radius: 8px;
538
+ margin-top: 12px;
539
+ border: 1px solid #e5e7eb;
540
+ }
541
+
542
+ .gradio-container {
543
+ max-width: 900px !important;
544
+ margin: auto !important;
545
+ }
546
+
547
+ .announcement-banner {
548
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
549
+ color: white;
550
+ padding: 16px 24px;
551
+ border-radius: 12px;
552
+ margin-bottom: 20px;
553
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
554
+ text-align: center;
555
+ font-size: 16px;
556
+ font-weight: 500;
557
+ animation: slideIn 0.5s ease-out;
558
+ }
559
+
560
+ @keyframes slideIn {
561
+ from {
562
+ opacity: 0;
563
+ transform: translateY(-20px);
564
+ }
565
+ to {
566
+ opacity: 1;
567
+ transform: translateY(0);
568
+ }
569
+ }
570
+
571
+ .announcement-banner strong {
572
+ font-weight: 700;
573
+ font-size: 18px;
574
+ }
575
+
576
+ .settings-panel {
577
+ background: #f9fafb;
578
+ padding: 16px;
579
+ border-radius: 8px;
580
+ margin-bottom: 12px;
581
+ border: 1px solid #e5e7eb;
582
+ }
583
+
584
+ .model-info {
585
+ background: #f0f9ff;
586
+ border: 1px solid #bae6fd;
587
+ padding: 12px;
588
+ border-radius: 8px;
589
+ margin-top: 8px;
590
+ font-size: 13px;
591
+ font-family: monospace;
592
+ white-space: pre-line;
593
+ }
594
+ """
595
+
596
+ def format_message_html(role, content, show_thinking=True):
597
+ """Format a single message as HTML."""
598
+ role_class = "user-message" if role == "user" else "assistant-message"
599
+ role_name = "You" if role == "user" else "SAM-X-1"
600
+
601
+ thinking = ""
602
+ answer = ""
603
+
604
+ if "<think>" in content:
605
+ parts = content.split("<think>", 1)
606
+ before_think = parts[0].strip()
607
+
608
+ if len(parts) > 1:
609
+ after_think = parts[1]
610
+
611
+ if "</think>" in after_think:
612
+ think_parts = after_think.split("</think>", 1)
613
+ thinking = think_parts[0].strip()
614
+ answer = (before_think + " " + think_parts[1]).strip()
615
+ elif "<think/>" in after_think:
616
+ think_parts = after_think.split("<think/>", 1)
617
+ thinking = think_parts[0].strip()
618
+ answer = (before_think + " " + think_parts[1]).strip()
619
+ else:
620
+ thinking = after_think.strip()
621
+ answer = before_think
622
+ else:
623
+ answer = before_think
624
+ else:
625
+ answer = content
626
+
627
+ html = f'<div class="{role_class}">'
628
+ html += f'<div class="message-header">{role_name}</div>'
629
+ html += f'<div class="message-content">'
630
+
631
+ if thinking and show_thinking:
632
+ html += f'<div class="thinking-content">💭 {thinking}</div>'
633
+
634
+ if answer:
635
+ html += f'<div>{answer}</div>'
636
+
637
+ html += '</div></div>'
638
+ return html
639
+
640
+ def render_history(history, show_thinking):
641
+ """Render chat history as HTML."""
642
+ html = ""
643
+ for msg in history:
644
+ html += format_message_html(msg["role"], msg["content"], show_thinking)
645
+ return html
646
+
647
+ def send_message(message, history, show_thinking, temperature, model_choice):
648
+ if not message.strip():
649
+ yield history, "", render_history(history, show_thinking), ""
650
+ return
651
+
652
+ # Switch backend based on selection
653
+ backend = available_models[model_choice]
654
+
655
+ # Add user message
656
+ history.append({"role": "user", "content": message})
657
+ yield history, "", render_history(history, show_thinking), backend.get_info()
658
+
659
+ # Generate prompt
660
+ prompt = f"User: {message}\nSam: <think>"
661
+
662
+ # Start assistant message
663
+ history.append({"role": "assistant", "content": "<think>"})
664
+
665
+ # Stream response
666
+ for new_chunk, in_thinking in generate_response_stream(prompt, temperature, backend):
667
+ history[-1]["content"] += new_chunk
668
+ yield history, "", render_history(history, show_thinking), backend.get_info()
669
+
670
+ # Create Gradio interface
671
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="slate")) as demo:
672
+ # Announcement Banner
673
+ gr.HTML("""
674
+ <div class="announcement-banner">
675
+ 🎉 <strong>NEW UPDATE:</strong> Multiple model variants now available!
676
+ Choose Fast/Mini/Nano for <strong>30-250% speed boost</strong>! ⚡
677
+ The models marked with (BETA) are not useful yet. <strong>They are still in development!</strong>
678
+ </div>
679
+ """)
680
+
681
+ gr.Markdown("# 🤖 SAM-X-1 Multi-Model Chat")
682
+
683
+ # Settings panel
684
+ with gr.Accordion("⚙️ Settings", open=False):
685
+ with gr.Row():
686
+ model_selector = gr.Dropdown(
687
+ choices=list(available_models.keys()),
688
+ value=list(available_models.keys())[0],
689
+ label="Model Selection",
690
+ info="Choose your speed/quality tradeoff"
691
+ )
692
+
693
+ model_info_box = gr.Textbox(
694
+ label="Selected Model Info",
695
+ value=list(available_models.values())[0].get_info(),
696
+ interactive=False,
697
+ lines=4,
698
+ elem_classes=["model-info"]
699
+ )
700
+
701
+ with gr.Row():
702
+ temperature_slider = gr.Slider(
703
+ minimum=0.0,
704
+ maximum=2.0,
705
+ value=0.7,
706
+ step=0.1,
707
+ label="Temperature",
708
+ info="Higher = more creative, Lower = more focused"
709
+ )
710
+ show_thinking_checkbox = gr.Checkbox(
711
+ label="Show Thinking Process",
712
+ value=True,
713
+ info="Display model's reasoning"
714
+ )
715
+
716
+ # Chat state and display
717
+ chatbot_state = gr.State([])
718
+ chat_html = gr.HTML(value="", elem_classes=["chat-container"])
719
+
720
+ # Input area
721
+ with gr.Row(elem_classes=["input-row"]):
722
+ msg_input = gr.Textbox(
723
+ placeholder="Ask me anything...",
724
+ show_label=False,
725
+ container=False,
726
+ scale=9
727
+ )
728
+ send_btn = gr.Button("Send", variant="primary", scale=1)
729
+
730
+ with gr.Row():
731
+ clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
732
+
733
+ # Event handlers
734
+ msg_input.submit(
735
+ send_message,
736
+ inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
737
+ outputs=[chatbot_state, msg_input, chat_html, model_info_box]
738
+ )
739
+
740
+ send_btn.click(
741
+ send_message,
742
+ inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
743
+ outputs=[chatbot_state, msg_input, chat_html, model_info_box]
744
+ )
745
+
746
+ clear_btn.click(
747
+ lambda: ([], ""),
748
+ outputs=[chatbot_state, chat_html]
749
+ )
750
+
751
+ show_thinking_checkbox.change(
752
+ lambda h, st: render_history(h, st),
753
+ inputs=[chatbot_state, show_thinking_checkbox],
754
+ outputs=[chat_html]
755
+ )
756
+
757
+ # Update model info when selection changes
758
+ model_selector.change(
759
+ lambda choice: available_models[choice].get_info(),
760
+ inputs=[model_selector],
761
+ outputs=[model_info_box]
762
+ )
763
+
764
+ demo.launch(debug=True, share=True)