"""
Sam-large-2 Distributed Inference - HEAD NODE
Edit the CONFIG below, then deploy.
"""
# ============================================================================
# ⚙️ CONFIGURATION - EDIT THIS
# ============================================================================
CONFIG = {
    "node_id": "head-main",                     # label shown in the UI header
    "layer_start": 0,                           # first transformer block this node runs
    "layer_end": 6,                             # exclusive end; values <= 0 mean "through the last block"
    "worker_urls": [],                          # downstream worker base URLs, called in pipeline order
    "secret_token": "sam2-distributed-secret-change-me",  # shared Bearer token for /api/forward calls
    "model_repo": "Smilyai-labs/Sam-large-2",   # HF Hub repo holding config.json + ckpt.weights.h5
    "cache_dir": "./model_cache",               # local download cache for hf_hub_download
}
# ============================================================================
# CPU Optimization
# ============================================================================
import os
NUM_CORES = os.cpu_count() or 4
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import json
import time
import io
import base64
from typing import Dict, List, Optional, Tuple, Any
import gradio as gr
import numpy as np
import requests
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download
# Apply the thread counts to TensorFlow's runtime as well; the TF_NUM_*
# env vars above only take effect if set before TF initializes.
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
# Repaired: this f-string was split across two lines by encoding damage
# (a syntax error); rejoined into a single print.
print(f"✅ CPU optimized: {NUM_CORES} threads")
# ============================================================================
# Model Architecture
# ============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied to query/key tensors.

    Cos/sin lookup tables are materialized lazily on the first call and
    cached as float32 constants. `call` slices the tables starting at
    `offset`, so cached autoregressive decoding keeps absolute positions
    consistent with the tokens already in the KV cache.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        # No weights of our own; tables are built lazily in _build_cache.
        super().build(input_shape)

    def _build_cache(self):
        # Build the cos/sin tables exactly once.
        if self.built_cache:
            return
        exponent = tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim
        inv_freq = 1.0 / (self.theta ** exponent)
        positions = tf.range(self.max_len, dtype=tf.float32)
        freqs = tf.einsum("i,j->ij", positions, inv_freq)
        angles = tf.concat([freqs, freqs], axis=-1)
        self.cos_cached = tf.constant(np.cos(angles.numpy()), dtype=tf.float32)
        self.sin_cached = tf.constant(np.sin(angles.numpy()), dtype=tf.float32)
        self.built_cache = True

    def rotate_half(self, x):
        # Map (x1, x2) -> (-x2, x1) along the last axis.
        first_half, second_half = tf.split(x, 2, axis=-1)
        return tf.concat([-second_half, first_half], axis=-1)

    def call(self, q, k, offset=0):
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        # Slice positions [offset, offset + seq_len) and broadcast over
        # the (batch, heads) leading dimensions.
        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        rotated_q = (q * cos) + (self.rotate_half(q) * sin)
        rotated_k = (k * cos) + (self.rotate_half(k) * sin)
        return rotated_q, rotated_k

    def get_config(self):
        return {**super().get_config(), "dim": self.dim, "max_len": self.max_len, "theta": self.theta}
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square layer norm: no mean-centering, learned gain only."""

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # One learned gain per feature along the last axis.
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
        super().build(input_shape)

    def call(self, x):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        normalized = x * tf.math.rsqrt(mean_square + self.epsilon)
        return normalized * self.scale

    def get_config(self):
        return {**super().get_config(), "epsilon": self.epsilon}
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block: causal self-attention with RoPE + SwiGLU FFN.

    Supports incremental decoding: pass `past_kv=(k, v)` from the previous
    step together with `use_cache=True` to receive the updated cache.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads  # assumes d_model % n_heads == 0 — TODO confirm upstream
        self.layer_idx = layer_idx

    def build(self, input_shape):
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        # Bias-free attention projections (LLaMA-style naming for checkpoint compat).
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        # SwiGLU feed-forward: down(silu(gate(x)) * up(x)).
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        # x holds hidden states for the *new* tokens only; cached tokens
        # arrive via past_kv.
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype
        res = x
        y = self.pre_attn_norm(x)
        # (B, T, d_model) -> (B, n_heads, T, head_dim)
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        # Offset RoPE by the cached length so absolute positions stay correct.
        past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
        q, k = self.rope(q, k, offset=past_len)
        if past_kv is not None:
            # Prepend cached keys/values along the sequence axis.
            k = tf.concat([past_kv[0], k], axis=2)
            v = tf.concat([past_kv[1], v], axis=2)
        # Cache includes the just-concatenated past, ready for the next step.
        new_kv = (k, v) if use_cache else None
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask over absolute positions: query at position p attends
        # only to keys at positions <= p (additive -1e9 elsewhere).
        full_len = tf.shape(k)[2]
        q_pos = tf.range(past_len, past_len + T)
        k_pos = tf.range(full_len)
        mask = tf.where(q_pos[:, None] >= k_pos[None, :], 0.0, -1e9)
        scores = scores + tf.cast(mask[None, None, :, :], dtype)
        attn = tf.nn.softmax(scores, axis=-1)
        attn_out = tf.reshape(tf.transpose(tf.matmul(attn, v), [0, 2, 1, 3]), [B, T, self.d_model])
        x = res + self.dropout(self.out_proj(attn_out), training=training)
        # SwiGLU FFN with residual.
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training), new_kv

    def get_config(self):
        return {**super().get_config(), "d_model": self.d_model, "n_heads": self.n_heads,
                "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len,
                "rope_theta": self.rope_theta, "layer_idx": self.layer_idx}
# ============================================================================
# State
# ============================================================================
class ModelState:
    """Container for everything load_model() produces: tokenizer, this
    node's slice of transformer blocks, and (standalone only) the head."""

    def __init__(self):
        self.config = None           # parsed config.json dict
        self.tokenizer = None        # tokenizers.Tokenizer built from GPT-2 + chat tokens
        self.eos_token_id = 50256    # overridden from config at load time
        self.embedding = None        # token embedding layer (head node runs it)
        self.blocks: List = []       # TransformerBlocks in [my_block_start, my_block_end)
        self.final_norm = None       # set only when running without workers
        self.lm_head = None          # set only when running without workers
        self.my_block_start = 0      # first block index owned by this node
        self.my_block_end = 0        # one past the last owned block index
