Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,12 +7,9 @@ NUM_CORES = os.cpu_count() or 4
|
|
| 7 |
|
| 8 |
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
|
| 9 |
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
|
| 10 |
-
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
| 11 |
-
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
|
| 12 |
-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
| 13 |
-
# NEW: Enable XLA (Accelerated Linear Algebra)
|
| 14 |
-
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
|
| 15 |
-
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
|
| 16 |
|
| 17 |
import gradio as gr
|
| 18 |
import tensorflow as tf
|
|
@@ -23,13 +20,11 @@ from tokenizers import Tokenizer
|
|
| 23 |
import numpy as np
|
| 24 |
import time
|
| 25 |
|
| 26 |
-
# Configure TF
|
| 27 |
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
|
| 28 |
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
|
| 29 |
-
tf.config.optimizer.set_jit(True) # Enable XLA JIT compilation
|
| 30 |
-
tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})
|
| 31 |
|
| 32 |
-
print(f"β
CPU optimized: {NUM_CORES} threads, oneDNN
|
| 33 |
|
| 34 |
# ============================================================================
|
| 35 |
# π FESTIVE MODE TOGGLE π
|
|
@@ -62,8 +57,6 @@ class RotaryEmbedding(keras.layers.Layer):
|
|
| 62 |
|
| 63 |
def build(self, input_shape):
|
| 64 |
super().build(input_shape)
|
| 65 |
-
# Pre-build cache immediately
|
| 66 |
-
self._build_cache()
|
| 67 |
|
| 68 |
def _build_cache(self):
|
| 69 |
if not self.built_cache:
|
|
@@ -71,19 +64,17 @@ class RotaryEmbedding(keras.layers.Layer):
|
|
| 71 |
t = tf.range(self.max_len, dtype=tf.float32)
|
| 72 |
freqs = tf.einsum("i,j->ij", t, inv_freq)
|
| 73 |
emb = tf.concat([freqs, freqs], axis=-1)
|
| 74 |
-
# Use tf.constant for faster access
|
| 75 |
self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
|
| 76 |
self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
|
| 77 |
self.built_cache = True
|
| 78 |
|
| 79 |
-
@tf.function(jit_compile=True)
|
| 80 |
def rotate_half(self, x):
|
| 81 |
x1, x2 = tf.split(x, 2, axis=-1)
|
| 82 |
return tf.concat([-x2, x1], axis=-1)
|
| 83 |
|
| 84 |
-
@tf.function(jit_compile=True)
|
| 85 |
def call(self, q, k, offset=0):
|
| 86 |
"""Apply rotary embeddings with position offset for KV-cache."""
|
|
|
|
| 87 |
seq_len = tf.shape(q)[2]
|
| 88 |
dtype = q.dtype
|
| 89 |
|
|
@@ -111,7 +102,6 @@ class RMSNorm(keras.layers.Layer):
|
|
| 111 |
self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
|
| 112 |
super().build(input_shape)
|
| 113 |
|
| 114 |
-
@tf.function(jit_compile=True)
|
| 115 |
def call(self, x):
|
| 116 |
variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
|
| 117 |
return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
|
|
@@ -149,8 +139,15 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 149 |
self.dropout = keras.layers.Dropout(self.dropout_rate)
|
| 150 |
super().build(input_shape)
|
| 151 |
|
| 152 |
-
# Removed @tf.function here to allow flexible past_kv handling
|
| 153 |
def call(self, x, training=None, past_kv=None, use_cache=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
B = tf.shape(x)[0]
|
| 155 |
T = tf.shape(x)[1]
|
| 156 |
dtype = x.dtype
|
|
@@ -158,9 +155,9 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 158 |
res = x
|
| 159 |
y = self.pre_attn_norm(x)
|
| 160 |
|
| 161 |
-
# Project Q, K, V
|
| 162 |
q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 163 |
-
q = tf.transpose(q, [0, 2, 1, 3])
|
| 164 |
|
| 165 |
k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 166 |
k = tf.transpose(k, [0, 2, 1, 3])
|
|
@@ -168,13 +165,13 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 168 |
v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 169 |
v = tf.transpose(v, [0, 2, 1, 3])
|
| 170 |
|
| 171 |
-
#
|
| 172 |
if past_kv is not None:
|
| 173 |
past_len = tf.shape(past_kv[0])[2]
|
| 174 |
else:
|
| 175 |
past_len = 0
|
| 176 |
|
| 177 |
-
# Apply RoPE
|
| 178 |
q, k = self.rope(q, k, offset=past_len)
|
| 179 |
|
| 180 |
# Concatenate with past KV
|
|
@@ -200,13 +197,13 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 200 |
attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
|
| 201 |
attn_out = tf.reshape(attn_out, [B, T, self.d_model])
|
| 202 |
|
| 203 |
-
x = res + self.out_proj(attn_out)
|
| 204 |
|
| 205 |
# FFN
|
| 206 |
res = x
|
| 207 |
y = self.pre_ffn_norm(x)
|
| 208 |
ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
|
| 209 |
-
output = res + ffn
|
| 210 |
|
| 211 |
return output, new_kv
|
| 212 |
|
|
@@ -252,8 +249,15 @@ class SAM1Model(keras.Model):
|
|
| 252 |
self.norm = RMSNorm(name="final_norm")
|
| 253 |
self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
|
| 254 |
|
| 255 |
-
# Don't use @tf.function on the main call to allow flexible caching
|
| 256 |
def call(self, input_ids, training=None, past_kv=None, use_cache=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
x = self.embed(input_ids)
|
| 258 |
|
| 259 |
new_past_kv = [] if use_cache else None
|
|
@@ -318,11 +322,12 @@ if use_checkpoint:
|
|
| 318 |
'n_heads': config['num_attention_heads'],
|
| 319 |
'ff_mult': config['intermediate_size'] / config['hidden_size'],
|
| 320 |
'max_len': config['max_position_embeddings'],
|
| 321 |
-
'dropout': 0.
|
| 322 |
'rope_theta': config['rope_theta']
|
| 323 |
}
|
| 324 |
model = SAM1Model(config=model_config)
|
| 325 |
|
|
|
|
| 326 |
dummy_input = tf.zeros((1, 16), dtype=tf.int32)
|
| 327 |
_ = model(dummy_input, training=False, use_cache=False)
|
| 328 |
print(f"β
Model architecture built: {model.count_params():,} parameters")
|
|
@@ -351,37 +356,11 @@ else:
|
|
| 351 |
if model:
|
| 352 |
print(f"β
Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
|
| 353 |
|
| 354 |
-
#
|
| 355 |
-
|
| 356 |
-
# ============================================================================
|
| 357 |
-
|
| 358 |
-
print("β‘ Compiling optimized inference functions...")
|
| 359 |
-
|
| 360 |
-
# Create traced functions for prefill and decode
|
| 361 |
-
@tf.function(
|
| 362 |
-
input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)],
|
| 363 |
-
jit_compile=True
|
| 364 |
-
)
|
| 365 |
-
def prefill_fn(input_ids):
|
| 366 |
-
"""Optimized prefill phase - processes entire prompt at once."""
|
| 367 |
-
return model(input_ids, training=False, use_cache=True)
|
| 368 |
-
|
| 369 |
-
@tf.function(
|
| 370 |
-
input_signature=[tf.TensorSpec(shape=[1, 1], dtype=tf.int32)],
|
| 371 |
-
jit_compile=True
|
| 372 |
-
)
|
| 373 |
-
def decode_fn_no_cache(input_ids):
|
| 374 |
-
"""Fast decode for single token without cache (fallback)."""
|
| 375 |
-
logits, _ = model(input_ids, training=False, use_cache=False)
|
| 376 |
-
return logits
|
| 377 |
-
|
| 378 |
-
# Warm up with compilation
|
| 379 |
-
print("π₯ Warming up and compiling model...")
|
| 380 |
warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
|
| 381 |
-
_ =
|
| 382 |
-
|
| 383 |
-
_ = decode_fn_no_cache(warmup_single)
|
| 384 |
-
print("β
Model warmed up and compiled")
|
| 385 |
|
| 386 |
# ============================================================================
|
| 387 |
# Optimized Inference Logic with KV-Cache
|
|
@@ -389,50 +368,39 @@ print("β
Model warmed up and compiled")
|
|
| 389 |
|
| 390 |
stop_generation = False
|
| 391 |
|
| 392 |
-
# Pre-compile sampling helper
|
| 393 |
-
@tf.function(jit_compile=True)
|
| 394 |
-
def apply_temperature_tf(logits, temperature):
|
| 395 |
-
"""GPU-accelerated temperature scaling."""
|
| 396 |
-
return logits / temperature
|
| 397 |
|
| 398 |
def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
|
| 399 |
-
"""
|
| 400 |
-
# Temperature
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
if repetition_penalty != 1.0 and token_freq:
|
| 406 |
-
penalty = np.ones_like(logits)
|
| 407 |
for token_id, freq in token_freq.items():
|
| 408 |
-
if token_id < len(
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
top_k_logits = logits[top_k_indices]
|
| 416 |
else:
|
| 417 |
-
top_k_indices = np.arange(len(
|
| 418 |
-
top_k_logits =
|
| 419 |
|
| 420 |
-
# Softmax (stable)
|
| 421 |
-
|
| 422 |
-
top_k_logits = top_k_logits - max_logit
|
| 423 |
top_k_probs = np.exp(top_k_logits)
|
| 424 |
-
|
| 425 |
-
top_k_probs = top_k_probs / sum_probs
|
| 426 |
|
| 427 |
-
# Top-P
|
| 428 |
if top_p < 1.0:
|
| 429 |
sorted_idx = np.argsort(top_k_probs)[::-1]
|
| 430 |
cumsum = np.cumsum(top_k_probs[sorted_idx])
|
| 431 |
cutoff = np.searchsorted(cumsum, top_p) + 1
|
| 432 |
-
cutoff = min(cutoff, len(sorted_idx))
|
| 433 |
nucleus_idx = sorted_idx[:cutoff]
|
| 434 |
nucleus_probs = top_k_probs[nucleus_idx]
|
| 435 |
-
nucleus_probs
|
| 436 |
sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
|
| 437 |
return int(top_k_indices[nucleus_idx[sampled]])
|
| 438 |
else:
|
|
@@ -448,23 +416,23 @@ def generate_stream(
|
|
| 448 |
top_p: float = 0.9,
|
| 449 |
repetition_penalty: float = 1.1
|
| 450 |
):
|
| 451 |
-
"""
|
| 452 |
global stop_generation
|
| 453 |
stop_generation = False
|
| 454 |
|
| 455 |
-
# Tokenize
|
| 456 |
prompt_ids = tokenizer.encode(prompt).ids
|
| 457 |
input_ids = [i for i in prompt_ids if i != eos_token_id]
|
| 458 |
|
| 459 |
if len(input_ids) == 0:
|
| 460 |
-
yield "Error: Empty prompt"
|
| 461 |
return
|
| 462 |
|
| 463 |
generated_text = ""
|
| 464 |
token_count = 0
|
| 465 |
token_freq = {}
|
| 466 |
|
| 467 |
-
#
|
| 468 |
im_end_id = tokenizer.token_to_id("<|im_end|>")
|
| 469 |
model_end_id = tokenizer.token_to_id("<im end for model tun>")
|
| 470 |
stop_ids = {eos_token_id, im_end_id, model_end_id}
|
|
@@ -474,81 +442,81 @@ def generate_stream(
|
|
| 474 |
|
| 475 |
start_time = time.time()
|
| 476 |
|
| 477 |
-
# === PREFILL
|
|
|
|
| 478 |
if len(input_ids) > max_context - max_tokens:
|
| 479 |
input_ids = input_ids[-(max_context - max_tokens):]
|
| 480 |
|
| 481 |
input_tensor = tf.constant([input_ids], dtype=tf.int32)
|
| 482 |
|
| 483 |
try:
|
| 484 |
-
logits, past_kv =
|
| 485 |
-
logits_np = logits[0, -1, :].numpy()
|
| 486 |
except Exception as e:
|
| 487 |
-
yield f"Error: {e}"
|
| 488 |
return
|
| 489 |
|
|
|
|
|
|
|
|
|
|
| 490 |
prefill_time = time.time() - start_time
|
| 491 |
-
print(f"β‘ Prefill: {len(input_ids)} tokens in {prefill_time:.
|
| 492 |
|
| 493 |
-
# ===
|
| 494 |
decode_start = time.time()
|
| 495 |
-
decode_times = []
|
| 496 |
|
| 497 |
for step in range(max_tokens):
|
| 498 |
if stop_generation:
|
| 499 |
-
yield generated_text + "\n\n*[
|
| 500 |
return
|
| 501 |
|
| 502 |
-
# Sample
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
|
|
|
|
|
|
| 506 |
if next_token_id in stop_ids:
|
| 507 |
break
|
| 508 |
|
|
|
|
| 509 |
token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
|
| 510 |
|
| 511 |
-
# Decode
|
| 512 |
token_text = tokenizer.decode([next_token_id])
|
| 513 |
generated_text += token_text
|
| 514 |
token_count += 1
|
| 515 |
-
|
| 516 |
-
# Yield every token for streaming
|
| 517 |
-
if token_count % 1 == 0: # Stream every token
|
| 518 |
-
yield generated_text
|
| 519 |
|
| 520 |
-
# ===
|
| 521 |
next_input = tf.constant([[next_token_id]], dtype=tf.int32)
|
| 522 |
|
| 523 |
try:
|
| 524 |
logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
|
| 525 |
-
logits_np = logits[0, -1, :].numpy()
|
| 526 |
except Exception as e:
|
| 527 |
-
yield generated_text + f"\n\n*[Error: {e}]*"
|
| 528 |
return
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
#
|
| 533 |
-
if past_kv and
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
|
|
|
|
|
|
| 538 |
|
| 539 |
-
# Stats
|
| 540 |
decode_time = time.time() - decode_start
|
| 541 |
total_time = time.time() - start_time
|
| 542 |
|
| 543 |
if token_count > 0:
|
| 544 |
-
avg_decode_time = np.mean(decode_times) if decode_times else 0
|
| 545 |
decode_tps = token_count / decode_time if decode_time > 0 else 0
|
|
|
|
| 546 |
|
| 547 |
stats = (
|
| 548 |
-
f"\n\n*[{token_count} tokens in {total_time:.1f}s
|
| 549 |
-
f"
|
| 550 |
-
f"Decode: {decode_tps:.1f} t/s | "
|
| 551 |
-
f"Avg: {1/avg_decode_time:.1f} t/s]*"
|
| 552 |
)
|
| 553 |
|
| 554 |
if not stop_generation:
|
|
@@ -558,7 +526,7 @@ def generate_stream(
|
|
| 558 |
|
| 559 |
|
| 560 |
# ============================================================================
|
| 561 |
-
# Chat Interface Logic
|
| 562 |
# ============================================================================
|
| 563 |
|
| 564 |
def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
|
|
@@ -567,6 +535,7 @@ def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) ->
|
|
| 567 |
for user_msg, assistant_msg in history:
|
| 568 |
prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
|
| 569 |
if assistant_msg:
|
|
|
|
| 570 |
clean_msg = assistant_msg.split("\n\n*[")[0]
|
| 571 |
prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
|
| 572 |
|
|
@@ -600,6 +569,7 @@ def chat_stream(
|
|
| 600 |
):
|
| 601 |
partial_response = generated
|
| 602 |
|
|
|
|
| 603 |
stop_tags = ["<|im_end|>", "<im end for model tun>"]
|
| 604 |
earliest_stop = len(partial_response)
|
| 605 |
should_stop = False
|
|
@@ -613,12 +583,14 @@ def chat_stream(
|
|
| 613 |
|
| 614 |
display_response = partial_response
|
| 615 |
if should_stop:
|
|
|
|
| 616 |
stats_start = partial_response.find("\n\n*[")
|
| 617 |
if stats_start > earliest_stop:
|
| 618 |
display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
|
| 619 |
else:
|
| 620 |
display_response = partial_response[:earliest_stop]
|
| 621 |
|
|
|
|
| 622 |
if reasoning_enabled:
|
| 623 |
if '<think>' in display_response and '</think>' in display_response:
|
| 624 |
start_idx = display_response.find('<think>')
|
|
@@ -650,7 +622,7 @@ def stop_gen():
|
|
| 650 |
|
| 651 |
|
| 652 |
# ============================================================================
|
| 653 |
-
# Gradio UI
|
| 654 |
# ============================================================================
|
| 655 |
|
| 656 |
custom_css = """
|
|
@@ -725,15 +697,15 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
|
| 725 |
<div id="welcome-modal" class="modal-overlay" style="display:none;">
|
| 726 |
<div class="modal-content">
|
| 727 |
<h2>π§ Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
|
| 728 |
-
<p
|
| 729 |
<div class="comparison-box">
|
| 730 |
<div class="comparison-mode mode-reasoning">
|
| 731 |
<h3>π‘ Reasoning Mode (ON)</h3>
|
| 732 |
-
<p>
|
| 733 |
</div>
|
| 734 |
<div class="comparison-mode mode-direct">
|
| 735 |
<h3>βͺ Direct Mode (OFF)</h3>
|
| 736 |
-
<p>
|
| 737 |
</div>
|
| 738 |
</div>
|
| 739 |
<button class="close-btn" onclick="document.getElementById('welcome-modal').style.display='none'">Got it! Start Chatting</button>
|
|
@@ -749,13 +721,13 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
|
| 749 |
<img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
|
| 750 |
alt="Sam-large-2" style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
|
| 751 |
<h1>π€ Sam-large-2 Chat π€</h1>
|
| 752 |
-
<p><strong>
|
| 753 |
-
<div class="twin-badge">Reasoning Model
|
| 754 |
<div class="celebration">π π« π― β‘ π₯</div>
|
| 755 |
</div>
|
| 756 |
""")
|
| 757 |
else:
|
| 758 |
-
gr.HTML("""<div class="header"><h1>π€ Sam-large-2 Chat</h1><p>
|
| 759 |
|
| 760 |
with gr.Row():
|
| 761 |
with gr.Column(scale=4):
|
|
@@ -793,12 +765,11 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
|
| 793 |
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
|
| 794 |
gr.Markdown("---")
|
| 795 |
gr.Markdown(f"""### π Sam-large-2 Model Info
|
| 796 |
-
**Type:**
|
| 797 |
**Vocab:** {config['vocab_size']:,}
|
| 798 |
**Layers:** {config['num_hidden_layers']}
|
| 799 |
**Context:** {config['max_position_embeddings']:,} tokens
|
| 800 |
-
**Optimization:** KV-Cache
|
| 801 |
-
**Expected Speed:** 10-30 tok/s (CPU)
|
| 802 |
""")
|
| 803 |
|
| 804 |
gr.Examples(
|
|
@@ -813,8 +784,8 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
|
| 813 |
|
| 814 |
gr.HTML("""
|
| 815 |
<footer>
|
| 816 |
-
<p><strong
|
| 817 |
-
<p style="font-size: 0.9rem; color: #999;">
|
| 818 |
</footer>
|
| 819 |
""")
|
| 820 |
|
|
@@ -874,7 +845,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
|
| 874 |
|
| 875 |
if __name__ == "__main__":
|
| 876 |
print("\n" + "=" * 60)
|
| 877 |
-
print("
|
| 878 |
print("=" * 60 + "\n")
|
| 879 |
demo.queue(max_size=20)
|
| 880 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
|
|
|
|
| 7 |
|
| 8 |
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
|
| 9 |
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
|
| 10 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Force CPU only
|
| 11 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' # Intel optimization
|
| 12 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TF logging
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
import gradio as gr
|
| 15 |
import tensorflow as tf
|
|
|
|
| 20 |
import numpy as np
|
| 21 |
import time
|
| 22 |
|
| 23 |
+
# Configure TF threading
|
| 24 |
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
|
| 25 |
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
print(f"β
CPU optimized: {NUM_CORES} threads, oneDNN enabled")
|
| 28 |
|
| 29 |
# ============================================================================
|
| 30 |
# π FESTIVE MODE TOGGLE π
|
|
|
|
| 57 |
|
| 58 |
def build(self, input_shape):
|
| 59 |
super().build(input_shape)
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def _build_cache(self):
|
| 62 |
if not self.built_cache:
|
|
|
|
| 64 |
t = tf.range(self.max_len, dtype=tf.float32)
|
| 65 |
freqs = tf.einsum("i,j->ij", t, inv_freq)
|
| 66 |
emb = tf.concat([freqs, freqs], axis=-1)
|
|
|
|
| 67 |
self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
|
| 68 |
self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
|
| 69 |
self.built_cache = True
|
| 70 |
|
|
|
|
| 71 |
def rotate_half(self, x):
|
| 72 |
x1, x2 = tf.split(x, 2, axis=-1)
|
| 73 |
return tf.concat([-x2, x1], axis=-1)
|
| 74 |
|
|
|
|
| 75 |
def call(self, q, k, offset=0):
|
| 76 |
"""Apply rotary embeddings with position offset for KV-cache."""
|
| 77 |
+
self._build_cache()
|
| 78 |
seq_len = tf.shape(q)[2]
|
| 79 |
dtype = q.dtype
|
| 80 |
|
|
|
|
| 102 |
self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
|
| 103 |
super().build(input_shape)
|
| 104 |
|
|
|
|
| 105 |
def call(self, x):
|
| 106 |
variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
|
| 107 |
return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
|
|
|
|
| 139 |
self.dropout = keras.layers.Dropout(self.dropout_rate)
|
| 140 |
super().build(input_shape)
|
| 141 |
|
|
|
|
| 142 |
def call(self, x, training=None, past_kv=None, use_cache=False):
|
| 143 |
+
"""
|
| 144 |
+
Args:
|
| 145 |
+
x: input tensor [B, T, D] (T=1 during cached generation)
|
| 146 |
+
past_kv: tuple of (past_k, past_v) each [B, n_heads, past_len, head_dim]
|
| 147 |
+
use_cache: whether to return updated kv cache
|
| 148 |
+
Returns:
|
| 149 |
+
output, (new_k, new_v) if use_cache else output, None
|
| 150 |
+
"""
|
| 151 |
B = tf.shape(x)[0]
|
| 152 |
T = tf.shape(x)[1]
|
| 153 |
dtype = x.dtype
|
|
|
|
| 155 |
res = x
|
| 156 |
y = self.pre_attn_norm(x)
|
| 157 |
|
| 158 |
+
# Project Q, K, V for current input
|
| 159 |
q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 160 |
+
q = tf.transpose(q, [0, 2, 1, 3]) # [B, n_heads, T, head_dim]
|
| 161 |
|
| 162 |
k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 163 |
k = tf.transpose(k, [0, 2, 1, 3])
|
|
|
|
| 165 |
v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 166 |
v = tf.transpose(v, [0, 2, 1, 3])
|
| 167 |
|
| 168 |
+
# Determine position offset for RoPE
|
| 169 |
if past_kv is not None:
|
| 170 |
past_len = tf.shape(past_kv[0])[2]
|
| 171 |
else:
|
| 172 |
past_len = 0
|
| 173 |
|
| 174 |
+
# Apply RoPE with position offset
|
| 175 |
q, k = self.rope(q, k, offset=past_len)
|
| 176 |
|
| 177 |
# Concatenate with past KV
|
|
|
|
| 197 |
attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
|
| 198 |
attn_out = tf.reshape(attn_out, [B, T, self.d_model])
|
| 199 |
|
| 200 |
+
x = res + self.dropout(self.out_proj(attn_out), training=training)
|
| 201 |
|
| 202 |
# FFN
|
| 203 |
res = x
|
| 204 |
y = self.pre_ffn_norm(x)
|
| 205 |
ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
|
| 206 |
+
output = res + self.dropout(ffn, training=training)
|
| 207 |
|
| 208 |
return output, new_kv
|
| 209 |
|
|
|
|
| 249 |
self.norm = RMSNorm(name="final_norm")
|
| 250 |
self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
|
| 251 |
|
|
|
|
| 252 |
def call(self, input_ids, training=None, past_kv=None, use_cache=False):
|
| 253 |
+
"""
|
| 254 |
+
Args:
|
| 255 |
+
input_ids: [B, T]
|
| 256 |
+
past_kv: list of (k, v) tuples, one per layer
|
| 257 |
+
use_cache: whether to return updated cache
|
| 258 |
+
Returns:
|
| 259 |
+
logits, new_past_kv (or None)
|
| 260 |
+
"""
|
| 261 |
x = self.embed(input_ids)
|
| 262 |
|
| 263 |
new_past_kv = [] if use_cache else None
|
|
|
|
| 322 |
'n_heads': config['num_attention_heads'],
|
| 323 |
'ff_mult': config['intermediate_size'] / config['hidden_size'],
|
| 324 |
'max_len': config['max_position_embeddings'],
|
| 325 |
+
'dropout': 0.1,
|
| 326 |
'rope_theta': config['rope_theta']
|
| 327 |
}
|
| 328 |
model = SAM1Model(config=model_config)
|
| 329 |
|
| 330 |
+
# Build model with dummy input
|
| 331 |
dummy_input = tf.zeros((1, 16), dtype=tf.int32)
|
| 332 |
_ = model(dummy_input, training=False, use_cache=False)
|
| 333 |
print(f"β
Model architecture built: {model.count_params():,} parameters")
|
|
|
|
| 356 |
if model:
|
| 357 |
print(f"β
Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
|
| 358 |
|
| 359 |
+
# Warm up the model
|
| 360 |
+
print("π₯ Warming up model...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
|
| 362 |
+
_, _ = model(warmup_input, training=False, use_cache=True)
|
| 363 |
+
print("β
Model warmed up")
|
|
|
|
|
|
|
| 364 |
|
| 365 |
# ============================================================================
|
| 366 |
# Optimized Inference Logic with KV-Cache
|
|
|
|
| 368 |
|
| 369 |
stop_generation = False
|
| 370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
|
| 373 |
+
"""Pure NumPy sampling for speed."""
|
| 374 |
+
# Temperature scaling
|
| 375 |
+
scaled_logits = logits / temperature
|
| 376 |
+
|
| 377 |
+
# Repetition penalty
|
| 378 |
+
if repetition_penalty != 1.0:
|
|
|
|
|
|
|
| 379 |
for token_id, freq in token_freq.items():
|
| 380 |
+
if token_id < len(scaled_logits):
|
| 381 |
+
scaled_logits[token_id] /= (repetition_penalty ** freq)
|
| 382 |
+
|
| 383 |
+
# Top-K filtering
|
| 384 |
+
if top_k > 0 and top_k < len(scaled_logits):
|
| 385 |
+
top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
|
| 386 |
+
top_k_logits = scaled_logits[top_k_indices]
|
|
|
|
| 387 |
else:
|
| 388 |
+
top_k_indices = np.arange(len(scaled_logits))
|
| 389 |
+
top_k_logits = scaled_logits
|
| 390 |
|
| 391 |
+
# Softmax (numerically stable)
|
| 392 |
+
top_k_logits = top_k_logits - np.max(top_k_logits)
|
|
|
|
| 393 |
top_k_probs = np.exp(top_k_logits)
|
| 394 |
+
top_k_probs /= top_k_probs.sum()
|
|
|
|
| 395 |
|
| 396 |
+
# Top-P (nucleus) filtering
|
| 397 |
if top_p < 1.0:
|
| 398 |
sorted_idx = np.argsort(top_k_probs)[::-1]
|
| 399 |
cumsum = np.cumsum(top_k_probs[sorted_idx])
|
| 400 |
cutoff = np.searchsorted(cumsum, top_p) + 1
|
|
|
|
| 401 |
nucleus_idx = sorted_idx[:cutoff]
|
| 402 |
nucleus_probs = top_k_probs[nucleus_idx]
|
| 403 |
+
nucleus_probs /= nucleus_probs.sum()
|
| 404 |
sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
|
| 405 |
return int(top_k_indices[nucleus_idx[sampled]])
|
| 406 |
else:
|
|
|
|
| 416 |
top_p: float = 0.9,
|
| 417 |
repetition_penalty: float = 1.1
|
| 418 |
):
|
| 419 |
+
"""Generate text with KV-cache for fast CPU inference."""
|
| 420 |
global stop_generation
|
| 421 |
stop_generation = False
|
| 422 |
|
| 423 |
+
# Tokenize prompt
|
| 424 |
prompt_ids = tokenizer.encode(prompt).ids
|
| 425 |
input_ids = [i for i in prompt_ids if i != eos_token_id]
|
| 426 |
|
| 427 |
if len(input_ids) == 0:
|
| 428 |
+
yield "Error: Empty prompt after tokenization"
|
| 429 |
return
|
| 430 |
|
| 431 |
generated_text = ""
|
| 432 |
token_count = 0
|
| 433 |
token_freq = {}
|
| 434 |
|
| 435 |
+
# Get special token IDs
|
| 436 |
im_end_id = tokenizer.token_to_id("<|im_end|>")
|
| 437 |
model_end_id = tokenizer.token_to_id("<im end for model tun>")
|
| 438 |
stop_ids = {eos_token_id, im_end_id, model_end_id}
|
|
|
|
| 442 |
|
| 443 |
start_time = time.time()
|
| 444 |
|
| 445 |
+
# === PREFILL PHASE ===
|
| 446 |
+
# Truncate if prompt is too long
|
| 447 |
if len(input_ids) > max_context - max_tokens:
|
| 448 |
input_ids = input_ids[-(max_context - max_tokens):]
|
| 449 |
|
| 450 |
input_tensor = tf.constant([input_ids], dtype=tf.int32)
|
| 451 |
|
| 452 |
try:
|
| 453 |
+
logits, past_kv = model(input_tensor, training=False, use_cache=True)
|
|
|
|
| 454 |
except Exception as e:
|
| 455 |
+
yield f"Error during prefill: {e}"
|
| 456 |
return
|
| 457 |
|
| 458 |
+
# Get logits for last position
|
| 459 |
+
next_token_logits = logits[0, -1, :].numpy()
|
| 460 |
+
|
| 461 |
prefill_time = time.time() - start_time
|
| 462 |
+
print(f"β‘ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
|
| 463 |
|
| 464 |
+
# === GENERATION LOOP ===
|
| 465 |
decode_start = time.time()
|
|
|
|
| 466 |
|
| 467 |
for step in range(max_tokens):
|
| 468 |
if stop_generation:
|
| 469 |
+
yield generated_text + "\n\n*[Generation stopped]*"
|
| 470 |
return
|
| 471 |
|
| 472 |
+
# Sample next token
|
| 473 |
+
next_token_id = sample_token(
|
| 474 |
+
next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# Stop conditions
|
| 478 |
if next_token_id in stop_ids:
|
| 479 |
break
|
| 480 |
|
| 481 |
+
# Update frequency tracking
|
| 482 |
token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
|
| 483 |
|
| 484 |
+
# Decode and yield
|
| 485 |
token_text = tokenizer.decode([next_token_id])
|
| 486 |
generated_text += token_text
|
| 487 |
token_count += 1
|
| 488 |
+
yield generated_text
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
+
# === DECODE PHASE (single token, reuse cache) ===
|
| 491 |
next_input = tf.constant([[next_token_id]], dtype=tf.int32)
|
| 492 |
|
| 493 |
try:
|
| 494 |
logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
|
|
|
|
| 495 |
except Exception as e:
|
| 496 |
+
yield generated_text + f"\n\n*[Error during generation: {e}]*"
|
| 497 |
return
|
| 498 |
+
|
| 499 |
+
next_token_logits = logits[0, -1, :].numpy()
|
| 500 |
+
|
| 501 |
+
# Truncate cache if too long
|
| 502 |
+
current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
|
| 503 |
+
if current_len > max_context:
|
| 504 |
+
trim_amount = current_len - max_context + 100 # Keep some buffer
|
| 505 |
+
past_kv = [
|
| 506 |
+
(k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
|
| 507 |
+
for k, v in past_kv
|
| 508 |
+
]
|
| 509 |
|
|
|
|
| 510 |
decode_time = time.time() - decode_start
|
| 511 |
total_time = time.time() - start_time
|
| 512 |
|
| 513 |
if token_count > 0:
|
|
|
|
| 514 |
decode_tps = token_count / decode_time if decode_time > 0 else 0
|
| 515 |
+
total_tps = token_count / total_time if total_time > 0 else 0
|
| 516 |
|
| 517 |
stats = (
|
| 518 |
+
f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
|
| 519 |
+
f"(prefill: {prefill_time:.1f}s, decode: {decode_tps:.1f} tok/s)]*"
|
|
|
|
|
|
|
| 520 |
)
|
| 521 |
|
| 522 |
if not stop_generation:
|
|
|
|
| 526 |
|
| 527 |
|
| 528 |
# ============================================================================
|
| 529 |
+
# Chat Interface Logic
|
| 530 |
# ============================================================================
|
| 531 |
|
| 532 |
def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
|
|
|
|
| 535 |
for user_msg, assistant_msg in history:
|
| 536 |
prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
|
| 537 |
if assistant_msg:
|
| 538 |
+
# Clean up any stats from previous messages
|
| 539 |
clean_msg = assistant_msg.split("\n\n*[")[0]
|
| 540 |
prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
|
| 541 |
|
|
|
|
| 569 |
):
|
| 570 |
partial_response = generated
|
| 571 |
|
| 572 |
+
# Robust end-of-turn detection
|
| 573 |
stop_tags = ["<|im_end|>", "<im end for model tun>"]
|
| 574 |
earliest_stop = len(partial_response)
|
| 575 |
should_stop = False
|
|
|
|
| 583 |
|
| 584 |
display_response = partial_response
|
| 585 |
if should_stop:
|
| 586 |
+
# Keep the stats portion if present
|
| 587 |
stats_start = partial_response.find("\n\n*[")
|
| 588 |
if stats_start > earliest_stop:
|
| 589 |
display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
|
| 590 |
else:
|
| 591 |
display_response = partial_response[:earliest_stop]
|
| 592 |
|
| 593 |
+
# Post-process reasoning tags for display
|
| 594 |
if reasoning_enabled:
|
| 595 |
if '<think>' in display_response and '</think>' in display_response:
|
| 596 |
start_idx = display_response.find('<think>')
|
|
|
|
| 622 |
|
| 623 |
|
| 624 |
# ============================================================================
|
| 625 |
+
# Gradio UI
|
| 626 |
# ============================================================================
|
| 627 |
|
| 628 |
custom_css = """
|
|
|
|
| 697 |
<div id="welcome-modal" class="modal-overlay" style="display:none;">
|
| 698 |
<div class="modal-content">
|
| 699 |
<h2>π§ Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
|
| 700 |
+
<p>Our latest model features <strong>Chain-of-Thought (CoT)</strong> functionality and <strong>KV-Cache optimization</strong> for fast CPU inference!</p>
|
| 701 |
<div class="comparison-box">
|
| 702 |
<div class="comparison-mode mode-reasoning">
|
| 703 |
<h3>π‘ Reasoning Mode (ON)</h3>
|
| 704 |
+
<p>The model performs a <strong>CoT step</strong> first. The internal thought process is contained within <code><think>...</think></code> tags.</p>
|
| 705 |
</div>
|
| 706 |
<div class="comparison-mode mode-direct">
|
| 707 |
<h3>βͺ Direct Mode (OFF)</h3>
|
| 708 |
+
<p>The model generates the final answer immediately, maximizing speed.</p>
|
| 709 |
</div>
|
| 710 |
</div>
|
| 711 |
<button class="close-btn" onclick="document.getElementById('welcome-modal').style.display='none'">Got it! Start Chatting</button>
|
|
|
|
| 721 |
<img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
|
| 722 |
alt="Sam-large-2" style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
|
| 723 |
<h1>π€ Sam-large-2 Chat π€</h1>
|
| 724 |
+
<p><strong>LATEST RELEASE!</strong> Our <strong>BEST Reasoning Model</strong> - Now with KV-Cache! <span class="speed-indicator">β‘ 5-20x Faster</span></p>
|
| 725 |
+
<div class="twin-badge">Reasoning Model</div>
|
| 726 |
<div class="celebration">π π« π― β‘ π₯</div>
|
| 727 |
</div>
|
| 728 |
""")
|
| 729 |
else:
|
| 730 |
+
gr.HTML("""<div class="header"><h1>π€ Sam-large-2 Chat</h1><p>Advanced Reasoning Model with KV-Cache</p></div>""")
|
| 731 |
|
| 732 |
with gr.Row():
|
| 733 |
with gr.Column(scale=4):
|
|
|
|
| 765 |
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
|
| 766 |
gr.Markdown("---")
|
| 767 |
gr.Markdown(f"""### π Sam-large-2 Model Info
|
| 768 |
+
**Type:** Chain-of-Thought Reasoning Model
|
| 769 |
**Vocab:** {config['vocab_size']:,}
|
| 770 |
**Layers:** {config['num_hidden_layers']}
|
| 771 |
**Context:** {config['max_position_embeddings']:,} tokens
|
| 772 |
+
**Optimization:** KV-Cache enabled β‘
|
|
|
|
| 773 |
""")
|
| 774 |
|
| 775 |
gr.Examples(
|
|
|
|
| 784 |
|
| 785 |
gr.HTML("""
|
| 786 |
<footer>
|
| 787 |
+
<p><strong>π Sam-large-2 - LATEST RELEASE with KV-Cache! π</strong></p>
|
| 788 |
+
<p style="font-size: 0.9rem; color: #999;">Trained from scratch on TPU v5e-8 β’ Built by Smily studios with TensorFlow & Gradio</p>
|
| 789 |
</footer>
|
| 790 |
""")
|
| 791 |
|
|
|
|
| 845 |
|
| 846 |
if __name__ == "__main__":
|
| 847 |
print("\n" + "=" * 60)
|
| 848 |
+
print("π Starting Sam-large-2 Chat with KV-Cache Optimization")
|
| 849 |
print("=" * 60 + "\n")
|
| 850 |
demo.queue(max_size=20)
|
| 851 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
|