Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import tensorflow as tf
|
| 3 |
import keras
|
|
@@ -52,13 +77,20 @@ class RotaryEmbedding(keras.layers.Layer):
|
|
| 52 |
x1, x2 = tf.split(x, 2, axis=-1)
|
| 53 |
return tf.concat([-x2, x1], axis=-1)
|
| 54 |
|
| 55 |
-
def call(self, q, k):
|
|
|
|
| 56 |
self._build_cache()
|
| 57 |
seq_len = tf.shape(q)[2]
|
| 58 |
dtype = q.dtype
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def get_config(self):
|
| 64 |
config = super().get_config()
|
|
@@ -110,29 +142,82 @@ class TransformerBlock(keras.layers.Layer):
|
|
| 110 |
self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
|
| 111 |
self.dropout = keras.layers.Dropout(dropout)
|
| 112 |
|
| 113 |
-
def call(self, x, training=None):
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
dtype = x.dtype
|
|
|
|
| 116 |
res = x
|
| 117 |
y = self.pre_attn_norm(x)
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
q
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
res = x
|
| 129 |
y = self.pre_ffn_norm(x)
|
| 130 |
ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
|
| 131 |
-
|
|
|
|
|
|
|
| 132 |
|
| 133 |
def get_config(self):
|
| 134 |
config = super().get_config()
|
| 135 |
-
config.update({
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
return config
|
| 137 |
|
| 138 |
|
|
@@ -149,25 +234,44 @@ class SAM1Model(keras.Model):
|
|
| 149 |
|
| 150 |
self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
|
| 151 |
ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
|
| 152 |
-
block_args = {
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
self.norm = RMSNorm(name="final_norm")
|
| 158 |
self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
|
| 159 |
|
| 160 |
-
def call(self, input_ids, training=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
x = self.embed(input_ids)
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
def get_config(self):
|
| 167 |
base_config = super().get_config()
|
| 168 |
base_config['config'] = self.cfg
|
| 169 |
return base_config
|
| 170 |
-
|
| 171 |
# --- Model and Tokenizer Loading ---
|
| 172 |
|
| 173 |
config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
|
|
@@ -255,96 +359,101 @@ def generate_stream(
|
|
| 255 |
top_p: float = 0.9,
|
| 256 |
repetition_penalty: float = 1.1
|
| 257 |
):
|
| 258 |
-
"""Generate text with
|
| 259 |
global stop_generation
|
| 260 |
stop_generation = False
|
| 261 |
|
| 262 |
-
# Tokenize prompt
|
| 263 |
prompt_ids = tokenizer.encode(prompt).ids
|
| 264 |
input_ids = [i for i in prompt_ids if i != eos_token_id]
|
| 265 |
|
| 266 |
-
input_tensor = tf.constant([input_ids], dtype=tf.int32)
|
| 267 |
generated_text = ""
|
| 268 |
token_count = 0
|
| 269 |
token_freq = {}
|
| 270 |
|
| 271 |
start_time = time.time()
|
| 272 |
|
| 273 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
for step in range(max_tokens):
|
| 275 |
if stop_generation:
|
| 276 |
yield generated_text + "\n\n*[Generation stopped]*"
|
| 277 |
break
|
| 278 |
|
| 279 |
-
#
|
| 280 |
-
|
| 281 |
-
next_token_logits = logits[0, -1, :].numpy()
|
| 282 |
-
|
| 283 |
-
# 2. Temperature
|
| 284 |
-
next_token_logits = next_token_logits / temperature
|
| 285 |
|
| 286 |
-
#
|
| 287 |
if repetition_penalty != 1.0:
|
| 288 |
for token_id, freq in token_freq.items():
|
| 289 |
-
if token_id < len(
|
| 290 |
-
|
| 291 |
|
| 292 |
-
#
|
| 293 |
-
# Top-K
|
| 294 |
if top_k > 0:
|
| 295 |
-
top_k_indices = np.argpartition(
|
| 296 |
-
top_k_logits =
|
| 297 |
-
top_k_probs =
|
|
|
|
| 298 |
|
| 299 |
-
# Top-P (
|
| 300 |
if top_p < 1.0:
|
| 301 |
-
|
| 302 |
-
cumsum = np.cumsum(top_k_probs[
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
|
| 310 |
-
next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
|
| 311 |
else:
|
| 312 |
-
|
| 313 |
-
next_token_id = int(top_k_indices[
|
| 314 |
else:
|
| 315 |
-
probs =
|
|
|
|
| 316 |
next_token_id = np.random.choice(len(probs), p=probs)
|
| 317 |
-
|
| 318 |
-
#
|
| 319 |
-
if next_token_id == eos_token_id
|
| 320 |
-
next_token_id == tokenizer.token_to_id("<|im_end|>") or \
|
| 321 |
-
next_token_id == tokenizer.token_to_id("<im end for model tun>"):
|
| 322 |
break
|
| 323 |
-
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
|
| 326 |
|
|
|
|
| 327 |
token_text = tokenizer.decode([next_token_id])
|
| 328 |
generated_text += token_text
|
| 329 |
token_count += 1
|
| 330 |
-
|
| 331 |
yield generated_text
|
| 332 |
|
| 333 |
-
#
|
| 334 |
-
|
|
|
|
|
|
|
| 335 |
|
| 336 |
-
# Truncate if
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
|
|
|
| 340 |
elapsed = time.time() - start_time
|
| 341 |
-
|
| 342 |
|
| 343 |
if token_count > 0 and not stop_generation:
|
| 344 |
-
generated_text += f"\n\n*[Generated {token_count} tokens in {elapsed:.1f}s ({
|
| 345 |
|
| 346 |
yield generated_text
|
| 347 |
-
|
| 348 |
# ============================================================================
|
| 349 |
# Chat Interface Logic
|
| 350 |
# ============================================================================
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# === CPU Threading Optimization ===
|
| 4 |
+
# Set these BEFORE importing TensorFlow
|
| 5 |
+
NUM_CORES = os.cpu_count() or 4
|
| 6 |
+
|
| 7 |
+
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
|
| 8 |
+
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
|
| 9 |
+
|
| 10 |
+
# Disable GPU (ensures CPU-only, avoids GPU detection overhead)
|
| 11 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
| 12 |
+
|
| 13 |
+
import tensorflow as tf
|
| 14 |
+
|
| 15 |
+
# Configure threading after import
|
| 16 |
+
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
|
| 17 |
+
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
|
| 18 |
+
|
| 19 |
+
# Enable oneDNN optimizations (significant on Intel CPUs)
|
| 20 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
|
| 21 |
+
|
| 22 |
+
# Optional: XLA JIT compilation (can help, test it)
|
| 23 |
+
# tf.config.optimizer.set_jit(True)
|
| 24 |
+
|
| 25 |
+
print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
|
| 26 |
import gradio as gr
|
| 27 |
import tensorflow as tf
|
| 28 |
import keras
|
|
|
|
| 77 |
x1, x2 = tf.split(x, 2, axis=-1)
|
| 78 |
return tf.concat([-x2, x1], axis=-1)
|
| 79 |
|
| 80 |
+
def call(self, q, k, offset=0):
|
| 81 |
+
"""Apply rotary embeddings with position offset for KV-cache."""
|
| 82 |
self._build_cache()
|
| 83 |
seq_len = tf.shape(q)[2]
|
| 84 |
dtype = q.dtype
|
| 85 |
+
|
| 86 |
+
# For q: positions are [offset, offset+seq_len)
|
| 87 |
+
# For k: same positions (k is only the new tokens, past_k already has RoPE applied)
|
| 88 |
+
cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
|
| 89 |
+
sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
|
| 90 |
+
|
| 91 |
+
q_embed = (q * cos) + (self.rotate_half(q) * sin)
|
| 92 |
+
k_embed = (k * cos) + (self.rotate_half(k) * sin)
|
| 93 |
+
return q_embed, k_embed
|
| 94 |
|
| 95 |
def get_config(self):
|
| 96 |
config = super().get_config()
|
|
|
|
| 142 |
self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
|
| 143 |
self.dropout = keras.layers.Dropout(dropout)
|
| 144 |
|
| 145 |
+
def call(self, x, training=None, past_kv=None, use_cache=False):
|
| 146 |
+
"""
|
| 147 |
+
Args:
|
| 148 |
+
x: input tensor [B, T, D] (T=1 during cached generation)
|
| 149 |
+
past_kv: tuple of (past_k, past_v) each [B, n_heads, past_len, head_dim]
|
| 150 |
+
use_cache: whether to return updated kv cache
|
| 151 |
+
Returns:
|
| 152 |
+
output, (new_k, new_v) if use_cache else output, None
|
| 153 |
+
"""
|
| 154 |
+
B = tf.shape(x)[0]
|
| 155 |
+
T = tf.shape(x)[1]
|
| 156 |
dtype = x.dtype
|
| 157 |
+
|
| 158 |
res = x
|
| 159 |
y = self.pre_attn_norm(x)
|
| 160 |
+
|
| 161 |
+
# Project Q, K, V for current input
|
| 162 |
+
q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 163 |
+
q = tf.transpose(q, [0, 2, 1, 3]) # [B, n_heads, T, head_dim]
|
| 164 |
+
|
| 165 |
+
k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 166 |
+
k = tf.transpose(k, [0, 2, 1, 3])
|
| 167 |
+
|
| 168 |
+
v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
|
| 169 |
+
v = tf.transpose(v, [0, 2, 1, 3])
|
| 170 |
+
|
| 171 |
+
# Determine position offset for RoPE
|
| 172 |
+
if past_kv is not None:
|
| 173 |
+
past_len = tf.shape(past_kv[0])[2]
|
| 174 |
+
else:
|
| 175 |
+
past_len = 0
|
| 176 |
+
|
| 177 |
+
# Apply RoPE with position offset
|
| 178 |
+
q, k = self.rope(q, k, offset=past_len)
|
| 179 |
+
|
| 180 |
+
# Concatenate with past KV
|
| 181 |
+
if past_kv is not None:
|
| 182 |
+
k = tf.concat([past_kv[0], k], axis=2)
|
| 183 |
+
v = tf.concat([past_kv[1], v], axis=2)
|
| 184 |
+
|
| 185 |
+
new_kv = (k, v) if use_cache else None
|
| 186 |
+
|
| 187 |
+
# Attention
|
| 188 |
+
full_len = tf.shape(k)[2]
|
| 189 |
scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
|
| 190 |
+
|
| 191 |
+
# Causal mask: q attends to all of k (including past)
|
| 192 |
+
# Shape: [T, full_len] where each query position can attend to positions <= its absolute position
|
| 193 |
+
q_positions = tf.range(past_len, past_len + T)
|
| 194 |
+
k_positions = tf.range(full_len)
|
| 195 |
+
mask = tf.cast(q_positions[:, None] >= k_positions[None, :], dtype)
|
| 196 |
+
mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
|
| 197 |
+
scores = scores + mask[None, None, :, :]
|
| 198 |
+
|
| 199 |
+
attn = tf.nn.softmax(scores, axis=-1)
|
| 200 |
+
attn_out = tf.matmul(attn, v)
|
| 201 |
+
attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
|
| 202 |
+
attn_out = tf.reshape(attn_out, [B, T, self.d_model])
|
| 203 |
+
|
| 204 |
+
x = res + self.dropout(self.out_proj(attn_out), training=training)
|
| 205 |
+
|
| 206 |
+
# FFN
|
| 207 |
res = x
|
| 208 |
y = self.pre_ffn_norm(x)
|
| 209 |
ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
|
| 210 |
+
output = res + self.dropout(ffn, training=training)
|
| 211 |
+
|
| 212 |
+
return output, new_kv
|
| 213 |
|
| 214 |
def get_config(self):
|
| 215 |
config = super().get_config()
|
| 216 |
+
config.update({
|
| 217 |
+
"d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim,
|
| 218 |
+
"dropout": self.dropout_rate, "max_len": self.max_len,
|
| 219 |
+
"rope_theta": self.rope_theta, "layer_idx": self.layer_idx
|
| 220 |
+
})
|
| 221 |
return config
|
| 222 |
|
| 223 |
|
|
|
|
| 234 |
|
| 235 |
self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
|
| 236 |
ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
|
| 237 |
+
block_args = {
|
| 238 |
+
'd_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'],
|
| 239 |
+
'ff_dim': ff_dim, 'dropout': self.cfg['dropout'],
|
| 240 |
+
'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']
|
| 241 |
+
}
|
| 242 |
+
self.blocks = [
|
| 243 |
+
TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
|
| 244 |
+
for i in range(self.cfg['n_layers'])
|
| 245 |
+
]
|
| 246 |
self.norm = RMSNorm(name="final_norm")
|
| 247 |
self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
|
| 248 |
|
| 249 |
+
def call(self, input_ids, training=None, past_kv=None, use_cache=False):
|
| 250 |
+
"""
|
| 251 |
+
Args:
|
| 252 |
+
input_ids: [B, T]
|
| 253 |
+
past_kv: list of (k, v) tuples, one per layer
|
| 254 |
+
use_cache: whether to return updated cache
|
| 255 |
+
Returns:
|
| 256 |
+
logits, new_past_kv (or None)
|
| 257 |
+
"""
|
| 258 |
x = self.embed(input_ids)
|
| 259 |
+
|
| 260 |
+
new_past_kv = [] if use_cache else None
|
| 261 |
+
|
| 262 |
+
for i, block in enumerate(self.blocks):
|
| 263 |
+
layer_past = past_kv[i] if past_kv is not None else None
|
| 264 |
+
x, layer_kv = block(x, training=training, past_kv=layer_past, use_cache=use_cache)
|
| 265 |
+
if use_cache:
|
| 266 |
+
new_past_kv.append(layer_kv)
|
| 267 |
+
|
| 268 |
+
logits = self.lm_head(self.norm(x))
|
| 269 |
+
return logits, new_past_kv
|
| 270 |
|
| 271 |
def get_config(self):
|
| 272 |
base_config = super().get_config()
|
| 273 |
base_config['config'] = self.cfg
|
| 274 |
return base_config
|
|
|
|
| 275 |
# --- Model and Tokenizer Loading ---
|
| 276 |
|
| 277 |
config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
|
|
|
|
| 359 |
top_p: float = 0.9,
|
| 360 |
repetition_penalty: float = 1.1
|
| 361 |
):
|
| 362 |
+
"""Generate text with KV-cache for fast CPU inference."""
|
| 363 |
global stop_generation
|
| 364 |
stop_generation = False
|
| 365 |
|
|
|
|
| 366 |
prompt_ids = tokenizer.encode(prompt).ids
|
| 367 |
input_ids = [i for i in prompt_ids if i != eos_token_id]
|
| 368 |
|
|
|
|
| 369 |
generated_text = ""
|
| 370 |
token_count = 0
|
| 371 |
token_freq = {}
|
| 372 |
|
| 373 |
start_time = time.time()
|
| 374 |
|
| 375 |
+
# === PREFILL PHASE ===
|
| 376 |
+
# Process entire prompt, build initial KV cache
|
| 377 |
+
input_tensor = tf.constant([input_ids], dtype=tf.int32)
|
| 378 |
+
logits, past_kv = model(input_tensor, training=False, use_cache=True)
|
| 379 |
+
|
| 380 |
+
# Get logits for last position
|
| 381 |
+
next_token_logits = logits[0, -1, :].numpy()
|
| 382 |
+
|
| 383 |
+
# === GENERATION LOOP ===
|
| 384 |
for step in range(max_tokens):
|
| 385 |
if stop_generation:
|
| 386 |
yield generated_text + "\n\n*[Generation stopped]*"
|
| 387 |
break
|
| 388 |
|
| 389 |
+
# Temperature
|
| 390 |
+
scaled_logits = next_token_logits / temperature
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
+
# Repetition penalty
|
| 393 |
if repetition_penalty != 1.0:
|
| 394 |
for token_id, freq in token_freq.items():
|
| 395 |
+
if token_id < len(scaled_logits):
|
| 396 |
+
scaled_logits[token_id] /= (repetition_penalty ** freq)
|
| 397 |
|
| 398 |
+
# Top-K sampling
|
|
|
|
| 399 |
if top_k > 0:
|
| 400 |
+
top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
|
| 401 |
+
top_k_logits = scaled_logits[top_k_indices]
|
| 402 |
+
top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
|
| 403 |
+
top_k_probs /= top_k_probs.sum()
|
| 404 |
|
| 405 |
+
# Top-P (nucleus) sampling
|
| 406 |
if top_p < 1.0:
|
| 407 |
+
sorted_idx = np.argsort(top_k_probs)[::-1]
|
| 408 |
+
cumsum = np.cumsum(top_k_probs[sorted_idx])
|
| 409 |
+
cutoff = np.searchsorted(cumsum, top_p) + 1
|
| 410 |
+
nucleus_idx = sorted_idx[:cutoff]
|
| 411 |
+
nucleus_probs = top_k_probs[nucleus_idx]
|
| 412 |
+
nucleus_probs /= nucleus_probs.sum()
|
| 413 |
+
sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
|
| 414 |
+
next_token_id = int(top_k_indices[nucleus_idx[sampled]])
|
|
|
|
|
|
|
| 415 |
else:
|
| 416 |
+
sampled = np.random.choice(len(top_k_probs), p=top_k_probs)
|
| 417 |
+
next_token_id = int(top_k_indices[sampled])
|
| 418 |
else:
|
| 419 |
+
probs = np.exp(scaled_logits - np.max(scaled_logits))
|
| 420 |
+
probs /= probs.sum()
|
| 421 |
next_token_id = np.random.choice(len(probs), p=probs)
|
| 422 |
+
|
| 423 |
+
# Stop conditions
|
| 424 |
+
if next_token_id == eos_token_id:
|
|
|
|
|
|
|
| 425 |
break
|
| 426 |
+
im_end_id = tokenizer.token_to_id("<|im_end|>")
|
| 427 |
+
model_end_id = tokenizer.token_to_id("<im end for model tun>")
|
| 428 |
+
if next_token_id in (im_end_id, model_end_id):
|
| 429 |
+
break
|
| 430 |
+
|
| 431 |
+
# Update frequency tracking
|
| 432 |
token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
|
| 433 |
|
| 434 |
+
# Decode and yield
|
| 435 |
token_text = tokenizer.decode([next_token_id])
|
| 436 |
generated_text += token_text
|
| 437 |
token_count += 1
|
|
|
|
| 438 |
yield generated_text
|
| 439 |
|
| 440 |
+
# === DECODE PHASE (single token, reuse cache) ===
|
| 441 |
+
next_input = tf.constant([[next_token_id]], dtype=tf.int32)
|
| 442 |
+
logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
|
| 443 |
+
next_token_logits = logits[0, -1, :].numpy()
|
| 444 |
|
| 445 |
+
# Truncate cache if too long
|
| 446 |
+
max_len = config['max_position_embeddings']
|
| 447 |
+
if past_kv[0][0].shape[2] > max_len:
|
| 448 |
+
past_kv = [(k[:, :, -max_len:, :], v[:, :, -max_len:, :]) for k, v in past_kv]
|
| 449 |
+
|
| 450 |
elapsed = time.time() - start_time
|
| 451 |
+
tps = token_count / elapsed if elapsed > 0 else 0
|
| 452 |
|
| 453 |
if token_count > 0 and not stop_generation:
|
| 454 |
+
generated_text += f"\n\n*[Generated {token_count} tokens in {elapsed:.1f}s ({tps:.1f} tok/s)]*"
|
| 455 |
|
| 456 |
yield generated_text
|
|
|
|
| 457 |
# ============================================================================
|
| 458 |
# Chat Interface Logic
|
| 459 |
# ============================================================================
|