# head-node/app.py
"""
Sam-large-2 Distributed Inference - HEAD NODE
Edit the CONFIG below, then deploy.
"""
# ============================================================================
# βš™οΈ CONFIGURATION - EDIT THIS
# ============================================================================
CONFIG = {
# This node's identity
"node_id": "head-main",
# Which transformer blocks this node runs (0-indexed)
# Sam-large-2 has 12 blocks (0-11)
"layer_start": 0,
"layer_end": 6, # exclusive, so this runs blocks 0,1,2,3,4,5
# Worker Space URLs (in order of execution)
# Leave empty [] for standalone mode (all layers on this node)
"worker_urls": [
# "https://YOUR-WORKER-SPACE.hf.space",
],
# Shared secret for worker communication
"secret_token": "sam2-distributed-secret-change-me",
# Model settings
"model_repo": "Smilyai-labs/Sam-large-2",
"cache_dir": "./model_cache",
}
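# Example (illustrative values, not the deployed config): splitting the 12
# blocks across two Spaces -- this head runs blocks 0-5 and a single worker
# runs 6-11. The worker URL below is hypothetical.
#
#   CONFIG = {
#       "node_id": "head-main",
#       "layer_start": 0,
#       "layer_end": 6,                                  # exclusive
#       "worker_urls": ["https://my-worker.hf.space"],   # hypothetical URL
#       "secret_token": "<same secret on every node>",
#       "model_repo": "Smilyai-labs/Sam-large-2",
#       "cache_dir": "./model_cache",
#   }
#
# The matching worker Space would then be configured with layer_start=6 and
# layer_end=12, so it owns the final blocks (plus final_norm and lm_head).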
# ============================================================================
# CPU Optimization - MUST be before TensorFlow import
# ============================================================================
import os
NUM_CORES = os.cpu_count() or 4
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import json
import time
import threading
import io
import base64
from typing import Dict, List, Optional, Tuple, Any
import gradio as gr
import numpy as np
import requests
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
print(f"βœ… CPU optimized: {NUM_CORES} threads")
# ============================================================================
# Model Architecture
# ============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
super().__init__(**kwargs)
self.dim = dim
self.max_len = max_len
self.theta = theta
self.built_cache = False
self.cos_cached = None
self.sin_cached = None
def build(self, input_shape):
super().build(input_shape)
def _build_cache(self):
if not self.built_cache:
inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
t = tf.range(self.max_len, dtype=tf.float32)
freqs = tf.einsum("i,j->ij", t, inv_freq)
emb = tf.concat([freqs, freqs], axis=-1)
self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
self.built_cache = True
def rotate_half(self, x):
x1, x2 = tf.split(x, 2, axis=-1)
return tf.concat([-x2, x1], axis=-1)
def call(self, q, k, offset=0):
self._build_cache()
seq_len = tf.shape(q)[2]
dtype = q.dtype
cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed
def get_config(self):
return {**super().get_config(), "dim": self.dim, "max_len": self.max_len, "theta": self.theta}
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
def __init__(self, epsilon=1e-5, **kwargs):
super().__init__(**kwargs)
self.epsilon = epsilon
def build(self, input_shape):
self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
super().build(input_shape)
def call(self, x):
variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
def get_config(self):
return {**super().get_config(), "epsilon": self.epsilon}
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
super().__init__(**kwargs)
self.d_model = d_model
self.n_heads = n_heads
self.ff_dim = ff_dim
self.dropout_rate = dropout
self.max_len = max_len
self.rope_theta = rope_theta
self.head_dim = d_model // n_heads
self.layer_idx = layer_idx
def build(self, input_shape):
self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
self.dropout = keras.layers.Dropout(self.dropout_rate)
super().build(input_shape)
def call(self, x, training=None, past_kv=None, use_cache=False):
B, T = tf.shape(x)[0], tf.shape(x)[1]
dtype = x.dtype
res = x
y = self.pre_attn_norm(x)
q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
q, k = self.rope(q, k, offset=past_len)
if past_kv is not None:
k = tf.concat([past_kv[0], k], axis=2)
v = tf.concat([past_kv[1], v], axis=2)
new_kv = (k, v) if use_cache else None
scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
full_len = tf.shape(k)[2]
q_pos = tf.range(past_len, past_len + T)
k_pos = tf.range(full_len)
mask = tf.where(q_pos[:, None] >= k_pos[None, :], 0.0, -1e9)
scores = scores + tf.cast(mask[None, None, :, :], dtype)
attn = tf.nn.softmax(scores, axis=-1)
attn_out = tf.reshape(tf.transpose(tf.matmul(attn, v), [0, 2, 1, 3]), [B, T, self.d_model])
x = res + self.dropout(self.out_proj(attn_out), training=training)
res = x
y = self.pre_ffn_norm(x)
ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
return res + self.dropout(ffn, training=training), new_kv
def get_config(self):
return {**super().get_config(), "d_model": self.d_model, "n_heads": self.n_heads,
"ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len,
"rope_theta": self.rope_theta, "layer_idx": self.layer_idx}
# ============================================================================
# State
# ============================================================================
class ModelState:
def __init__(self):
self.config = None
self.tokenizer = None
self.eos_token_id = 50256
# Model components
self.embedding = None
self.blocks: List = []
self.final_norm = None
self.lm_head = None
self.my_block_start = 0
self.my_block_end = 0
STATE = ModelState()
stop_generation = False
# ============================================================================
# Serialization
# ============================================================================
def serialize_tensor(tensor: tf.Tensor) -> str:
buffer = io.BytesIO()
np.save(buffer, tensor.numpy(), allow_pickle=False)
return base64.b64encode(buffer.getvalue()).decode('utf-8')
def deserialize_tensor(data: str) -> tf.Tensor:
buffer = io.BytesIO(base64.b64decode(data))
return tf.constant(np.load(buffer, allow_pickle=False))
def serialize_kv_cache(past_kv):
    if past_kv is None:
        return None
    # Entries may themselves be None (a block run with use_cache=False), so
    # test each entry instead of destructuring a possibly-None tuple
    return [{"k": serialize_tensor(kv[0]), "v": serialize_tensor(kv[1])} if kv is not None else None
            for kv in past_kv]
def deserialize_kv_cache(data):
if data is None:
return None
return [(deserialize_tensor(item["k"]), deserialize_tensor(item["v"])) if item else None for item in data]
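# Minimal round-trip sanity check for the codec above (a debugging aid that is
# defined here but never called at startup; invoke it manually if needed):
def _selftest_serialization():
    x = tf.random.normal((1, 4, 8))
    y = deserialize_tensor(serialize_tensor(x))
    assert y.shape == x.shape and np.allclose(x.numpy(), y.numpy())
    kv = deserialize_kv_cache(serialize_kv_cache([(x, x), None]))
    assert kv[1] is None and np.allclose(kv[0][0].numpy(), x.numpy())
    return True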
# ============================================================================
# HTTP Communication
# ============================================================================
def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
"""Send hidden states to worker and get result."""
try:
response = requests.post(
f"{url.rstrip('/')}/api/forward",
json={
"hidden_states": serialize_tensor(hidden_states),
"past_kv": serialize_kv_cache(past_kv),
"use_cache": use_cache,
},
headers={"Authorization": f"Bearer {CONFIG['secret_token']}"},
timeout=120
)
if response.status_code == 200:
result = response.json()
output = deserialize_tensor(result["hidden_states"])
new_kv = deserialize_kv_cache(result.get("past_kv"))
return output, new_kv
else:
raise RuntimeError(f"Worker returned {response.status_code}")
except Exception as e:
raise RuntimeError(f"Worker call failed: {e}")
# ============================================================================
# Model Loading
# ============================================================================
def load_model():
"""Load model and extract components for this node."""
print("πŸš€ Loading model...")
# Load config
config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
with open(config_path, 'r') as f:
model_config = json.load(f)
STATE.config = model_config
# Load tokenizer
from transformers import AutoTokenizer
from tokenizers import Tokenizer
hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
hf_tokenizer.add_special_tokens({"additional_special_tokens":
["<|im_start|>", "<|im_end|>", "<think>", "</think>", "<CONTINUE>", "<im end for model tun>"]})
os.makedirs("./temp_tokenizer", exist_ok=True)
hf_tokenizer.save_pretrained("./temp_tokenizer")
STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
STATE.eos_token_id = model_config.get('eos_token_id', 50256)
# Load weights
weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])
# Build full model to load weights
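    # Note: the shared ckpt.weights.h5 is restored into a temporary full
    # 12-block model below, so peak memory during load is the whole model;
    # once only this node's slice is stored in STATE, the unused blocks
    # become garbage-collectible when load_model() returns.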
n_layers = model_config['num_hidden_layers']
d_model = model_config['hidden_size']
n_heads = model_config['num_attention_heads']
ff_dim = model_config['intermediate_size']
max_len = model_config['max_position_embeddings']
rope_theta = model_config['rope_theta']
vocab_size = model_config['vocab_size']
# Temporary full model
embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
for i in range(n_layers)]
final_norm = RMSNorm(name="final_norm")
lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")
# Build
dummy = tf.zeros((1, 16), dtype=tf.int32)
x = embedding(dummy)
for block in blocks:
x, _ = block(x)
x = final_norm(x)
_ = lm_head(x)
# Load weights into a temp model structure
class TempModel(keras.Model):
def __init__(self):
super().__init__()
self.embed = embedding
self.blocks = blocks
self.norm = final_norm
self.lm_head = lm_head
def call(self, x):
x = self.embed(x)
for b in self.blocks:
x, _ = b(x)
return self.lm_head(self.norm(x))
temp_model = TempModel()
temp_model(dummy)
temp_model.load_weights(weights_path)
print("βœ… Weights loaded")
# Extract components for this node
STATE.my_block_start = CONFIG["layer_start"]
STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers
# HEAD always has embedding
STATE.embedding = embedding
# Extract our blocks
STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
print(f"βœ… Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")
# HEAD has final norm and lm_head only if no workers OR we handle last block
has_workers = len(CONFIG["worker_urls"]) > 0
if not has_workers:
STATE.final_norm = final_norm
STATE.lm_head = lm_head
print("βœ… Loaded final norm and LM head (standalone mode)")
# Warmup
print("πŸ”₯ Warming up...")
dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
x = STATE.embedding(dummy)
for block in STATE.blocks:
x, _ = block(x, use_cache=False)
    if STATE.lm_head is not None:
_ = STATE.lm_head(STATE.final_norm(x))
print("βœ… Model ready!")
return True
# ============================================================================
# Distributed Forward
# ============================================================================
def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
"""
Full forward pass through HEAD + all workers.
Returns logits and updated KV caches.
"""
# Embedding
x = STATE.embedding(input_ids)
# Local blocks
new_local_kv = [] if use_cache else None
for i, block in enumerate(STATE.blocks):
block_past = past_kv_local[i] if past_kv_local else None
x, kv = block(x, past_kv=block_past, use_cache=use_cache)
if use_cache:
new_local_kv.append(kv)
# Workers
new_worker_kv = {} if use_cache else None
for worker_url in CONFIG["worker_urls"]:
worker_past = past_kv_workers.get(worker_url) if past_kv_workers else None
x, worker_kv = call_worker(worker_url, x, worker_past, use_cache)
if use_cache:
new_worker_kv[worker_url] = worker_kv
# Final (only if standalone or last worker returned to us)
# In distributed mode, the last worker applies final_norm + lm_head
    if STATE.lm_head is not None:
logits = STATE.lm_head(STATE.final_norm(x))
else:
# x should already be logits from last worker
logits = x
return logits, new_local_kv, new_worker_kv
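# Latency note: the pipeline above is strictly sequential -- during decoding,
# each token's hidden states make one HTTP round-trip (with base64 .npy
# payloads) per worker, so per-token latency grows roughly linearly with the
# number of workers plus network overhead.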
# ============================================================================
# Generation
# ============================================================================
def sample_token(logits, temperature, top_k, top_p, token_freq, rep_penalty):
    logits = np.array(logits) / temperature
    # Repetition penalty (CTRL-style): divide positive logits and multiply
    # negative ones, so repeated tokens always become less likely; dividing
    # a negative logit would instead make the token *more* probable
    for tid, freq in token_freq.items():
        if tid < len(logits):
            penalty = rep_penalty ** freq
            logits[tid] = logits[tid] / penalty if logits[tid] > 0 else logits[tid] * penalty
if 0 < top_k < len(logits):
top_k_idx = np.argpartition(logits, -top_k)[-top_k:]
top_k_logits = logits[top_k_idx]
else:
top_k_idx = np.arange(len(logits))
top_k_logits = logits
top_k_logits = top_k_logits - np.max(top_k_logits)
probs = np.exp(top_k_logits)
probs /= probs.sum()
if top_p < 1.0:
sorted_idx = np.argsort(probs)[::-1]
cumsum = np.cumsum(probs[sorted_idx])
cutoff = np.searchsorted(cumsum, top_p) + 1
nucleus_idx = sorted_idx[:cutoff]
nucleus_probs = probs[nucleus_idx]
nucleus_probs /= nucleus_probs.sum()
sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
return int(top_k_idx[nucleus_idx[sampled]])
return int(top_k_idx[np.random.choice(len(probs), p=probs)])
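# Worked example (temperature=1.0, top_k=2, top_p=0.9, no repetition history):
# logits [2.0, 1.0, 0.0] -> top-k keeps indices {0, 1} with logits [2.0, 1.0];
# softmax gives probs ~[0.73, 0.27]; the nucleus needs cumulative mass >= 0.9,
# and 0.73 < 0.9, so both tokens stay and we sample between them (~73% / 27%).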
def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_p=0.9, rep_penalty=1.1):
global stop_generation
stop_generation = False
input_ids = [i for i in STATE.tokenizer.encode(prompt).ids if i != STATE.eos_token_id]
if not input_ids:
yield "Error: Empty prompt"
return
generated = ""
token_freq = {}
stop_ids = {STATE.eos_token_id, STATE.tokenizer.token_to_id("<|im_end|>"),
STATE.tokenizer.token_to_id("<im end for model tun>")}
stop_ids.discard(None)
max_ctx = STATE.config['max_position_embeddings']
if len(input_ids) > max_ctx - max_tokens:
input_ids = input_ids[-(max_ctx - max_tokens):]
start = time.time()
# Prefill
input_tensor = tf.constant([input_ids], dtype=tf.int32)
try:
logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
except Exception as e:
yield f"Error: {e}"
return
next_logits = logits[0, -1, :].numpy()
prefill_time = time.time() - start
print(f"⚑ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
# Generate
decode_start = time.time()
tokens_generated = 0
for _ in range(max_tokens):
if stop_generation:
yield generated + "\n\n*[Stopped]*"
return
next_id = sample_token(next_logits, temperature, top_k, top_p, token_freq, rep_penalty)
if next_id in stop_ids:
break
token_freq[next_id] = token_freq.get(next_id, 0) + 1
generated += STATE.tokenizer.decode([next_id])
tokens_generated += 1
yield generated
# Next step
next_input = tf.constant([[next_id]], dtype=tf.int32)
try:
logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
except Exception as e:
yield generated + f"\n\n*[Error: {e}]*"
return
next_logits = logits[0, -1, :].numpy()
# Stats
if tokens_generated > 0:
total = time.time() - start
tps = tokens_generated / (time.time() - decode_start)
workers = len(CONFIG["worker_urls"])
mode = f", {workers} workers" if workers else " standalone"
generated += f"\n\n*[{tokens_generated} tokens in {total:.1f}s ({tps:.1f} tok/s){mode}]*"
yield generated
def format_prompt(message: str, history: list, reasoning: bool) -> str:
prompt = ""
for user, assistant in history:
prompt += f"<|im_start|>user\n{user}<|im_end|>\n"
if assistant:
prompt += f"<|im_start|>assistant\n{assistant.split('*[')[0].strip()}<|im_end|>\n"
prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
if reasoning:
prompt += "<think>"
return prompt
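# Example: format_prompt("Hi", [], reasoning=True) yields
#   <|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think>
# i.e. the assistant turn is left open (optionally pre-seeded with <think>)
# for the model to complete.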
def chat_respond(message, history, max_tokens, temp, top_k, top_p, rep_pen, reasoning):
if not message.strip():
yield history
return
prompt = format_prompt(message, history, reasoning)
for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
display = text
for tag in ["<|im_end|>", "<im end for model tun>"]:
if tag in display:
idx = display.find(tag)
stats = display.find("\n\n*[")
display = display[:idx] + (display[stats:] if stats > idx else "")
if reasoning and '<think>' in display and '</think>' in display:
s, e = display.find('<think>'), display.find('</think>')
if s < e:
thought = display[s+7:e].strip()
display = display[:s] + f'<details><summary>🧠 Reasoning</summary><p>{thought}</p></details>' + display[e+8:]
yield history + [[message, display.strip()]]
def stop():
global stop_generation
stop_generation = True
# ============================================================================
# Gradio UI
# ============================================================================
def create_ui():
workers = CONFIG["worker_urls"]
mode = f"Distributed ({len(workers)} workers)" if workers else "Standalone"
with gr.Blocks(title="Sam-large-2 HEAD") as app:
gr.Markdown(f"""
# 👑 Sam-large-2 - HEAD NODE
**Mode:** {mode} | **Blocks:** {CONFIG['layer_start']}-{CONFIG['layer_end']-1} | **ID:** {CONFIG['node_id']}
""")
if workers:
gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))
reasoning = gr.State(False)
chatbot = gr.Chatbot(height=500)
with gr.Row():
reason_btn = gr.Button("πŸ’‘", size="sm", scale=0)
msg = gr.Textbox(placeholder="Type message...", show_label=False, scale=8)
send = gr.Button("Send", variant="primary", scale=1)
stop_btn = gr.Button("⏹️", scale=0)
with gr.Accordion("βš™οΈ Settings", open=False):
max_tok = gr.Slider(50, 1024, 512, label="Max Tokens")
temp = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
topk = gr.Slider(1, 100, 40, label="Top-K")
topp = gr.Slider(0.1, 1.0, 0.9, label="Top-P")
rep = gr.Slider(1.0, 2.0, 1.1, label="Repetition Penalty")
def toggle(r):
return not r, gr.update(variant="primary" if not r else "secondary")
reason_btn.click(toggle, [reasoning], [reasoning, reason_btn])
inputs = [msg, chatbot, max_tok, temp, topk, topp, rep, reasoning]
submit = msg.submit(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
stop_btn.click(stop, cancels=[submit, click])
gr.Button("πŸ—‘οΈ Clear").click(lambda: ([], ""), outputs=[chatbot, msg])
return app
# ============================================================================
# Main
# ============================================================================
print("=" * 60)
print("πŸš€ Sam-large-2 HEAD Node Starting")
print(f" Blocks: {CONFIG['layer_start']} to {CONFIG['layer_end']}")
print(f" Workers: {CONFIG['worker_urls'] or 'None (standalone)'}")
print("=" * 60)
load_model()
app = create_ui()
app.queue()
app.launch(server_name="0.0.0.0", server_port=7860)