Bc-AI committed on
Commit
0855c89
·
verified ·
1 Parent(s): ce4914d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -56
app.py CHANGED
@@ -8,30 +8,17 @@ Edit the CONFIG below, then deploy.
8
  # ============================================================================
9
 
10
  CONFIG = {
11
- # This node's identity
12
  "node_id": "head-main",
13
-
14
- # Which transformer blocks this node runs (0-indexed)
15
- # Sam-large-2 has 12 blocks (0-11)
16
  "layer_start": 0,
17
- "layer_end": 6, # exclusive, so this runs blocks 0,1,2,3,4,5
18
-
19
- # Worker Space URLs (in order of execution)
20
- # Leave empty [] for standalone mode (all layers on this node)
21
- "worker_urls": [
22
- # "https://YOUR-WORKER-SPACE.hf.space",
23
- ],
24
-
25
- # Shared secret for worker communication
26
  "secret_token": "sam2-distributed-secret-change-me",
27
-
28
- # Model settings
29
  "model_repo": "Smilyai-labs/Sam-large-2",
30
  "cache_dir": "./model_cache",
31
  }
32
 
33
  # ============================================================================
34
- # CPU Optimization - MUST be before TensorFlow import
35
  # ============================================================================
36
 
37
  import os
@@ -45,7 +32,6 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
45
 
46
  import json
47
  import time
48
- import threading
49
  import io
50
  import base64
51
  from typing import Dict, List, Optional, Tuple, Any
@@ -204,13 +190,10 @@ class ModelState:
204
  self.config = None
205
  self.tokenizer = None
206
  self.eos_token_id = 50256
207
-
208
- # Model components
209
  self.embedding = None
210
  self.blocks: List = []
211
  self.final_norm = None
212
  self.lm_head = None
213
-
214
  self.my_block_start = 0
215
  self.my_block_end = 0
216
 
@@ -245,7 +228,6 @@ def deserialize_kv_cache(data):
245
  # ============================================================================
246
 
247
  def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
248
- """Send hidden states to worker and get result."""
249
  try:
250
  response = requests.post(
251
  f"{url.rstrip('/')}/api/forward",
@@ -273,16 +255,13 @@ def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=Fals
273
  # ============================================================================
274
 
275
  def load_model():
276
- """Load model and extract components for this node."""
277
  print("πŸš€ Loading model...")
278
 
279
- # Load config
280
  config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
281
  with open(config_path, 'r') as f:
282
  model_config = json.load(f)
283
  STATE.config = model_config
284
 
285
- # Load tokenizer
286
  from transformers import AutoTokenizer
287
  from tokenizers import Tokenizer
288
 
@@ -294,10 +273,8 @@ def load_model():
294
  STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
295
  STATE.eos_token_id = model_config.get('eos_token_id', 50256)
296
 
297
- # Load weights
298
  weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])
299
 
300
- # Build full model to load weights
301
  n_layers = model_config['num_hidden_layers']
302
  d_model = model_config['hidden_size']
303
  n_heads = model_config['num_attention_heads']
@@ -306,14 +283,12 @@ def load_model():
306
  rope_theta = model_config['rope_theta']
307
  vocab_size = model_config['vocab_size']
308
 
309
- # Temporary full model
310
  embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
311
  blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
312
  for i in range(n_layers)]
313
  final_norm = RMSNorm(name="final_norm")
314
  lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")
315
 
316
- # Build
317
  dummy = tf.zeros((1, 16), dtype=tf.int32)
318
  x = embedding(dummy)
319
  for block in blocks:
@@ -321,7 +296,6 @@ def load_model():
321
  x = final_norm(x)
322
  _ = lm_head(x)
323
 
324
- # Load weights into a temp model structure
325
  class TempModel(keras.Model):
326
  def __init__(self):
327
  super().__init__()
@@ -340,25 +314,19 @@ def load_model():
340
  temp_model.load_weights(weights_path)
341
  print("βœ… Weights loaded")
342
 