STATE = ModelState()
# Cooperative cancellation flag: set by stop(), polled by generate_stream().
stop_generation = False
# ============================================================================
# Serialization
# ============================================================================
def serialize_tensor(tensor: tf.Tensor) -> str:
    """Encode a tensor as base64 text (.npy bytes via np.save) for JSON transport."""
    raw = io.BytesIO()
    np.save(raw, tensor.numpy(), allow_pickle=False)
    return base64.b64encode(raw.getvalue()).decode('utf-8')
def deserialize_tensor(data: str) -> tf.Tensor:
    """Inverse of serialize_tensor: base64 -> .npy bytes -> tf constant."""
    raw = base64.b64decode(data)
    array = np.load(io.BytesIO(raw), allow_pickle=False)
    return tf.constant(array)
def serialize_kv_cache(past_kv):
    """Encode a list of (k, v) tensor pairs as JSON-safe dicts (None passes through)."""
    if past_kv is None:
        return None
    encoded = []
    for k, v in past_kv:
        if k is None:
            encoded.append(None)
        else:
            encoded.append({"k": serialize_tensor(k), "v": serialize_tensor(v)})
    return encoded
def deserialize_kv_cache(data):
    """Inverse of serialize_kv_cache: {"k","v"} dicts back to (k, v) tuples."""
    if data is None:
        return None
    restored = []
    for item in data:
        if item:
            restored.append((deserialize_tensor(item["k"]), deserialize_tensor(item["v"])))
        else:
            restored.append(None)
    return restored
