Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,6 @@ os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
|
|
| 10 |
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Force CPU only
|
| 11 |
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' # Intel optimization
|
| 12 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TF logging
|
| 13 |
-
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0' # We'll handle precision manually
|
| 14 |
|
| 15 |
import gradio as gr
|
| 16 |
import tensorflow as tf
|
|
@@ -27,9 +26,11 @@ tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
|
|
| 27 |
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
|
| 28 |
|
| 29 |
# Enable XLA JIT compilation for CPU
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
print(f"β
CPU optimized: {NUM_CORES} threads, oneDNN enabled, XLA JIT enabled")
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# ============================================================================
|
| 35 |
# π FESTIVE MODE TOGGLE π
|
|
@@ -46,53 +47,45 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
|
|
| 46 |
CACHE_DIR = "./model_cache"
|
| 47 |
|
| 48 |
# ============================================================================
|
| 49 |
-
#
|
| 50 |
# ============================================================================
|
| 51 |
|
| 52 |
@keras.saving.register_keras_serializable()
|
| 53 |
class RotaryEmbedding(keras.layers.Layer):
|
| 54 |
-
"""
|
| 55 |
|
| 56 |
def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
|
| 57 |
super().__init__(**kwargs)
|
| 58 |
self.dim = dim
|
| 59 |
self.max_len = max_len
|
| 60 |
self.theta = theta
|
|
|
|
| 61 |
self.cos_cached = None
|
| 62 |
self.sin_cached = None
|
| 63 |
|
| 64 |
def build(self, input_shape):
|
| 65 |
-
# Pre-compute RoPE cache during build
|
| 66 |
-
inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
|
| 67 |
-
t = np.arange(self.max_len, dtype=np.float32)
|
| 68 |
-
freqs = np.outer(t, inv_freq)
|
| 69 |
-
emb = np.concatenate([freqs, freqs], axis=-1)
|
| 70 |
-
|
| 71 |
-
# Store as non-trainable weights for better graph optimization
|
| 72 |
-
self.cos_cached = self.add_weight(
|
| 73 |
-
name="cos_cache",
|
| 74 |
-
shape=emb.shape,
|
| 75 |
-
initializer=keras.initializers.Constant(np.cos(emb)),
|
| 76 |
-
trainable=False
|
| 77 |
-
)
|
| 78 |
-
self.sin_cached = self.add_weight(
|
| 79 |
-
name="sin_cache",
|
| 80 |
-
shape=emb.shape,
|
| 81 |
-
initializer=keras.initializers.Constant(np.sin(emb)),
|
| 82 |
-
trainable=False
|
| 83 |
-
)
|
| 84 |
super().build(input_shape)
|
| 85 |
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
def call(self, q, k, offset=0):
|
| 88 |
"""Apply rotary embeddings with position offset for KV-cache."""
|
|
|
|
| 89 |
seq_len = tf.shape(q)[2]
|
| 90 |
dtype = q.dtype
|
| 91 |
|
| 92 |
cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
|
| 93 |
sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
|
| 94 |
|
| 95 |
-
# Fused rotate_half
|
| 96 |
x1_q, x2_q = tf.split(q, 2, axis=-1)
|
| 97 |
x1_k, x2_k = tf.split(k, 2, axis=-1)
|
| 98 |
|
|
@@ -109,8 +102,6 @@ class RotaryEmbedding(keras.layers.Layer):
|
|
| 109 |
|
| 110 |
@keras.saving.register_keras_serializable()
|
| 111 |
class RMSNorm(keras.layers.Layer):
|
| 112 |
-
"""Optimized RMSNorm."""
|
| 113 |
-
|
| 114 |
def __init__(self, epsilon=1e-5, **kwargs):
|
| 115 |
super().__init__(**kwargs)
|
| 116 |
self.epsilon = epsilon
|
|
@@ -120,9 +111,7 @@ class RMSNorm(keras.layers.Layer):
|
|
| 120 |
self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
|
| 121 |
super().build(input_shape)
|
| 122 |
|
| 123 |
-
@tf.function(reduce_retracing=True)
|
| 124 |
def call(self, x):
|
| 125 |
-
# Fused computation
|
| 126 |
variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
|
| 127 |
return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
|
| 128 |
|
|
@@ -134,7 +123,7 @@ class RMSNorm(keras.layers.Layer):
|
|
| 134 |
|
| 135 |
@keras.saving.register_keras_serializable()
|
| 136 |
class TransformerBlock(keras.layers.Layer):
|
| 137 |
-
"""
|
| 138 |
|
| 139 |
def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
|
| 140 |
super().__init__(**kwargs)
|
|
@@ -149,17 +138,21 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 149 |
self.scale = 1.0 / np.sqrt(self.head_dim)
|
| 150 |
|
| 151 |
def build(self, input_shape):
|
|
|
|
| 152 |
self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
|
| 153 |
self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
|
| 154 |
|
| 155 |
-
#
|
| 156 |
-
self.
|
|
|
|
|
|
|
| 157 |
self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
|
| 158 |
|
| 159 |
self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
|
| 160 |
|
| 161 |
-
#
|
| 162 |
-
self.
|
|
|
|
| 163 |
self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
|
| 164 |
|
| 165 |
self.dropout = keras.layers.Dropout(self.dropout_rate)
|
|
@@ -172,11 +165,15 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 172 |
res = x
|
| 173 |
y = self.pre_attn_norm(x)
|
| 174 |
|
| 175 |
-
#
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# Determine position offset for RoPE
|
| 182 |
past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
|
|
@@ -198,7 +195,7 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 198 |
# Optimized causal mask
|
| 199 |
q_positions = tf.range(past_len, past_len + T)
|
| 200 |
k_positions = tf.range(full_len)
|
| 201 |
-
mask = tf.cast(q_positions[:, None] < k_positions[None, :],
|
| 202 |
scores = scores + mask[None, None, :, :]
|
| 203 |
|
| 204 |
attn = tf.nn.softmax(scores, axis=-1)
|
|
@@ -208,12 +205,10 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 208 |
|
| 209 |
x = res + self.dropout(self.out_proj(attn_out), training=training)
|
| 210 |
|
| 211 |
-
#
|
| 212 |
res = x
|
| 213 |
y = self.pre_ffn_norm(x)
|
| 214 |
-
|
| 215 |
-
gate, up = tf.split(gate_up, 2, axis=-1)
|
| 216 |
-
ffn = self.down_proj(tf.nn.silu(gate) * up)
|
| 217 |
output = res + self.dropout(ffn, training=training)
|
| 218 |
|
| 219 |
return output, new_kv
|
|
@@ -234,8 +229,6 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 234 |
|
| 235 |
@keras.saving.register_keras_serializable()
|
| 236 |
class SAM1Model(keras.Model):
|
| 237 |
-
"""Optimized SAM model with compiled inference."""
|
| 238 |
-
|
| 239 |
def __init__(self, **kwargs):
|
| 240 |
super().__init__()
|
| 241 |
if 'config' in kwargs and isinstance(kwargs['config'], dict):
|
|
@@ -261,9 +254,6 @@ class SAM1Model(keras.Model):
|
|
| 261 |
]
|
| 262 |
self.norm = RMSNorm(name="final_norm")
|
| 263 |
self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
|
| 264 |
-
|
| 265 |
-
self._compiled_prefill = None
|
| 266 |
-
self._compiled_decode = None
|
| 267 |
|
| 268 |
def call(self, input_ids, training=None, past_kv=None, use_cache=False):
|
| 269 |
x = self.embed(input_ids)
|
|
@@ -279,19 +269,6 @@ class SAM1Model(keras.Model):
|
|
| 279 |
logits = self.lm_head(self.norm(x))
|
| 280 |
return logits, new_past_kv
|
| 281 |
|
| 282 |
-
@tf.function(reduce_retracing=True)
|
| 283 |
-
def prefill(self, input_ids):
|
| 284 |
-
"""Compiled prefill for initial prompt processing."""
|
| 285 |
-
return self.call(input_ids, training=False, past_kv=None, use_cache=True)
|
| 286 |
-
|
| 287 |
-
@tf.function(reduce_retracing=True, input_signature=[
|
| 288 |
-
tf.TensorSpec(shape=[1, 1], dtype=tf.int32),
|
| 289 |
-
tf.TensorSpec(shape=[None], dtype=tf.variant) # For the list of KV tuples
|
| 290 |
-
])
|
| 291 |
-
def decode_step(self, input_ids, past_kv):
|
| 292 |
-
"""Compiled single-token decode step."""
|
| 293 |
-
return self.call(input_ids, training=False, past_kv=past_kv, use_cache=True)
|
| 294 |
-
|
| 295 |
def get_config(self):
|
| 296 |
base_config = super().get_config()
|
| 297 |
base_config['config'] = self.cfg
|
|
@@ -299,15 +276,9 @@ class SAM1Model(keras.Model):
|
|
| 299 |
|
| 300 |
|
| 301 |
# ============================================================================
|
| 302 |
-
# Optimized Sampling
|
| 303 |
# ============================================================================
|
| 304 |
|
| 305 |
-
@lru_cache(maxsize=128)
|
| 306 |
-
def get_top_k_mask(vocab_size, top_k):
|
| 307 |
-
"""Cache top-k masks for common vocab sizes."""
|
| 308 |
-
return top_k
|
| 309 |
-
|
| 310 |
-
|
| 311 |
class FastSampler:
|
| 312 |
"""Vectorized sampler for faster token selection."""
|
| 313 |
|
|
@@ -317,6 +288,9 @@ class FastSampler:
|
|
| 317 |
|
| 318 |
def sample(self, logits, temperature, top_k, top_p, token_freq, repetition_penalty):
|
| 319 |
"""Optimized sampling with vectorized operations."""
|
|
|
|
|
|
|
|
|
|
| 320 |
# Temperature scaling
|
| 321 |
if temperature != 1.0:
|
| 322 |
logits = logits / temperature
|
|
@@ -328,7 +302,8 @@ class FastSampler:
|
|
| 328 |
valid_mask = freq_tokens < len(logits)
|
| 329 |
freq_tokens = freq_tokens[valid_mask]
|
| 330 |
freq_values = freq_values[valid_mask]
|
| 331 |
-
|
|
|
|
| 332 |
|
| 333 |
# Top-K filtering with partial sort
|
| 334 |
if 0 < top_k < len(logits):
|
|
@@ -455,6 +430,30 @@ for _ in range(3):
|
|
| 455 |
|
| 456 |
print("β
Model warmed up and traces compiled")
|
| 457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
# ============================================================================
|
| 459 |
# Optimized Inference Logic with KV-Cache
|
| 460 |
# ============================================================================
|
|
@@ -494,7 +493,7 @@ def generate_stream(
|
|
| 494 |
|
| 495 |
max_context = config['max_position_embeddings']
|
| 496 |
|
| 497 |
-
start_time = time.perf_counter()
|
| 498 |
|
| 499 |
# === PREFILL PHASE ===
|
| 500 |
if len(input_ids) > max_context - max_tokens:
|
|
@@ -503,23 +502,21 @@ def generate_stream(
|
|
| 503 |
input_tensor = tf.constant([input_ids], dtype=tf.int32)
|
| 504 |
|
| 505 |
try:
|
| 506 |
-
logits, past_kv =
|
| 507 |
except Exception as e:
|
| 508 |
yield f"Error during prefill: {e}"
|
| 509 |
return
|
| 510 |
|
| 511 |
-
# Get logits for last position
|
| 512 |
next_token_logits = logits[0, -1, :].numpy()
|
| 513 |
|
| 514 |
prefill_time = time.perf_counter() - start_time
|
| 515 |
-
|
|
|
|
| 516 |
|
| 517 |
# === GENERATION LOOP ===
|
| 518 |
decode_start = time.perf_counter()
|
| 519 |
|
| 520 |
-
# Pre-compute constants
|
| 521 |
-
yield_interval = 1 # Yield every token for streaming
|
| 522 |
-
|
| 523 |
for step in range(max_tokens):
|
| 524 |
if stop_generation:
|
| 525 |
yield generated_text + "\n\n*[Generation stopped]*"
|
|
@@ -541,23 +538,21 @@ def generate_stream(
|
|
| 541 |
token_text = tokenizer.decode([next_token_id])
|
| 542 |
generated_text += token_text
|
| 543 |
token_count += 1
|
| 544 |
-
|
| 545 |
-
if step % yield_interval == 0:
|
| 546 |
-
yield generated_text
|
| 547 |
|
| 548 |
# === DECODE PHASE (single token, reuse cache) ===
|
| 549 |
next_input = tf.constant([[next_token_id]], dtype=tf.int32)
|
| 550 |
|
| 551 |
try:
|
| 552 |
-
logits, past_kv =
|
| 553 |
except Exception as e:
|
| 554 |
yield generated_text + f"\n\n*[Error during generation: {e}]*"
|
| 555 |
return
|
| 556 |
|
| 557 |
next_token_logits = logits[0, -1, :].numpy()
|
| 558 |
|
| 559 |
-
# Truncate cache if too long (less
|
| 560 |
-
if step % 100 ==
|
| 561 |
current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
|
| 562 |
if current_len > max_context:
|
| 563 |
trim_amount = current_len - max_context + 100
|
|
@@ -827,7 +822,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
|
| 827 |
**Vocab:** {config['vocab_size']:,}
|
| 828 |
**Layers:** {config['num_hidden_layers']}
|
| 829 |
**Context:** {config['max_position_embeddings']:,} tokens
|
| 830 |
-
**Optimization:** KV-Cache + XLA
|
| 831 |
""")
|
| 832 |
|
| 833 |
gr.Examples(
|
|
|
|
| 10 |
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Force CPU only
|
| 11 |
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' # Intel optimization
|
| 12 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TF logging
|
|
|
|
| 13 |
|
| 14 |
import gradio as gr
|
| 15 |
import tensorflow as tf
|
|
|
|
| 26 |
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
|
| 27 |
|
| 28 |
# Enable XLA JIT compilation for CPU
|
| 29 |
+
try:
|
| 30 |
+
tf.config.optimizer.set_jit(True)
|
| 31 |
+
print(f"β
CPU optimized: {NUM_CORES} threads, oneDNN enabled, XLA JIT enabled")
|
| 32 |
+
except:
|
| 33 |
+
print(f"β
CPU optimized: {NUM_CORES} threads, oneDNN enabled")
|
| 34 |
|
| 35 |
# ============================================================================
|
| 36 |
# π FESTIVE MODE TOGGLE π
|
|
|
|
| 47 |
CACHE_DIR = "./model_cache"
|
| 48 |
|
| 49 |
# ============================================================================
|
| 50 |
+
# Model Architecture - MUST MATCH CHECKPOINT STRUCTURE
|
| 51 |
# ============================================================================
|
| 52 |
|
| 53 |
@keras.saving.register_keras_serializable()
|
| 54 |
class RotaryEmbedding(keras.layers.Layer):
|
| 55 |
+
"""RoPE with pre-computed cache (no trainable weights - compatible with checkpoint)."""
|
| 56 |
|
| 57 |
def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
|
| 58 |
super().__init__(**kwargs)
|
| 59 |
self.dim = dim
|
| 60 |
self.max_len = max_len
|
| 61 |
self.theta = theta
|
| 62 |
+
self.built_cache = False
|
| 63 |
self.cos_cached = None
|
| 64 |
self.sin_cached = None
|
| 65 |
|
| 66 |
def build(self, input_shape):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
super().build(input_shape)
|
| 68 |
|
| 69 |
+
def _build_cache(self):
|
| 70 |
+
if not self.built_cache:
|
| 71 |
+
inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
|
| 72 |
+
t = np.arange(self.max_len, dtype=np.float32)
|
| 73 |
+
freqs = np.outer(t, inv_freq)
|
| 74 |
+
emb = np.concatenate([freqs, freqs], axis=-1)
|
| 75 |
+
self.cos_cached = tf.constant(np.cos(emb), dtype=tf.float32)
|
| 76 |
+
self.sin_cached = tf.constant(np.sin(emb), dtype=tf.float32)
|
| 77 |
+
self.built_cache = True
|
| 78 |
+
|
| 79 |
def call(self, q, k, offset=0):
|
| 80 |
"""Apply rotary embeddings with position offset for KV-cache."""
|
| 81 |
+
self._build_cache()
|
| 82 |
seq_len = tf.shape(q)[2]
|
| 83 |
dtype = q.dtype
|
| 84 |
|
| 85 |
cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
|
| 86 |
sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
|
| 87 |
|
| 88 |
+
# Fused rotate_half
|
| 89 |
x1_q, x2_q = tf.split(q, 2, axis=-1)
|
| 90 |
x1_k, x2_k = tf.split(k, 2, axis=-1)
|
| 91 |
|
|
|
|
| 102 |
|
| 103 |
@keras.saving.register_keras_serializable()
|
| 104 |
class RMSNorm(keras.layers.Layer):
|
|
|
|
|
|
|
| 105 |
def __init__(self, epsilon=1e-5, **kwargs):
|
| 106 |
super().__init__(**kwargs)
|
| 107 |
self.epsilon = epsilon
|
|
|
|
| 111 |
self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
|
| 112 |
super().build(input_shape)
|
| 113 |
|
|
|
|
| 114 |
def call(self, x):
|
|
|
|
| 115 |
variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
|
| 116 |
return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
|
| 117 |
|
|
|
|
| 123 |
|
| 124 |
@keras.saving.register_keras_serializable()
|
| 125 |
class TransformerBlock(keras.layers.Layer):
|
| 126 |
+
"""Transformer block - MATCHES ORIGINAL CHECKPOINT STRUCTURE."""
|
| 127 |
|
| 128 |
def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
|
| 129 |
super().__init__(**kwargs)
|
|
|
|
| 138 |
self.scale = 1.0 / np.sqrt(self.head_dim)
|
| 139 |
|
| 140 |
def build(self, input_shape):
|
| 141 |
+
# MUST use same layer names as checkpoint
|
| 142 |
self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
|
| 143 |
self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
|
| 144 |
|
| 145 |
+
# Separate Q, K, V projections (matches checkpoint)
|
| 146 |
+
self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
|
| 147 |
+
self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
|
| 148 |
+
self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
|
| 149 |
self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
|
| 150 |
|
| 151 |
self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
|
| 152 |
|
| 153 |
+
# Separate gate, up, down projections (matches checkpoint)
|
| 154 |
+
self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
|
| 155 |
+
self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
|
| 156 |
self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
|
| 157 |
|
| 158 |
self.dropout = keras.layers.Dropout(self.dropout_rate)
|
|
|
|
| 165 |
res = x
|
| 166 |
y = self.pre_attn_norm(x)
|
| 167 |
|
| 168 |
+
# Separate Q, K, V projections
|
| 169 |
+
q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 170 |
+
q = tf.transpose(q, [0, 2, 1, 3]) # [B, n_heads, T, head_dim]
|
| 171 |
+
|
| 172 |
+
k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 173 |
+
k = tf.transpose(k, [0, 2, 1, 3])
|
| 174 |
+
|
| 175 |
+
v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 176 |
+
v = tf.transpose(v, [0, 2, 1, 3])
|
| 177 |
|
| 178 |
# Determine position offset for RoPE
|
| 179 |
past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
|
|
|
|
| 195 |
# Optimized causal mask
|
| 196 |
q_positions = tf.range(past_len, past_len + T)
|
| 197 |
k_positions = tf.range(full_len)
|
| 198 |
+
mask = tf.cast(q_positions[:, None] < k_positions[None, :], scores.dtype) * -1e9
|
| 199 |
scores = scores + mask[None, None, :, :]
|
| 200 |
|
| 201 |
attn = tf.nn.softmax(scores, axis=-1)
|
|
|
|
| 205 |
|
| 206 |
x = res + self.dropout(self.out_proj(attn_out), training=training)
|
| 207 |
|
| 208 |
+
# FFN with SwiGLU
|
| 209 |
res = x
|
| 210 |
y = self.pre_ffn_norm(x)
|
| 211 |
+
ffn = self.down_proj(tf.nn.silu(self.gate_proj(y)) * self.up_proj(y))
|
|
|
|
|
|
|
| 212 |
output = res + self.dropout(ffn, training=training)
|
| 213 |
|
| 214 |
return output, new_kv
|
|
|
|
| 229 |
|
| 230 |
@keras.saving.register_keras_serializable()
|
| 231 |
class SAM1Model(keras.Model):
|
|
|
|
|
|
|
| 232 |
def __init__(self, **kwargs):
|
| 233 |
super().__init__()
|
| 234 |
if 'config' in kwargs and isinstance(kwargs['config'], dict):
|
|
|
|
| 254 |
]
|
| 255 |
self.norm = RMSNorm(name="final_norm")
|
| 256 |
self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
def call(self, input_ids, training=None, past_kv=None, use_cache=False):
|
| 259 |
x = self.embed(input_ids)
|
|
|
|
| 269 |
logits = self.lm_head(self.norm(x))
|
| 270 |
return logits, new_past_kv
|
| 271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
def get_config(self):
|
| 273 |
base_config = super().get_config()
|
| 274 |
base_config['config'] = self.cfg
|
|
|
|
| 276 |
|
| 277 |
|
| 278 |
# ============================================================================
|
| 279 |
+
# Optimized Sampling
|
| 280 |
# ============================================================================
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
class FastSampler:
|
| 283 |
"""Vectorized sampler for faster token selection."""
|
| 284 |
|
|
|
|
| 288 |
|
| 289 |
def sample(self, logits, temperature, top_k, top_p, token_freq, repetition_penalty):
|
| 290 |
"""Optimized sampling with vectorized operations."""
|
| 291 |
+
# Make a copy to avoid modifying original
|
| 292 |
+
logits = logits.copy()
|
| 293 |
+
|
| 294 |
# Temperature scaling
|
| 295 |
if temperature != 1.0:
|
| 296 |
logits = logits / temperature
|
|
|
|
| 302 |
valid_mask = freq_tokens < len(logits)
|
| 303 |
freq_tokens = freq_tokens[valid_mask]
|
| 304 |
freq_values = freq_values[valid_mask]
|
| 305 |
+
if len(freq_tokens) > 0:
|
| 306 |
+
logits[freq_tokens] /= np.power(repetition_penalty, freq_values)
|
| 307 |
|
| 308 |
# Top-K filtering with partial sort
|
| 309 |
if 0 < top_k < len(logits):
|
|
|
|
| 430 |
|
| 431 |
print("β
Model warmed up and traces compiled")
|
| 432 |
|
| 433 |
+
# ============================================================================
|
| 434 |
+
# Compiled Inference Functions
|
| 435 |
+
# ============================================================================
|
| 436 |
+
|
| 437 |
+
# Create tf.function wrapped inference for speed
|
| 438 |
+
@tf.function(reduce_retracing=True)
|
| 439 |
+
def model_prefill(input_ids):
|
| 440 |
+
"""Compiled prefill function."""
|
| 441 |
+
return model(input_ids, training=False, use_cache=True)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
@tf.function(reduce_retracing=True)
|
| 445 |
+
def model_decode(input_ids, past_kv):
|
| 446 |
+
"""Compiled single-token decode function."""
|
| 447 |
+
return model(input_ids, training=False, past_kv=past_kv, use_cache=True)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
# Additional warmup for compiled functions
|
| 451 |
+
print("π₯ Compiling tf.function traces...")
|
| 452 |
+
_ = model_prefill(warmup_input)
|
| 453 |
+
_ = model_decode(single_token, past_kv)
|
| 454 |
+
print("β
Compiled functions ready")
|
| 455 |
+
|
| 456 |
+
|
| 457 |
# ============================================================================
|
| 458 |
# Optimized Inference Logic with KV-Cache
|
| 459 |
# ============================================================================
|
|
|
|
| 493 |
|
| 494 |
max_context = config['max_position_embeddings']
|
| 495 |
|
| 496 |
+
start_time = time.perf_counter()
|
| 497 |
|
| 498 |
# === PREFILL PHASE ===
|
| 499 |
if len(input_ids) > max_context - max_tokens:
|
|
|
|
| 502 |
input_tensor = tf.constant([input_ids], dtype=tf.int32)
|
| 503 |
|
| 504 |
try:
|
| 505 |
+
logits, past_kv = model_prefill(input_tensor)
|
| 506 |
except Exception as e:
|
| 507 |
yield f"Error during prefill: {e}"
|
| 508 |
return
|
| 509 |
|
| 510 |
+
# Get logits for last position
|
| 511 |
next_token_logits = logits[0, -1, :].numpy()
|
| 512 |
|
| 513 |
prefill_time = time.perf_counter() - start_time
|
| 514 |
+
prefill_tps = len(input_ids) / prefill_time if prefill_time > 0 else 0
|
| 515 |
+
print(f"β‘ Prefill: {len(input_ids)} tokens in {prefill_time:.3f}s ({prefill_tps:.1f} tok/s)")
|
| 516 |
|
| 517 |
# === GENERATION LOOP ===
|
| 518 |
decode_start = time.perf_counter()
|
| 519 |
|
|
|
|
|
|
|
|
|
|
| 520 |
for step in range(max_tokens):
|
| 521 |
if stop_generation:
|
| 522 |
yield generated_text + "\n\n*[Generation stopped]*"
|
|
|
|
| 538 |
token_text = tokenizer.decode([next_token_id])
|
| 539 |
generated_text += token_text
|
| 540 |
token_count += 1
|
| 541 |
+
yield generated_text
|
|
|
|
|
|
|
| 542 |
|
| 543 |
# === DECODE PHASE (single token, reuse cache) ===
|
| 544 |
next_input = tf.constant([[next_token_id]], dtype=tf.int32)
|
| 545 |
|
| 546 |
try:
|
| 547 |
+
logits, past_kv = model_decode(next_input, past_kv)
|
| 548 |
except Exception as e:
|
| 549 |
yield generated_text + f"\n\n*[Error during generation: {e}]*"
|
| 550 |
return
|
| 551 |
|
| 552 |
next_token_logits = logits[0, -1, :].numpy()
|
| 553 |
|
| 554 |
+
# Truncate cache if too long (check less frequently)
|
| 555 |
+
if step % 100 == 99:
|
| 556 |
current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
|
| 557 |
if current_len > max_context:
|
| 558 |
trim_amount = current_len - max_context + 100
|
|
|
|
| 822 |
**Vocab:** {config['vocab_size']:,}
|
| 823 |
**Layers:** {config['num_hidden_layers']}
|
| 824 |
**Context:** {config['max_position_embeddings']:,} tokens
|
| 825 |
+
**Optimization:** KV-Cache + XLA β‘
|
| 826 |
""")
|
| 827 |
|
| 828 |
gr.Examples(
|