343
- # Extract components for this node
344
  STATE.my_block_start = CONFIG["layer_start"]
345
  STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers
346
 
347
- # HEAD always has embedding
348
  STATE.embedding = embedding
349
-
350
- # Extract our blocks
351
  STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
352
  print(f"βœ… Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")
353
 
354
- # HEAD has final norm and lm_head only if no workers OR we handle last block
355
  has_workers = len(CONFIG["worker_urls"]) > 0
356
  if not has_workers:
357
  STATE.final_norm = final_norm
358
  STATE.lm_head = lm_head
359
  print("βœ… Loaded final norm and LM head (standalone mode)")
360
 
361
- # Warmup
362
  print("πŸ”₯ Warming up...")
363
  dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
364
  x = STATE.embedding(dummy)
@@ -375,14 +343,8 @@ def load_model():
375
  # ============================================================================
376
 
377
  def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
378
- """
379
- Full forward pass through HEAD + all workers.
380
- Returns logits and updated KV caches.
381
- """
382
- # Embedding
383
  x = STATE.embedding(input_ids)
384
 
385
- # Local blocks
386
  new_local_kv = [] if use_cache else None
387
  for i, block in enumerate(STATE.blocks):
388
  block_past = past_kv_local[i] if past_kv_local else None
@@ -390,7 +352,6 @@ def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None,
390
  if use_cache:
391
  new_local_kv.append(kv)
392
 
393
- # Workers
394
  new_worker_kv = {} if use_cache else None
395
  for worker_url in CONFIG["worker_urls"]:
396
  worker_past = past_kv_workers.get(worker_url) if past_kv_workers else None
@@ -398,12 +359,9 @@ def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None,
398
  if use_cache:
399
  new_worker_kv[worker_url] = worker_kv
400
 
401
- # Final (only if standalone or last worker returned to us)
402
- # In distributed mode, the last worker applies final_norm + lm_head
403
  if STATE.lm_head:
404
  logits = STATE.lm_head(STATE.final_norm(x))
405
  else:
406
- # x should already be logits from last worker
407
  logits = x
408
 
409
  return logits, new_local_kv, new_worker_kv
@@ -465,7 +423,6 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
465
 
466
  start = time.time()
467
 
468
- # Prefill
469
  input_tensor = tf.constant([input_ids], dtype=tf.int32)
470
  try:
471
  logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
@@ -477,7 +434,6 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
477
  prefill_time = time.time() - start
478
  print(f"⚑ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
479
 
480
- # Generate
481
  decode_start = time.time()
482
  tokens_generated = 0
483
 
@@ -496,7 +452,6 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
496
  tokens_generated += 1
497
  yield generated
498
 
499
- # Next step
500
  next_input = tf.constant([[next_id]], dtype=tf.int32)
501
  try:
502
  logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
@@ -506,7 +461,6 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
506
 
507
  next_logits = logits[0, -1, :].numpy()
508
 
509
- # Stats
510
  if tokens_generated > 0:
511
  total = time.time() - start
512
  tps = tokens_generated / (time.time() - decode_start)
@@ -519,10 +473,12 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
519
 
520
  def format_prompt(message: str, history: list, reasoning: bool) -> str:
521
  prompt = ""
522
- for user, assistant in history:
523
- prompt += f"<|im_start|>user\n{user}<|im_end|>\n"
524
- if assistant:
525
- prompt += f"<|im_start|>assistant\n{assistant.split('*[')[0].strip()}<|im_end|>\n"
 
 
526
  prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
527
  if reasoning:
528
  prompt += "<think>"
@@ -536,21 +492,27 @@ def chat_respond(message, history, max_tokens, temp, top_k, top_p, rep_pen, reas
536
 
537
  prompt = format_prompt(message, history, reasoning)
538
 
 
 
 
539
  for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
540
  display = text
 
 
541
  for tag in ["<|im_end|>", "<im end for model tun>"]:
542
  if tag in display:
543
  idx = display.find(tag)
544
  stats = display.find("\n\n*[")
545
  display = display[:idx] + (display[stats:] if stats > idx else "")
546
 
 
547
  if reasoning and '<think>' in display and '</think>' in display:
548
  s, e = display.find('<think>'), display.find('</think>')
549
  if s < e:
550
  thought = display[s+7:e].strip()
551
  display = display[:s] + f'<details><summary>🧠 Reasoning</summary><p>{thought}</p></details>' + display[e+8:]
552
 
553
- yield history + [[message, display.strip()]]
554
 
555
 
556
  def stop():
@@ -575,7 +537,11 @@ def create_ui():
575
  gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))
