Update app.py
Browse files
app.py
CHANGED
|
@@ -8,30 +8,17 @@ Edit the CONFIG below, then deploy.
|
|
| 8 |
# ============================================================================
|
| 9 |
|
| 10 |
CONFIG = {
|
| 11 |
-
# This node's identity
|
| 12 |
"node_id": "head-main",
|
| 13 |
-
|
| 14 |
-
# Which transformer blocks this node runs (0-indexed)
|
| 15 |
-
# Sam-large-2 has 12 blocks (0-11)
|
| 16 |
"layer_start": 0,
|
| 17 |
-
"layer_end": 6,
|
| 18 |
-
|
| 19 |
-
# Worker Space URLs (in order of execution)
|
| 20 |
-
# Leave empty [] for standalone mode (all layers on this node)
|
| 21 |
-
"worker_urls": [
|
| 22 |
-
# "https://YOUR-WORKER-SPACE.hf.space",
|
| 23 |
-
],
|
| 24 |
-
|
| 25 |
-
# Shared secret for worker communication
|
| 26 |
"secret_token": "sam2-distributed-secret-change-me",
|
| 27 |
-
|
| 28 |
-
# Model settings
|
| 29 |
"model_repo": "Smilyai-labs/Sam-large-2",
|
| 30 |
"cache_dir": "./model_cache",
|
| 31 |
}
|
| 32 |
|
| 33 |
# ============================================================================
|
| 34 |
-
# CPU Optimization
|
| 35 |
# ============================================================================
|
| 36 |
|
| 37 |
import os
|
|
@@ -45,7 +32,6 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
|
| 45 |
|
| 46 |
import json
|
| 47 |
import time
|
| 48 |
-
import threading
|
| 49 |
import io
|
| 50 |
import base64
|
| 51 |
from typing import Dict, List, Optional, Tuple, Any
|
|
@@ -204,13 +190,10 @@ class ModelState:
|
|
| 204 |
self.config = None
|
| 205 |
self.tokenizer = None
|
| 206 |
self.eos_token_id = 50256
|
| 207 |
-
|
| 208 |
-
# Model components
|
| 209 |
self.embedding = None
|
| 210 |
self.blocks: List = []
|
| 211 |
self.final_norm = None
|
| 212 |
self.lm_head = None
|
| 213 |
-
|
| 214 |
self.my_block_start = 0
|
| 215 |
self.my_block_end = 0
|
| 216 |
|
|
@@ -245,7 +228,6 @@ def deserialize_kv_cache(data):
|
|
| 245 |
# ============================================================================
|
| 246 |
|
| 247 |
def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
|
| 248 |
-
"""Send hidden states to worker and get result."""
|
| 249 |
try:
|
| 250 |
response = requests.post(
|
| 251 |
f"{url.rstrip('/')}/api/forward",
|
|
@@ -273,16 +255,13 @@ def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=Fals
|
|
| 273 |
# ============================================================================
|
| 274 |
|
| 275 |
def load_model():
|
| 276 |
-
"""Load model and extract components for this node."""
|
| 277 |
print("π Loading model...")
|
| 278 |
|
| 279 |
-
# Load config
|
| 280 |
config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
|
| 281 |
with open(config_path, 'r') as f:
|
| 282 |
model_config = json.load(f)
|
| 283 |
STATE.config = model_config
|
| 284 |
|
| 285 |
-
# Load tokenizer
|
| 286 |
from transformers import AutoTokenizer
|
| 287 |
from tokenizers import Tokenizer
|
| 288 |
|
|
@@ -294,10 +273,8 @@ def load_model():
|
|
| 294 |
STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
|
| 295 |
STATE.eos_token_id = model_config.get('eos_token_id', 50256)
|
| 296 |
|
| 297 |
-
# Load weights
|
| 298 |
weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])
|
| 299 |
|
| 300 |
-
# Build full model to load weights
|
| 301 |
n_layers = model_config['num_hidden_layers']
|
| 302 |
d_model = model_config['hidden_size']
|
| 303 |
n_heads = model_config['num_attention_heads']
|
|
@@ -306,14 +283,12 @@ def load_model():
|
|
| 306 |
rope_theta = model_config['rope_theta']
|
| 307 |
vocab_size = model_config['vocab_size']
|
| 308 |
|
| 309 |
-
# Temporary full model
|
| 310 |
embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
|
| 311 |
blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
|
| 312 |
for i in range(n_layers)]
|
| 313 |
final_norm = RMSNorm(name="final_norm")
|
| 314 |
lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")
|
| 315 |
|
| 316 |
-
# Build
|
| 317 |
dummy = tf.zeros((1, 16), dtype=tf.int32)
|
| 318 |
x = embedding(dummy)
|
| 319 |
for block in blocks:
|
|
@@ -321,7 +296,6 @@ def load_model():
|
|
| 321 |
x = final_norm(x)
|
| 322 |
_ = lm_head(x)
|
| 323 |
|
| 324 |
-
# Load weights into a temp model structure
|
| 325 |
class TempModel(keras.Model):
|
| 326 |
def __init__(self):
|
| 327 |
super().__init__()
|
|
@@ -340,25 +314,19 @@ def load_model():
|
|
| 340 |
temp_model.load_weights(weights_path)
|
| 341 |
print("β
Weights loaded")
|
| 342 |
|
| 343 |
-
# Extract components for this node
|
| 344 |
STATE.my_block_start = CONFIG["layer_start"]
|
| 345 |
STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers
|
| 346 |
|
| 347 |
-
# HEAD always has embedding
|
| 348 |
STATE.embedding = embedding
|
| 349 |
-
|
| 350 |
-
# Extract our blocks
|
| 351 |
STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
|
| 352 |
print(f"β
Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")
|
| 353 |
|
| 354 |
-
# HEAD has final norm and lm_head only if no workers OR we handle last block
|
| 355 |
has_workers = len(CONFIG["worker_urls"]) > 0
|
| 356 |
if not has_workers:
|
| 357 |
STATE.final_norm = final_norm
|
| 358 |
STATE.lm_head = lm_head
|
| 359 |
print("β
Loaded final norm and LM head (standalone mode)")
|
| 360 |
|
| 361 |
-
# Warmup
|
| 362 |
print("π₯ Warming up...")
|
| 363 |
dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
|
| 364 |
x = STATE.embedding(dummy)
|
|
@@ -375,14 +343,8 @@ def load_model():
|
|
| 375 |
# ============================================================================
|
| 376 |
|
| 377 |
def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
|
| 378 |
-
"""
|
| 379 |
-
Full forward pass through HEAD + all workers.
|
| 380 |
-
Returns logits and updated KV caches.
|
| 381 |
-
"""
|
| 382 |
-
# Embedding
|
| 383 |
x = STATE.embedding(input_ids)
|
| 384 |
|
| 385 |
-
# Local blocks
|
| 386 |
new_local_kv = [] if use_cache else None
|
| 387 |
for i, block in enumerate(STATE.blocks):
|
| 388 |
block_past = past_kv_local[i] if past_kv_local else None
|
|
@@ -390,7 +352,6 @@ def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None,
|
|
| 390 |
if use_cache:
|
| 391 |
new_local_kv.append(kv)
|
| 392 |
|
| 393 |
-
# Workers
|
| 394 |
new_worker_kv = {} if use_cache else None
|
| 395 |
for worker_url in CONFIG["worker_urls"]:
|
| 396 |
worker_past = past_kv_workers.get(worker_url) if past_kv_workers else None
|
|
@@ -398,12 +359,9 @@ def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None,
|
|
| 398 |
if use_cache:
|
| 399 |
new_worker_kv[worker_url] = worker_kv
|
| 400 |
|
| 401 |
-
# Final (only if standalone or last worker returned to us)
|
| 402 |
-
# In distributed mode, the last worker applies final_norm + lm_head
|
| 403 |
if STATE.lm_head:
|
| 404 |
logits = STATE.lm_head(STATE.final_norm(x))
|
| 405 |
else:
|
| 406 |
-
# x should already be logits from last worker
|
| 407 |
logits = x
|
| 408 |
|
| 409 |
return logits, new_local_kv, new_worker_kv
|
|
@@ -465,7 +423,6 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
|
|
| 465 |
|
| 466 |
start = time.time()
|
| 467 |
|
| 468 |
-
# Prefill
|
| 469 |
input_tensor = tf.constant([input_ids], dtype=tf.int32)
|
| 470 |
try:
|
| 471 |
logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
|
|
@@ -477,7 +434,6 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
|
|
| 477 |
prefill_time = time.time() - start
|
| 478 |
print(f"β‘ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
|
| 479 |
|
| 480 |
-
# Generate
|
| 481 |
decode_start = time.time()
|
| 482 |
tokens_generated = 0
|
| 483 |
|
|
@@ -496,7 +452,6 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
|
|
| 496 |
tokens_generated += 1
|
| 497 |
yield generated
|
| 498 |
|
| 499 |
-
# Next step
|
| 500 |
next_input = tf.constant([[next_id]], dtype=tf.int32)
|
| 501 |
try:
|
| 502 |
logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
|
|
@@ -506,7 +461,6 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
|
|
| 506 |
|
| 507 |
next_logits = logits[0, -1, :].numpy()
|
| 508 |
|
| 509 |
-
# Stats
|
| 510 |
if tokens_generated > 0:
|
| 511 |
total = time.time() - start
|
| 512 |
tps = tokens_generated / (time.time() - decode_start)
|
|
@@ -519,10 +473,12 @@ def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_
|
|
| 519 |
|
| 520 |
def format_prompt(message: str, history: list, reasoning: bool) -> str:
|
| 521 |
prompt = ""
|
| 522 |
-
for
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
|
|
|
|
|
|
| 526 |
prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
|
| 527 |
if reasoning:
|
| 528 |
prompt += "<think>"
|
|
@@ -536,21 +492,27 @@ def chat_respond(message, history, max_tokens, temp, top_k, top_p, rep_pen, reas
|
|
| 536 |
|
| 537 |
prompt = format_prompt(message, history, reasoning)
|
| 538 |
|
|
|
|
|
|
|
|
|
|
| 539 |
for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
|
| 540 |
display = text
|
|
|
|
|
|
|
| 541 |
for tag in ["<|im_end|>", "<im end for model tun>"]:
|
| 542 |
if tag in display:
|
| 543 |
idx = display.find(tag)
|
| 544 |
stats = display.find("\n\n*[")
|
| 545 |
display = display[:idx] + (display[stats:] if stats > idx else "")
|
| 546 |
|
|
|
|
| 547 |
if reasoning and '<think>' in display and '</think>' in display:
|
| 548 |
s, e = display.find('<think>'), display.find('</think>')
|
| 549 |
if s < e:
|
| 550 |
thought = display[s+7:e].strip()
|
| 551 |
display = display[:s] + f'<details><summary>π§ Reasoning</summary><p>{thought}</p></details>' + display[e+8:]
|
| 552 |
|
| 553 |
-
yield history + [
|
| 554 |
|
| 555 |
|
| 556 |
def stop():
|
|
@@ -575,7 +537,11 @@ def create_ui():
|
|
| 575 |
gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))
|
| 576 |
|
| 577 |
reasoning = gr.State(False)
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
|
| 580 |
with gr.Row():
|
| 581 |
reason_btn = gr.Button("π‘", size="sm", scale=0)
|
|
@@ -600,7 +566,7 @@ def create_ui():
|
|
| 600 |
click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
|
| 601 |
stop_btn.click(stop, cancels=[submit, click])
|
| 602 |
|
| 603 |
-
gr.Button("ποΈ Clear").click(lambda:
|
| 604 |
|
| 605 |
return app
|
| 606 |
|
|
|
|
| 8 |
# ============================================================================
|
| 9 |
|
| 10 |
CONFIG = {
|
|
|
|
| 11 |
"node_id": "head-main",
|
|
|
|
|
|
|
|
|
|
| 12 |
"layer_start": 0,
|
| 13 |
+
"layer_end": 6,
|
| 14 |
+
"worker_urls": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"secret_token": "sam2-distributed-secret-change-me",
|
|
|
|
|
|
|
| 16 |
"model_repo": "Smilyai-labs/Sam-large-2",
|
| 17 |
"cache_dir": "./model_cache",
|
| 18 |
}
|
| 19 |
|
| 20 |
# ============================================================================
|
| 21 |
+
# CPU Optimization
|
| 22 |
# ============================================================================
|
| 23 |
|
| 24 |
import os
|
|
|
|
| 32 |
|
| 33 |
import json
|
| 34 |
import time
|
|
|
|
| 35 |
import io
|
| 36 |
import base64
|
| 37 |
from typing import Dict, List, Optional, Tuple, Any
|
|
|
|
| 190 |
self.config = None
|
| 191 |
self.tokenizer = None
|
| 192 |
self.eos_token_id = 50256
|
|
|
|
|
|
|
| 193 |
self.embedding = None
|
| 194 |
self.blocks: List = []
|
| 195 |
self.final_norm = None
|
| 196 |
self.lm_head = None
|
|
|
|
| 197 |
self.my_block_start = 0
|
| 198 |
self.my_block_end = 0
|
| 199 |
|
|
|
|
| 228 |
# ============================================================================
|
| 229 |
|
| 230 |
def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
|
|
|
|
| 231 |
try:
|
| 232 |
response = requests.post(
|
| 233 |
f"{url.rstrip('/')}/api/forward",
|
|
|
|
| 255 |
# ============================================================================
|
| 256 |
|
| 257 |
def load_model():
|
|
|
|
| 258 |
print("π Loading model...")
|
| 259 |
|
|
|
|
| 260 |
config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
|
| 261 |
with open(config_path, 'r') as f:
|
| 262 |
model_config = json.load(f)
|
| 263 |
STATE.config = model_config
|
| 264 |
|
|
|
|
| 265 |
from transformers import AutoTokenizer
|
| 266 |
from tokenizers import Tokenizer
|
| 267 |
|
|
|
|
| 273 |
STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
|
| 274 |
STATE.eos_token_id = model_config.get('eos_token_id', 50256)
|
| 275 |
|
|
|
|
| 276 |
weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])
|
| 277 |
|
|
|
|
| 278 |
n_layers = model_config['num_hidden_layers']
|
| 279 |
d_model = model_config['hidden_size']
|
| 280 |
n_heads = model_config['num_attention_heads']
|
|
|
|
| 283 |
rope_theta = model_config['rope_theta']
|
| 284 |
vocab_size = model_config['vocab_size']
|
| 285 |
|
|
|
|
| 286 |
embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
|
| 287 |
blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
|
| 288 |
for i in range(n_layers)]
|
| 289 |
final_norm = RMSNorm(name="final_norm")
|
| 290 |
lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")
|
| 291 |
|
|
|
|
| 292 |
dummy = tf.zeros((1, 16), dtype=tf.int32)
|
| 293 |
x = embedding(dummy)
|
| 294 |
for block in blocks:
|
|
|
|
| 296 |
x = final_norm(x)
|
| 297 |
_ = lm_head(x)
|
| 298 |
|
|
|
|
| 299 |
class TempModel(keras.Model):
|
| 300 |
def __init__(self):
|
| 301 |
super().__init__()
|
|
|
|
| 314 |
temp_model.load_weights(weights_path)
|
| 315 |
print("β
Weights loaded")
|
| 316 |
|
|
|
|
| 317 |
STATE.my_block_start = CONFIG["layer_start"]
|
| 318 |
STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers
|
| 319 |
|
|
|
|
| 320 |
STATE.embedding = embedding
|
|
|
|
|
|
|
| 321 |
STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
|
| 322 |
print(f"β
Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")
|
| 323 |
|
|
|
|
| 324 |
has_workers = len(CONFIG["worker_urls"]) > 0
|
| 325 |
if not has_workers:
|
| 326 |
STATE.final_norm = final_norm
|
| 327 |
STATE.lm_head = lm_head
|
| 328 |
print("β
Loaded final norm and LM head (standalone mode)")
|
| 329 |
|
|
|
|
| 330 |
print("π₯ Warming up...")
|
| 331 |
dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
|
| 332 |
x = STATE.embedding(dummy)
|
|
|
|
| 343 |
# ============================================================================
|
| 344 |
|
| 345 |
def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
x = STATE.embedding(input_ids)
|
| 347 |
|
|
|
|
| 348 |
new_local_kv = [] if use_cache else None
|
| 349 |
for i, block in enumerate(STATE.blocks):
|
| 350 |
block_past = past_kv_local[i] if past_kv_local else None
|
|
|
|
| 352 |
if use_cache:
|
| 353 |
new_local_kv.append(kv)
|
| 354 |
|
|
|
|
| 355 |
new_worker_kv = {} if use_cache else None
|
| 356 |
for worker_url in CONFIG["worker_urls"]:
|
| 357 |
worker_past = past_kv_workers.get(worker_url) if past_kv_workers else None
|
|
|
|
| 359 |
if use_cache:
|
| 360 |
new_worker_kv[worker_url] = worker_kv
|
| 361 |
|
|
|
|
|
|
|
| 362 |
if STATE.lm_head:
|
| 363 |
logits = STATE.lm_head(STATE.final_norm(x))
|
| 364 |
else:
|
|
|
|
| 365 |
logits = x
|
| 366 |
|
| 367 |
return logits, new_local_kv, new_worker_kv
|
|
|
|
| 423 |
|
| 424 |
start = time.time()
|
| 425 |
|
|
|
|
| 426 |
input_tensor = tf.constant([input_ids], dtype=tf.int32)
|
| 427 |
try:
|
| 428 |
logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
|
|
|
|
| 434 |
prefill_time = time.time() - start
|
| 435 |
print(f"β‘ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
|
| 436 |
|
|
|
|
| 437 |
decode_start = time.time()
|
| 438 |
tokens_generated = 0
|
| 439 |
|
|
|
|
| 452 |
tokens_generated += 1
|
| 453 |
yield generated
|
| 454 |
|
|
|
|
| 455 |
next_input = tf.constant([[next_id]], dtype=tf.int32)
|
| 456 |
try:
|
| 457 |
logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
|
|
|
|
| 461 |
|
| 462 |
next_logits = logits[0, -1, :].numpy()
|
| 463 |
|
|
|
|
| 464 |
if tokens_generated > 0:
|
| 465 |
total = time.time() - start
|
| 466 |
tps = tokens_generated / (time.time() - decode_start)
|
|
|
|
| 473 |
|
| 474 |
def format_prompt(message: str, history: list, reasoning: bool) -> str:
|
| 475 |
prompt = ""
|
| 476 |
+
for msg in history:
|
| 477 |
+
if msg["role"] == "user":
|
| 478 |
+
prompt += f"<|im_start|>user\n{msg['content']}<|im_end|>\n"
|
| 479 |
+
elif msg["role"] == "assistant":
|
| 480 |
+
content = msg['content'].split('*[')[0].strip()
|
| 481 |
+
prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
|
| 482 |
prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
|
| 483 |
if reasoning:
|
| 484 |
prompt += "<think>"
|
|
|
|
| 492 |
|
| 493 |
prompt = format_prompt(message, history, reasoning)
|
| 494 |
|
| 495 |
+
# Add user message to history
|
| 496 |
+
history = history + [{"role": "user", "content": message}]
|
| 497 |
+
|
| 498 |
for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
|
| 499 |
display = text
|
| 500 |
+
|
| 501 |
+
# Clean stop tags
|
| 502 |
for tag in ["<|im_end|>", "<im end for model tun>"]:
|
| 503 |
if tag in display:
|
| 504 |
idx = display.find(tag)
|
| 505 |
stats = display.find("\n\n*[")
|
| 506 |
display = display[:idx] + (display[stats:] if stats > idx else "")
|
| 507 |
|
| 508 |
+
# Format reasoning
|
| 509 |
if reasoning and '<think>' in display and '</think>' in display:
|
| 510 |
s, e = display.find('<think>'), display.find('</think>')
|
| 511 |
if s < e:
|
| 512 |
thought = display[s+7:e].strip()
|
| 513 |
display = display[:s] + f'<details><summary>π§ Reasoning</summary><p>{thought}</p></details>' + display[e+8:]
|
| 514 |
|
| 515 |
+
yield history + [{"role": "assistant", "content": display.strip()}]
|
| 516 |
|
| 517 |
|
| 518 |
def stop():
|
|
|
|
| 537 |
gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))
|
| 538 |
|
| 539 |
reasoning = gr.State(False)
|
| 540 |
+
|
| 541 |
+
chatbot = gr.Chatbot(
|
| 542 |
+
height=500,
|
| 543 |
+
type="messages" # Use new messages format
|
| 544 |
+
)
|
| 545 |
|
| 546 |
with gr.Row():
|
| 547 |
reason_btn = gr.Button("π‘", size="sm", scale=0)
|
|
|
|
| 566 |
click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
|
| 567 |
stop_btn.click(stop, cancels=[submit, click])
|
| 568 |
|
| 569 |
+
gr.Button("ποΈ Clear").click(lambda: [], outputs=[chatbot])
|
| 570 |
|
| 571 |
return app
|
| 572 |
|