Keeby-smilyai committed on
Commit 5d1d6ad · verified · 1 Parent(s): 436b502

Update app.py

Files changed (1):
  app.py (+340 -216)

app.py CHANGED
@@ -1,42 +1,35 @@
  import os

- # === CPU Threading Optimization ===
- # Set these BEFORE importing TensorFlow
  NUM_CORES = os.cpu_count() or 4

  os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
  os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)

- # Disable GPU (ensures CPU-only, avoids GPU detection overhead)
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
-
- import tensorflow as tf
-
- # Configure threading after import
- tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
- tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
-
- # Enable oneDNN optimizations (significant on Intel CPUs)
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
-
- # Optional: XLA JIT compilation (can help, test it)
- # tf.config.optimizer.set_jit(True)
-
- print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
  import gradio as gr
  import tensorflow as tf
  import keras
  from huggingface_hub import hf_hub_download
  import json
- import os
  from tokenizers import Tokenizer
  import numpy as np
  import time

  # ============================================================================
  # 🎊 FESTIVE MODE TOGGLE 🎊
  # ============================================================================
- FESTIVE = True

  # ============================================================================
  # Configuration & Model Loading
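The removed block above is the bug this hunk fixes: the threading and oneDNN settings were being applied after `import tensorflow`, too late for some of them to take effect. The general pattern, as a minimal standalone sketch (thread count hard-coded here purely for illustration):

```python
import os

# Environment flags must be set before TensorFlow is imported,
# since TF consults them while the module initializes.
os.environ['TF_NUM_INTEROP_THREADS'] = '4'
os.environ['TF_NUM_INTRAOP_THREADS'] = '4'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide GPUs, force CPU

import tensorflow as tf  # noqa: E402 (import deliberately after env setup)

# The runtime API can still adjust threading, but only before the first op runs.
tf.config.threading.set_inter_op_parallelism_threads(4)
tf.config.threading.set_intra_op_parallelism_threads(4)
```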
@@ -48,7 +41,7 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
  CACHE_DIR = "./model_cache"

  # ============================================================================
- # Model Architecture Definitions
  # ============================================================================

  @keras.saving.register_keras_serializable()
@@ -59,10 +52,12 @@ class RotaryEmbedding(keras.layers.Layer):
          self.max_len = max_len
          self.theta = theta
          self.built_cache = False
-
      def build(self, input_shape):
          super().build(input_shape)
-
      def _build_cache(self):
          if not self.built_cache:
              inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
@@ -72,26 +67,24 @@
              self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
              self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
              self.built_cache = True
-
      def rotate_half(self, x):
          x1, x2 = tf.split(x, 2, axis=-1)
          return tf.concat([-x2, x1], axis=-1)
-
      def call(self, q, k, offset=0):
          """Apply rotary embeddings with position offset for KV-cache."""
          self._build_cache()
          seq_len = tf.shape(q)[2]
          dtype = q.dtype
-
-         # For q: positions are [offset, offset+seq_len)
-         # For k: same positions (k is only the new tokens, past_k already has RoPE applied)
          cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
          sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
-
          q_embed = (q * cos) + (self.rotate_half(q) * sin)
          k_embed = (k * cos) + (self.rotate_half(k) * sin)
          return q_embed, k_embed
-
      def get_config(self):
          config = super().get_config()
          config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
@@ -103,14 +96,16 @@ class RMSNorm(keras.layers.Layer):
      def __init__(self, epsilon=1e-5, **kwargs):
          super().__init__(**kwargs)
          self.epsilon = epsilon
-
      def build(self, input_shape):
          self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
-
      def call(self, x):
          variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
          return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
-
      def get_config(self):
          config = super().get_config()
          config.update({"epsilon": self.epsilon})
@@ -129,19 +124,21 @@ class TransformerBlock(keras.layers.Layer):
          self.rope_theta = rope_theta
          self.head_dim = d_model // n_heads
          self.layer_idx = layer_idx
-
-         self.pre_attn_norm = RMSNorm()
-         self.pre_ffn_norm = RMSNorm()
-         self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
-         self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
-         self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
-         self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
-         self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
-         self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
-         self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
-         self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
-         self.dropout = keras.layers.Dropout(dropout)
-
      def call(self, x, training=None, past_kv=None, use_cache=False):
          """
          Args:
@@ -154,69 +151,72 @@
          B = tf.shape(x)[0]
          T = tf.shape(x)[1]
          dtype = x.dtype
-
          res = x
          y = self.pre_attn_norm(x)
-
          # Project Q, K, V for current input
          q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
          q = tf.transpose(q, [0, 2, 1, 3])  # [B, n_heads, T, head_dim]
-
          k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
          k = tf.transpose(k, [0, 2, 1, 3])
-
          v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
          v = tf.transpose(v, [0, 2, 1, 3])
-
          # Determine position offset for RoPE
          if past_kv is not None:
              past_len = tf.shape(past_kv[0])[2]
          else:
              past_len = 0
-
          # Apply RoPE with position offset
          q, k = self.rope(q, k, offset=past_len)
-
          # Concatenate with past KV
          if past_kv is not None:
              k = tf.concat([past_kv[0], k], axis=2)
              v = tf.concat([past_kv[1], v], axis=2)
-
          new_kv = (k, v) if use_cache else None
-
          # Attention
          full_len = tf.shape(k)[2]
          scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
-
-         # Causal mask: q attends to all of k (including past)
-         # Shape: [T, full_len] where each query position can attend to positions <= its absolute position
          q_positions = tf.range(past_len, past_len + T)
          k_positions = tf.range(full_len)
          mask = tf.cast(q_positions[:, None] >= k_positions[None, :], dtype)
          mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
          scores = scores + mask[None, None, :, :]
-
          attn = tf.nn.softmax(scores, axis=-1)
          attn_out = tf.matmul(attn, v)
          attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
          attn_out = tf.reshape(attn_out, [B, T, self.d_model])
-
          x = res + self.dropout(self.out_proj(attn_out), training=training)
-
          # FFN
          res = x
          y = self.pre_ffn_norm(x)
          ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
          output = res + self.dropout(ffn, training=training)
-
          return output, new_kv
-
      def get_config(self):
          config = super().get_config()
          config.update({
-             "d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim,
-             "dropout": self.dropout_rate, "max_len": self.max_len,
-             "rope_theta": self.rope_theta, "layer_idx": self.layer_idx
          })
          return config
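The mask in `call` is what keeps a cached decode step causal: query rows are numbered from `past_len` upward, so a single new token (T = 1) may attend to the entire cache plus itself. A standalone NumPy illustration of the boolean mask (assumed shapes):

```python
import numpy as np

def causal_mask(past_len, T, full_len):
    q_pos = np.arange(past_len, past_len + T)[:, None]  # absolute query positions
    k_pos = np.arange(full_len)[None, :]                # key positions incl. cache
    return q_pos >= k_pos                               # True = may attend

print(causal_mask(past_len=0, T=3, full_len=3).astype(int))
# [[1 0 0]
#  [1 1 0]
#  [1 1 1]]
print(causal_mask(past_len=3, T=1, full_len=4).astype(int))
# [[1 1 1 1]]  <- one new token sees the whole cache and itself
```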
 
@@ -231,13 +231,16 @@ class SAM1Model(keras.Model):
              self.cfg = kwargs
          else:
              self.cfg = kwargs.get('cfg', kwargs)
-
          self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
          ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
          block_args = {
-             'd_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'],
-             'ff_dim': ff_dim, 'dropout': self.cfg['dropout'],
-             'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']
          }
          self.blocks = [
              TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
@@ -245,7 +248,7 @@
          ]
          self.norm = RMSNorm(name="final_norm")
          self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
-
      def call(self, input_ids, training=None, past_kv=None, use_cache=False):
          """
          Args:
@@ -256,22 +259,24 @@
              logits, new_past_kv (or None)
          """
          x = self.embed(input_ids)
-
          new_past_kv = [] if use_cache else None
-
          for i, block in enumerate(self.blocks):
              layer_past = past_kv[i] if past_kv is not None else None
              x, layer_kv = block(x, training=training, past_kv=layer_past, use_cache=use_cache)
              if use_cache:
                  new_past_kv.append(layer_kv)
-
          logits = self.lm_head(self.norm(x))
          return logits, new_past_kv
-
      def get_config(self):
          base_config = super().get_config()
          base_config['config'] = self.cfg
          return base_config
  # --- Model and Tokenizer Loading ---

  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
@@ -287,6 +292,7 @@ except Exception as e:
          use_checkpoint = False
      except Exception as e_model:
          print(f"❌ Also failed to find model.keras: {e_model}")

  with open(config_path, 'r') as f:
      config = json.load(f)
@@ -320,37 +326,88 @@ if use_checkpoint:
          'rope_theta': config['rope_theta']
      }
      model = SAM1Model(config=model_config)
-     dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
-     _ = model(dummy_input, training=False)
      print(f"✅ Model architecture built: {model.count_params():,} parameters")
      try:
          model.load_weights(weights_path)
          print("✅ Checkpoint weights loaded successfully!")
      except Exception as e:
          print(f"❌ Failed to load checkpoint weights: {e}")
  else:
      print("📦 Loading full saved model...")
      try:
-         custom_objects = {'SAM1Model': SAM1Model, 'TransformerBlock': TransformerBlock, 'RMSNorm': RMSNorm, 'RotaryEmbedding': RotaryEmbedding}
          model = keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
          print("✅ Model loaded successfully")
      except Exception as e:
          print(f"❌ Failed to load model: {e}")

  if model:
-     print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")

  # ============================================================================
- # Optimized Inference Logic (TF Functions)
  # ============================================================================

- # Define fast forward for real generation
- @tf.function(reduce_retracing=True)
- def fast_forward(input_tensor):
-     return model(input_tensor, training=False)
-
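Dropping `fast_forward` is deliberate: a `tf.function`-wrapped forward pass retraces for every new input shape, and autoregressive generation produces a new sequence length on nearly every call; `reduce_retracing=True` only mitigates this. A small sketch of the retracing behavior (illustrative):

```python
import tensorflow as tf

@tf.function
def f(x):
    # Python side effects run only while tracing, so this print
    # fires once per distinct input signature.
    print("tracing for shape", x.shape)
    return tf.reduce_sum(x)

f(tf.zeros([1, 8]))   # traces a graph for (1, 8)
f(tf.zeros([1, 8]))   # reuses it
f(tf.zeros([1, 9]))   # traces again: new length -> new graph
```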
  stop_generation = False

  def generate_stream(
      prompt: str,
      max_tokens: int = 512,
@@ -362,118 +419,134 @@ def generate_stream(
      """Generate text with KV-cache for fast CPU inference."""
      global stop_generation
      stop_generation = False
-
      prompt_ids = tokenizer.encode(prompt).ids
      input_ids = [i for i in prompt_ids if i != eos_token_id]
-
      generated_text = ""
      token_count = 0
      token_freq = {}
-
      start_time = time.time()
-
      # === PREFILL PHASE ===
-     # Process entire prompt, build initial KV cache
      input_tensor = tf.constant([input_ids], dtype=tf.int32)
-     logits, past_kv = model(input_tensor, training=False, use_cache=True)
      # Get logits for last position
      next_token_logits = logits[0, -1, :].numpy()
-
      # === GENERATION LOOP ===
      for step in range(max_tokens):
          if stop_generation:
              yield generated_text + "\n\n*[Generation stopped]*"
-             break
-
-         # Temperature
-         scaled_logits = next_token_logits / temperature
-
-         # Repetition penalty
-         if repetition_penalty != 1.0:
-             for token_id, freq in token_freq.items():
-                 if token_id < len(scaled_logits):
-                     scaled_logits[token_id] /= (repetition_penalty ** freq)
-
-         # Top-K sampling
-         if top_k > 0:
-             top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
-             top_k_logits = scaled_logits[top_k_indices]
-             top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
-             top_k_probs /= top_k_probs.sum()
-
-             # Top-P (nucleus) sampling
-             if top_p < 1.0:
-                 sorted_idx = np.argsort(top_k_probs)[::-1]
-                 cumsum = np.cumsum(top_k_probs[sorted_idx])
-                 cutoff = np.searchsorted(cumsum, top_p) + 1
-                 nucleus_idx = sorted_idx[:cutoff]
-                 nucleus_probs = top_k_probs[nucleus_idx]
-                 nucleus_probs /= nucleus_probs.sum()
-                 sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
-                 next_token_id = int(top_k_indices[nucleus_idx[sampled]])
-             else:
-                 sampled = np.random.choice(len(top_k_probs), p=top_k_probs)
-                 next_token_id = int(top_k_indices[sampled])
-         else:
-             probs = np.exp(scaled_logits - np.max(scaled_logits))
-             probs /= probs.sum()
-             next_token_id = np.random.choice(len(probs), p=probs)
-
          # Stop conditions
-         if next_token_id == eos_token_id:
-             break
-         im_end_id = tokenizer.token_to_id("<|im_end|>")
-         model_end_id = tokenizer.token_to_id("<im end for model tun>")
-         if next_token_id in (im_end_id, model_end_id):
              break
-
          # Update frequency tracking
          token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
-
          # Decode and yield
          token_text = tokenizer.decode([next_token_id])
          generated_text += token_text
          token_count += 1
          yield generated_text
-
          # === DECODE PHASE (single token, reuse cache) ===
          next_input = tf.constant([[next_token_id]], dtype=tf.int32)
-         logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
-         next_token_logits = logits[0, -1, :].numpy()
          # Truncate cache if too long
-         max_len = config['max_position_embeddings']
-         if past_kv[0][0].shape[2] > max_len:
-             past_kv = [(k[:, :, -max_len:, :], v[:, :, -max_len:, :]) for k, v in past_kv]
-
-     elapsed = time.time() - start_time
-     tps = token_count / elapsed if elapsed > 0 else 0
-
-     if token_count > 0 and not stop_generation:
-         generated_text += f"\n\n*[Generated {token_count} tokens in {elapsed:.1f}s ({tps:.1f} tok/s)]*"

      yield generated_text
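The whole function follows the standard two-phase KV-cache pattern: one batched pass over the prompt (prefill), then one single-token pass per step that reuses the cached keys/values (decode). Stripped to its skeleton (names match the diff; `sample` and `prompt_tensor` are placeholders, so this is a sketch rather than runnable code):

```python
# Prefill: run the full prompt once and keep each layer's (K, V).
logits, past_kv = model(prompt_tensor, training=False, use_cache=True)

for _ in range(max_tokens):
    token = sample(logits[0, -1, :].numpy())   # only the last position matters
    # Decode: one token in, cache grows by one position per layer.
    # Each step skips re-encoding the prompt entirely.
    next_input = tf.constant([[token]], dtype=tf.int32)
    logits, past_kv = model(next_input, training=False,
                            past_kv=past_kv, use_cache=True)
```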
 
 
  # ============================================================================
  # Chat Interface Logic
  # ============================================================================

  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
-     """Format message history and SEED <think> if enabled"""
      prompt = ""
      for user_msg, assistant_msg in history:
          prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
          if assistant_msg:
-             prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
-
      prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
-
-     # 🧠 REAL REASONING: Just add the tag. The model will do the rest.
      if reasoning_enabled:
          prompt += "<think>"
-
      return prompt

  def chat_stream(
      message: str,
      history: list,
@@ -487,59 +560,67 @@ def chat_stream(
      if not message.strip():
          yield history
          return
-
      prompt = format_chat_prompt(message, history, reasoning_enabled)
      partial_response = ""
-
-     # ⚡ NO FAKE REASONING HERE. We trust the model.
-
      for generated in generate_stream(
          prompt, max_tokens, temperature, top_k, top_p, repetition_penalty
      ):
          partial_response = generated
-
-         # Robust End-of-Turn Detection
          stop_tags = ["<|im_end|>", "<im end for model tun>"]
          earliest_stop = len(partial_response)
          should_stop = False

          for tag in stop_tags:
              if tag in partial_response:
-                 earliest_stop = min(earliest_stop, partial_response.find(tag))
-                 should_stop = True
-
          if should_stop:
-             partial_response = partial_response[:earliest_stop]

-         # Post-process reasoning tags for display (Collapsing the REAL thought)
          if reasoning_enabled:
-             if '<think>' in partial_response and '</think>' in partial_response:
-                 start_idx = partial_response.find('<think>')
-                 end_idx = partial_response.find('</think>')
                  if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
-                     thought_content = partial_response[start_idx + len('<think>'):end_idx].strip()
-
-                     # Safe formatting outside f-string
                      formatted_thought = thought_content.replace("\n", "<br>")
-
                      details_html = (
                          f'<details class="reasoning-block">'
-                         f'<summary>Model Reasoning (Click to show/hide)</summary>'
                          f'<p>{formatted_thought}</p>'
                          f'</details>'
                      )
-                     partial_response = partial_response[:start_idx] + details_html + partial_response[end_idx + len('</think>'):]
-             elif start_idx != -1 and end_idx == -1:
-                 # Model is currently thinking...
-                 partial_response = partial_response.replace('<think>', '**Thinking:** ')
-
-         yield history + [[message, partial_response.strip()]]

  def stop_gen():
      global stop_generation
      stop_generation = True
      return None

  # ============================================================================
  # Gradio UI
  # ============================================================================
@@ -582,7 +663,7 @@ footer { text-align: center; padding: 2rem; color: #666; border-top: 1px solid #
  .gradio-html details.reasoning-block p { margin-top: 5px; padding-left: 10px; border-left: 1px dashed #ccc; white-space: pre-wrap; }
  .modal-overlay {
      position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: rgba(0, 0, 0, 0.7);
-     display: flex; justify-content: center; align-items: center; z-index: 1000;
  }
  .modal-content {
      background: white; padding: 30px; border-radius: 15px; width: 90%; max-width: 900px;
@@ -601,25 +682,26 @@ footer { text-align: center; padding: 2rem; color: #666; border-top: 1px solid #
      border: none; border-radius: 8px; cursor: pointer; font-size: 1rem; transition: background-color 0.3s;
  }
  .close-btn:hover { background-color: #5d3a84; }
  """

- festive_css = custom_css
- custom_css = festive_css
-
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
-     reasoning_enabled = gr.State(False)
-     modal_shown = gr.State(False)
-
      welcome_modal_html = gr.HTML(
          """
          <div id="welcome-modal" class="modal-overlay" style="display:none;">
              <div class="modal-content">
                  <h2>🧠 Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
-                 <p>Our latest model, **Sam-large-2**, features **Chain-of-Thought (CoT)** functionality. You can toggle this feature using the 💡 button next to the input field.</p>
                  <div class="comparison-box">
                      <div class="comparison-mode mode-reasoning">
                          <h3>💡 Reasoning Mode (ON)</h3>
-                         <p>The model performs a **CoT step** first. The internal thought process is contained within the <code>&lt;think>...&lt;/think></code> tags.</p>
                      </div>
                      <div class="comparison-mode mode-direct">
                          <h3>⚪ Direct Mode (OFF)</h3>
@@ -636,35 +718,44 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
          gr.HTML("""
          <div class="header">
              <div class="celebration">🎉 🎊 ✨ 🎈 🎆</div>
-             <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
                   alt="Sam-large-2" style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
              <h1>🤖 Sam-large-2 Chat 🤖</h1>
-             <p><strong>LATEST RELEASE!</strong> Our **BEST Reasoning Model** - Full Chain-of-Thought!</p>
              <div class="twin-badge">Reasoning Model</div>
              <div class="celebration">🚀 💫 🎯 ⚡ 🔥</div>
          </div>
          """)
      else:
-         gr.HTML("""<div class="header"><h1>🤖 Sam-large-2 Chat</h1><p>Advanced Reasoning Model</p></div>""")

      with gr.Row():
          with gr.Column(scale=4):
              chatbot = gr.Chatbot(
-                 height=600, show_label=False,
-                 avatar_images=(None, "https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/KtiMi-aDUOOeN--YNT-Fu.jpeg"),
                  bubble_full_width=False
              )
              with gr.Row():
                  with gr.Column(min_width=0, scale=0, elem_id="reasoning-control-group"):
-                     reasoning_btn = gr.Button("💡", size="sm", elem_id="reasoning-toggle-btn", elem_classes=["off"])
                      gr.HTML('<span class="new-tag-red">NEW</span>')
-                 msg = gr.Textbox(placeholder="Type your message here...", show_label=False, scale=8, container=False)
                  submit_btn = gr.Button("Send 🚀" if FESTIVE else "Send", variant="primary", scale=1)
                  stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)
              with gr.Row():
                  clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
                  retry_btn = gr.Button("🔄 Retry", size="sm")
-
          with gr.Column(scale=1):
              gr.Markdown("### ⚙️ Generation Settings")
              max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=50, label="Max Tokens")
@@ -674,20 +765,30 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
              repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
              gr.Markdown("---")
              gr.Markdown(f"""### 🎊 Sam-large-2 Model Info
- **Type:** Chain-of-Thought Reasoning Model
- **Vocab:** {config['vocab_size']}
- **Reasoning:** Full CoT support (uses **<think>** tags)
- """)

-     gr.Examples(examples=["Explain quantum computing", "Write a short poem about AI", "Solve 24*12 with reasoning"], inputs=msg)
-
      gr.HTML("""
-     <footer>
-         <p><strong>🎉 Sam-large-2 - LATEST RELEASE! 🎉</strong></p>
-         <p style="font-size: 0.9rem; color: #999;">Trained from scratch on TPU v5e-8 • Built by Smily studios with TensorFlow & Gradio</p>
-     </footer>
-     """)
-
      def show_modal_js():
          return """
          (function() {
@@ -697,31 +798,54 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
              }
          })();
          """
      demo.load(None, inputs=None, outputs=None, js=show_modal_js())

      def toggle_reasoning(current_state):
          new_state = not current_state
          return new_state, gr.update(elem_classes="" if new_state else "off")

-     reasoning_btn.click(fn=toggle_reasoning, inputs=[reasoning_enabled], outputs=[reasoning_enabled, reasoning_btn], preprocess=False)

      common_inputs = [msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled]
-
-     submit_event = msg.submit(chat_stream, inputs=common_inputs, outputs=[chatbot]).then(lambda: "", outputs=[msg])
-     click_event = submit_btn.click(chat_stream, inputs=common_inputs, outputs=[chatbot]).then(lambda: "", outputs=[msg])
-
      stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[submit_event, click_event])
      clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
-
      def retry_last(history, max_tok, temp, topk, topp, rep_pen, reasoning_en):
-         if not history: return history
          last_user_msg = history[-1][0]
          for update in chat_stream(last_user_msg, history[:-1], max_tok, temp, topk, topp, rep_pen, reasoning_en):
              yield update
-
-     retry_event = retry_btn.click(retry_last, inputs=[chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled], outputs=[chatbot])
      stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[retry_event])

  if __name__ == "__main__":
      demo.queue(max_size=20)
      demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
 
  import os

+ # ============================================================================
+ # CPU Optimization - MUST be before TensorFlow import
+ # ============================================================================
  NUM_CORES = os.cpu_count() or 4

  os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
  os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Force CPU only
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'  # Intel optimization
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # Reduce TF logging

  import gradio as gr
  import tensorflow as tf
  import keras
  from huggingface_hub import hf_hub_download
  import json
  from tokenizers import Tokenizer
  import numpy as np
  import time

+ # Configure TF threading
+ tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
+ tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
+
+ print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
+
  # ============================================================================
  # 🎊 FESTIVE MODE TOGGLE 🎊
  # ============================================================================
+ FESTIVE = True

  # ============================================================================
  # Configuration & Model Loading
 
  CACHE_DIR = "./model_cache"

  # ============================================================================
+ # Model Architecture Definitions (Optimized with KV-Cache)
  # ============================================================================

  @keras.saving.register_keras_serializable()
 
          self.max_len = max_len
          self.theta = theta
          self.built_cache = False
+         self.cos_cached = None
+         self.sin_cached = None
+
      def build(self, input_shape):
          super().build(input_shape)
+
      def _build_cache(self):
          if not self.built_cache:
              inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
 
              self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
              self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
              self.built_cache = True
+
      def rotate_half(self, x):
          x1, x2 = tf.split(x, 2, axis=-1)
          return tf.concat([-x2, x1], axis=-1)
+
      def call(self, q, k, offset=0):
          """Apply rotary embeddings with position offset for KV-cache."""
          self._build_cache()
          seq_len = tf.shape(q)[2]
          dtype = q.dtype
+
          cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
          sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
+
          q_embed = (q * cos) + (self.rotate_half(q) * sin)
          k_embed = (k * cos) + (self.rotate_half(k) * sin)
          return q_embed, k_embed
+
      def get_config(self):
          config = super().get_config()
          config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
 
      def __init__(self, epsilon=1e-5, **kwargs):
          super().__init__(**kwargs)
          self.epsilon = epsilon
+         self.scale = None
+
      def build(self, input_shape):
          self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
+         super().build(input_shape)
+
      def call(self, x):
          variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
          return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
+
      def get_config(self):
          config = super().get_config()
          config.update({"epsilon": self.epsilon})
 
          self.rope_theta = rope_theta
          self.head_dim = d_model // n_heads
          self.layer_idx = layer_idx
+
+     def build(self, input_shape):
+         self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
+         self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
+         self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
+         self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
+         self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
+         self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
+         self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
+         self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
+         self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
+         self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
+         self.dropout = keras.layers.Dropout(self.dropout_rate)
+         super().build(input_shape)
+
      def call(self, x, training=None, past_kv=None, use_cache=False):
          """
          Args:
 
          B = tf.shape(x)[0]
          T = tf.shape(x)[1]
          dtype = x.dtype
+
          res = x
          y = self.pre_attn_norm(x)
+
          # Project Q, K, V for current input
          q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
          q = tf.transpose(q, [0, 2, 1, 3])  # [B, n_heads, T, head_dim]
+
          k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
          k = tf.transpose(k, [0, 2, 1, 3])
+
          v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
          v = tf.transpose(v, [0, 2, 1, 3])
+
          # Determine position offset for RoPE
          if past_kv is not None:
              past_len = tf.shape(past_kv[0])[2]
          else:
              past_len = 0
+
          # Apply RoPE with position offset
          q, k = self.rope(q, k, offset=past_len)
+
          # Concatenate with past KV
          if past_kv is not None:
              k = tf.concat([past_kv[0], k], axis=2)
              v = tf.concat([past_kv[1], v], axis=2)
+
          new_kv = (k, v) if use_cache else None
+
          # Attention
          full_len = tf.shape(k)[2]
          scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
+
+         # Causal mask
          q_positions = tf.range(past_len, past_len + T)
          k_positions = tf.range(full_len)
          mask = tf.cast(q_positions[:, None] >= k_positions[None, :], dtype)
          mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
          scores = scores + mask[None, None, :, :]
+
          attn = tf.nn.softmax(scores, axis=-1)
          attn_out = tf.matmul(attn, v)
          attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
          attn_out = tf.reshape(attn_out, [B, T, self.d_model])
+
          x = res + self.dropout(self.out_proj(attn_out), training=training)
+
          # FFN
          res = x
          y = self.pre_ffn_norm(x)
          ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
          output = res + self.dropout(ffn, training=training)
+
          return output, new_kv
+
      def get_config(self):
          config = super().get_config()
          config.update({
+             "d_model": self.d_model,
+             "n_heads": self.n_heads,
+             "ff_dim": self.ff_dim,
+             "dropout": self.dropout_rate,
+             "max_len": self.max_len,
+             "rope_theta": self.rope_theta,
+             "layer_idx": self.layer_idx
          })
          return config
 
 
              self.cfg = kwargs
          else:
              self.cfg = kwargs.get('cfg', kwargs)
+
          self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
          ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
          block_args = {
+             'd_model': self.cfg['d_model'],
+             'n_heads': self.cfg['n_heads'],
+             'ff_dim': ff_dim,
+             'dropout': self.cfg['dropout'],
+             'max_len': self.cfg['max_len'],
+             'rope_theta': self.cfg['rope_theta']
          }
          self.blocks = [
              TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
 
          ]
          self.norm = RMSNorm(name="final_norm")
          self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
+
      def call(self, input_ids, training=None, past_kv=None, use_cache=False):
          """
          Args:
 
              logits, new_past_kv (or None)
          """
          x = self.embed(input_ids)
+
          new_past_kv = [] if use_cache else None
+
          for i, block in enumerate(self.blocks):
              layer_past = past_kv[i] if past_kv is not None else None
              x, layer_kv = block(x, training=training, past_kv=layer_past, use_cache=use_cache)
              if use_cache:
                  new_past_kv.append(layer_kv)
+
          logits = self.lm_head(self.norm(x))
          return logits, new_past_kv
+
      def get_config(self):
          base_config = super().get_config()
          base_config['config'] = self.cfg
          return base_config
+
+
  # --- Model and Tokenizer Loading ---

  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
 
          use_checkpoint = False
      except Exception as e_model:
          print(f"❌ Also failed to find model.keras: {e_model}")
+         raise RuntimeError("Could not load model weights")

  with open(config_path, 'r') as f:
      config = json.load(f)
 
          'rope_theta': config['rope_theta']
      }
      model = SAM1Model(config=model_config)
+
+     # Build model with dummy input
+     dummy_input = tf.zeros((1, 16), dtype=tf.int32)
+     _ = model(dummy_input, training=False, use_cache=False)
      print(f"✅ Model architecture built: {model.count_params():,} parameters")
+
      try:
          model.load_weights(weights_path)
          print("✅ Checkpoint weights loaded successfully!")
      except Exception as e:
          print(f"❌ Failed to load checkpoint weights: {e}")
+         raise
  else:
      print("📦 Loading full saved model...")
      try:
+         custom_objects = {
+             'SAM1Model': SAM1Model,
+             'TransformerBlock': TransformerBlock,
+             'RMSNorm': RMSNorm,
+             'RotaryEmbedding': RotaryEmbedding
+         }
          model = keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
          print("✅ Model loaded successfully")
      except Exception as e:
          print(f"❌ Failed to load model: {e}")
+         raise

  if model:
+     print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
+
+     # Warm up the model
+     print("🔥 Warming up model...")
+     warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
+     _, _ = model(warmup_input, training=False, use_cache=True)
+     print("✅ Model warmed up")

  # ============================================================================
+ # Optimized Inference Logic with KV-Cache
  # ============================================================================

  stop_generation = False

+
+ def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
+     """Pure NumPy sampling for speed."""
+     # Temperature scaling
+     scaled_logits = logits / temperature
+
+     # Repetition penalty
+     if repetition_penalty != 1.0:
+         for token_id, freq in token_freq.items():
+             if token_id < len(scaled_logits):
+                 scaled_logits[token_id] /= (repetition_penalty ** freq)
+
+     # Top-K filtering
+     if top_k > 0 and top_k < len(scaled_logits):
+         top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
+         top_k_logits = scaled_logits[top_k_indices]
+     else:
+         top_k_indices = np.arange(len(scaled_logits))
+         top_k_logits = scaled_logits
+
+     # Softmax (numerically stable)
+     top_k_logits = top_k_logits - np.max(top_k_logits)
+     top_k_probs = np.exp(top_k_logits)
+     top_k_probs /= top_k_probs.sum()
+
+     # Top-P (nucleus) filtering
+     if top_p < 1.0:
+         sorted_idx = np.argsort(top_k_probs)[::-1]
+         cumsum = np.cumsum(top_k_probs[sorted_idx])
+         cutoff = np.searchsorted(cumsum, top_p) + 1
+         nucleus_idx = sorted_idx[:cutoff]
+         nucleus_probs = top_k_probs[nucleus_idx]
+         nucleus_probs /= nucleus_probs.sum()
+         sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
+         return int(top_k_indices[nucleus_idx[sampled]])
+     else:
+         sampled = np.random.choice(len(top_k_probs), p=top_k_probs)
+         return int(top_k_indices[sampled])
+
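Because `sample_token` operates on a plain NumPy vector, it can be exercised without the model. A hypothetical call (the logits below are made up for illustration):

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0], dtype=np.float32)
next_id = sample_token(
    logits,
    temperature=0.7,
    top_k=3,                 # only tokens 0, 1, 2 survive filtering
    top_p=0.9,
    token_freq={0: 2},       # token 0 was already emitted twice
    repetition_penalty=1.3,  # so its logit is divided by 1.3**2
)
print(next_id)  # usually 0 or 1; never 3 or 4
```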
  def generate_stream(
      prompt: str,
      max_tokens: int = 512,

      """Generate text with KV-cache for fast CPU inference."""
      global stop_generation
      stop_generation = False
+
+     # Tokenize prompt
      prompt_ids = tokenizer.encode(prompt).ids
      input_ids = [i for i in prompt_ids if i != eos_token_id]
+
+     if len(input_ids) == 0:
+         yield "Error: Empty prompt after tokenization"
+         return
+
      generated_text = ""
      token_count = 0
      token_freq = {}
+
+     # Get special token IDs
+     im_end_id = tokenizer.token_to_id("<|im_end|>")
+     model_end_id = tokenizer.token_to_id("<im end for model tun>")
+     stop_ids = {eos_token_id, im_end_id, model_end_id}
+     stop_ids.discard(None)
+
+     max_context = config['max_position_embeddings']
+
      start_time = time.time()
+
      # === PREFILL PHASE ===
+     # Truncate if prompt is too long
+     if len(input_ids) > max_context - max_tokens:
+         input_ids = input_ids[-(max_context - max_tokens):]
+
      input_tensor = tf.constant([input_ids], dtype=tf.int32)

+     try:
+         logits, past_kv = model(input_tensor, training=False, use_cache=True)
+     except Exception as e:
+         yield f"Error during prefill: {e}"
+         return
+
      # Get logits for last position
      next_token_logits = logits[0, -1, :].numpy()
+
+     prefill_time = time.time() - start_time
+     print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
+
      # === GENERATION LOOP ===
+     decode_start = time.time()
+
      for step in range(max_tokens):
          if stop_generation:
              yield generated_text + "\n\n*[Generation stopped]*"
+             return
+
+         # Sample next token
+         next_token_id = sample_token(
+             next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
+         )
+
          # Stop conditions
+         if next_token_id in stop_ids:
              break
+
          # Update frequency tracking
          token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
+
          # Decode and yield
          token_text = tokenizer.decode([next_token_id])
          generated_text += token_text
          token_count += 1
          yield generated_text
+
          # === DECODE PHASE (single token, reuse cache) ===
          next_input = tf.constant([[next_token_id]], dtype=tf.int32)

+         try:
+             logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
+         except Exception as e:
+             yield generated_text + f"\n\n*[Error during generation: {e}]*"
+             return
+
+         next_token_logits = logits[0, -1, :].numpy()
+
          # Truncate cache if too long
+         current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
+         if current_len > max_context:
+             trim_amount = current_len - max_context + 100  # Keep some buffer
+             past_kv = [
+                 (k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
+                 for k, v in past_kv
+             ]
+
+     decode_time = time.time() - decode_start
+     total_time = time.time() - start_time

+     if token_count > 0:
+         decode_tps = token_count / decode_time if decode_time > 0 else 0
+         total_tps = token_count / total_time if total_time > 0 else 0
+
+         stats = (
+             f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
+             f"(prefill: {prefill_time:.1f}s, decode: {decode_tps:.1f} tok/s)]*"
+         )
+
+         if not stop_generation:
+             generated_text += stats
+
      yield generated_text
+
+
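A note on the truncation step above: each cached tensor is `[batch, n_heads, seq, head_dim]`, so the slice runs along axis 2, and the extra 100-position buffer keeps the trim from firing on every subsequent step. This is a sliding-window approximation; the oldest tokens simply drop out of attention. A shape-level sketch with stand-in arrays:

```python
import numpy as np

max_context = 2048
past_kv = [(np.zeros((1, 8, 2148, 64)), np.zeros((1, 8, 2148, 64)))]

current_len = past_kv[0][0].shape[2]
if current_len > max_context:
    trim_amount = current_len - max_context + 100   # drop the 200 oldest positions
    past_kv = [(k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
               for k, v in past_kv]

print(past_kv[0][0].shape)  # (1, 8, 1948, 64): ~100 steps until the next trim
```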
  # ============================================================================
  # Chat Interface Logic
  # ============================================================================

  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
+     """Format message history and seed <think> if enabled."""
      prompt = ""
      for user_msg, assistant_msg in history:
          prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
          if assistant_msg:
+             # Clean up any stats from previous messages
+             clean_msg = assistant_msg.split("\n\n*[")[0]
+             prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
+
      prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
+
      if reasoning_enabled:
          prompt += "<think>"
+
      return prompt
+
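Concretely, the function emits a ChatML-style transcript. For a one-turn history with reasoning enabled, the returned string looks like this (example values, using the new `format_chat_prompt` above):

```python
history = [("Hi!", "Hello! How can I help?")]
print(format_chat_prompt("What is 2 + 2?", history, reasoning_enabled=True))
# <|im_start|>user
# Hi!<|im_end|>
# <|im_start|>assistant
# Hello! How can I help?<|im_end|>
# <|im_start|>user
# What is 2 + 2?<|im_end|>
# <|im_start|>assistant
# <think>
```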
  def chat_stream(
      message: str,
      history: list,

      if not message.strip():
          yield history
          return
+
      prompt = format_chat_prompt(message, history, reasoning_enabled)
      partial_response = ""
+
      for generated in generate_stream(
          prompt, max_tokens, temperature, top_k, top_p, repetition_penalty
      ):
          partial_response = generated
+
+         # Robust end-of-turn detection
          stop_tags = ["<|im_end|>", "<im end for model tun>"]
          earliest_stop = len(partial_response)
          should_stop = False

          for tag in stop_tags:
              if tag in partial_response:
+                 idx = partial_response.find(tag)
+                 if idx < earliest_stop:
+                     earliest_stop = idx
+                 should_stop = True
+
+         display_response = partial_response
          if should_stop:
+             # Keep the stats portion if present
+             stats_start = partial_response.find("\n\n*[")
+             if stats_start > earliest_stop:
+                 display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
+             else:
+                 display_response = partial_response[:earliest_stop]

+         # Post-process reasoning tags for display
          if reasoning_enabled:
+             if '<think>' in display_response and '</think>' in display_response:
+                 start_idx = display_response.find('<think>')
+                 end_idx = display_response.find('</think>')
                  if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
+                     thought_content = display_response[start_idx + len('<think>'):end_idx].strip()
                      formatted_thought = thought_content.replace("\n", "<br>")
                      details_html = (
                          f'<details class="reasoning-block">'
+                         f'<summary>🧠 Model Reasoning (Click to expand)</summary>'
                          f'<p>{formatted_thought}</p>'
                          f'</details>'
                      )
+                     display_response = (
+                         display_response[:start_idx] +
+                         details_html +
+                         display_response[end_idx + len('</think>'):]
+                     )
+             elif '<think>' in display_response and '</think>' not in display_response:
+                 display_response = display_response.replace('<think>', '**🧠 Thinking:** ')
+
+         yield history + [[message, display_response.strip()]]
+

  def stop_gen():
      global stop_generation
      stop_generation = True
      return None

+
  # ============================================================================
  # Gradio UI
  # ============================================================================
 
  .gradio-html details.reasoning-block p { margin-top: 5px; padding-left: 10px; border-left: 1px dashed #ccc; white-space: pre-wrap; }
  .modal-overlay {
      position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: rgba(0, 0, 0, 0.7);
+     display: flex; justify-content: center; align-items: center; z-index: 1000;
  }
  .modal-content {
      background: white; padding: 30px; border-radius: 15px; width: 90%; max-width: 900px;
 
      border: none; border-radius: 8px; cursor: pointer; font-size: 1rem; transition: background-color 0.3s;
  }
  .close-btn:hover { background-color: #5d3a84; }
+ .speed-indicator {
+     background: linear-gradient(135deg, #00b894, #00cec9);
+     color: white; padding: 5px 10px; border-radius: 10px; font-size: 0.8rem;
+     display: inline-block; margin-left: 10px;
+ }
  """

  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+     reasoning_enabled = gr.State(False)
+
      welcome_modal_html = gr.HTML(
          """
          <div id="welcome-modal" class="modal-overlay" style="display:none;">
              <div class="modal-content">
                  <h2>🧠 Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
+                 <p>Our latest model features <strong>Chain-of-Thought (CoT)</strong> functionality and <strong>KV-Cache optimization</strong> for fast CPU inference!</p>
                  <div class="comparison-box">
                      <div class="comparison-mode mode-reasoning">
                          <h3>💡 Reasoning Mode (ON)</h3>
+                         <p>The model performs a <strong>CoT step</strong> first. The internal thought process is contained within <code>&lt;think>...&lt;/think></code> tags.</p>
                      </div>
                      <div class="comparison-mode mode-direct">
                          <h3>⚪ Direct Mode (OFF)</h3>
 
          gr.HTML("""
          <div class="header">
              <div class="celebration">🎉 🎊 ✨ 🎈 🎆</div>
+             <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
                   alt="Sam-large-2" style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
              <h1>🤖 Sam-large-2 Chat 🤖</h1>
+             <p><strong>LATEST RELEASE!</strong> Our <strong>BEST Reasoning Model</strong> - Now with KV-Cache! <span class="speed-indicator">⚡ 5-20x Faster</span></p>
              <div class="twin-badge">Reasoning Model</div>
              <div class="celebration">🚀 💫 🎯 ⚡ 🔥</div>
          </div>
          """)
      else:
+         gr.HTML("""<div class="header"><h1>🤖 Sam-large-2 Chat</h1><p>Advanced Reasoning Model with KV-Cache</p></div>""")

      with gr.Row():
          with gr.Column(scale=4):
              chatbot = gr.Chatbot(
+                 height=600,
+                 show_label=False,
+                 avatar_images=(
+                     None,
+                     "https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/KtiMi-aDUOOeN--YNT-Fu.jpeg"
+                 ),
                  bubble_full_width=False
              )
              with gr.Row():
                  with gr.Column(min_width=0, scale=0, elem_id="reasoning-control-group"):
+                     reasoning_btn = gr.Button("💡", size="sm", elem_id="reasoning-toggle-btn", elem_classes=["off"])
                      gr.HTML('<span class="new-tag-red">NEW</span>')
+                 msg = gr.Textbox(
+                     placeholder="Type your message here...",
+                     show_label=False,
+                     scale=8,
+                     container=False
+                 )
                  submit_btn = gr.Button("Send 🚀" if FESTIVE else "Send", variant="primary", scale=1)
                  stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)
              with gr.Row():
                  clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
                  retry_btn = gr.Button("🔄 Retry", size="sm")
+
          with gr.Column(scale=1):
              gr.Markdown("### ⚙️ Generation Settings")
              max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=50, label="Max Tokens")
 
              repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
              gr.Markdown("---")
              gr.Markdown(f"""### 🎊 Sam-large-2 Model Info
+ **Type:** Chain-of-Thought Reasoning Model
+ **Vocab:** {config['vocab_size']:,}
+ **Layers:** {config['num_hidden_layers']}
+ **Context:** {config['max_position_embeddings']:,} tokens
+ **Optimization:** KV-Cache enabled ⚡
+ """)
+
+     gr.Examples(
+         examples=[
+             "Explain quantum computing in simple terms",
+             "Write a short poem about artificial intelligence",
+             "What is 24 * 12? Show your reasoning.",
+             "What are the main differences between Python and JavaScript?"
+         ],
+         inputs=msg
+     )

      gr.HTML("""
+     <footer>
+         <p><strong>🎉 Sam-large-2 - LATEST RELEASE with KV-Cache! 🎉</strong></p>
+         <p style="font-size: 0.9rem; color: #999;">Trained from scratch on TPU v5e-8 • Built by Smily studios with TensorFlow & Gradio</p>
+     </footer>
+     """)
+
      def show_modal_js():
          return """
          (function() {
 
              }
          })();
          """
+
      demo.load(None, inputs=None, outputs=None, js=show_modal_js())

      def toggle_reasoning(current_state):
          new_state = not current_state
          return new_state, gr.update(elem_classes="" if new_state else "off")

+     reasoning_btn.click(
+         fn=toggle_reasoning,
+         inputs=[reasoning_enabled],
+         outputs=[reasoning_enabled, reasoning_btn],
+         preprocess=False
+     )

      common_inputs = [msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled]
+
+     submit_event = msg.submit(
+         chat_stream,
+         inputs=common_inputs,
+         outputs=[chatbot]
+     ).then(lambda: "", outputs=[msg])
+
+     click_event = submit_btn.click(
+         chat_stream,
+         inputs=common_inputs,
+         outputs=[chatbot]
+     ).then(lambda: "", outputs=[msg])
+
      stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[submit_event, click_event])
      clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
+
      def retry_last(history, max_tok, temp, topk, topp, rep_pen, reasoning_en):
+         if not history:
+             return history
          last_user_msg = history[-1][0]
          for update in chat_stream(last_user_msg, history[:-1], max_tok, temp, topk, topp, rep_pen, reasoning_en):
              yield update
+
+     retry_event = retry_btn.click(
+         retry_last,
+         inputs=[chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled],
+         outputs=[chatbot]
+     )
      stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[retry_event])

  if __name__ == "__main__":
+     print("\n" + "=" * 60)
+     print("🚀 Starting Sam-large-2 Chat with KV-Cache Optimization")
+     print("=" * 60 + "\n")
      demo.queue(max_size=20)
      demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)