Keeby-smilyai committed on
Commit
90b1095
·
verified ·
1 Parent(s): cbfe110

Update app.py

Files changed (1)
  1. app.py +81 -86
app.py CHANGED
@@ -10,7 +10,6 @@ os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Force CPU only
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'  # Intel optimization
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TF logging
- os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0'  # We'll handle precision manually
 
  import gradio as gr
  import tensorflow as tf
@@ -27,9 +26,11 @@ tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
  tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
 
  # Enable XLA JIT compilation for CPU
- tf.config.optimizer.set_jit(True)
-
- print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled, XLA JIT enabled")
+ try:
+     tf.config.optimizer.set_jit(True)
+     print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled, XLA JIT enabled")
+ except:
+     print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
 
  # ============================================================================
  # 🎊 FESTIVE MODE TOGGLE 🎊
@@ -46,53 +47,45 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
  CACHE_DIR = "./model_cache"
 
  # ============================================================================
- # Optimized Model Architecture with KV-Cache
+ # Model Architecture - MUST MATCH CHECKPOINT STRUCTURE
  # ============================================================================
 
  @keras.saving.register_keras_serializable()
  class RotaryEmbedding(keras.layers.Layer):
-     """Optimized RoPE with pre-computed cache."""
+     """RoPE with pre-computed cache (no trainable weights - compatible with checkpoint)."""
 
      def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
          super().__init__(**kwargs)
          self.dim = dim
          self.max_len = max_len
          self.theta = theta
+         self.built_cache = False
          self.cos_cached = None
          self.sin_cached = None
 
      def build(self, input_shape):
-         # Pre-compute RoPE cache during build
-         inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
-         t = np.arange(self.max_len, dtype=np.float32)
-         freqs = np.outer(t, inv_freq)
-         emb = np.concatenate([freqs, freqs], axis=-1)
-
-         # Store as non-trainable weights for better graph optimization
-         self.cos_cached = self.add_weight(
-             name="cos_cache",
-             shape=emb.shape,
-             initializer=keras.initializers.Constant(np.cos(emb)),
-             trainable=False
-         )
-         self.sin_cached = self.add_weight(
-             name="sin_cache",
-             shape=emb.shape,
-             initializer=keras.initializers.Constant(np.sin(emb)),
-             trainable=False
-         )
          super().build(input_shape)
 
-     @tf.function(reduce_retracing=True)
+     def _build_cache(self):
+         if not self.built_cache:
+             inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
+             t = np.arange(self.max_len, dtype=np.float32)
+             freqs = np.outer(t, inv_freq)
+             emb = np.concatenate([freqs, freqs], axis=-1)
+             self.cos_cached = tf.constant(np.cos(emb), dtype=tf.float32)
+             self.sin_cached = tf.constant(np.sin(emb), dtype=tf.float32)
+             self.built_cache = True
+
      def call(self, q, k, offset=0):
          """Apply rotary embeddings with position offset for KV-cache."""
+         self._build_cache()
          seq_len = tf.shape(q)[2]
          dtype = q.dtype
 
          cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
          sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
 
-         # Fused rotate_half operation
+         # Fused rotate_half
          x1_q, x2_q = tf.split(q, 2, axis=-1)
          x1_k, x2_k = tf.split(k, 2, axis=-1)
 
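For reference, the rotary-embedding math that the new `_build_cache` precomputes and `call` applies looks like this in plain NumPy (an illustrative sketch, not part of the commit; the final `q * cos + rotate_half(q) * sin` combination happens just past the end of this hunk):

```python
# Illustrative sketch only: the RoPE cache and its application, single head, NumPy.
import numpy as np

def rope_cache(dim, max_len, theta=10000.0):
    # Inverse frequencies over even channel indices, then positions x frequencies
    inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    freqs = np.outer(np.arange(max_len, dtype=np.float32), inv_freq)
    emb = np.concatenate([freqs, freqs], axis=-1)      # [max_len, dim]
    return np.cos(emb), np.sin(emb)

def rotate_half(x):
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rope(x, cos, sin, offset=0):
    # x: [seq_len, dim]; slice the cache at the KV-cache offset
    seq_len = x.shape[0]
    return x * cos[offset:offset + seq_len] + rotate_half(x) * sin[offset:offset + seq_len]

cos, sin = rope_cache(dim=8, max_len=32)
q = np.random.randn(4, 8).astype(np.float32)   # 4 positions, head_dim 8
print(apply_rope(q, cos, sin, offset=0).shape)  # (4, 8)
```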
@@ -109,8 +102,6 @@ class RotaryEmbedding(keras.layers.Layer):
 
  @keras.saving.register_keras_serializable()
  class RMSNorm(keras.layers.Layer):
-     """Optimized RMSNorm."""
-
      def __init__(self, epsilon=1e-5, **kwargs):
          super().__init__(**kwargs)
          self.epsilon = epsilon
@@ -120,9 +111,7 @@ class RMSNorm(keras.layers.Layer):
          self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
          super().build(input_shape)
 
-     @tf.function(reduce_retracing=True)
      def call(self, x):
-         # Fused computation
          variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
          return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
 
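The `call` that remains is plain RMSNorm: scale by the reciprocal root-mean-square of the features, then by a learned per-channel weight (no mean subtraction, unlike LayerNorm). A minimal NumPy sketch of the same formula, illustrative only:

```python
# Illustrative sketch only: RMSNorm in NumPy.
import numpy as np

def rms_norm(x, scale, epsilon=1e-5):
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return x * (1.0 / np.sqrt(variance + epsilon)) * scale

x = np.random.randn(2, 4).astype(np.float32)
scale = np.ones(4, dtype=np.float32)   # the layer's learned "scale" weight
print(rms_norm(x, scale).shape)        # (2, 4)
```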
@@ -134,7 +123,7 @@ class RMSNorm(keras.layers.Layer):
 
  @keras.saving.register_keras_serializable()
  class TransformerBlock(keras.layers.Layer):
-     """Optimized transformer block with efficient attention."""
+     """Transformer block - MATCHES ORIGINAL CHECKPOINT STRUCTURE."""
 
      def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
          super().__init__(**kwargs)
@@ -149,17 +138,21 @@ class TransformerBlock(keras.layers.Layer):
          self.scale = 1.0 / np.sqrt(self.head_dim)
 
      def build(self, input_shape):
+         # MUST use same layer names as checkpoint
          self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
          self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
 
-         # Fused QKV projection for better memory access
-         self.qkv_proj = keras.layers.Dense(self.d_model * 3, use_bias=False, name="qkv_proj")
+         # Separate Q, K, V projections (matches checkpoint)
+         self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
+         self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
+         self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
          self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
 
          self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
 
-         # Fused gate/up projection
-         self.gate_up_proj = keras.layers.Dense(self.ff_dim * 2, use_bias=False, name="gate_up_proj")
+         # Separate gate, up, down projections (matches checkpoint)
+         self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
+         self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
          self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
 
          self.dropout = keras.layers.Dropout(self.dropout_rate)
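The switch from a fused `qkv_proj` to separate `q_proj`/`k_proj`/`v_proj` (and from `gate_up_proj` to `gate_proj`/`up_proj`) only changes layer names and weight layout so they line up with the checkpoint; the computation itself is equivalent. A small NumPy sketch of that equivalence, illustrative and not from the repo:

```python
# Illustrative sketch only: a fused QKV projection equals three separate
# projections whose kernels are concatenated along the output axis.
import numpy as np

d_model = 8
rng = np.random.default_rng(0)
x = rng.standard_normal((2, d_model)).astype(np.float32)          # [tokens, d_model]
w_q = rng.standard_normal((d_model, d_model)).astype(np.float32)
w_k = rng.standard_normal((d_model, d_model)).astype(np.float32)
w_v = rng.standard_normal((d_model, d_model)).astype(np.float32)

# Separate projections (new layout: q_proj / k_proj / v_proj)
q, k, v = x @ w_q, x @ w_k, x @ w_v

# Fused projection (old layout: qkv_proj with 3 * d_model outputs)
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)                    # [d_model, 3*d_model]
q2, k2, v2 = np.split(x @ w_qkv, 3, axis=1)

print(np.allclose(q, q2), np.allclose(k, k2), np.allclose(v, v2))  # True True True
```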
@@ -172,11 +165,15 @@ class TransformerBlock(keras.layers.Layer):
          res = x
          y = self.pre_attn_norm(x)
 
-         # Fused QKV projection
-         qkv = self.qkv_proj(y)
-         qkv = tf.reshape(qkv, [B, T, 3, self.n_heads, self.head_dim])
-         qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])  # [3, B, n_heads, T, head_dim]
-         q, k, v = qkv[0], qkv[1], qkv[2]
+         # Separate Q, K, V projections
+         q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
+         q = tf.transpose(q, [0, 2, 1, 3])  # [B, n_heads, T, head_dim]
+
+         k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
+         k = tf.transpose(k, [0, 2, 1, 3])
+
+         v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
+         v = tf.transpose(v, [0, 2, 1, 3])
 
          # Determine position offset for RoPE
          past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
@@ -198,7 +195,7 @@ class TransformerBlock(keras.layers.Layer):
          # Optimized causal mask
          q_positions = tf.range(past_len, past_len + T)
          k_positions = tf.range(full_len)
-         mask = tf.cast(q_positions[:, None] < k_positions[None, :], q.dtype) * -1e9
+         mask = tf.cast(q_positions[:, None] < k_positions[None, :], scores.dtype) * -1e9
          scores = scores + mask[None, None, :, :]
 
          attn = tf.nn.softmax(scores, axis=-1)
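A standalone NumPy sketch of the offset-aware causal mask built here (illustrative only): query positions start at `past_len`, keys cover the full cached length, and any key strictly after a query gets a large negative bias.

```python
# Illustrative sketch only: causal mask for incremental decoding with a KV-cache.
import numpy as np

def causal_mask(past_len, new_tokens):
    full_len = past_len + new_tokens
    q_pos = np.arange(past_len, full_len)     # positions of the new queries
    k_pos = np.arange(full_len)               # positions of all cached + new keys
    return (q_pos[:, None] < k_pos[None, :]).astype(np.float32) * -1e9

print(causal_mask(past_len=2, new_tokens=3))
# shape (3, 5): each new query attends to all earlier keys and itself, never to later keys.
```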
@@ -208,12 +205,10 @@ class TransformerBlock(keras.layers.Layer):
 
          x = res + self.dropout(self.out_proj(attn_out), training=training)
 
-         # Optimized FFN with fused gate/up
+         # FFN with SwiGLU
          res = x
          y = self.pre_ffn_norm(x)
-         gate_up = self.gate_up_proj(y)
-         gate, up = tf.split(gate_up, 2, axis=-1)
-         ffn = self.down_proj(tf.nn.silu(gate) * up)
+         ffn = self.down_proj(tf.nn.silu(self.gate_proj(y)) * self.up_proj(y))
          output = res + self.dropout(ffn, training=training)
 
          return output, new_kv
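The restored feed-forward is a standard SwiGLU: `down(silu(gate(x)) * up(x))`. An illustrative NumPy sketch under the same gate/up/down naming:

```python
# Illustrative sketch only: SwiGLU feed-forward in NumPy.
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, w_gate, w_up, w_down):
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

d_model, ff_dim = 4, 8
rng = np.random.default_rng(1)
x = rng.standard_normal((3, d_model)).astype(np.float32)
w_gate = rng.standard_normal((d_model, ff_dim)).astype(np.float32)
w_up = rng.standard_normal((d_model, ff_dim)).astype(np.float32)
w_down = rng.standard_normal((ff_dim, d_model)).astype(np.float32)
print(swiglu_ffn(x, w_gate, w_up, w_down).shape)  # (3, 4)
```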
@@ -234,8 +229,6 @@ class TransformerBlock(keras.layers.Layer):
 
  @keras.saving.register_keras_serializable()
  class SAM1Model(keras.Model):
-     """Optimized SAM model with compiled inference."""
-
      def __init__(self, **kwargs):
          super().__init__()
          if 'config' in kwargs and isinstance(kwargs['config'], dict):
@@ -261,9 +254,6 @@ class SAM1Model(keras.Model):
          ]
          self.norm = RMSNorm(name="final_norm")
          self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
-
-         self._compiled_prefill = None
-         self._compiled_decode = None
 
      def call(self, input_ids, training=None, past_kv=None, use_cache=False):
          x = self.embed(input_ids)
@@ -279,19 +269,6 @@ class SAM1Model(keras.Model):
          logits = self.lm_head(self.norm(x))
          return logits, new_past_kv
 
-     @tf.function(reduce_retracing=True)
-     def prefill(self, input_ids):
-         """Compiled prefill for initial prompt processing."""
-         return self.call(input_ids, training=False, past_kv=None, use_cache=True)
-
-     @tf.function(reduce_retracing=True, input_signature=[
-         tf.TensorSpec(shape=[1, 1], dtype=tf.int32),
-         tf.TensorSpec(shape=[None], dtype=tf.variant)  # For the list of KV tuples
-     ])
-     def decode_step(self, input_ids, past_kv):
-         """Compiled single-token decode step."""
-         return self.call(input_ids, training=False, past_kv=past_kv, use_cache=True)
-
      def get_config(self):
          base_config = super().get_config()
          base_config['config'] = self.cfg
@@ -299,15 +276,9 @@
 
 
  # ============================================================================
- # Optimized Sampling Functions
+ # Optimized Sampling
  # ============================================================================
 
- @lru_cache(maxsize=128)
- def get_top_k_mask(vocab_size, top_k):
-     """Cache top-k masks for common vocab sizes."""
-     return top_k
-
-
  class FastSampler:
      """Vectorized sampler for faster token selection."""
 
@@ -317,6 +288,9 @@ class FastSampler:
 
      def sample(self, logits, temperature, top_k, top_p, token_freq, repetition_penalty):
          """Optimized sampling with vectorized operations."""
+         # Make a copy to avoid modifying original
+         logits = logits.copy()
+
          # Temperature scaling
          if temperature != 1.0:
              logits = logits / temperature
@@ -328,7 +302,8 @@
          valid_mask = freq_tokens < len(logits)
          freq_tokens = freq_tokens[valid_mask]
          freq_values = freq_values[valid_mask]
-         logits[freq_tokens] /= np.power(repetition_penalty, freq_values)
+         if len(freq_tokens) > 0:
+             logits[freq_tokens] /= np.power(repetition_penalty, freq_values)
 
          # Top-K filtering with partial sort
          if 0 < top_k < len(logits):
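An illustrative sketch of the guarded repetition penalty as written above, assuming `token_freq` is a token-id → count dict (its construction is outside this hunk). Tokens that already appeared get their logit divided by `repetition_penalty` raised to their count; the added length check avoids fancy-indexing with an empty array.

```python
# Illustrative sketch only: frequency-based repetition penalty on a logits vector.
import numpy as np

def apply_repetition_penalty(logits, token_freq, repetition_penalty):
    logits = logits.copy()                     # keep the caller's array intact
    if token_freq and repetition_penalty != 1.0:
        freq_tokens = np.fromiter(token_freq.keys(), dtype=np.int64)
        freq_values = np.fromiter(token_freq.values(), dtype=np.float32)
        valid = freq_tokens < len(logits)
        freq_tokens, freq_values = freq_tokens[valid], freq_values[valid]
        if len(freq_tokens) > 0:
            logits[freq_tokens] /= np.power(repetition_penalty, freq_values)
    return logits

logits = np.array([2.0, 1.0, 0.5, 3.0], dtype=np.float32)
print(apply_repetition_penalty(logits, token_freq={3: 2}, repetition_penalty=1.2))
```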
@@ -455,6 +430,30 @@ for _ in range(3):
 
  print("✅ Model warmed up and traces compiled")
 
+ # ============================================================================
+ # Compiled Inference Functions
+ # ============================================================================
+
+ # Create tf.function wrapped inference for speed
+ @tf.function(reduce_retracing=True)
+ def model_prefill(input_ids):
+     """Compiled prefill function."""
+     return model(input_ids, training=False, use_cache=True)
+
+
+ @tf.function(reduce_retracing=True)
+ def model_decode(input_ids, past_kv):
+     """Compiled single-token decode function."""
+     return model(input_ids, training=False, past_kv=past_kv, use_cache=True)
+
+
+ # Additional warmup for compiled functions
+ print("🔥 Compiling tf.function traces...")
+ _ = model_prefill(warmup_input)
+ _ = model_decode(single_token, past_kv)
+ print("✅ Compiled functions ready")
+
+
  # ============================================================================
  # Optimized Inference Logic with KV-Cache
  # ============================================================================
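The module-level `model_prefill`/`model_decode` wrappers rely on `tf.function` reusing its trace as long as input shapes stay fixed (the decode input is always `[1, 1]`). A toy sketch of that behavior with a stand-in layer, not the app's model:

```python
# Illustrative sketch only: a fixed-shape decode step is traced once and reused.
import tensorflow as tf

dense = tf.keras.layers.Dense(4)   # stand-in for the real model

@tf.function(reduce_retracing=True)
def decode_step(x):
    return dense(x)

for _ in range(3):
    decode_step(tf.constant([[1.0]]))   # same [1, 1] shape on every call

print(decode_step.experimental_get_tracing_count())  # 1 - the trace is reused
```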
@@ -494,7 +493,7 @@ def generate_stream(
 
      max_context = config['max_position_embeddings']
 
-     start_time = time.perf_counter()  # More precise timing
+     start_time = time.perf_counter()
 
      # === PREFILL PHASE ===
      if len(input_ids) > max_context - max_tokens:
@@ -503,23 +502,21 @@
 
      input_tensor = tf.constant([input_ids], dtype=tf.int32)
 
      try:
-         logits, past_kv = model(input_tensor, training=False, use_cache=True)
+         logits, past_kv = model_prefill(input_tensor)
      except Exception as e:
          yield f"Error during prefill: {e}"
          return
 
-     # Get logits for last position (avoid copy with indexing)
+     # Get logits for last position
      next_token_logits = logits[0, -1, :].numpy()
 
      prefill_time = time.perf_counter() - start_time
-     print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.3f}s ({len(input_ids)/prefill_time:.1f} tok/s)")
+     prefill_tps = len(input_ids) / prefill_time if prefill_time > 0 else 0
+     print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.3f}s ({prefill_tps:.1f} tok/s)")
 
      # === GENERATION LOOP ===
      decode_start = time.perf_counter()
 
-     # Pre-compute constants
-     yield_interval = 1  # Yield every token for streaming
-
      for step in range(max_tokens):
          if stop_generation:
              yield generated_text + "\n\n*[Generation stopped]*"
@@ -541,23 +538,21 @@
          token_text = tokenizer.decode([next_token_id])
          generated_text += token_text
          token_count += 1
-
-         if step % yield_interval == 0:
-             yield generated_text
+         yield generated_text
 
          # === DECODE PHASE (single token, reuse cache) ===
          next_input = tf.constant([[next_token_id]], dtype=tf.int32)
 
          try:
-             logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
+             logits, past_kv = model_decode(next_input, past_kv)
          except Exception as e:
              yield generated_text + f"\n\n*[Error during generation: {e}]*"
              return
 
          next_token_logits = logits[0, -1, :].numpy()
 
-         # Truncate cache if too long (less frequent check)
-         if step % 100 == 0:
+         # Truncate cache if too long (check less frequently)
+         if step % 100 == 99:
              current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
              if current_len > max_context:
                  trim_amount = current_len - max_context + 100
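The hunk only shows the length check; the actual trimming happens past the end of the diff. A hedged sketch of what such truncation might look like, assuming `past_kv` is a list of `(k, v)` tensors shaped `[batch, n_heads, seq_len, head_dim]` and that the oldest positions are dropped:

```python
# Illustrative sketch only: one plausible KV-cache truncation, dropping the
# oldest positions along the sequence axis once the cache exceeds max_context.
import tensorflow as tf

def trim_kv_cache(past_kv, max_context, slack=100):
    current_len = past_kv[0][0].shape[2]
    if current_len <= max_context:
        return past_kv
    trim = current_len - max_context + slack
    return [(k[:, :, trim:, :], v[:, :, trim:, :]) for k, v in past_kv]

past_kv = [(tf.zeros([1, 2, 600, 8]), tf.zeros([1, 2, 600, 8]))]
trimmed = trim_kv_cache(past_kv, max_context=512)
print(trimmed[0][0].shape)  # (1, 2, 412, 8) after dropping the oldest 188 positions
```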
@@ -827,7 +822,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
      **Vocab:** {config['vocab_size']:,}
      **Layers:** {config['num_hidden_layers']}
      **Context:** {config['max_position_embeddings']:,} tokens
-     **Optimization:** KV-Cache + XLA JIT ⚡
+     **Optimization:** KV-Cache + XLA ⚡
      """)
 
      gr.Examples(