# ============================================================================
# HTTP Communication
# ============================================================================
def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
    """POST hidden states to a worker's /api/forward and return (output, new_kv).

    Args:
        url: worker base URL (trailing slash tolerated).
        hidden_states: activations to forward to the remote stage.
        past_kv: that worker's KV cache from the previous step, or None.
        use_cache: ask the worker to return its updated cache.

    Raises:
        RuntimeError: on transport failure or a non-200 HTTP response.
    """
    payload = {
        "hidden_states": serialize_tensor(hidden_states),
        "past_kv": serialize_kv_cache(past_kv),
        "use_cache": use_cache,
    }
    try:
        response = requests.post(
            f"{url.rstrip('/')}/api/forward",
            json=payload,
            headers={"Authorization": f"Bearer {CONFIG['secret_token']}"},
            timeout=120,
        )
    except Exception as e:
        # Transport-level failure (connection refused, timeout, ...).
        # Chained so the original traceback survives.
        raise RuntimeError(f"Worker call failed: {e}") from e
    # Bug fix: previously this error was raised inside the same try block,
    # so it was caught and re-wrapped as "Worker call failed: Worker
    # returned N". Raising outside keeps the message clean.
    if response.status_code != 200:
        raise RuntimeError(f"Worker returned {response.status_code}")
    result = response.json()
    output = deserialize_tensor(result["hidden_states"])
    new_kv = deserialize_kv_cache(result.get("past_kv"))
    return output, new_kv
# ============================================================================
# Model Loading
# ============================================================================
def load_model():
    """Download config/tokenizer/weights, build the model, populate STATE.

    The full model is instantiated so the checkpoint can be loaded, then
    only blocks [layer_start, layer_end) are kept on this node. The final
    norm and LM head are retained only in standalone mode (no workers).

    Returns:
        True on success (raises on download/build failure).
    """
    print("đ Loading model...")
    config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
    with open(config_path, 'r') as f:
        model_config = json.load(f)
    STATE.config = model_config
    # Tokenizer: GPT-2 base plus chat special tokens, round-tripped through
    # disk so the fast `tokenizers.Tokenizer` picks them up.
    from transformers import AutoTokenizer
    from tokenizers import Tokenizer
    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # NOTE(review): the empty strings below appear to be special tokens
    # stripped by encoding damage (likely <think>/</think> among them) —
    # TODO restore from the original repo.
    hf_tokenizer.add_special_tokens({"additional_special_tokens":
        ["<|im_start|>", "<|im_end|>", "", "", "", ""]})
    os.makedirs("./temp_tokenizer", exist_ok=True)
    hf_tokenizer.save_pretrained("./temp_tokenizer")
    STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
    STATE.eos_token_id = model_config.get('eos_token_id', 50256)
    weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])
    n_layers = model_config['num_hidden_layers']
    d_model = model_config['hidden_size']
    n_heads = model_config['num_attention_heads']
    ff_dim = model_config['intermediate_size']
    max_len = model_config['max_position_embeddings']
    rope_theta = model_config['rope_theta']
    vocab_size = model_config['vocab_size']
    embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
    blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
              for i in range(n_layers)]
    final_norm = RMSNorm(name="final_norm")
    lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")
    # Trace once so every layer creates its variables before weight loading.
    dummy = tf.zeros((1, 16), dtype=tf.int32)
    x = embedding(dummy)
    for block in blocks:
        x, _ = block(x)
    x = final_norm(x)
    _ = lm_head(x)

    class TempModel(keras.Model):
        # Throwaway container whose layer tree matches the checkpoint layout.
        def __init__(self):
            super().__init__()
            self.embed = embedding
            self.blocks = blocks
            self.norm = final_norm
            self.lm_head = lm_head

        def call(self, x):
            x = self.embed(x)
            for b in self.blocks:
                x, _ = b(x)
            return self.lm_head(self.norm(x))

    temp_model = TempModel()
    temp_model(dummy)
    temp_model.load_weights(weights_path)
    # The four prints below were mojibake-broken multi-line f-strings
    # (syntax errors); rejoined onto single lines.
    print("✅ Weights loaded")
    STATE.my_block_start = CONFIG["layer_start"]
    STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers
    STATE.embedding = embedding
    STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
    print(f"✅ Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")
    has_workers = len(CONFIG["worker_urls"]) > 0
    if not has_workers:
        STATE.final_norm = final_norm
        STATE.lm_head = lm_head
        print("✅ Loaded final norm and LM head (standalone mode)")
    print("đĨ Warming up...")
    dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
    x = STATE.embedding(dummy)
    for block in STATE.blocks:
        x, _ = block(x, use_cache=False)
    if STATE.lm_head:
        _ = STATE.lm_head(STATE.final_norm(x))
    print("✅ Model ready!")
    return True
