Keeby-smilyai committed
Commit 49df435 · verified · 1 Parent(s): 5479729

Update app.py

Files changed (1)
  1. app.py +110 -139
app.py CHANGED
@@ -7,12 +7,9 @@ NUM_CORES = os.cpu_count() or 4
7
 
8
  os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
9
  os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
10
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
11
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
12
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
13
- # NEW: Enable XLA (Accelerated Linear Algebra)
14
- os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
15
- os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
16
 
17
  import gradio as gr
18
  import tensorflow as tf
@@ -23,13 +20,11 @@ from tokenizers import Tokenizer
23
  import numpy as np
24
  import time
25
 
26
- # Configure TF for maximum performance
27
  tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
28
  tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
29
- tf.config.optimizer.set_jit(True) # Enable XLA JIT compilation
30
- tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})
31
 
32
- print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN + XLA enabled")
33
 
34
  # ============================================================================
35
  # 🎊 FESTIVE MODE TOGGLE 🎊
@@ -62,8 +57,6 @@ class RotaryEmbedding(keras.layers.Layer):
62
 
63
  def build(self, input_shape):
64
  super().build(input_shape)
65
- # Pre-build cache immediately
66
- self._build_cache()
67
 
68
  def _build_cache(self):
69
  if not self.built_cache:
@@ -71,19 +64,17 @@ class RotaryEmbedding(keras.layers.Layer):
71
  t = tf.range(self.max_len, dtype=tf.float32)
72
  freqs = tf.einsum("i,j->ij", t, inv_freq)
73
  emb = tf.concat([freqs, freqs], axis=-1)
74
- # Use tf.constant for faster access
75
  self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
76
  self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
77
  self.built_cache = True
78
 
79
- @tf.function(jit_compile=True)
80
  def rotate_half(self, x):
81
  x1, x2 = tf.split(x, 2, axis=-1)
82
  return tf.concat([-x2, x1], axis=-1)
83
 
84
- @tf.function(jit_compile=True)
85
  def call(self, q, k, offset=0):
86
  """Apply rotary embeddings with position offset for KV-cache."""
 
87
  seq_len = tf.shape(q)[2]
88
  dtype = q.dtype
89
 
@@ -111,7 +102,6 @@ class RMSNorm(keras.layers.Layer):
111
  self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
112
  super().build(input_shape)
113
 
114
- @tf.function(jit_compile=True)
115
  def call(self, x):
116
  variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
117
  return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
@@ -149,8 +139,15 @@ class TransformerBlock(keras.layers.Layer):
149
  self.dropout = keras.layers.Dropout(self.dropout_rate)
150
  super().build(input_shape)
151
 
152
- # Removed @tf.function here to allow flexible past_kv handling
153
  def call(self, x, training=None, past_kv=None, use_cache=False):
154
  B = tf.shape(x)[0]
155
  T = tf.shape(x)[1]
156
  dtype = x.dtype
@@ -158,9 +155,9 @@ class TransformerBlock(keras.layers.Layer):
158
  res = x
159
  y = self.pre_attn_norm(x)
160
 
161
- # Project Q, K, V
162
  q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
163
- q = tf.transpose(q, [0, 2, 1, 3])
164
 
165
  k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
166
  k = tf.transpose(k, [0, 2, 1, 3])
@@ -168,13 +165,13 @@ class TransformerBlock(keras.layers.Layer):
168
  v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
169
  v = tf.transpose(v, [0, 2, 1, 3])
170
 
171
- # Position offset
172
  if past_kv is not None:
173
  past_len = tf.shape(past_kv[0])[2]
174
  else:
175
  past_len = 0
176
 
177
- # Apply RoPE
178
  q, k = self.rope(q, k, offset=past_len)
179
 
180
  # Concatenate with past KV
@@ -200,13 +197,13 @@ class TransformerBlock(keras.layers.Layer):
200
  attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
201
  attn_out = tf.reshape(attn_out, [B, T, self.d_model])
202
 
203
- x = res + self.out_proj(attn_out)
204
 
205
  # FFN
206
  res = x
207
  y = self.pre_ffn_norm(x)
208
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
209
- output = res + ffn
210
 
211
  return output, new_kv
212
 
@@ -252,8 +249,15 @@ class SAM1Model(keras.Model):
252
  self.norm = RMSNorm(name="final_norm")
253
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
254
 
255
- # Don't use @tf.function on the main call to allow flexible caching
256
  def call(self, input_ids, training=None, past_kv=None, use_cache=False):
257
  x = self.embed(input_ids)
258
 
259
  new_past_kv = [] if use_cache else None
@@ -318,11 +322,12 @@ if use_checkpoint:
318
  'n_heads': config['num_attention_heads'],
319
  'ff_mult': config['intermediate_size'] / config['hidden_size'],
320
  'max_len': config['max_position_embeddings'],
321
- 'dropout': 0.0, # Set to 0 for inference
322
  'rope_theta': config['rope_theta']
323
  }
324
  model = SAM1Model(config=model_config)
325
 
 
326
  dummy_input = tf.zeros((1, 16), dtype=tf.int32)
327
  _ = model(dummy_input, training=False, use_cache=False)
328
  print(f"✅ Model architecture built: {model.count_params():,} parameters")
@@ -351,37 +356,11 @@ else:
351
  if model:
352
  print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
353
 
354
- # ============================================================================
355
- # OPTIMIZATION: Create compiled inference functions
356
- # ============================================================================
357
-
358
- print("⚡ Compiling optimized inference functions...")
359
-
360
- # Create traced functions for prefill and decode
361
- @tf.function(
362
- input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)],
363
- jit_compile=True
364
- )
365
- def prefill_fn(input_ids):
366
- """Optimized prefill phase - processes entire prompt at once."""
367
- return model(input_ids, training=False, use_cache=True)
368
-
369
- @tf.function(
370
- input_signature=[tf.TensorSpec(shape=[1, 1], dtype=tf.int32)],
371
- jit_compile=True
372
- )
373
- def decode_fn_no_cache(input_ids):
374
- """Fast decode for single token without cache (fallback)."""
375
- logits, _ = model(input_ids, training=False, use_cache=False)
376
- return logits
377
-
378
- # Warm up with compilation
379
- print("🔥 Warming up and compiling model...")
380
  warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
381
- _ = prefill_fn(warmup_input)
382
- warmup_single = tf.constant([[1]], dtype=tf.int32)
383
- _ = decode_fn_no_cache(warmup_single)
384
- print("✅ Model warmed up and compiled")
385
 
386
  # ============================================================================
387
  # Optimized Inference Logic with KV-Cache
@@ -389,50 +368,39 @@ print("βœ… Model warmed up and compiled")
389
 
390
  stop_generation = False
391
 
392
- # Pre-compile sampling helper
393
- @tf.function(jit_compile=True)
394
- def apply_temperature_tf(logits, temperature):
395
- """GPU-accelerated temperature scaling."""
396
- return logits / temperature
397
 
398
  def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
399
- """Optimized NumPy sampling."""
400
- # Temperature
401
- if temperature != 1.0:
402
- logits = logits / temperature
403
-
404
- # Repetition penalty (vectorized)
405
- if repetition_penalty != 1.0 and token_freq:
406
- penalty = np.ones_like(logits)
407
  for token_id, freq in token_freq.items():
408
- if token_id < len(logits):
409
- penalty[token_id] = repetition_penalty ** freq
410
- logits = logits / penalty
411
-
412
- # Top-K
413
- if top_k > 0 and top_k < len(logits):
414
- top_k_indices = np.argpartition(logits, -top_k)[-top_k:]
415
- top_k_logits = logits[top_k_indices]
416
  else:
417
- top_k_indices = np.arange(len(logits))
418
- top_k_logits = logits
419
 
420
- # Softmax (stable)
421
- max_logit = np.max(top_k_logits)
422
- top_k_logits = top_k_logits - max_logit
423
  top_k_probs = np.exp(top_k_logits)
424
- sum_probs = np.sum(top_k_probs)
425
- top_k_probs = top_k_probs / sum_probs
426
 
427
- # Top-P
428
  if top_p < 1.0:
429
  sorted_idx = np.argsort(top_k_probs)[::-1]
430
  cumsum = np.cumsum(top_k_probs[sorted_idx])
431
  cutoff = np.searchsorted(cumsum, top_p) + 1
432
- cutoff = min(cutoff, len(sorted_idx))
433
  nucleus_idx = sorted_idx[:cutoff]
434
  nucleus_probs = top_k_probs[nucleus_idx]
435
- nucleus_probs = nucleus_probs / np.sum(nucleus_probs)
436
  sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
437
  return int(top_k_indices[nucleus_idx[sampled]])
438
  else:
@@ -448,23 +416,23 @@ def generate_stream(
448
  top_p: float = 0.9,
449
  repetition_penalty: float = 1.1
450
  ):
451
- """Optimized generation with compiled functions."""
452
  global stop_generation
453
  stop_generation = False
454
 
455
- # Tokenize
456
  prompt_ids = tokenizer.encode(prompt).ids
457
  input_ids = [i for i in prompt_ids if i != eos_token_id]
458
 
459
  if len(input_ids) == 0:
460
- yield "Error: Empty prompt"
461
  return
462
 
463
  generated_text = ""
464
  token_count = 0
465
  token_freq = {}
466
 
467
- # Stop tokens
468
  im_end_id = tokenizer.token_to_id("<|im_end|>")
469
  model_end_id = tokenizer.token_to_id("<im end for model tun>")
470
  stop_ids = {eos_token_id, im_end_id, model_end_id}
@@ -474,81 +442,81 @@ def generate_stream(
474
 
475
  start_time = time.time()
476
 
477
- # === PREFILL (optimized) ===
 
478
  if len(input_ids) > max_context - max_tokens:
479
  input_ids = input_ids[-(max_context - max_tokens):]
480
 
481
  input_tensor = tf.constant([input_ids], dtype=tf.int32)
482
 
483
  try:
484
- logits, past_kv = prefill_fn(input_tensor)
485
- logits_np = logits[0, -1, :].numpy()
486
  except Exception as e:
487
- yield f"Error: {e}"
488
  return
489
 
 
 
 
490
  prefill_time = time.time() - start_time
491
- print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.3f}s ({len(input_ids)/prefill_time:.1f} tok/s)")
492
 
493
- # === DECODE LOOP ===
494
  decode_start = time.time()
495
- decode_times = []
496
 
497
  for step in range(max_tokens):
498
  if stop_generation:
499
- yield generated_text + "\n\n*[Stopped]*"
500
  return
501
 
502
- # Sample
503
- token_start = time.time()
504
- next_token_id = sample_token(logits_np, temperature, top_k, top_p, token_freq, repetition_penalty)
505
-
 
 
506
  if next_token_id in stop_ids:
507
  break
508
 
 
509
  token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
510
 
511
- # Decode
512
  token_text = tokenizer.decode([next_token_id])
513
  generated_text += token_text
514
  token_count += 1
515
-
516
- # Yield every token for streaming
517
- if token_count % 1 == 0: # Stream every token
518
- yield generated_text
519
 
520
- # === OPTIMIZED DECODE ===
521
  next_input = tf.constant([[next_token_id]], dtype=tf.int32)
522
 
523
  try:
524
  logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
525
- logits_np = logits[0, -1, :].numpy()
526
  except Exception as e:
527
- yield generated_text + f"\n\n*[Error: {e}]*"
528
  return
529
-
530
- decode_times.append(time.time() - token_start)
531
-
532
- # Cache management
533
- if past_kv and len(past_kv) > 0:
534
- current_len = past_kv[0][0].shape[2]
535
- if current_len > max_context:
536
- trim = current_len - max_context + 100
537
- past_kv = [(k[:, :, trim:, :], v[:, :, trim:, :]) for k, v in past_kv]
 
 
538
 
539
- # Stats
540
  decode_time = time.time() - decode_start
541
  total_time = time.time() - start_time
542
 
543
  if token_count > 0:
544
- avg_decode_time = np.mean(decode_times) if decode_times else 0
545
  decode_tps = token_count / decode_time if decode_time > 0 else 0
 
546
 
547
  stats = (
548
- f"\n\n*[{token_count} tokens in {total_time:.1f}s | "
549
- f"Prefill: {len(input_ids)/prefill_time:.1f} t/s | "
550
- f"Decode: {decode_tps:.1f} t/s | "
551
- f"Avg: {1/avg_decode_time:.1f} t/s]*"
552
  )
553
 
554
  if not stop_generation:
@@ -558,7 +526,7 @@ def generate_stream(
558
 
559
 
560
  # ============================================================================
561
- # Chat Interface Logic (unchanged from original)
562
  # ============================================================================
563
 
564
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
@@ -567,6 +535,7 @@ def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) ->
567
  for user_msg, assistant_msg in history:
568
  prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
569
  if assistant_msg:
 
570
  clean_msg = assistant_msg.split("\n\n*[")[0]
571
  prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
572
 
@@ -600,6 +569,7 @@ def chat_stream(
600
  ):
601
  partial_response = generated
602
 
 
603
  stop_tags = ["<|im_end|>", "<im end for model tun>"]
604
  earliest_stop = len(partial_response)
605
  should_stop = False
@@ -613,12 +583,14 @@ def chat_stream(
613
 
614
  display_response = partial_response
615
  if should_stop:
 
616
  stats_start = partial_response.find("\n\n*[")
617
  if stats_start > earliest_stop:
618
  display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
619
  else:
620
  display_response = partial_response[:earliest_stop]
621
 
 
622
  if reasoning_enabled:
623
  if '<think>' in display_response and '</think>' in display_response:
624
  start_idx = display_response.find('<think>')
@@ -650,7 +622,7 @@ def stop_gen():
650
 
651
 
652
  # ============================================================================
653
- # Gradio UI (same as original, just using optimized backend)
654
  # ============================================================================
655
 
656
  custom_css = """
@@ -725,15 +697,15 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
725
  <div id="welcome-modal" class="modal-overlay" style="display:none;">
726
  <div class="modal-content">
727
  <h2>🧠 Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
728
- <p>⚡ <strong>OPTIMIZED VERSION</strong> with XLA JIT compilation for 5-10x faster CPU inference!</p>
729
  <div class="comparison-box">
730
  <div class="comparison-mode mode-reasoning">
731
  <h3>💡 Reasoning Mode (ON)</h3>
732
- <p>Chain-of-Thought with <code>&lt;think>...&lt;/think></code> tags.</p>
733
  </div>
734
  <div class="comparison-mode mode-direct">
735
  <h3>⚪ Direct Mode (OFF)</h3>
736
- <p>Immediate answer generation for maximum speed.</p>
737
  </div>
738
  </div>
739
  <button class="close-btn" onclick="document.getElementById('welcome-modal').style.display='none'">Got it! Start Chatting</button>
@@ -749,13 +721,13 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
749
  <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
750
  alt="Sam-large-2" style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
751
  <h1>🤖 Sam-large-2 Chat 🤖</h1>
752
- <p><strong>XLA-OPTIMIZED!</strong> Our <strong>FASTEST Reasoning Model</strong> <span class="speed-indicator">⚡ 10-30 tok/s on CPU</span></p>
753
- <div class="twin-badge">Reasoning Model + XLA JIT</div>
754
  <div class="celebration">🚀 💫 🎯 ⚡ 🔥</div>
755
  </div>
756
  """)
757
  else:
758
- gr.HTML("""<div class="header"><h1>🤖 Sam-large-2 Chat</h1><p>XLA-Optimized Reasoning Model</p></div>""")
759
 
760
  with gr.Row():
761
  with gr.Column(scale=4):
@@ -793,12 +765,11 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
793
  repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
794
  gr.Markdown("---")
795
  gr.Markdown(f"""### 🎊 Sam-large-2 Model Info
796
- **Type:** XLA-Optimized CoT Model
797
  **Vocab:** {config['vocab_size']:,}
798
  **Layers:** {config['num_hidden_layers']}
799
  **Context:** {config['max_position_embeddings']:,} tokens
800
- **Optimization:** KV-Cache + XLA JIT ⚡
801
- **Expected Speed:** 10-30 tok/s (CPU)
802
  """)
803
 
804
  gr.Examples(
@@ -813,8 +784,8 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
813
 
814
  gr.HTML("""
815
  <footer>
816
- <p><strong>⚡ Sam-large-2 - XLA-OPTIMIZED RELEASE! ⚡</strong></p>
817
- <p style="font-size: 0.9rem; color: #999;">5-10x faster with XLA JIT compilation • KV-Cache enabled • Built by Smily studios</p>
818
  </footer>
819
  """)
820
 
@@ -874,7 +845,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
874
 
875
  if __name__ == "__main__":
876
  print("\n" + "=" * 60)
877
- print("⚡ Starting Sam-large-2 Chat (XLA-OPTIMIZED)")
878
  print("=" * 60 + "\n")
879
  demo.queue(max_size=20)
880
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
 
7
 
8
  os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
9
  os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
10
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Force CPU only
11
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' # Intel optimization
12
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TF logging
 
 
 
13
 
14
  import gradio as gr
15
  import tensorflow as tf
 
20
  import numpy as np
21
  import time
22
 
23
+ # Configure TF threading
24
  tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
25
  tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
 
 
26
 
27
+ print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
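A quick way to sanity-check this CPU-only configuration at startup, as a minimal sketch that is not part of this commit (it relies only on the standard tf.config getters):

import tensorflow as tf

# Illustrative check, not in app.py: confirm no GPUs are visible (hidden by
# CUDA_VISIBLE_DEVICES=-1) and that both thread pools picked up NUM_CORES.
assert tf.config.list_physical_devices('GPU') == []
print("inter-op threads:", tf.config.threading.get_inter_op_parallelism_threads())
print("intra-op threads:", tf.config.threading.get_intra_op_parallelism_threads())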
28
 
29
  # ============================================================================
30
  # 🎊 FESTIVE MODE TOGGLE 🎊
 
57
 
58
  def build(self, input_shape):
59
  super().build(input_shape)
 
 
60
 
61
  def _build_cache(self):
62
  if not self.built_cache:
 
64
  t = tf.range(self.max_len, dtype=tf.float32)
65
  freqs = tf.einsum("i,j->ij", t, inv_freq)
66
  emb = tf.concat([freqs, freqs], axis=-1)
 
67
  self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
68
  self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
69
  self.built_cache = True
70
 
 
71
  def rotate_half(self, x):
72
  x1, x2 = tf.split(x, 2, axis=-1)
73
  return tf.concat([-x2, x1], axis=-1)
74
 
 
75
  def call(self, q, k, offset=0):
76
  """Apply rotary embeddings with position offset for KV-cache."""
77
+ self._build_cache()
78
  seq_len = tf.shape(q)[2]
79
  dtype = q.dtype
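For reference, the cos/sin tables cached above feed the standard rotary formula q*cos + rotate_half(q)*sin; a minimal standalone sketch of that application (shapes and slicing are assumptions, not the literal code in app.py):

import tensorflow as tf

def rotate_half(x):
    x1, x2 = tf.split(x, 2, axis=-1)
    return tf.concat([-x2, x1], axis=-1)

def apply_rope(q, k, cos_cached, sin_cached, offset):
    # q, k: [B, n_heads, T, head_dim]; cos/sin caches: [max_len, head_dim].
    # Slicing at `offset` keeps absolute positions correct during cached decode.
    T = tf.shape(q)[2]
    cos = cos_cached[offset:offset + T]   # [T, head_dim], broadcasts over B and heads
    sin = sin_cached[offset:offset + T]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin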
80
 
 
102
  self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
103
  super().build(input_shape)
104
 
 
105
  def call(self, x):
106
  variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
107
  return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
 
139
  self.dropout = keras.layers.Dropout(self.dropout_rate)
140
  super().build(input_shape)
141
 
 
142
  def call(self, x, training=None, past_kv=None, use_cache=False):
143
+ """
144
+ Args:
145
+ x: input tensor [B, T, D] (T=1 during cached generation)
146
+ past_kv: tuple of (past_k, past_v) each [B, n_heads, past_len, head_dim]
147
+ use_cache: whether to return updated kv cache
148
+ Returns:
149
+ output, (new_k, new_v) if use_cache else output, None
150
+ """
151
  B = tf.shape(x)[0]
152
  T = tf.shape(x)[1]
153
  dtype = x.dtype
 
155
  res = x
156
  y = self.pre_attn_norm(x)
157
 
158
+ # Project Q, K, V for current input
159
  q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
160
+ q = tf.transpose(q, [0, 2, 1, 3]) # [B, n_heads, T, head_dim]
161
 
162
  k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
163
  k = tf.transpose(k, [0, 2, 1, 3])
 
165
  v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
166
  v = tf.transpose(v, [0, 2, 1, 3])
167
 
168
+ # Determine position offset for RoPE
169
  if past_kv is not None:
170
  past_len = tf.shape(past_kv[0])[2]
171
  else:
172
  past_len = 0
173
 
174
+ # Apply RoPE with position offset
175
  q, k = self.rope(q, k, offset=past_len)
176
 
177
  # Concatenate with past KV
 
197
  attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
198
  attn_out = tf.reshape(attn_out, [B, T, self.d_model])
199
 
200
+ x = res + self.dropout(self.out_proj(attn_out), training=training)
201
 
202
  # FFN
203
  res = x
204
  y = self.pre_ffn_norm(x)
205
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
206
+ output = res + self.dropout(ffn, training=training)
207
 
208
  return output, new_kv
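The "Concatenate with past KV" step referenced above sits outside this hunk; conceptually it extends the cache along the sequence axis before attention, roughly as in this sketch (assumes tensorflow as tf and per-layer (k, v) tensors of shape [B, n_heads, len, head_dim]):

def extend_kv(past_kv, k, v):
    # past_kv: (past_k, past_v) or None; k, v hold the current T tokens.
    # The concatenated tensors become this layer's new_kv when use_cache=True.
    if past_kv is not None:
        k = tf.concat([past_kv[0], k], axis=2)
        v = tf.concat([past_kv[1], v], axis=2)
    return k, v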
209
 
 
249
  self.norm = RMSNorm(name="final_norm")
250
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
251
 
 
252
  def call(self, input_ids, training=None, past_kv=None, use_cache=False):
253
+ """
254
+ Args:
255
+ input_ids: [B, T]
256
+ past_kv: list of (k, v) tuples, one per layer
257
+ use_cache: whether to return updated cache
258
+ Returns:
259
+ logits, new_past_kv (or None)
260
+ """
261
  x = self.embed(input_ids)
262
 
263
  new_past_kv = [] if use_cache else None
 
322
  'n_heads': config['num_attention_heads'],
323
  'ff_mult': config['intermediate_size'] / config['hidden_size'],
324
  'max_len': config['max_position_embeddings'],
325
+ 'dropout': 0.1,
326
  'rope_theta': config['rope_theta']
327
  }
328
  model = SAM1Model(config=model_config)
329
 
330
+ # Build model with dummy input
331
  dummy_input = tf.zeros((1, 16), dtype=tf.int32)
332
  _ = model(dummy_input, training=False, use_cache=False)
333
  print(f"✅ Model architecture built: {model.count_params():,} parameters")
 
356
  if model:
357
  print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
358
 
359
+ # Warm up the model
360
+ print("🔥 Warming up model...")
361
  warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
362
+ _, _ = model(warmup_input, training=False, use_cache=True)
363
+ print("✅ Model warmed up")
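The warm-up call exercises the KV-cache code path once so the first user request does not pay the one-time setup cost; a rough way to observe the effect (illustrative sketch, not part of the commit):

import time

t0 = time.time()
_ = model(warmup_input, training=False, use_cache=True)   # first call includes one-time setup
t1 = time.time()
_ = model(warmup_input, training=False, use_cache=True)   # subsequent calls are cheaper
print(f"first: {t1 - t0:.2f}s, second: {time.time() - t1:.2f}s")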
 
 
364
 
365
  # ============================================================================
366
  # Optimized Inference Logic with KV-Cache
 
368
 
369
  stop_generation = False
370
 
 
 
 
 
 
371
 
372
  def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
373
+ """Pure NumPy sampling for speed."""
374
+ # Temperature scaling
375
+ scaled_logits = logits / temperature
376
+
377
+ # Repetition penalty
378
+ if repetition_penalty != 1.0:
 
 
379
  for token_id, freq in token_freq.items():
380
+ if token_id < len(scaled_logits):
381
+ scaled_logits[token_id] /= (repetition_penalty ** freq)
382
+
383
+ # Top-K filtering
384
+ if top_k > 0 and top_k < len(scaled_logits):
385
+ top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
386
+ top_k_logits = scaled_logits[top_k_indices]
 
387
  else:
388
+ top_k_indices = np.arange(len(scaled_logits))
389
+ top_k_logits = scaled_logits
390
 
391
+ # Softmax (numerically stable)
392
+ top_k_logits = top_k_logits - np.max(top_k_logits)
 
393
  top_k_probs = np.exp(top_k_logits)
394
+ top_k_probs /= top_k_probs.sum()
 
395
 
396
+ # Top-P (nucleus) filtering
397
  if top_p < 1.0:
398
  sorted_idx = np.argsort(top_k_probs)[::-1]
399
  cumsum = np.cumsum(top_k_probs[sorted_idx])
400
  cutoff = np.searchsorted(cumsum, top_p) + 1
 
401
  nucleus_idx = sorted_idx[:cutoff]
402
  nucleus_probs = top_k_probs[nucleus_idx]
403
+ nucleus_probs /= nucleus_probs.sum()
404
  sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
405
  return int(top_k_indices[nucleus_idx[sampled]])
406
  else:
 
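For intuition on the nucleus (top-p) cutoff computed in sample_token above, a small worked example with illustrative numbers:

import numpy as np

probs = np.array([0.50, 0.25, 0.15, 0.07, 0.03])   # already sorted descending
cumsum = np.cumsum(probs)                          # [0.50, 0.75, 0.90, 0.97, 1.00]
cutoff = np.searchsorted(cumsum, 0.8) + 1          # top_p = 0.8 -> keep 3 tokens
nucleus = probs[:cutoff] / probs[:cutoff].sum()    # renormalized to ~[0.556, 0.278, 0.167]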
416
  top_p: float = 0.9,
417
  repetition_penalty: float = 1.1
418
  ):
419
+ """Generate text with KV-cache for fast CPU inference."""
420
  global stop_generation
421
  stop_generation = False
422
 
423
+ # Tokenize prompt
424
  prompt_ids = tokenizer.encode(prompt).ids
425
  input_ids = [i for i in prompt_ids if i != eos_token_id]
426
 
427
  if len(input_ids) == 0:
428
+ yield "Error: Empty prompt after tokenization"
429
  return
430
 
431
  generated_text = ""
432
  token_count = 0
433
  token_freq = {}
434
 
435
+ # Get special token IDs
436
  im_end_id = tokenizer.token_to_id("<|im_end|>")
437
  model_end_id = tokenizer.token_to_id("<im end for model tun>")
438
  stop_ids = {eos_token_id, im_end_id, model_end_id}
 
442
 
443
  start_time = time.time()
444
 
445
+ # === PREFILL PHASE ===
446
+ # Truncate if prompt is too long
447
  if len(input_ids) > max_context - max_tokens:
448
  input_ids = input_ids[-(max_context - max_tokens):]
449
 
450
  input_tensor = tf.constant([input_ids], dtype=tf.int32)
451
 
452
  try:
453
+ logits, past_kv = model(input_tensor, training=False, use_cache=True)
 
454
  except Exception as e:
455
+ yield f"Error during prefill: {e}"
456
  return
457
 
458
+ # Get logits for last position
459
+ next_token_logits = logits[0, -1, :].numpy()
460
+
461
  prefill_time = time.time() - start_time
462
+ print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
463
 
464
+ # === GENERATION LOOP ===
465
  decode_start = time.time()
 
466
 
467
  for step in range(max_tokens):
468
  if stop_generation:
469
+ yield generated_text + "\n\n*[Generation stopped]*"
470
  return
471
 
472
+ # Sample next token
473
+ next_token_id = sample_token(
474
+ next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
475
+ )
476
+
477
+ # Stop conditions
478
  if next_token_id in stop_ids:
479
  break
480
 
481
+ # Update frequency tracking
482
  token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
483
 
484
+ # Decode and yield
485
  token_text = tokenizer.decode([next_token_id])
486
  generated_text += token_text
487
  token_count += 1
488
+ yield generated_text
 
 
 
489
 
490
+ # === DECODE PHASE (single token, reuse cache) ===
491
  next_input = tf.constant([[next_token_id]], dtype=tf.int32)
492
 
493
  try:
494
  logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
 
495
  except Exception as e:
496
+ yield generated_text + f"\n\n*[Error during generation: {e}]*"
497
  return
498
+
499
+ next_token_logits = logits[0, -1, :].numpy()
500
+
501
+ # Truncate cache if too long
502
+ current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
503
+ if current_len > max_context:
504
+ trim_amount = current_len - max_context + 100 # Keep some buffer
505
+ past_kv = [
506
+ (k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
507
+ for k, v in past_kv
508
+ ]
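This trimming caps each layer's cache just below max_context (about 100 tokens of headroom); a rough estimate of what that cache costs in memory, using hypothetical model sizes and float32 tensors:

def kv_cache_bytes(n_layers, n_heads, seq_len, head_dim, batch=1, bytes_per_el=4):
    # Two tensors (K and V) of shape [batch, n_heads, seq_len, head_dim] per layer.
    return n_layers * 2 * batch * n_heads * seq_len * head_dim * bytes_per_el

# e.g. 24 layers, 16 heads, head_dim 64, 2048 cached tokens -> ~0.4 GB (hypothetical sizes)
print(kv_cache_bytes(24, 16, 2048, 64) / 1e9)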
509
 
 
510
  decode_time = time.time() - decode_start
511
  total_time = time.time() - start_time
512
 
513
  if token_count > 0:
 
514
  decode_tps = token_count / decode_time if decode_time > 0 else 0
515
+ total_tps = token_count / total_time if total_time > 0 else 0
516
 
517
  stats = (
518
+ f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
519
+ f"(prefill: {prefill_time:.1f}s, decode: {decode_tps:.1f} tok/s)]*"
 
 
520
  )
521
 
522
  if not stop_generation:
 
526
 
527
 
528
  # ============================================================================
529
+ # Chat Interface Logic
530
  # ============================================================================
531
 
532
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
 
535
  for user_msg, assistant_msg in history:
536
  prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
537
  if assistant_msg:
538
+ # Clean up any stats from previous messages
539
  clean_msg = assistant_msg.split("\n\n*[")[0]
540
  prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
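For reference, with a one-turn history the loop above yields a ChatML-style transcript shaped like the sketch below; the system preamble and the trailing turn for the new user message are built outside this hunk (illustrative only):

# history = [("Hi", "Hello!")] contributes:
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
# Hello!<|im_end|>
# ...followed by the new user turn and the assistant start tag.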
541
 
 
569
  ):
570
  partial_response = generated
571
 
572
+ # Robust end-of-turn detection
573
  stop_tags = ["<|im_end|>", "<im end for model tun>"]
574
  earliest_stop = len(partial_response)
575
  should_stop = False
 
583
 
584
  display_response = partial_response
585
  if should_stop:
586
+ # Keep the stats portion if present
587
  stats_start = partial_response.find("\n\n*[")
588
  if stats_start > earliest_stop:
589
  display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
590
  else:
591
  display_response = partial_response[:earliest_stop]
592
 
593
+ # Post-process reasoning tags for display
594
  if reasoning_enabled:
595
  if '<think>' in display_response and '</think>' in display_response:
596
  start_idx = display_response.find('<think>')
 
622
 
623
 
624
  # ============================================================================
625
+ # Gradio UI
626
  # ============================================================================
627
 
628
  custom_css = """
 
697
  <div id="welcome-modal" class="modal-overlay" style="display:none;">
698
  <div class="modal-content">
699
  <h2>🧠 Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
700
+ <p>Our latest model features <strong>Chain-of-Thought (CoT)</strong> functionality and <strong>KV-Cache optimization</strong> for fast CPU inference!</p>
701
  <div class="comparison-box">
702
  <div class="comparison-mode mode-reasoning">
703
  <h3>💡 Reasoning Mode (ON)</h3>
704
+ <p>The model performs a <strong>CoT step</strong> first. The internal thought process is contained within <code>&lt;think>...&lt;/think></code> tags.</p>
705
  </div>
706
  <div class="comparison-mode mode-direct">
707
  <h3>⚪ Direct Mode (OFF)</h3>
708
+ <p>The model generates the final answer immediately, maximizing speed.</p>
709
  </div>
710
  </div>
711
  <button class="close-btn" onclick="document.getElementById('welcome-modal').style.display='none'">Got it! Start Chatting</button>
 
721
  <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
722
  alt="Sam-large-2" style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
723
  <h1>🤖 Sam-large-2 Chat 🤖</h1>
724
+ <p><strong>LATEST RELEASE!</strong> Our <strong>BEST Reasoning Model</strong> - Now with KV-Cache! <span class="speed-indicator">⚡ 5-20x Faster</span></p>
725
+ <div class="twin-badge">Reasoning Model</div>
726
  <div class="celebration">🚀 💫 🎯 ⚡ 🔥</div>
727
  </div>
728
  """)
729
  else:
730
+ gr.HTML("""<div class="header"><h1>🤖 Sam-large-2 Chat</h1><p>Advanced Reasoning Model with KV-Cache</p></div>""")
731
 
732
  with gr.Row():
733
  with gr.Column(scale=4):
 
765
  repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
766
  gr.Markdown("---")
767
  gr.Markdown(f"""### 🎊 Sam-large-2 Model Info
768
+ **Type:** Chain-of-Thought Reasoning Model
769
  **Vocab:** {config['vocab_size']:,}
770
  **Layers:** {config['num_hidden_layers']}
771
  **Context:** {config['max_position_embeddings']:,} tokens
772
+ **Optimization:** KV-Cache enabled ⚡
 
773
  """)
774
 
775
  gr.Examples(
 
784
 
785
  gr.HTML("""
786
  <footer>
787
+ <p><strong>🎉 Sam-large-2 - LATEST RELEASE with KV-Cache! 🎉</strong></p>
788
+ <p style="font-size: 0.9rem; color: #999;">Trained from scratch on TPU v5e-8 • Built by Smily studios with TensorFlow & Gradio</p>
789
  </footer>
790
  """)
791
 
 
845
 
846
  if __name__ == "__main__":
847
  print("\n" + "=" * 60)
848
+ print("🚀 Starting Sam-large-2 Chat with KV-Cache Optimization")
849
  print("=" * 60 + "\n")
850
  demo.queue(max_size=20)
851
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)