Keeby-smilyai committed · verified
Commit 436b502 · 1 Parent(s): 579190c

Update app.py

Files changed (1):
  1. app.py +185 -76
app.py CHANGED
@@ -1,3 +1,28 @@
+import os
+
+# === CPU Threading Optimization ===
+# Set these BEFORE importing TensorFlow
+NUM_CORES = os.cpu_count() or 4
+
+os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
+os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
+
+# Enable oneDNN optimizations (significant on Intel CPUs; must also be set before the import)
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
+
+# Disable GPU (ensures CPU-only, avoids GPU detection overhead)
+os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+
+import tensorflow as tf
+
+# Configure threading again after import
+tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
+tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
+
+# Optional: XLA JIT compilation (can help, test it)
+# tf.config.optimizer.set_jit(True)
+
+print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
 import gradio as gr
 import tensorflow as tf
 import keras
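Aside on the threading hunk: a minimal, optional sanity check (not part of the commit) that the requested pool sizes took effect. Both getters are standard tf.config.threading API:

import tensorflow as tf

# Should print NUM_CORES twice if the env vars / setters were applied
print(tf.config.threading.get_inter_op_parallelism_threads())
print(tf.config.threading.get_intra_op_parallelism_threads())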
@@ -52,13 +77,20 @@ class RotaryEmbedding(keras.layers.Layer):
         x1, x2 = tf.split(x, 2, axis=-1)
         return tf.concat([-x2, x1], axis=-1)
 
-    def call(self, q, k):
+    def call(self, q, k, offset=0):
+        """Apply rotary embeddings with position offset for KV-cache."""
         self._build_cache()
         seq_len = tf.shape(q)[2]
         dtype = q.dtype
-        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
-        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
-        return (q * cos) + (self.rotate_half(q) * sin), (k * cos) + (self.rotate_half(k) * sin)
+
+        # For q: positions are [offset, offset + seq_len)
+        # For k: same positions (k is only the new tokens; past_k already has RoPE applied)
+        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
+        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
+
+        q_embed = (q * cos) + (self.rotate_half(q) * sin)
+        k_embed = (k * cos) + (self.rotate_half(k) * sin)
+        return q_embed, k_embed
 
     def get_config(self):
         config = super().get_config()
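Why the offset argument is enough for correctness: RoPE attention scores depend only on the relative distance between query and key positions, so rotating only the new tokens at their absolute positions matches recomputing the whole sequence from scratch. A minimal NumPy sketch (standalone, using the same split-half convention as rotate_half above):

import numpy as np

def rope(x, pos, theta=10000.0):
    # x: [T, d]; pos: [T] absolute positions
    d = x.shape[-1]
    inv_freq = 1.0 / theta ** (np.arange(0, d, 2) / d)
    ang = pos[:, None] * inv_freq[None, :]  # [T, d/2]
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., :d // 2], x[..., d // 2:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

rng = np.random.default_rng(0)
q, k = rng.normal(size=(1, 8)), rng.normal(size=(1, 8))

# Score for (q at position 5, k at position 3) equals the same pair shifted by +7
s1 = rope(q, np.array([5])) @ rope(k, np.array([3])).T
s2 = rope(q, np.array([12])) @ rope(k, np.array([10])).T
assert np.allclose(s1, s2)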
@@ -110,29 +142,82 @@ class TransformerBlock(keras.layers.Layer):
         self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
         self.dropout = keras.layers.Dropout(dropout)
 
-    def call(self, x, training=None):
-        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
+    def call(self, x, training=None, past_kv=None, use_cache=False):
+        """
+        Args:
+            x: input tensor [B, T, D] (T=1 during cached generation)
+            past_kv: tuple of (past_k, past_v), each [B, n_heads, past_len, head_dim]
+            use_cache: whether to return the updated KV cache
+        Returns:
+            (output, (new_k, new_v)) if use_cache else (output, None)
+        """
+        B = tf.shape(x)[0]
+        T = tf.shape(x)[1]
         dtype = x.dtype
+
         res = x
         y = self.pre_attn_norm(x)
-        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
-        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
-        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
-        q, k = self.rope(q, k)
+
+        # Project Q, K, V for the current input
+        q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
+        q = tf.transpose(q, [0, 2, 1, 3])  # [B, n_heads, T, head_dim]
+
+        k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
+        k = tf.transpose(k, [0, 2, 1, 3])
+
+        v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
+        v = tf.transpose(v, [0, 2, 1, 3])
+
+        # Determine the position offset for RoPE
+        if past_kv is not None:
+            past_len = tf.shape(past_kv[0])[2]
+        else:
+            past_len = 0
+
+        # Apply RoPE with the position offset
+        q, k = self.rope(q, k, offset=past_len)
+
+        # Concatenate with past KV
+        if past_kv is not None:
+            k = tf.concat([past_kv[0], k], axis=2)
+            v = tf.concat([past_kv[1], v], axis=2)
+
+        new_kv = (k, v) if use_cache else None
+
+        # Attention
+        full_len = tf.shape(k)[2]
         scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
-        mask = tf.where(tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
-        scores += mask
-        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
-        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
-        x = res + self.dropout(self.out_proj(attn), training=training)
+
+        # Causal mask: q attends to all of k (including past)
+        # Shape [T, full_len]: each query position may attend to key positions <= its absolute position
+        q_positions = tf.range(past_len, past_len + T)
+        k_positions = tf.range(full_len)
+        mask = tf.cast(q_positions[:, None] >= k_positions[None, :], dtype)
+        mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
+        scores = scores + mask[None, None, :, :]
+
+        attn = tf.nn.softmax(scores, axis=-1)
+        attn_out = tf.matmul(attn, v)
+        attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
+        attn_out = tf.reshape(attn_out, [B, T, self.d_model])
+
+        x = res + self.dropout(self.out_proj(attn_out), training=training)
+
+        # FFN
         res = x
         y = self.pre_ffn_norm(x)
         ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
-        return res + self.dropout(ffn, training=training)
+        output = res + self.dropout(ffn, training=training)
+
+        return output, new_kv
 
     def get_config(self):
         config = super().get_config()
-        config.update({"d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len, "rope_theta": self.rope_theta, "layer_idx": self.layer_idx})
+        config.update({
+            "d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim,
+            "dropout": self.dropout_rate, "max_len": self.max_len,
+            "rope_theta": self.rope_theta, "layer_idx": self.layer_idx
+        })
         return config
 
 
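A useful regression test for this block (a sketch; constructor arguments inferred from the get_config keys above, with dropout at 0 so both paths are deterministic): decoding token-by-token through the cache should reproduce the full-sequence output.

import numpy as np
import tensorflow as tf

blk = TransformerBlock(d_model=64, n_heads=4, ff_dim=128, dropout=0.0,
                       max_len=32, rope_theta=10000.0, layer_idx=0)
x = tf.random.normal([1, 5, 64])

full, _ = blk(x, training=False)                                     # all 5 tokens at once
_, kv = blk(x[:, :4], training=False, use_cache=True)                # prefill tokens 0-3
step, _ = blk(x[:, 4:], training=False, past_kv=kv, use_cache=True)  # cached step for token 4

np.testing.assert_allclose(step.numpy(), full[:, 4:].numpy(), atol=1e-4)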
@@ -149,25 +234,44 @@ class SAM1Model(keras.Model):
 
         self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
         ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
-        block_args = {'d_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'], 'ff_dim': ff_dim, 'dropout': self.cfg['dropout'], 'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']}
-        self.blocks = []
-        for i in range(self.cfg['n_layers']):
-            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
-            self.blocks.append(block)
+        block_args = {
+            'd_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'],
+            'ff_dim': ff_dim, 'dropout': self.cfg['dropout'],
+            'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']
+        }
+        self.blocks = [
+            TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
+            for i in range(self.cfg['n_layers'])
+        ]
         self.norm = RMSNorm(name="final_norm")
         self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
 
-    def call(self, input_ids, training=None):
+    def call(self, input_ids, training=None, past_kv=None, use_cache=False):
+        """
+        Args:
+            input_ids: [B, T]
+            past_kv: list of (k, v) tuples, one per layer
+            use_cache: whether to return the updated cache
+        Returns:
+            logits, new_past_kv (or None)
+        """
         x = self.embed(input_ids)
-        for block in self.blocks:
-            x = block(x, training=training)
-        return self.lm_head(self.norm(x))
+
+        new_past_kv = [] if use_cache else None
+
+        for i, block in enumerate(self.blocks):
+            layer_past = past_kv[i] if past_kv is not None else None
+            x, layer_kv = block(x, training=training, past_kv=layer_past, use_cache=use_cache)
+            if use_cache:
+                new_past_kv.append(layer_kv)
+
+        logits = self.lm_head(self.norm(x))
+        return logits, new_past_kv
 
     def get_config(self):
         base_config = super().get_config()
         base_config['config'] = self.cfg
         return base_config
-
 # --- Model and Tokenizer Loading ---
 
 config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
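The two-phase contract the new call signature exposes, in miniature (model is the instance loaded below; the token ids here are placeholders):

# Prefill: run the whole prompt once, keep the per-layer (k, v) cache
ids = tf.constant([[1, 2, 3]], dtype=tf.int32)
logits, cache = model(ids, training=False, use_cache=True)

# Decode: feed only the newest token, reusing and extending the cache
nxt = tf.constant([[4]], dtype=tf.int32)
logits, cache = model(nxt, training=False, past_kv=cache, use_cache=True)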
@@ -255,96 +359,101 @@ def generate_stream(
     top_p: float = 0.9,
     repetition_penalty: float = 1.1
 ):
-    """Generate text with streaming output using REAL model inference"""
+    """Generate text with KV-cache for fast CPU inference."""
     global stop_generation
     stop_generation = False
 
-    # Tokenize prompt
     prompt_ids = tokenizer.encode(prompt).ids
    input_ids = [i for i in prompt_ids if i != eos_token_id]
 
-    input_tensor = tf.constant([input_ids], dtype=tf.int32)
     generated_text = ""
     token_count = 0
     token_freq = {}
 
     start_time = time.time()
 
-    # --- REAL INFERENCE LOOP ---
+    # === PREFILL PHASE ===
+    # Process entire prompt, build initial KV cache
+    input_tensor = tf.constant([input_ids], dtype=tf.int32)
+    logits, past_kv = model(input_tensor, training=False, use_cache=True)
+
+    # Get logits for last position
+    next_token_logits = logits[0, -1, :].numpy()
+
+    # === GENERATION LOOP ===
     for step in range(max_tokens):
         if stop_generation:
             yield generated_text + "\n\n*[Generation stopped]*"
             break
 
-        # 1. Forward Pass (Real Model)
-        logits = fast_forward(input_tensor)
-        next_token_logits = logits[0, -1, :].numpy()
-
-        # 2. Temperature
-        next_token_logits = next_token_logits / temperature
+        # Temperature
+        scaled_logits = next_token_logits / temperature
 
-        # 3. Repetition Penalty
+        # Repetition penalty
         if repetition_penalty != 1.0:
             for token_id, freq in token_freq.items():
-                if token_id < len(next_token_logits):
-                    next_token_logits[token_id] /= (repetition_penalty ** freq)
+                if token_id < len(scaled_logits):
+                    scaled_logits[token_id] /= (repetition_penalty ** freq)
 
-        # 4. Sampling (Top-K / Top-P)
-        # Top-K
+        # Top-K sampling
         if top_k > 0:
-            top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
-            top_k_logits = next_token_logits[top_k_indices]
-            top_k_probs = tf.nn.softmax(top_k_logits).numpy()
+            top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
+            top_k_logits = scaled_logits[top_k_indices]
+            top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
+            top_k_probs /= top_k_probs.sum()
 
-            # Top-P (Nucleus)
+            # Top-P (nucleus) sampling
             if top_p < 1.0:
-                sorted_indices = np.argsort(top_k_probs)[::-1]
-                cumsum = np.cumsum(top_k_probs[sorted_indices])
-                cutoff_idx = np.searchsorted(cumsum, top_p)
-                nucleus_indices = sorted_indices[:cutoff_idx + 1]
-
-                nucleus_logits = top_k_logits[nucleus_indices]
-                nucleus_probs = tf.nn.softmax(nucleus_logits).numpy()
-
-                sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
-                next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
+                sorted_idx = np.argsort(top_k_probs)[::-1]
+                cumsum = np.cumsum(top_k_probs[sorted_idx])
+                cutoff = np.searchsorted(cumsum, top_p) + 1
+                nucleus_idx = sorted_idx[:cutoff]
+                nucleus_probs = top_k_probs[nucleus_idx]
+                nucleus_probs /= nucleus_probs.sum()
+                sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
+                next_token_id = int(top_k_indices[nucleus_idx[sampled]])
             else:
-                sampled_idx = np.random.choice(len(top_k_probs), p=top_k_probs)
-                next_token_id = int(top_k_indices[sampled_idx])
+                sampled = np.random.choice(len(top_k_probs), p=top_k_probs)
+                next_token_id = int(top_k_indices[sampled])
         else:
-            probs = tf.nn.softmax(next_token_logits).numpy()
+            probs = np.exp(scaled_logits - np.max(scaled_logits))
+            probs /= probs.sum()
             next_token_id = np.random.choice(len(probs), p=probs)
-
-        # 5. Stop Conditions
-        if next_token_id == eos_token_id or \
-           next_token_id == tokenizer.token_to_id("<|im_end|>") or \
-           next_token_id == tokenizer.token_to_id("<im end for model tun>"):
+
+        # Stop conditions
+        if next_token_id == eos_token_id:
             break
-
-        # 6. Update Input & History
+        im_end_id = tokenizer.token_to_id("<|im_end|>")
+        model_end_id = tokenizer.token_to_id("<im end for model tun>")
+        if next_token_id in (im_end_id, model_end_id):
+            break
+
+        # Update frequency tracking
         token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
 
+        # Decode and yield
         token_text = tokenizer.decode([next_token_id])
         generated_text += token_text
         token_count += 1
-
         yield generated_text
 
-        # Prepare next input
-        input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)
+        # === DECODE PHASE (single token, reuse cache) ===
+        next_input = tf.constant([[next_token_id]], dtype=tf.int32)
+        logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
+        next_token_logits = logits[0, -1, :].numpy()
 
-        # Truncate if exceeding context
-        if input_tensor.shape[1] > config['max_position_embeddings']:
-            input_tensor = input_tensor[:, -config['max_position_embeddings']:]
-
+        # Truncate cache if too long
+        max_len = config['max_position_embeddings']
+        if past_kv[0][0].shape[2] > max_len:
+            past_kv = [(k[:, :, -max_len:, :], v[:, :, -max_len:, :]) for k, v in past_kv]
+
     elapsed = time.time() - start_time
-    tokens_per_sec = token_count / elapsed if elapsed > 0 else 0
+    tps = token_count / elapsed if elapsed > 0 else 0
 
     if token_count > 0 and not stop_generation:
-        generated_text += f"\n\n*[Generated {token_count} tokens in {elapsed:.1f}s ({tokens_per_sec:.1f} tok/s)]*"
+        generated_text += f"\n\n*[Generated {token_count} tokens in {elapsed:.1f}s ({tps:.1f} tok/s)]*"
 
     yield generated_text
-
 # ============================================================================
 # Chat Interface Logic
 # ============================================================================
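For reference, the sampling chain in the loop above (temperature, then top-k filter, then nucleus cutoff, then renormalize and draw) condensed into a standalone NumPy toy; the function name and values here are illustrative only:

import numpy as np

def sample_token(logits, temperature=0.8, top_k=4, top_p=0.9, seed=0):
    rng = np.random.default_rng(seed)
    logits = logits / temperature
    top_idx = np.argpartition(logits, -top_k)[-top_k:]       # keep the k largest logits
    probs = np.exp(logits[top_idx] - logits[top_idx].max())  # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]
    keep = order[:np.searchsorted(np.cumsum(probs[order]), top_p) + 1]
    p = probs[keep] / probs[keep].sum()                      # renormalized nucleus
    return int(top_idx[keep[rng.choice(len(p), p=p)]])

print(sample_token(np.array([2.0, 1.0, 0.5, 0.1, -1.0])))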
 