# ============================================================================
# Distributed Forward
# ============================================================================
def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
    """Run one step through the local pipeline stage, then through each
    worker stage over HTTP in CONFIG["worker_urls"] order.

    Args:
        input_ids: (batch, seq) int32 token ids for the new tokens.
        past_kv_local: per-block (k, v) caches for this node's blocks.
        past_kv_workers: {worker_url: kv_cache} for the remote stages.
        use_cache: if True, collect and return updated caches.

    Returns:
        (output, new_local_kv, new_worker_kv) — output is logits when this
        node holds the LM head (standalone mode), raw hidden states otherwise.
    """
    x = STATE.embedding(input_ids)
    new_local_kv = [] if use_cache else None
    for i, block in enumerate(STATE.blocks):
        block_past = past_kv_local[i] if past_kv_local else None
        x, kv = block(x, past_kv=block_past, use_cache=use_cache)
        if use_cache:
            new_local_kv.append(kv)
    new_worker_kv = {} if use_cache else None
    # Hidden states flow worker-to-worker in list order; each worker keeps
    # its own KV cache keyed by URL.
    for worker_url in CONFIG["worker_urls"]:
        worker_past = past_kv_workers.get(worker_url) if past_kv_workers else None
        x, worker_kv = call_worker(worker_url, x, worker_past, use_cache)
        if use_cache:
            new_worker_kv[worker_url] = worker_kv
    # Final norm + LM head exist on this node only in standalone mode.
    if STATE.lm_head:
        logits = STATE.lm_head(STATE.final_norm(x))
    else:
        logits = x
    return logits, new_local_kv, new_worker_kv
# ============================================================================
# Generation
# ============================================================================
def sample_token(logits, temperature, top_k, top_p, token_freq, rep_penalty):
    """Sample one token id with temperature, repetition penalty, top-k and
    top-p (nucleus) filtering.

    Args:
        logits: 1-D array-like of unnormalized vocab scores.
        temperature: softmax temperature (> 0); lower is greedier.
        top_k: keep only the k highest logits (0 or >= vocab size disables).
        top_p: nucleus probability threshold in (0, 1]; 1.0 disables.
        token_freq: {token_id: count} of tokens generated so far.
        rep_penalty: >= 1.0; discourages already-generated tokens.

    Returns:
        int: the sampled token id.
    """
    logits = np.array(logits) / temperature
    # CTRL-style repetition penalty. Bug fix: the original divided
    # unconditionally, but dividing a *negative* logit by a penalty > 1
    # raises its probability — boosting repeated tokens. Divide positive
    # logits, multiply negative ones.
    for tid, freq in token_freq.items():
        if tid < len(logits):
            factor = rep_penalty ** freq
            if logits[tid] > 0:
                logits[tid] /= factor
            else:
                logits[tid] *= factor
    # Top-k: restrict to the k highest-scoring candidates.
    if 0 < top_k < len(logits):
        top_k_idx = np.argpartition(logits, -top_k)[-top_k:]
        top_k_logits = logits[top_k_idx]
    else:
        top_k_idx = np.arange(len(logits))
        top_k_logits = logits
    # Stable softmax over the surviving candidates.
    top_k_logits = top_k_logits - np.max(top_k_logits)
    probs = np.exp(top_k_logits)
    probs /= probs.sum()
    if top_p < 1.0:
        # Nucleus sampling: smallest prefix of the sorted distribution
        # whose cumulative mass reaches top_p.
        sorted_idx = np.argsort(probs)[::-1]
        cumsum = np.cumsum(probs[sorted_idx])
        cutoff = np.searchsorted(cumsum, top_p) + 1
        nucleus_idx = sorted_idx[:cutoff]
        nucleus_probs = probs[nucleus_idx]
        nucleus_probs /= nucleus_probs.sum()
        sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
        return int(top_k_idx[nucleus_idx[sampled]])
    return int(top_k_idx[np.random.choice(len(probs), p=probs)])
