""" Sam-large-2 Distributed Inference - HEAD NODE Edit the CONFIG below, then deploy. """ # ============================================================================ # âš™ī¸ CONFIGURATION - EDIT THIS # ============================================================================ CONFIG = { "node_id": "head-main", "layer_start": 0, "layer_end": 6, "worker_urls": [], "secret_token": "sam2-distributed-secret-change-me", "model_repo": "Smilyai-labs/Sam-large-2", "cache_dir": "./model_cache", } # ============================================================================ # CPU Optimization # ============================================================================ import os NUM_CORES = os.cpu_count() or 4 os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES) os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES) os.environ['CUDA_VISIBLE_DEVICES'] = '-1' os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import json import time import io import base64 from typing import Dict, List, Optional, Tuple, Any import gradio as gr import numpy as np import requests import tensorflow as tf import keras from huggingface_hub import hf_hub_download tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES) tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES) print(f"✅ CPU optimized: {NUM_CORES} threads") # ============================================================================ # Model Architecture # ============================================================================ @keras.saving.register_keras_serializable() class RotaryEmbedding(keras.layers.Layer): def __init__(self, dim, max_len=2048, theta=10000, **kwargs): super().__init__(**kwargs) self.dim = dim self.max_len = max_len self.theta = theta self.built_cache = False self.cos_cached = None self.sin_cached = None def build(self, input_shape): super().build(input_shape) def _build_cache(self): if not self.built_cache: inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim)) t = tf.range(self.max_len, dtype=tf.float32) freqs = tf.einsum("i,j->ij", t, inv_freq) emb = tf.concat([freqs, freqs], axis=-1) self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32) self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32) self.built_cache = True def rotate_half(self, x): x1, x2 = tf.split(x, 2, axis=-1) return tf.concat([-x2, x1], axis=-1) def call(self, q, k, offset=0): self._build_cache() seq_len = tf.shape(q)[2] dtype = q.dtype cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :] sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :] q_embed = (q * cos) + (self.rotate_half(q) * sin) k_embed = (k * cos) + (self.rotate_half(k) * sin) return q_embed, k_embed def get_config(self): return {**super().get_config(), "dim": self.dim, "max_len": self.max_len, "theta": self.theta} @keras.saving.register_keras_serializable() class RMSNorm(keras.layers.Layer): def __init__(self, epsilon=1e-5, **kwargs): super().__init__(**kwargs) self.epsilon = epsilon def build(self, input_shape): self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones") super().build(input_shape) def call(self, x): variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) return x * tf.math.rsqrt(variance + self.epsilon) * self.scale def get_config(self): return {**super().get_config(), "epsilon": self.epsilon} @keras.saving.register_keras_serializable() class 
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm transformer block: RoPE attention + SwiGLU feed-forward."""

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

    def build(self, input_shape):
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype

        # Attention with rotary embeddings and optional KV cache
        res = x
        y = self.pre_attn_norm(x)
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
        q, k = self.rope(q, k, offset=past_len)
        if past_kv is not None:
            k = tf.concat([past_kv[0], k], axis=2)
            v = tf.concat([past_kv[1], v], axis=2)
        new_kv = (k, v) if use_cache else None

        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        full_len = tf.shape(k)[2]
        q_pos = tf.range(past_len, past_len + T)
        k_pos = tf.range(full_len)
        mask = tf.where(q_pos[:, None] >= k_pos[None, :], 0.0, -1e9)  # causal mask
        scores = scores + tf.cast(mask[None, None, :, :], dtype)
        attn = tf.nn.softmax(scores, axis=-1)
        attn_out = tf.reshape(tf.transpose(tf.matmul(attn, v), [0, 2, 1, 3]), [B, T, self.d_model])
        x = res + self.dropout(self.out_proj(attn_out), training=training)

        # SwiGLU feed-forward
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training), new_kv

    def get_config(self):
        return {**super().get_config(), "d_model": self.d_model, "n_heads": self.n_heads,
                "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len,
                "rope_theta": self.rope_theta, "layer_idx": self.layer_idx}


# ============================================================================
# State
# ============================================================================
class ModelState:
    def __init__(self):
        self.config = None
        self.tokenizer = None
        self.eos_token_id = 50256
        self.embedding = None
        self.blocks: List = []
        self.final_norm = None
        self.lm_head = None
        self.my_block_start = 0
        self.my_block_end = 0


STATE = ModelState()
stop_generation = False
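
# Optional sanity sketch (never called at startup): one prefill step followed
# by one cached decode step on a single TransformerBlock. Dimensions are
# illustrative; the real sizes come from config.json in load_model().
def _kv_cache_demo():
    block = TransformerBlock(d_model=32, n_heads=4, ff_dim=64, dropout=0.0,
                             max_len=64, rope_theta=10000)
    x = tf.random.normal((1, 5, 32))      # "prompt": 5 positions
    _, kv = block(x, use_cache=True)      # kv holds keys/values for 5 positions
    step = tf.random.normal((1, 1, 32))   # one newly generated position
    _, kv_next = block(step, past_kv=kv, use_cache=True)
    assert int(kv_next[0].shape[2]) == 6  # cache grew by one position
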
# ============================================================================
# Serialization
# ============================================================================
def serialize_tensor(tensor: tf.Tensor) -> str:
    buffer = io.BytesIO()
    np.save(buffer, tensor.numpy(), allow_pickle=False)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def deserialize_tensor(data: str) -> tf.Tensor:
    buffer = io.BytesIO(base64.b64decode(data))
    return tf.constant(np.load(buffer, allow_pickle=False))


def serialize_kv_cache(past_kv):
    if past_kv is None:
        return None
    return [{"k": serialize_tensor(k), "v": serialize_tensor(v)} if k is not None else None
            for k, v in past_kv]


def deserialize_kv_cache(data):
    if data is None:
        return None
    return [(deserialize_tensor(item["k"]), deserialize_tensor(item["v"])) if item else None
            for item in data]


# ============================================================================
# HTTP Communication
# ============================================================================
def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
    try:
        response = requests.post(
            f"{url.rstrip('/')}/api/forward",
            json={
                "hidden_states": serialize_tensor(hidden_states),
                "past_kv": serialize_kv_cache(past_kv),
                "use_cache": use_cache,
            },
            headers={"Authorization": f"Bearer {CONFIG['secret_token']}"},
            timeout=120,
        )
        if response.status_code == 200:
            result = response.json()
            output = deserialize_tensor(result["hidden_states"])
            new_kv = deserialize_kv_cache(result.get("past_kv"))
            return output, new_kv
        else:
            raise RuntimeError(f"Worker returned {response.status_code}")
    except Exception as e:
        raise RuntimeError(f"Worker call failed: {e}")
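
# Optional sanity sketch (never called at startup): round-trips a tensor
# through the base64/NumPy wire format that call_worker() sends over HTTP.
def _serialization_round_trip():
    x = tf.random.normal((1, 4, 8))
    y = deserialize_tensor(serialize_tensor(x))
    assert np.allclose(x.numpy(), y.numpy())
    # KV caches use the same encoding, one {"k": ..., "v": ...} dict per layer
    kv = [(x, x)]
    kv_restored = deserialize_kv_cache(serialize_kv_cache(kv))
    assert np.allclose(kv_restored[0][0].numpy(), x.numpy())
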
# ============================================================================
# Model Loading
# ============================================================================
def load_model():
    print("🚀 Loading model...")
    config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
    with open(config_path, 'r') as f:
        model_config = json.load(f)
    STATE.config = model_config

    from transformers import AutoTokenizer
    from tokenizers import Tokenizer
    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # NOTE: the reasoning/stop tags below are inferred from their usage in
    # generate_stream/chat_respond; adjust if the checkpoint's tokenizer differs.
    hf_tokenizer.add_special_tokens({"additional_special_tokens": [
        "<|im_start|>", "<|im_end|>", "<think>", "</think>", "<s>", "</s>"]})
    os.makedirs("./temp_tokenizer", exist_ok=True)
    hf_tokenizer.save_pretrained("./temp_tokenizer")
    STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
    STATE.eos_token_id = model_config.get('eos_token_id', 50256)

    weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])
    n_layers = model_config['num_hidden_layers']
    d_model = model_config['hidden_size']
    n_heads = model_config['num_attention_heads']
    ff_dim = model_config['intermediate_size']
    max_len = model_config['max_position_embeddings']
    rope_theta = model_config['rope_theta']
    vocab_size = model_config['vocab_size']

    # Build the full model once so the checkpoint can be loaded, then keep
    # only this node's shard of blocks.
    embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
    blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
              for i in range(n_layers)]
    final_norm = RMSNorm(name="final_norm")
    lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")

    dummy = tf.zeros((1, 16), dtype=tf.int32)
    x = embedding(dummy)
    for block in blocks:
        x, _ = block(x)
    x = final_norm(x)
    _ = lm_head(x)

    class TempModel(keras.Model):
        def __init__(self):
            super().__init__()
            self.embed = embedding
            self.blocks = blocks
            self.norm = final_norm
            self.lm_head = lm_head

        def call(self, x):
            x = self.embed(x)
            for b in self.blocks:
                x, _ = b(x)
            return self.lm_head(self.norm(x))

    temp_model = TempModel()
    temp_model(dummy)
    temp_model.load_weights(weights_path)
    print("✅ Weights loaded")

    STATE.my_block_start = CONFIG["layer_start"]
    STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers
    STATE.embedding = embedding
    STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
    print(f"✅ Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")

    has_workers = len(CONFIG["worker_urls"]) > 0
    if not has_workers:
        STATE.final_norm = final_norm
        STATE.lm_head = lm_head
        print("✅ Loaded final norm and LM head (standalone mode)")

    print("đŸ”Ĩ Warming up...")
    dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
    x = STATE.embedding(dummy)
    for block in STATE.blocks:
        x, _ = block(x, use_cache=False)
    if STATE.lm_head:
        _ = STATE.lm_head(STATE.final_norm(x))
    print("✅ Model ready!")
    return True


# ============================================================================
# Distributed Forward
# ============================================================================
def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
    x = STATE.embedding(input_ids)

    # Local shard of transformer blocks
    new_local_kv = [] if use_cache else None
    for i, block in enumerate(STATE.blocks):
        block_past = past_kv_local[i] if past_kv_local else None
        x, kv = block(x, past_kv=block_past, use_cache=use_cache)
        if use_cache:
            new_local_kv.append(kv)

    # Remote shards, in pipeline order
    new_worker_kv = {} if use_cache else None
    for worker_url in CONFIG["worker_urls"]:
        worker_past = past_kv_workers.get(worker_url) if past_kv_workers else None
        x, worker_kv = call_worker(worker_url, x, worker_past, use_cache)
        if use_cache:
            new_worker_kv[worker_url] = worker_kv

    if STATE.lm_head:
        logits = STATE.lm_head(STATE.final_norm(x))
    else:
        # In distributed mode the last worker is expected to apply the final
        # norm and LM head, so x already holds the logits.
        logits = x
    return logits, new_local_kv, new_worker_kv


# ============================================================================
# Generation
# ============================================================================
def sample_token(logits, temperature, top_k, top_p, token_freq, rep_penalty):
    logits = np.array(logits) / temperature
    # Repetition penalty: scale down logits of already-emitted tokens
    for tid, freq in token_freq.items():
        if tid < len(logits):
            logits[tid] /= (rep_penalty ** freq)
    # Top-k filtering
    if 0 < top_k < len(logits):
        top_k_idx = np.argpartition(logits, -top_k)[-top_k:]
        top_k_logits = logits[top_k_idx]
    else:
        top_k_idx = np.arange(len(logits))
        top_k_logits = logits
    top_k_logits = top_k_logits - np.max(top_k_logits)
    probs = np.exp(top_k_logits)
    probs /= probs.sum()
    # Nucleus (top-p) filtering
    if top_p < 1.0:
        sorted_idx = np.argsort(probs)[::-1]
        cumsum = np.cumsum(probs[sorted_idx])
        cutoff = np.searchsorted(cumsum, top_p) + 1
        nucleus_idx = sorted_idx[:cutoff]
        nucleus_probs = probs[nucleus_idx]
        nucleus_probs /= nucleus_probs.sum()
        sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
        return int(top_k_idx[nucleus_idx[sampled]])
    return int(top_k_idx[np.random.choice(len(probs), p=probs)])
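
# Optional sanity sketch (never called at startup): with top_k=1 the sampler
# must return the argmax token regardless of temperature/top_p. Toy values.
def _sampling_demo():
    logits = np.array([0.1, 5.0, 0.2, 0.3], dtype=np.float32)
    tid = sample_token(logits, temperature=1.0, top_k=1, top_p=1.0,
                       token_freq={}, rep_penalty=1.1)
    assert tid == 1
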
def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_p=0.9, rep_penalty=1.1):
    global stop_generation
    stop_generation = False

    input_ids = [i for i in STATE.tokenizer.encode(prompt).ids if i != STATE.eos_token_id]
    if not input_ids:
        yield "Error: Empty prompt"
        return

    generated = ""
    token_freq = {}
    stop_ids = {STATE.eos_token_id,
                STATE.tokenizer.token_to_id("<|im_end|>"),
                STATE.tokenizer.token_to_id("</s>")}
    stop_ids.discard(None)

    # Truncate the prompt so prompt + completion fits in the context window
    max_ctx = STATE.config['max_position_embeddings']
    if len(input_ids) > max_ctx - max_tokens:
        input_ids = input_ids[-(max_ctx - max_tokens):]

    start = time.time()
    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    try:
        logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
    except Exception as e:
        yield f"Error: {e}"
        return

    next_logits = logits[0, -1, :].numpy()
    prefill_time = time.time() - start
    print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")

    decode_start = time.time()
    tokens_generated = 0
    for _ in range(max_tokens):
        if stop_generation:
            yield generated + "\n\n*[Stopped]*"
            return
        next_id = sample_token(next_logits, temperature, top_k, top_p, token_freq, rep_penalty)
        if next_id in stop_ids:
            break
        token_freq[next_id] = token_freq.get(next_id, 0) + 1
        generated += STATE.tokenizer.decode([next_id])
        tokens_generated += 1
        yield generated

        next_input = tf.constant([[next_id]], dtype=tf.int32)
        try:
            logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
        except Exception as e:
            yield generated + f"\n\n*[Error: {e}]*"
            return
        next_logits = logits[0, -1, :].numpy()

    if tokens_generated > 0:
        total = time.time() - start
        tps = tokens_generated / (time.time() - decode_start)
        workers = len(CONFIG["worker_urls"])
        mode = f", {workers} workers" if workers else " standalone"
        generated += f"\n\n*[{tokens_generated} tokens in {total:.1f}s ({tps:.1f} tok/s){mode}]*"
        yield generated


def format_prompt(message: str, history: list, reasoning: bool) -> str:
    prompt = ""
    for msg in history:
        if msg["role"] == "user":
            prompt += f"<|im_start|>user\n{msg['content']}<|im_end|>\n"
        elif msg["role"] == "assistant":
            content = msg['content'].split('*[')[0].strip()
            prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
    if reasoning:
        prompt += "<think>"
    return prompt
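
# For reference, format_prompt produces ChatML-style prompts. With one prior
# exchange and reasoning off, the result looks like:
#
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant
#   Hello!<|im_end|>
#   <|im_start|>user
#   How are you?<|im_end|>
#   <|im_start|>assistant
#
# With reasoning on, a "<think>" tag is appended so the model opens with its
# reasoning trace, which chat_respond below folds into a collapsible block.
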
def chat_respond(message, history, max_tokens, temp, top_k, top_p, rep_pen, reasoning):
    if not message.strip():
        yield history
        return
    prompt = format_prompt(message, history, reasoning)

    # Add user message to history
    history = history + [{"role": "user", "content": message}]

    for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
        display = text
        # Clean stop tags, keeping the trailing stats line if present
        for tag in ["<|im_end|>", "</s>"]:
            if tag in display:
                idx = display.find(tag)
                stats = display.find("\n\n*[")
                display = display[:idx] + (display[stats:] if stats > idx else "")
        # Format reasoning as a collapsible block
        if reasoning and '<think>' in display and '</think>' in display:
            s, e = display.find('<think>'), display.find('</think>')
            if s < e:
                thought = display[s + 7:e].strip()
                display = (display[:s]
                           + f'<details><summary>🧠 Reasoning</summary>\n\n{thought}\n\n</details>\n\n'
                           + display[e + 8:])
        yield history + [{"role": "assistant", "content": display.strip()}]


def stop():
    global stop_generation
    stop_generation = True


# ============================================================================
# Gradio UI
# ============================================================================
def create_ui():
    workers = CONFIG["worker_urls"]
    mode = f"Distributed ({len(workers)} workers)" if workers else "Standalone"

    with gr.Blocks(title="Sam-large-2 HEAD") as app:
        gr.Markdown(f"""
# 👑 Sam-large-2 - HEAD NODE
**Mode:** {mode} | **Blocks:** {CONFIG['layer_start']}-{CONFIG['layer_end'] - 1} | **ID:** {CONFIG['node_id']}
""")
        if workers:
            gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))

        reasoning = gr.State(False)
        chatbot = gr.Chatbot(
            height=500,
            type="messages"  # Use new messages format
        )
        with gr.Row():
            reason_btn = gr.Button("💡", size="sm", scale=0)
            msg = gr.Textbox(placeholder="Type message...", show_label=False, scale=8)
            send = gr.Button("Send", variant="primary", scale=1)
            stop_btn = gr.Button("âšī¸", scale=0)
        with gr.Accordion("âš™ī¸ Settings", open=False):
            max_tok = gr.Slider(50, 1024, 512, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
            topk = gr.Slider(1, 100, 40, label="Top-K")
            topp = gr.Slider(0.1, 1.0, 0.9, label="Top-P")
            rep = gr.Slider(1.0, 2.0, 1.1, label="Repetition Penalty")

        def toggle(r):
            return not r, gr.update(variant="primary" if not r else "secondary")

        reason_btn.click(toggle, [reasoning], [reasoning, reason_btn])

        inputs = [msg, chatbot, max_tok, temp, topk, topp, rep, reasoning]
        submit = msg.submit(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        stop_btn.click(stop, cancels=[submit, click])
        gr.Button("đŸ—‘ī¸ Clear").click(lambda: [], outputs=[chatbot])
    return app


# ============================================================================
# Main
# ============================================================================
print("=" * 60)
print("🚀 Sam-large-2 HEAD Node Starting")
print(f"   Blocks: {CONFIG['layer_start']} to {CONFIG['layer_end']}")
print(f"   Workers: {CONFIG['worker_urls'] or 'None (standalone)'}")
print("=" * 60)

load_model()
app = create_ui()
app.queue()
app.launch(server_name="0.0.0.0", server_port=7860)
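
# Example head-node CONFIG for a pipeline split across two machines
# (hypothetical URL; layer counts must match the checkpoint's
# num_hidden_layers). The head runs blocks [layer_start, layer_end) and
# streams hidden states to each worker in worker_urls, in order:
#
#   CONFIG = {
#       "node_id": "head-main",
#       "layer_start": 0,
#       "layer_end": 12,                          # head serves blocks 0-11
#       "worker_urls": ["http://worker-1:7860"],  # worker serves blocks 12+
#       "secret_token": "...",                    # must match on all nodes
#       "model_repo": "Smilyai-labs/Sam-large-2",
#       "cache_dir": "./model_cache",
#   }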