576
 
577
  reasoning = gr.State(False)
578
- chatbot = gr.Chatbot(height=500)
 
 
 
 
579
 
580
  with gr.Row():
581
  reason_btn = gr.Button("πŸ’‘", size="sm", scale=0)
@@ -600,7 +566,7 @@ def create_ui():
600
  click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
601
  stop_btn.click(stop, cancels=[submit, click])
602
 
603
- gr.Button("πŸ—‘οΈ Clear").click(lambda: ([], ""), outputs=[chatbot, msg])
604
 
605
  return app
606
 
 
8
  # ============================================================================
9
 
10
  CONFIG = {
 
11
  "node_id": "head-main",
 
 
 
12
  "layer_start": 0,
13
+ "layer_end": 6,
14
+ "worker_urls": [],
 
 
 
 
 
 
 
15
  "secret_token": "sam2-distributed-secret-change-me",
 
 
16
  "model_repo": "Smilyai-labs/Sam-large-2",
17
  "cache_dir": "./model_cache",
18
  }
19
 
20
  # ============================================================================
21
+ # CPU Optimization
22
  # ============================================================================
23
 
24
  import os
 
32
 
33
  import json
34
  import time
 
35
  import io
36
  import base64
37
  from typing import Dict, List, Optional, Tuple, Any
 
190
  self.config = None
191
  self.tokenizer = None
192
  self.eos_token_id = 50256
 
 
193
  self.embedding = None
194
  self.blocks: List = []
195
  self.final_norm = None
196
  self.lm_head = None
 
197
  self.my_block_start = 0
198
  self.my_block_end = 0
199
 
 
228
  # ============================================================================
229
 
230
  def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
 
231
  try:
232
  response = requests.post(
233
  f"{url.rstrip('/')}/api/forward",
 
255
  # ============================================================================
256
 
257
  def load_model():
 
258
  print("πŸš€ Loading model...")
259
 
 
260
  config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
261
  with open(config_path, 'r') as f:
262
  model_config = json.load(f)
263
  STATE.config = model_config
264
 
 
265
  from transformers import AutoTokenizer
266
  from tokenizers import Tokenizer
267
 
 
273
  STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
274
  STATE.eos_token_id = model_config.get('eos_token_id', 50256)
275
 
 
276
  weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])
277
 
 
278
  n_layers = model_config['num_hidden_layers']
279
  d_model = model_config['hidden_size']
280
  n_heads = model_config['num_attention_heads']
 
283
  rope_theta = model_config['rope_theta']
284
  vocab_size = model_config['vocab_size']
285
 
 
286
  embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
287
  blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
288
  for i in range(n_layers)]
289
  final_norm = RMSNorm(name="final_norm")
290
  lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")
291
 
 
292
  dummy = tf.zeros((1, 16), dtype=tf.int32)
293
  x = embedding(dummy)
294
  for block in blocks:
 
296
  x = final_norm(x)
297
  _ = lm_head(x)
298
 
 
299
  class TempModel(keras.Model):
300
  def __init__(self):
301
  super().__init__()
 
314
  temp_model.load_weights(weights_path)
315
  print("βœ… Weights loaded")
316
 
 
317
  STATE.my_block_start = CONFIG["layer_start"]
318
  STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers
319
 
 
320
  STATE.embedding = embedding
 
 
321
  STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
322
  print(f"βœ… Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")
323
 
 
324
  has_workers = len(CONFIG["worker_urls"]) > 0
