Bc-AI committed on
Commit
3760699
Β·
verified Β·
1 Parent(s): e670cfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +534 -498
app.py CHANGED
@@ -1,584 +1,620 @@
1
  """
2
- WAN-Distributed JAX Inference on Hugging Face Spaces
3
- Each Space runs this app and can be configured as head or worker.
4
  """
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import os
 
 
 
 
 
 
 
 
7
  import json
8
  import time
9
  import threading
10
- import queue
11
- from typing import Dict, List, Optional, Any
12
- from dataclasses import dataclass, field
13
- import hashlib
14
 
15
  import gradio as gr
16
  import numpy as np
17
  import requests
 
 
 
18
 
19
- # Use CPU JAX
20
- os.environ["JAX_PLATFORMS"] = "cpu"
21
- import jax
22
- import jax.numpy as jnp
23
 
 
24
 
25
  # ============================================================================
26
- # CONFIGURATION
27
  # ============================================================================
28
 
29
- @dataclass
30
- class NodeConfig:
31
- """Node configuration from environment."""
32
- role: str = os.environ.get("NODE_ROLE", "worker") # "head" or "worker"
33
- node_id: str = os.environ.get("NODE_ID", hashlib.md5(os.urandom(8)).hexdigest()[:8])
34
- head_url: str = os.environ.get("HEAD_URL", "") # URL of head Space (for workers)
35
- secret_token: str = os.environ.get("SECRET_TOKEN", "default-token")
36
- port: int = int(os.environ.get("PORT", "7860"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
- CONFIG = NodeConfig()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
41
 
42
  # ============================================================================
43
- # SHARED STATE
44
  # ============================================================================
45
 
46
- class ClusterState:
47
- """Shared state for the cluster."""
48
-
49
- def __init__(self):
50
- self.workers: Dict[str, Dict] = {} # worker_id -> info
51
- self.shards: Dict[str, np.ndarray] = {} # shard_name -> data
52
- self.lock = threading.Lock()
53
- self.is_initialized = False
54
- self.pending_results: Dict[str, Any] = {}
55
- self.request_queue: queue.Queue = queue.Queue()
56
-
57
- def register_worker(self, worker_id: str, url: str, info: Dict) -> bool:
58
- with self.lock:
59
- self.workers[worker_id] = {
60
- "url": url,
61
- "info": info,
62
- "registered_at": time.time(),
63
- "last_seen": time.time(),
64
- "status": "active"
65
- }
66
- return True
67
-
68
- def get_workers(self) -> List[Dict]:
69
- with self.lock:
70
- return [
71
- {"worker_id": wid, **winfo}
72
- for wid, winfo in self.workers.items()
73
- if winfo.get("status") == "active"
74
- ]
75
-
76
- def store_shard(self, name: str, data: np.ndarray):
77
- with self.lock:
78
- self.shards[name] = data
79
-
80
- def get_shard(self, name: str) -> Optional[np.ndarray]:
81
- with self.lock:
82
- return self.shards.get(name)
83
-
84
- def heartbeat(self, worker_id: str):
85
- with self.lock:
86
- if worker_id in self.workers:
87
- self.workers[worker_id]["last_seen"] = time.time()
88
-
89
-
90
- STATE = ClusterState()
91
 
 
 
 
 
 
 
 
 
 
92
 
93
  # ============================================================================
94
- # HTTP COMMUNICATION LAYER
95
  # ============================================================================
96
 
97
- def make_request(url: str, endpoint: str, data: Dict, timeout: int = 30) -> Optional[Dict]:
98
- """Make HTTP request to another Space."""
99
  try:
100
- full_url = f"{url.rstrip('/')}/api/{endpoint}"
101
- headers = {"Authorization": f"Bearer {CONFIG.secret_token}"}
102
-
103
  response = requests.post(
104
- full_url,
105
- json=data,
106
- headers=headers,
107
- timeout=timeout
 
 
 
 
108
  )
109
 
110
  if response.status_code == 200:
111
- return response.json()
 
 
 
112
  else:
113
- print(f"Request failed: {response.status_code} - {response.text}")
114
- return None
115
  except Exception as e:
116
- print(f"Request error: {e}")
117
- return None
118
-
119
 
120
  # ============================================================================
121
- # WORKER LOGIC
122
  # ============================================================================
123
 
124
- def worker_register_with_head():
125
- """Register this worker with the head node."""
126
- if not CONFIG.head_url:
127
- print("No HEAD_URL configured, cannot register")
128
- return False
129
-
130
- # Get this Space's URL from environment or construct it
131
- space_url = os.environ.get("SPACE_URL", f"http://localhost:{CONFIG.port}")
132
-
133
- result = make_request(
134
- CONFIG.head_url,
135
- "register_worker",
136
- {
137
- "worker_id": CONFIG.node_id,
138
- "worker_url": space_url,
139
- "info": {
140
- "jax_devices": len(jax.devices()),
141
- "platform": jax.default_backend(),
142
- }
143
- }
144
- )
145
-
146
- if result and result.get("success"):
147
- print(f"Registered with head at {CONFIG.head_url}")
148
- return True
149
- return False
150
-
151
-
152
- def worker_heartbeat_loop():
153
- """Send periodic heartbeats to head."""
154
- while True:
155
- time.sleep(30)
156
- if CONFIG.head_url:
157
- make_request(
158
- CONFIG.head_url,
159
- "heartbeat",
160
- {"worker_id": CONFIG.node_id}
161
- )
162
-
163
-
164
- def worker_forward_pass(input_data: np.ndarray) -> np.ndarray:
165
- """Run forward pass on local shards."""
166
- x = jnp.array(input_data)
167
-
168
- # Apply each stored shard (simple linear layers for demo)
169
- for name, weight in sorted(STATE.shards.items()):
170
- if weight.ndim == 2:
171
- # Matrix multiply for weight matrices
172
- if x.shape[-1] == weight.shape[0]:
173
- x = x @ weight
174
- elif weight.ndim == 1:
175
- # Add for biases
176
- if x.shape[-1] == weight.shape[0]:
177
- x = x + weight
178
-
179
- # Apply simple activation
180
- x = jax.nn.relu(x)
181
-
182
- return np.array(x)
183
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  # ============================================================================
186
- # HEAD NODE LOGIC
187
  # ============================================================================
188
 
189
- def head_distribute_model(params: Dict[str, np.ndarray]) -> bool:
190
- """Distribute model parameters to workers."""
191
- workers = STATE.get_workers()
192
- if not workers:
193
- print("No workers available")
194
- return False
195
-
196
- # Simple round-robin distribution
197
- param_list = list(params.items())
198
- shards_per_worker = max(1, len(param_list) // len(workers))
199
-
200
- for i, worker in enumerate(workers):
201
- start_idx = i * shards_per_worker
202
- end_idx = start_idx + shards_per_worker if i < len(workers) - 1 else len(param_list)
203
-
204
- worker_shards = dict(param_list[start_idx:end_idx])
205
-
206
- for shard_name, shard_data in worker_shards.items():
207
- result = make_request(
208
- worker["url"],
209
- "store_shard",
210
- {
211
- "name": shard_name,
212
- "data": shard_data.tolist(),
213
- "shape": list(shard_data.shape),
214
- "dtype": str(shard_data.dtype)
215
- },
216
- timeout=60
217
- )
218
-
219
- if not result or not result.get("success"):
220
- print(f"Failed to send shard {shard_name} to worker {worker['worker_id']}")
221
- return False
222
-
223
- print(f"Distributed {len(params)} shards to {len(workers)} workers")
224
- return True
225
-
226
-
227
- def head_run_inference(input_data: np.ndarray) -> np.ndarray:
228
- """Run distributed inference across workers."""
229
- workers = STATE.get_workers()
230
-
231
- if not workers:
232
- # No workers, run locally
233
- return worker_forward_pass(input_data)
234
-
235
- # Pipeline through workers
236
- current_data = input_data
237
-
238
- for worker in workers:
239
- result = make_request(
240
- worker["url"],
241
- "forward",
242
- {
243
- "data": current_data.tolist(),
244
- "shape": list(current_data.shape),
245
- },
246
- timeout=60
247
- )
248
-
249
- if result and "output" in result:
250
- current_data = np.array(result["output"])
251
- else:
252
- print(f"Worker {worker['worker_id']} failed, using local fallback")
253
- current_data = worker_forward_pass(current_data)
254
 
255
- return current_data
256
-
257
 
258
  # ============================================================================
259
- # API ENDPOINTS (Gradio doesn't have native API, so we use a simple approach)
260
  # ============================================================================
261
 
262
- def api_handler(endpoint: str, data: Dict) -> Dict:
263
- """Handle API requests based on endpoint."""
264
-
265
- # Verify token
266
- # (In production, check Authorization header)
267
 
268
- if endpoint == "register_worker":
269
- success = STATE.register_worker(
270
- data["worker_id"],
271
- data["worker_url"],
272
- data.get("info", {})
273
- )
274
- return {"success": success, "message": "Worker registered" if success else "Failed"}
275
-
276
- elif endpoint == "heartbeat":
277
- STATE.heartbeat(data.get("worker_id", ""))
278
- return {"success": True}
279
-
280
- elif endpoint == "store_shard":
281
- shard_data = np.array(data["data"], dtype=data.get("dtype", "float32"))
282
- shard_data = shard_data.reshape(data["shape"])
283
- STATE.store_shard(data["name"], shard_data)
284
- return {"success": True, "shard": data["name"]}
285
-
286
- elif endpoint == "forward":
287
- input_data = np.array(data["data"]).reshape(data["shape"])
288
- output = worker_forward_pass(input_data)
289
- return {"output": output.tolist(), "shape": list(output.shape)}
290
-
291
- elif endpoint == "status":
292
- return {
293
- "node_id": CONFIG.node_id,
294
- "role": CONFIG.role,
295
- "workers": len(STATE.get_workers()),
296
- "shards": list(STATE.shards.keys()),
297
- "jax_devices": len(jax.devices()),
298
- }
299
-
300
- elif endpoint == "get_workers":
301
- return {"workers": STATE.get_workers()}
302
 
 
 
 
303
  else:
304
- return {"error": f"Unknown endpoint: {endpoint}"}
305
-
306
-
307
- # ============================================================================
308
- # GRADIO INTERFACE
309
- # ============================================================================
310
-
311
- def create_test_model(num_layers: int = 4, hidden_size: int = 128) -> Dict[str, np.ndarray]:
312
- """Create a simple test model."""
313
- params = {}
314
 
315
- for i in range(num_layers):
316
- params[f"layer_{i}_weight"] = np.random.randn(hidden_size, hidden_size).astype(np.float32) * 0.02
317
- params[f"layer_{i}_bias"] = np.zeros(hidden_size, dtype=np.float32)
 
 
 
 
 
 
318
 
319
- return params
320
 
321
 
322
- def gradio_run_inference(input_text: str) -> str:
323
- """Run inference from Gradio UI."""
324
- # Simple tokenization (ASCII values normalized)
325
- tokens = np.array([ord(c) / 128.0 for c in input_text[:128]], dtype=np.float32)
326
 
327
- # Pad to fixed size
328
- if len(tokens) < 128:
329
- tokens = np.pad(tokens, (0, 128 - len(tokens)))
 
330
 
331
- # Run inference
332
- start_time = time.time()
333
 
334
- if CONFIG.role == "head":
335
- output = head_run_inference(tokens)
336
- else:
337
- output = worker_forward_pass(tokens)
338
-
339
- latency = (time.time() - start_time) * 1000
340
-
341
- # Format output
342
- result = f"Output shape: {output.shape}\n"
343
- result += f"Output mean: {output.mean():.4f}\n"
344
- result += f"Output std: {output.std():.4f}\n"
345
- result += f"Latency: {latency:.1f}ms\n"
346
- result += f"Workers used: {len(STATE.get_workers())}"
347
-
348
- return result
349
-
350
-
351
- def gradio_get_status() -> str:
352
- """Get cluster status for Gradio UI."""
353
- status = {
354
- "Node ID": CONFIG.node_id,
355
- "Role": CONFIG.role,
356
- "JAX Devices": len(jax.devices()),
357
- "JAX Backend": jax.default_backend(),
358
- "Stored Shards": len(STATE.shards),
359
- "Shard Names": list(STATE.shards.keys())[:10], # First 10
360
- }
361
-
362
- if CONFIG.role == "head":
363
- workers = STATE.get_workers()
364
- status["Connected Workers"] = len(workers)
365
- status["Worker List"] = [
366
- f"{w['worker_id']} @ {w['url']}"
367
- for w in workers
368
- ]
369
- else:
370
- status["Head URL"] = CONFIG.head_url
371
- status["Registered"] = STATE.is_initialized
372
 
373
- return json.dumps(status, indent=2)
374
-
375
-
376
- def gradio_init_model(num_layers: int, hidden_size: int) -> str:
377
- """Initialize and distribute model."""
378
- params = create_test_model(int(num_layers), int(hidden_size))
379
 
380
- if CONFIG.role == "head":
381
- workers = STATE.get_workers()
382
- if workers:
383
- success = head_distribute_model(params)
384
- if success:
385
- return f"Distributed {len(params)} shards to {len(workers)} workers"
386
- else:
387
- return "Failed to distribute model"
388
- else:
389
- # Store locally
390
- for name, data in params.items():
391
- STATE.store_shard(name, data)
392
- return f"No workers - stored {len(params)} shards locally"
393
- else:
394
- # Worker stores locally
395
- for name, data in params.items():
396
- STATE.store_shard(name, data)
397
- return f"Stored {len(params)} shards locally"
398
-
399
-
400
- def gradio_register_worker(worker_url: str) -> str:
401
- """Manually register a worker (for head node)."""
402
- if CONFIG.role != "head":
403
- return "Only head node can register workers"
404
-
405
- # Ping the worker
406
- result = make_request(worker_url, "status", {})
407
 
408
- if result:
409
- worker_id = result.get("node_id", f"worker_{len(STATE.workers)}")
410
- STATE.register_worker(worker_id, worker_url, result)
411
- return f"Registered worker {worker_id}"
412
- else:
413
- return f"Failed to reach worker at {worker_url}"
414
-
415
-
416
- def gradio_api_call(endpoint: str, json_data: str) -> str:
417
- """Make API call (for testing)."""
418
  try:
419
- data = json.loads(json_data) if json_data else {}
420
- result = api_handler(endpoint, data)
421
- return json.dumps(result, indent=2)
422
  except Exception as e:
423
- return f"Error: {e}"
424
-
425
-
426
- # ============================================================================
427
- # MAIN APP
428
- # ============================================================================
429
-
430
- def create_app():
431
- """Create Gradio app based on node role."""
432
 
433
- # Start background tasks
434
- if CONFIG.role == "worker" and CONFIG.head_url:
435
- # Register with head
436
- threading.Thread(target=lambda: time.sleep(5) or worker_register_with_head(), daemon=True).start()
437
- # Heartbeat loop
438
- threading.Thread(target=worker_heartbeat_loop, daemon=True).start()
439
 
440
- # Create Gradio interface
441
- with gr.Blocks(title=f"WAN-JAX {CONFIG.role.upper()} - {CONFIG.node_id}") as app:
442
- gr.Markdown(f"""
443
- # 🌐 WAN-Distributed JAX Inference
 
 
 
 
444
 
445
- **Node ID:** `{CONFIG.node_id}` | **Role:** `{CONFIG.role.upper()}`
446
 
447
- {"This is the **HEAD** node - it coordinates workers and runs inference." if CONFIG.role == "head" else "This is a **WORKER** node - it stores model shards and computes."}
448
- """)
449
 
450
- with gr.Tab("Status"):
451
- status_output = gr.Textbox(label="Cluster Status", lines=15)
452
- refresh_btn = gr.Button("Refresh Status")
453
- refresh_btn.click(gradio_get_status, outputs=status_output)
454
-
455
- # Auto-refresh on load
456
- app.load(gradio_get_status, outputs=status_output)
457
 
458
- with gr.Tab("Inference"):
459
- with gr.Row():
460
- with gr.Column():
461
- input_text = gr.Textbox(
462
- label="Input Text",
463
- placeholder="Enter text to process...",
464
- lines=3
465
- )
466
- infer_btn = gr.Button("Run Inference", variant="primary")
467
-
468
- with gr.Column():
469
- output_text = gr.Textbox(label="Output", lines=8)
470
-
471
- infer_btn.click(gradio_run_inference, inputs=input_text, outputs=output_text)
472
 
473
- with gr.Tab("Model"):
474
- with gr.Row():
475
- num_layers = gr.Slider(1, 12, value=4, step=1, label="Number of Layers")
476
- hidden_size = gr.Slider(32, 512, value=128, step=32, label="Hidden Size")
477
-
478
- init_btn = gr.Button("Initialize Model")
479
- init_output = gr.Textbox(label="Result")
480
-
481
- init_btn.click(
482
- gradio_init_model,
483
- inputs=[num_layers, hidden_size],
484
- outputs=init_output
485
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
- if CONFIG.role == "head":
488
- with gr.Tab("Workers"):
489
- worker_url_input = gr.Textbox(
490
- label="Worker Space URL",
491
- placeholder="https://username-spacename.hf.space"
492
- )
493
- register_btn = gr.Button("Register Worker")
494
- register_output = gr.Textbox(label="Result")
495
-
496
- register_btn.click(
497
- gradio_register_worker,
498
- inputs=worker_url_input,
499
- outputs=register_output
500
- )
501
 
502
- with gr.Tab("API"):
503
- gr.Markdown("""
504
- ### Direct API Access
505
- Use this tab to test API endpoints directly.
506
-
507
- **Endpoints:**
508
- - `status` - Get node status
509
- - `register_worker` - Register a worker (head only)
510
- - `store_shard` - Store a model shard
511
- - `forward` - Run forward pass
512
- - `get_workers` - List workers (head only)
513
- """)
514
-
515
- endpoint_input = gr.Textbox(label="Endpoint", value="status")
516
- json_input = gr.Textbox(label="JSON Data", value="{}", lines=5)
517
- api_btn = gr.Button("Call API")
518
- api_output = gr.Textbox(label="Response", lines=10)
519
-
520
- api_btn.click(
521
- gradio_api_call,
522
- inputs=[endpoint_input, json_input],
523
- outputs=api_output
524
- )
525
-
526
- return app
527
 
 
 
 
528
 
529
  # ============================================================================
530
- # FASTAPI MOUNTING FOR TRUE API ACCESS
531
  # ============================================================================
532
 
533
- # Optional: Mount FastAPI for proper API endpoints
534
- try:
535
- from fastapi import FastAPI, Request, HTTPException
536
- from fastapi.responses import JSONResponse
537
 
538
- api_app = FastAPI()
539
-
540
- @api_app.post("/api/{endpoint}")
541
- async def api_endpoint(endpoint: str, request: Request):
542
- # Check authorization
543
- auth_header = request.headers.get("Authorization", "")
544
- if not auth_header.startswith("Bearer "):
545
- # Allow without auth for testing, but log it
546
- pass
547
 
548
- try:
549
- data = await request.json()
550
- except:
551
- data = {}
552
 
553
- result = api_handler(endpoint, data)
554
- return JSONResponse(result)
555
-
556
- @api_app.get("/api/status")
557
- async def get_status():
558
- return JSONResponse(api_handler("status", {}))
559
-
560
- # Mount Gradio app
561
- app = create_app()
562
- api_app = gr.mount_gradio_app(api_app, app, path="/")
563
-
564
- print("Running with FastAPI + Gradio")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
 
566
- except ImportError:
567
- # FastAPI not available, use pure Gradio
568
- app = create_app()
569
- print("Running with pure Gradio")
570
-
571
 
572
  # ============================================================================
573
- # LAUNCH
574
  # ============================================================================
575
 
576
- if __name__ == "__main__":
577
- print(f"Starting WAN-JAX Node")
578
- print(f" Node ID: {CONFIG.node_id}")
579
- print(f" Role: {CONFIG.role}")
580
- print(f" Head URL: {CONFIG.head_url}")
581
- print(f" JAX devices: {jax.devices()}")
582
-
583
- app = create_app()
584
- app.launch(server_name="0.0.0.0", server_port=CONFIG.port)
 
 
1
  """
2
+ Sam-large-2 Distributed Inference - HEAD NODE
3
+ Edit the CONFIG below, then deploy.
4
  """
5
 
6
+ # ============================================================================
7
+ # βš™οΈ CONFIGURATION - EDIT THIS
8
+ # ============================================================================
9
+
10
+ CONFIG = {
11
+ # This node's identity
12
+ "node_id": "head-main",
13
+
14
+ # Which transformer blocks this node runs (0-indexed)
15
+ # Sam-large-2 has 12 blocks (0-11)
16
+ "layer_start": 0,
17
+ "layer_end": 6, # exclusive, so this runs blocks 0,1,2,3,4,5
18
+
19
+ # Worker Space URLs (in order of execution)
20
+ # Leave empty [] for standalone mode (all layers on this node)
21
+ "worker_urls": [
22
+ # "https://YOUR-WORKER-SPACE.hf.space",
23
+ ],
24
+
25
+ # Shared secret for worker communication
26
+ "secret_token": "sam2-distributed-secret-change-me",
27
+
28
+ # Model settings
29
+ "model_repo": "Smilyai-labs/Sam-large-2",
30
+ "cache_dir": "./model_cache",
31
+ }
32
+
33
+ # ============================================================================
34
+ # CPU Optimization - MUST be before TensorFlow import
35
+ # ============================================================================
36
+
37
  import os
38
+ NUM_CORES = os.cpu_count() or 4
39
+
40
+ os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
41
+ os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
42
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
43
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
44
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
45
+
46
  import json
47
  import time
48
  import threading
49
+ import io
50
+ import base64
51
+ from typing import Dict, List, Optional, Tuple, Any
 
52
 
53
  import gradio as gr
54
  import numpy as np
55
  import requests
56
+ import tensorflow as tf
57
+ import keras
58
+ from huggingface_hub import hf_hub_download
59
 
60
+ tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
61
+ tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
 
 
62
 
63
+ print(f"βœ… CPU optimized: {NUM_CORES} threads")
64
 
65
  # ============================================================================
66
+ # Model Architecture
67
  # ============================================================================
68
 
69
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied to query/key tensors.

    Precomputes cos/sin tables for up to ``max_len`` positions on first call
    and rotates q/k by position-dependent angles (standard RoPE formulation
    with the "rotate half" trick).
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim          # rotary dimension (the per-head dim)
        self.max_len = max_len  # number of positions cached
        self.theta = theta      # RoPE base frequency
        # cos/sin tables are built lazily on first call, not in build().
        self.built_cache = False
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        # No trainable weights; caches are created lazily in _build_cache().
        super().build(input_shape)

    def _build_cache(self):
        # Materialize the [max_len, dim] cos/sin lookup tables once.
        # NOTE(review): calls .numpy(), so this must execute eagerly — it would
        # fail inside a tf.function graph trace; confirm callers stay eager.
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            # Duplicate the frequencies so the table spans the full last dim.
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
            self.built_cache = True

    def rotate_half(self, x):
        # Swap the two halves of the last axis and negate the (new) first half.
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k, offset=0):
        """Rotate q and k; ``offset`` shifts positions for KV-cache decoding.

        q/k use layout [batch, heads, seq, head_dim] (seq length read from
        axis 2 below).
        """
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        # Slice the cached tables at the absolute positions of this chunk and
        # broadcast over batch/head axes.
        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def get_config(self):
        # Serialize constructor args so Keras can rebuild the layer from config.
        return {**super().get_config(), "dim": self.dim, "max_len": self.max_len, "theta": self.theta}
109
+
110
+
111
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square layer normalization with a learned per-feature scale."""

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # One learned gain per feature, initialized to the identity.
        feature_dim = input_shape[-1]
        self.scale = self.add_weight(name="scale", shape=(feature_dim,), initializer="ones")
        super().build(input_shape)

    def call(self, x):
        # Normalize by the RMS of the features, then apply the learned gain.
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        normalized = x * tf.math.rsqrt(mean_square + self.epsilon)
        return normalized * self.scale

    def get_config(self):
        config = super().get_config()
        config["epsilon"] = self.epsilon
        return config
127
+
128
+
129
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm transformer decoder block: causal self-attention + SwiGLU FFN.

    Supports incremental decoding: pass the previous (k, v) tensors as
    ``past_kv`` and set ``use_cache=True`` to get the updated cache back.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads  # assumes d_model divisible by n_heads
        self.layer_idx = layer_idx

    def build(self, input_shape):
        # RMSNorm before each sub-layer (pre-norm architecture).
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        # Attention projections, no biases (Llama-style naming in the checkpoint).
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        # SwiGLU feed-forward: down(silu(gate(x)) * up(x)).
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        """Run one block over x [batch, seq, d_model]; returns (output, new_kv).

        ``new_kv`` is the concatenated (k, v) pair when ``use_cache`` is True,
        else None.
        """
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype

        # --- causal self-attention with residual connection ---
        res = x
        y = self.pre_attn_norm(x)

        # Project and reshape to [B, heads, T, head_dim].
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        # Rotate q/k starting at the absolute position after the cached prefix.
        past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
        q, k = self.rope(q, k, offset=past_len)

        # Prepend cached keys/values along the sequence axis (axis 2).
        if past_kv is not None:
            k = tf.concat([past_kv[0], k], axis=2)
            v = tf.concat([past_kv[1], v], axis=2)

        new_kv = (k, v) if use_cache else None

        # Scaled dot-product scores over the full (cached + new) sequence.
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: a query at absolute position p may attend to keys <= p;
        # -1e9 effectively zeroes masked entries after the softmax.
        full_len = tf.shape(k)[2]
        q_pos = tf.range(past_len, past_len + T)
        k_pos = tf.range(full_len)
        mask = tf.where(q_pos[:, None] >= k_pos[None, :], 0.0, -1e9)
        scores = scores + tf.cast(mask[None, None, :, :], dtype)

        attn = tf.nn.softmax(scores, axis=-1)
        # Merge heads back into [B, T, d_model].
        attn_out = tf.reshape(tf.transpose(tf.matmul(attn, v), [0, 2, 1, 3]), [B, T, self.d_model])
        x = res + self.dropout(self.out_proj(attn_out), training=training)

        # --- SwiGLU feed-forward with residual connection ---
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training), new_kv

    def get_config(self):
        # Serialize constructor args so Keras can rebuild the layer from config.
        return {**super().get_config(), "d_model": self.d_model, "n_heads": self.n_heads,
                "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len,
                "rope_theta": self.rope_theta, "layer_idx": self.layer_idx}
196
 
197
 
198
+ # ============================================================================
199
+ # State
200
+ # ============================================================================
201
+
202
class ModelState:
    """Container for the model pieces owned by this node (filled by load_model)."""

    def __init__(self):
        self.config = None         # parsed config.json dict
        self.tokenizer = None      # tokenizers.Tokenizer instance
        self.eos_token_id = 50256  # GPT-2 default; overridden from config on load

        # Model components
        self.embedding = None      # token embedding (head node always holds this)
        self.blocks: List = []     # the transformer blocks assigned to this node
        self.final_norm = None     # set only in standalone mode (no workers)
        self.lm_head = None        # set only in standalone mode (no workers)

        # Half-open range [my_block_start, my_block_end) of blocks run here.
        self.my_block_start = 0
        self.my_block_end = 0

# Module-level singletons: shared model state and a cooperative stop flag.
STATE = ModelState()
stop_generation = False
219
 
220
  # ============================================================================
221
+ # Serialization
222
  # ============================================================================
223
 
224
def serialize_tensor(tensor: tf.Tensor) -> str:
    """Encode a tensor as a base64 string of its .npy byte representation.

    Inverse of deserialize_tensor; used to ship activations over JSON.
    """
    raw = io.BytesIO()
    np.save(raw, tensor.numpy(), allow_pickle=False)
    return base64.b64encode(raw.getvalue()).decode('utf-8')
228
+
229
def deserialize_tensor(data: str) -> tf.Tensor:
    """Decode a base64 .npy payload (from serialize_tensor) back to a tensor."""
    raw_bytes = base64.b64decode(data)
    array = np.load(io.BytesIO(raw_bytes), allow_pickle=False)
    return tf.constant(array)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
def serialize_kv_cache(past_kv):
    """Serialize a per-layer KV cache to a JSON-safe list.

    Each entry is either None (no cache for that layer) or a (k, v) tensor
    pair, mirroring what deserialize_kv_cache produces.

    Fix: the previous version unpacked every entry with ``for k, v in past_kv``
    and only then checked ``k is not None``, so a None *entry* (which
    deserialize_kv_cache legitimately emits) raised TypeError on unpacking.
    Check the entry itself before unpacking so None entries round-trip.
    """
    if past_kv is None:
        return None
    serialized = []
    for entry in past_kv:
        if entry is None:
            serialized.append(None)
        else:
            k, v = entry
            serialized.append({"k": serialize_tensor(k), "v": serialize_tensor(v)})
    return serialized
237
+
238
def deserialize_kv_cache(data):
    """Rebuild per-layer (k, v) tensor pairs from serialize_kv_cache output."""
    if data is None:
        return None
    cache = []
    for item in data:
        if item:
            cache.append((deserialize_tensor(item["k"]), deserialize_tensor(item["v"])))
        else:
            cache.append(None)
    return cache
242
 
243
  # ============================================================================
244
+ # HTTP Communication
245
  # ============================================================================
246
 
247
def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
    """Send hidden states to a worker node and return its output.

    Posts the serialized activations (and optional KV cache) to the worker's
    /api/forward endpoint, then deserializes the returned hidden states and
    updated cache.

    Returns:
        (hidden_states, past_kv) from the worker; past_kv is None when the
        worker sent no cache.

    Raises:
        RuntimeError: on transport failure, a non-200 response, or an
            undecodable response body.

    Fixes vs. previous version: the non-200 RuntimeError is no longer raised
    inside the same ``try`` that wraps it, so it surfaces as
    "Worker returned <code>" instead of being double-wrapped as
    "Worker call failed: Worker returned <code>"; original exceptions are now
    chained with ``from e`` for debuggability.
    """
    payload = {
        "hidden_states": serialize_tensor(hidden_states),
        "past_kv": serialize_kv_cache(past_kv),
        "use_cache": use_cache,
    }
    try:
        response = requests.post(
            f"{url.rstrip('/')}/api/forward",
            json=payload,
            headers={"Authorization": f"Bearer {CONFIG['secret_token']}"},
            timeout=120,
        )
    except Exception as e:
        # Transport-level failure (DNS, timeout, connection refused, ...).
        raise RuntimeError(f"Worker call failed: {e}") from e

    # Checked outside the try above so it is not re-wrapped by that handler.
    if response.status_code != 200:
        raise RuntimeError(f"Worker returned {response.status_code}")

    try:
        result = response.json()
        output = deserialize_tensor(result["hidden_states"])
        new_kv = deserialize_kv_cache(result.get("past_kv"))
    except Exception as e:
        # Malformed/undecodable worker response.
        raise RuntimeError(f"Worker call failed: {e}") from e
    return output, new_kv
 
 
270
 
271
  # ============================================================================
272
+ # Model Loading
273
  # ============================================================================
274
 
275
def load_model():
    """Download, build, and load the model; keep only this node's pipeline slice.

    The full Keras model is constructed so the single HDF5 checkpoint can be
    loaded, then the HEAD node keeps the embedding, its assigned transformer
    blocks, and (in standalone mode only) the final norm + LM head.
    Populates STATE in place and returns True on success.
    """
    print("πŸš€ Loading model...")

    # Load config from the model repo; STATE.config is read later for
    # max_position_embeddings during generation.
    config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
    with open(config_path, 'r') as f:
        model_config = json.load(f)
    STATE.config = model_config

    # Load tokenizer: GPT-2 base vocabulary extended with the chat/control
    # tokens, then round-tripped through disk to obtain a fast `tokenizers`
    # Tokenizer (imports are local so startup fails fast only on this path).
    from transformers import AutoTokenizer
    from tokenizers import Tokenizer

    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    hf_tokenizer.add_special_tokens({"additional_special_tokens":
        ["<|im_start|>", "<|im_end|>", "<think>", "</think>", "<CONTINUE>", "<im end for model tun>"]})
    os.makedirs("./temp_tokenizer", exist_ok=True)
    hf_tokenizer.save_pretrained("./temp_tokenizer")
    STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
    STATE.eos_token_id = model_config.get('eos_token_id', 50256)

    # Load weights (single full-model checkpoint, even in distributed mode —
    # every node downloads all weights and discards the blocks it doesn't own).
    weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])

    # Build full model to load weights
    n_layers = model_config['num_hidden_layers']
    d_model = model_config['hidden_size']
    n_heads = model_config['num_attention_heads']
    ff_dim = model_config['intermediate_size']
    max_len = model_config['max_position_embeddings']
    rope_theta = model_config['rope_theta']
    vocab_size = model_config['vocab_size']

    # Temporary full model — layer names must match the checkpoint's.
    embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
    blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
              for i in range(n_layers)]
    final_norm = RMSNorm(name="final_norm")
    lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")

    # Build: run a dummy forward pass so every layer creates its variables
    # before load_weights() is called.
    dummy = tf.zeros((1, 16), dtype=tf.int32)
    x = embedding(dummy)
    for block in blocks:
        x, _ = block(x)
    x = final_norm(x)
    _ = lm_head(x)

    # Load weights into a temp model structure (Keras needs a Model object
    # whose layer topology mirrors the one that saved the checkpoint).
    class TempModel(keras.Model):
        def __init__(self):
            super().__init__()
            self.embed = embedding
            self.blocks = blocks
            self.norm = final_norm
            self.lm_head = lm_head
        def call(self, x):
            x = self.embed(x)
            for b in self.blocks:
                x, _ = b(x)
            return self.lm_head(self.norm(x))

    temp_model = TempModel()
    temp_model(dummy)
    temp_model.load_weights(weights_path)
    print("βœ… Weights loaded")

    # Extract components for this node; layer_end <= 0 means "through the end".
    STATE.my_block_start = CONFIG["layer_start"]
    STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers

    # HEAD always has embedding
    STATE.embedding = embedding

    # Extract our blocks
    STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
    print(f"βœ… Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")

    # HEAD keeps final norm and lm_head only in standalone mode; with workers,
    # the last worker is expected to apply them (see forward_pass).
    # NOTE(review): assumes STATE.lm_head/final_norm default to a falsy value
    # (e.g. None) when workers exist — confirm STATE's definition.
    has_workers = len(CONFIG["worker_urls"]) > 0
    if not has_workers:
        STATE.final_norm = final_norm
        STATE.lm_head = lm_head
        print("βœ… Loaded final norm and LM head (standalone mode)")

    # Warmup: trace/compile the local slice once so the first user request
    # doesn't pay graph-building latency.
    print("πŸ”₯ Warming up...")
    dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
    x = STATE.embedding(dummy)
    for block in STATE.blocks:
        x, _ = block(x, use_cache=False)
    if STATE.lm_head:
        _ = STATE.lm_head(STATE.final_norm(x))

    print("βœ… Model ready!")
    return True
372
 
373
  # ============================================================================
374
+ # Distributed Forward
375
  # ============================================================================
376
 
377
def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
    """Run one full forward pass: embedding β†’ local blocks β†’ remote workers.

    Args:
        input_ids: int32 tensor of token ids, shape (1, seq_len).
        past_kv_local: per-block KV caches for this node's blocks, or None.
        past_kv_workers: dict mapping worker URL -> that worker's KV cache.
        use_cache: when True, collect and return updated KV caches.

    Returns:
        (logits, new_local_kv, new_worker_kv). In distributed mode the final
        worker applies final_norm + lm_head, so the tensor it sends back is
        already logits; standalone mode applies them here.
    """
    hidden = STATE.embedding(input_ids)

    # Local transformer blocks, threading each block's own past cache through.
    local_caches = [] if use_cache else None
    for idx, layer in enumerate(STATE.blocks):
        prior = past_kv_local[idx] if past_kv_local else None
        hidden, cache = layer(hidden, past_kv=prior, use_cache=use_cache)
        if use_cache:
            local_caches.append(cache)

    # Remote workers, visited in pipeline order.
    worker_caches = {} if use_cache else None
    for url in CONFIG["worker_urls"]:
        prior = past_kv_workers.get(url) if past_kv_workers else None
        hidden, cache = call_worker(url, hidden, prior, use_cache)
        if use_cache:
            worker_caches[url] = cache

    # Standalone mode projects to vocabulary here; otherwise the last worker
    # already returned logits.
    if STATE.lm_head:
        logits = STATE.lm_head(STATE.final_norm(hidden))
    else:
        logits = hidden

    return logits, local_caches, worker_caches
 
410
 
411
  # ============================================================================
412
+ # Generation
413
  # ============================================================================
414
 
415
def sample_token(logits, temperature, top_k, top_p, token_freq, rep_penalty):
    """Sample the next token id from raw logits.

    Applies (in order) temperature scaling, repetition penalty, top-k
    filtering, softmax, and optional nucleus (top-p) filtering.

    Args:
        logits: 1-D array-like of vocabulary logits.
        temperature: softmax temperature (> 0); lower = greedier.
        top_k: keep only the k highest logits (0 or >= vocab disables).
        top_p: nucleus mass; < 1.0 enables top-p filtering.
        token_freq: dict token_id -> count of prior occurrences in this
            generation, used for the repetition penalty.
        rep_penalty: penalty base (>= 1.0); applied as penalty**freq.

    Returns:
        int: sampled token id.
    """
    logits = np.array(logits) / temperature

    # Repetition penalty (CTRL-style): always push penalized tokens toward
    # lower probability. Dividing a NEGATIVE logit by the penalty would move
    # it toward 0 and *increase* its probability, so negative logits are
    # multiplied instead.
    for tid, freq in token_freq.items():
        if tid < len(logits):
            factor = rep_penalty ** freq
            if logits[tid] > 0:
                logits[tid] /= factor
            else:
                logits[tid] *= factor

    # Top-k: restrict to the k largest logits (argpartition is O(n)).
    if 0 < top_k < len(logits):
        top_k_idx = np.argpartition(logits, -top_k)[-top_k:]
        top_k_logits = logits[top_k_idx]
    else:
        top_k_idx = np.arange(len(logits))
        top_k_logits = logits

    # Numerically stable softmax over the surviving logits.
    top_k_logits = top_k_logits - np.max(top_k_logits)
    probs = np.exp(top_k_logits)
    probs /= probs.sum()

    # Nucleus (top-p): keep the smallest prefix of tokens whose cumulative
    # probability reaches top_p, then renormalize and sample.
    if top_p < 1.0:
        sorted_idx = np.argsort(probs)[::-1]
        cumsum = np.cumsum(probs[sorted_idx])
        cutoff = np.searchsorted(cumsum, top_p) + 1
        nucleus_idx = sorted_idx[:cutoff]
        nucleus_probs = probs[nucleus_idx]
        nucleus_probs /= nucleus_probs.sum()
        sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
        return int(top_k_idx[nucleus_idx[sampled]])

    return int(top_k_idx[np.random.choice(len(probs), p=probs)])
444
 
445
 
446
def generate_stream(prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                    top_k: int = 40, top_p: float = 0.9, rep_penalty: float = 1.1):
    """Generate a completion for `prompt`, yielding the growing text after
    each token (Gradio streaming pattern).

    Uses KV caching: one full prefill pass over the prompt, then one-token
    decode steps. Honors the module-level `stop_generation` flag set by
    stop(). Yields error strings (rather than raising) so the UI can display
    failures inline.
    """
    global stop_generation
    stop_generation = False

    # Tokenize, dropping any EOS ids embedded in the prompt text.
    input_ids = [i for i in STATE.tokenizer.encode(prompt).ids if i != STATE.eos_token_id]
    if not input_ids:
        yield "Error: Empty prompt"
        return

    generated = ""
    token_freq = {}  # token_id -> occurrence count, feeds the repetition penalty

    # Stop on EOS or either chat end-of-turn token; discard(None) covers the
    # case where a special token is absent from the tokenizer vocab.
    stop_ids = {STATE.eos_token_id, STATE.tokenizer.token_to_id("<|im_end|>"),
                STATE.tokenizer.token_to_id("<im end for model tun>")}
    stop_ids.discard(None)

    # Trim the prompt from the left so prompt + generation fits the context.
    max_ctx = STATE.config['max_position_embeddings']
    if len(input_ids) > max_ctx - max_tokens:
        input_ids = input_ids[-(max_ctx - max_tokens):]

    start = time.time()

    # Prefill: single pass over the whole prompt, priming local + worker caches.
    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    try:
        logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
    except Exception as e:
        yield f"Error: {e}"
        return

    # Only the last position's logits are needed to pick the next token.
    next_logits = logits[0, -1, :].numpy()
    prefill_time = time.time() - start
    print(f"⚑ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")

    # Decode loop: sample, yield, then run a one-token forward step.
    decode_start = time.time()
    tokens_generated = 0

    for _ in range(max_tokens):
        if stop_generation:
            yield generated + "\n\n*[Stopped]*"
            return

        next_id = sample_token(next_logits, temperature, top_k, top_p, token_freq, rep_penalty)

        if next_id in stop_ids:
            break

        token_freq[next_id] = token_freq.get(next_id, 0) + 1
        generated += STATE.tokenizer.decode([next_id])
        tokens_generated += 1
        yield generated

        # Next step: feed only the new token; caches carry the history.
        next_input = tf.constant([[next_id]], dtype=tf.int32)
        try:
            logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
        except Exception as e:
            yield generated + f"\n\n*[Error: {e}]*"
            return

        next_logits = logits[0, -1, :].numpy()

    # Append a throughput footer; chat_respond strips it via the "*[" marker.
    if tokens_generated > 0:
        total = time.time() - start
        tps = tokens_generated / (time.time() - decode_start)
        workers = len(CONFIG["worker_urls"])
        mode = f", {workers} workers" if workers else " standalone"
        generated += f"\n\n*[{tokens_generated} tokens in {total:.1f}s ({tps:.1f} tok/s){mode}]*"

    yield generated
518
+
519
+
520
def format_prompt(message: str, history: list, reasoning: bool) -> str:
    """Render chat history + the new user message into the ChatML-style
    prompt the model was trained on.

    Prior assistant turns are stripped of the "*[...]*" stats footer. When
    `reasoning` is True, the prompt ends with an opening <think> tag so the
    model starts with a reasoning block.
    """
    parts = []
    for user_turn, assistant_turn in history:
        parts.append(f"<|im_start|>user\n{user_turn}<|im_end|>\n")
        if assistant_turn:
            clean = assistant_turn.split('*[')[0].strip()
            parts.append(f"<|im_start|>assistant\n{clean}<|im_end|>\n")
    parts.append(f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n")
    if reasoning:
        parts.append("<think>")
    return "".join(parts)
530
+
531
+
532
def chat_respond(message, history, max_tokens, temp, top_k, top_p, rep_pen, reasoning):
    """Gradio chat handler: stream model output into the chatbot history.

    Yields the updated history on every generation chunk. Stop tags are
    stripped from the display (keeping any trailing stats footer), and when
    reasoning is on, a completed <think>...</think> span is collapsed into
    an HTML <details> element.
    """
    if not message.strip():
        yield history
        return

    prompt = format_prompt(message, history, reasoning)

    for chunk in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
        shown = chunk

        # Cut at the first stop tag; re-attach the stats footer if it
        # appears after the cut point.
        for stop_tag in ("<|im_end|>", "<im end for model tun>"):
            if stop_tag not in shown:
                continue
            cut = shown.find(stop_tag)
            footer = shown.find("\n\n*[")
            tail = shown[footer:] if footer > cut else ""
            shown = shown[:cut] + tail

        if reasoning and '<think>' in shown and '</think>' in shown:
            open_at = shown.find('<think>')
            close_at = shown.find('</think>')
            if open_at < close_at:
                inner = shown[open_at + 7:close_at].strip()
                collapsed = f'<details><summary>🧠 Reasoning</summary><p>{inner}</p></details>'
                shown = shown[:open_at] + collapsed + shown[close_at + 8:]

        yield history + [[message, shown.strip()]]
554
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
 
556
def stop():
    """Signal the running generate_stream loop to halt at its next token."""
    global stop_generation
    stop_generation = True
559
 
560
  # ============================================================================
561
+ # Gradio UI
562
  # ============================================================================
563
 
564
def create_ui():
    """Build and return the Gradio Blocks chat UI for the HEAD node."""
    workers = CONFIG["worker_urls"]
    mode = f"Distributed ({len(workers)} workers)" if workers else "Standalone"

    with gr.Blocks(title="Sam-large-2 HEAD") as app:
        # Header banner showing node mode, block range, and id.
        gr.Markdown(f"""
        # πŸ‘‘ Sam-large-2 - HEAD NODE
        **Mode:** {mode} | **Blocks:** {CONFIG['layer_start']}-{CONFIG['layer_end']-1} | **ID:** {CONFIG['node_id']}
        """)

        if workers:
            gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))

        # Per-session flag toggling <think> reasoning mode.
        reasoning = gr.State(False)
        chatbot = gr.Chatbot(height=500)

        with gr.Row():
            reason_btn = gr.Button("πŸ’‘", size="sm", scale=0)
            msg = gr.Textbox(placeholder="Type message...", show_label=False, scale=8)
            send = gr.Button("Send", variant="primary", scale=1)
            stop_btn = gr.Button("⏹️", scale=0)

        # Sampling controls, passed straight through to generate_stream.
        with gr.Accordion("βš™οΈ Settings", open=False):
            max_tok = gr.Slider(50, 1024, 512, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
            topk = gr.Slider(1, 100, 40, label="Top-K")
            topp = gr.Slider(0.1, 1.0, 0.9, label="Top-P")
            rep = gr.Slider(1.0, 2.0, 1.1, label="Repetition Penalty")

        # Flip the reasoning flag and restyle the button to show its state.
        def toggle(r):
            return not r, gr.update(variant="primary" if not r else "secondary")

        reason_btn.click(toggle, [reasoning], [reasoning, reason_btn])

        # Enter-key and Send button both stream into the chatbot, then clear
        # the textbox; Stop both sets the flag and cancels the stream events.
        inputs = [msg, chatbot, max_tok, temp, topk, topp, rep, reasoning]
        submit = msg.submit(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        stop_btn.click(stop, cancels=[submit, click])

        gr.Button("πŸ—‘οΈ Clear").click(lambda: ([], ""), outputs=[chatbot, msg])

    return app
 
 
 
 
606
 
607
  # ============================================================================
608
+ # Main
609
  # ============================================================================
610
 
611
# Startup banner with this node's configuration.
print("=" * 60)
print("πŸš€ Sam-large-2 HEAD Node Starting")
print(f"   Blocks: {CONFIG['layer_start']} to {CONFIG['layer_end']}")
print(f"   Workers: {CONFIG['worker_urls'] or 'None (standalone)'}")
print("=" * 60)

# Load weights synchronously before serving, then launch the UI.
# queue() is required for streaming responses and event cancellation.
load_model()
app = create_ui()
app.queue()
app.launch(server_name="0.0.0.0", server_port=7860)