def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_p=0.9, rep_penalty=1.1):
    """Yield progressively longer generated text for `prompt`.

    One prefill forward pass over the whole prompt, then one single-token
    pass per generated token using the KV caches. Yields the accumulated
    text after every token; the final yield appends a stats footer of the
    form "\\n\\n*[N tokens in Ts (R tok/s)...]*".
    """
    global stop_generation
    stop_generation = False
    # Drop EOS ids from the prompt so a stray EOS can't end generation early.
    input_ids = [i for i in STATE.tokenizer.encode(prompt).ids if i != STATE.eos_token_id]
    if not input_ids:
        yield "Error: Empty prompt"
        return
    generated = ""
    token_freq = {}
    # Mojibake repair: the original third entry was token_to_id("") — the
    # tag text was stripped by encoding damage; "</think>" matches the
    # offsets used in chat_respond. token_to_id returns None for unknown
    # tokens and discard(None) drops it, so this stays safe either way.
    stop_ids = {STATE.eos_token_id, STATE.tokenizer.token_to_id("<|im_end|>"),
                STATE.tokenizer.token_to_id("</think>")}
    stop_ids.discard(None)
    # Truncate the prompt so there is room for max_tokens of output.
    # Guard fix: when max_tokens >= max_ctx the original computed a
    # negative slice bound and kept the wrong end of the prompt.
    max_ctx = STATE.config['max_position_embeddings']
    keep = max(max_ctx - max_tokens, 1)
    if len(input_ids) > keep:
        input_ids = input_ids[-keep:]
    start = time.time()
    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    try:
        logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
    except Exception as e:
        yield f"Error: {e}"
        return
    next_logits = logits[0, -1, :].numpy()
    prefill_time = time.time() - start
    print(f"⥠Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
    decode_start = time.time()
    tokens_generated = 0
    for _ in range(max_tokens):
        if stop_generation:
            yield generated + "\n\n*[Stopped]*"
            return
        next_id = sample_token(next_logits, temperature, top_k, top_p, token_freq, rep_penalty)
        if next_id in stop_ids:
            break
        token_freq[next_id] = token_freq.get(next_id, 0) + 1
        generated += STATE.tokenizer.decode([next_id])
        tokens_generated += 1
        yield generated
        next_input = tf.constant([[next_id]], dtype=tf.int32)
        try:
            logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
        except Exception as e:
            yield generated + f"\n\n*[Error: {e}]*"
            return
        next_logits = logits[0, -1, :].numpy()
    if tokens_generated > 0:
        total = time.time() - start
        tps = tokens_generated / (time.time() - decode_start)
        workers = len(CONFIG["worker_urls"])
        mode = f", {workers} workers" if workers else " standalone"
        generated += f"\n\n*[{tokens_generated} tokens in {total:.1f}s ({tps:.1f} tok/s){mode}]*"
    yield generated
def format_prompt(message: str, history: list, reasoning: bool) -> str:
    """Build a ChatML-style prompt from chat history plus the new message.

    Args:
        message: the user's new message.
        history: messages-format list of {"role", "content"} dicts.
        reasoning: if True, open a <think> block for the assistant.

    Returns:
        Prompt string ending with an open assistant turn.
    """
    parts = []
    for msg in history:
        if msg["role"] == "user":
            parts.append(f"<|im_start|>user\n{msg['content']}<|im_end|>\n")
        elif msg["role"] == "assistant":
            # Strip the trailing "*[... stats]*" footer added by generate_stream.
            content = msg['content'].split('*[')[0].strip()
            parts.append(f"<|im_start|>assistant\n{content}<|im_end|>\n")
    parts.append(f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n")
    if reasoning:
        # Mojibake repair: the original appended "" here; the s+7 / e+8
        # offsets in chat_respond identify the lost tag as "<think>".
        parts.append("<think>")
    return "".join(parts)
def chat_respond(message, history, max_tokens, temp, top_k, top_p, rep_pen, reasoning):
    """Gradio handler: stream the assistant's reply into the chat history.

    Yields the full messages-format history on every streamed chunk.
    """
    if not message.strip():
        yield history
        return
    prompt = format_prompt(message, history, reasoning)
    # Add the user turn before streaming the assistant turn.
    history = history + [{"role": "user", "content": message}]
    for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
        display = text
        # Cut the reply at a stop tag, keeping the trailing stats footer.
        # Bug fix: the original list also held an empty string (a
        # mojibake-stripped token); find("") matches at index 0, which
        # truncated every streamed chunk to nothing. Dropped here.
        for tag in ["<|im_end|>"]:
            if tag in display:
                idx = display.find(tag)
                stats = display.find("\n\n*[")
                display = display[:idx] + (display[stats:] if stats > idx else "")
        # Render <think>...</think> reasoning. Tags reconstructed from the
        # original's s+7 / e+8 offsets (len("<think>") / len("</think>"));
        # the f-string below was also syntax-broken across lines and has
        # been rejoined.
        if reasoning and '<think>' in display and '</think>' in display:
            s, e = display.find('<think>'), display.find('</think>')
            if s < e:
                thought = display[s + 7:e].strip()
                display = display[:s] + f"🧠 Reasoning\n{thought}\n" + display[e + 8:]
        yield history + [{"role": "assistant", "content": display.strip()}]
def stop():
    """Request cancellation: generate_stream checks this flag each token."""
    global stop_generation
    stop_generation = True
# ============================================================================
# Gradio UI
# ============================================================================
def create_ui():
    """Build the Gradio chat UI and wire streaming, stop, and settings."""
    workers = CONFIG["worker_urls"]
    mode = f"Distributed ({len(workers)} workers)" if workers else "Standalone"
    with gr.Blocks(title="Sam-large-2 HEAD") as app:
        gr.Markdown(f"""
# đ Sam-large-2 - HEAD NODE
**Mode:** {mode} | **Blocks:** {CONFIG['layer_start']}-{CONFIG['layer_end']-1} | **ID:** {CONFIG['node_id']}
""")
        if workers:
            gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))
        # Reasoning toggle state, flipped by the button below.
        reasoning = gr.State(False)
        chatbot = gr.Chatbot(
            height=500,
            type="messages"  # Use new messages format
        )
        with gr.Row():
            reason_btn = gr.Button("đĄ", size="sm", scale=0)
            msg = gr.Textbox(placeholder="Type message...", show_label=False, scale=8)
            send = gr.Button("Send", variant="primary", scale=1)
            stop_btn = gr.Button("âšī¸", scale=0)
        with gr.Accordion("âī¸ Settings", open=False):
            max_tok = gr.Slider(50, 1024, 512, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
            topk = gr.Slider(1, 100, 40, label="Top-K")
            topp = gr.Slider(0.1, 1.0, 0.9, label="Top-P")
            rep = gr.Slider(1.0, 2.0, 1.1, label="Repetition Penalty")

        def toggle(r):
            # Flip the reasoning flag and restyle the button to show state.
            return not r, gr.update(variant="primary" if not r else "secondary")

        reason_btn.click(toggle, [reasoning], [reasoning, reason_btn])
        inputs = [msg, chatbot, max_tok, temp, topk, topp, rep, reasoning]
        # Keep the event handles so the stop button can cancel them;
        # .then(...) clears the textbox after the stream finishes.
        submit = msg.submit(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        stop_btn.click(stop, cancels=[submit, click])
        gr.Button("đī¸ Clear").click(lambda: [], outputs=[chatbot])
    return app
# ============================================================================
# Main
# ============================================================================
# Startup: banner, load weights into STATE, then serve the UI.
print("=" * 60)
print("đ Sam-large-2 HEAD Node Starting")
print(f" Blocks: {CONFIG['layer_start']} to {CONFIG['layer_end']}")
print(f" Workers: {CONFIG['worker_urls'] or 'None (standalone)'}")
print("=" * 60)
load_model()
app = create_ui()
# Queueing is required for streaming generator handlers in Gradio.
app.queue()
app.launch(server_name="0.0.0.0", server_port=7860)