325
  if not has_workers:
326
  STATE.final_norm = final_norm
327
  STATE.lm_head = lm_head
328
  print("βœ… Loaded final norm and LM head (standalone mode)")
329
 
 
330
  print("πŸ”₯ Warming up...")
331
  dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
332
  x = STATE.embedding(dummy)
 
343
  # ============================================================================
344
 
345
  def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
 
 
 
 
 
346
  x = STATE.embedding(input_ids)
347
 
 
348
  new_local_kv = [] if use_cache else None
349
  for i, block in enumerate(STATE.blocks):
350
  block_past = past_kv_local[i] if past_kv_local else None
 
352
  if use_cache:
353
  new_local_kv.append(kv)
354
 
 
355
  new_worker_kv = {} if use_cache else None
356
  for worker_url in CONFIG["worker_urls"]:
357
  worker_past = past_kv_workers.get(worker_url) if past_kv_workers else None
 
359
  if use_cache:
360
  new_worker_kv[worker_url] = worker_kv
361
 
 
 
362
  if STATE.lm_head:
363
  logits = STATE.lm_head(STATE.final_norm(x))
364
  else:
 
365
  logits = x
366
 
367
  return logits, new_local_kv, new_worker_kv
 
423
 
424
  start = time.time()
425
 
 
426
  input_tensor = tf.constant([input_ids], dtype=tf.int32)
427
  try:
428
  logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
 
434
  prefill_time = time.time() - start
435
  print(f"⚑ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
436
 
 
437
  decode_start = time.time()
438
  tokens_generated = 0
439
 
 
452
  tokens_generated += 1
453
  yield generated
454
 
 
455
  next_input = tf.constant([[next_id]], dtype=tf.int32)
456
  try:
457
  logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
 
461
 
462
  next_logits = logits[0, -1, :].numpy()
463
 
 
464
  if tokens_generated > 0:
465
  total = time.time() - start
466
  tps = tokens_generated / (time.time() - decode_start)
 
473
 
474
  def format_prompt(message: str, history: list, reasoning: bool) -> str:
475
  prompt = ""
476
+ for msg in history:
477
+ if msg["role"] == "user":
478
+ prompt += f"<|im_start|>user\n{msg['content']}<|im_end|>\n"
479
+ elif msg["role"] == "assistant":
480
+ content = msg['content'].split('*[')[0].strip()
481
+ prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
482
  prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
483
  if reasoning:
484
  prompt += "<think>"
 
492
 
493
  prompt = format_prompt(message, history, reasoning)
494
 
495
+ # Add user message to history
496
+ history = history + [{"role": "user", "content": message}]
497
+
498
  for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
499
  display = text
500
+
501
+ # Clean stop tags
502
  for tag in ["<|im_end|>", "<im end for model tun>"]:
503
  if tag in display:
504
  idx = display.find(tag)
505
  stats = display.find("\n\n*[")
506
  display = display[:idx] + (display[stats:] if stats > idx else "")
507
 
508
+ # Format reasoning
509
  if reasoning and '<think>' in display and '</think>' in display:
510
  s, e = display.find('<think>'), display.find('</think>')
511
  if s < e:
512
  thought = display[s+7:e].strip()
513
  display = display[:s] + f'<details><summary>🧠 Reasoning</summary><p>{thought}</p></details>' + display[e+8:]
514
 
515
+ yield history + [{"role": "assistant", "content": display.strip()}]
516
 
517
 
518
  def stop():
 
537
  gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))
538
 
539
  reasoning = gr.State(False)
540
+
541
+ chatbot = gr.Chatbot(
542
+ height=500,
543
+ type="messages" # Use new messages format
544
+ )
545
 
546
  with gr.Row():
547
  reason_btn = gr.Button("πŸ’‘", size="sm", scale=0)
 
566
  click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
567
  stop_btn.click(stop, cancels=[submit, click])
568
 
569
+ gr.Button("πŸ—‘οΈ Clear").click(lambda: [], outputs=[chatbot])
570
 
571
  return app
572