import os

os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod

# ==============================================================================
# Model Architecture (Must match training code)
# ==============================================================================


@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied to query/key tensors.

    cos/sin tables are precomputed once in build() for all positions up to
    ``max_len`` and sliced per call.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False  # guards against recomputing the tables

    def build(self, input_shape):
        if not self.built_cache:
            # Standard RoPE frequencies: theta^(-2i/dim) for i in [0, dim/2).
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            positions = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", positions, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
            self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        """Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1)."""
        first_half, second_half = tf.split(x, 2, axis=-1)
        return tf.concat([-second_half, first_half], axis=-1)

    def call(self, q, k):
        # q/k are assumed [batch, heads, seq, head_dim] (seq is axis 2 below).
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
        rotated_q = (q * cos) + (self.rotate_half(q) * sin)
        rotated_k = (k * cos) + (self.rotate_half(k) * sin)
        return rotated_q, rotated_k

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config


@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square layer norm (no mean subtraction, learned scale only)."""

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config


@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block: causal self-attention with RoPE + SwiGLU FFN."""

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype

        # --- attention sub-layer (pre-norm, residual) ---
        res = x
        y = self.pre_attn_norm(x)
        # [B, T, D] -> [B, heads, T, head_dim]
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: large negative where the (lower-triangular) band is zero.
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype),
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)

        # --- feed-forward sub-layer (SwiGLU, pre-norm, residual) ---
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx,
        })
        return config


@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only LM: token embedding -> N TransformerBlocks -> RMSNorm -> lm_head."""

    def __init__(self, **kwargs):
        super().__init__()
        # Accept the config either nested under 'config' (deserialization path),
        # as flat kwargs (direct construction), or under 'cfg' as a fallback.
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)

        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta'],
        }
        self.blocks = []
        for i in range(self.cfg['n_layers']):
            self.blocks.append(TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args))
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
# ==============================================================================
# Helper Functions
# ==============================================================================

def count_parameters(model):
    """Count total and non-zero parameters in model.

    Returns (total_params, non_zero_params); the gap between the two is the
    number of weights zeroed out (e.g. by pruning).
    """
    total_params = 0
    non_zero_params = 0
    for weight in model.weights:
        w = weight.numpy()
        total_params += w.size
        non_zero_params += np.count_nonzero(w)
    return total_params, non_zero_params


def format_param_count(count):
    """Format parameter count in human readable format (e.g. 1.23M, 4.50B)."""
    if count >= 1e9:
        return f"{count / 1e9:.2f}B"
    elif count >= 1e6:
        return f"{count / 1e6:.2f}M"
    elif count >= 1e3:
        return f"{count / 1e3:.2f}K"
    else:
        return str(count)


# ==============================================================================
# Model Backend Interface
# ==============================================================================

class ModelBackend(ABC):
    """Minimal interface an inference backend must implement."""

    @abstractmethod
    def predict(self, input_ids):
        """Return next-token logits for a sequence of input token ids."""

    @abstractmethod
    def get_name(self):
        """Return the backend's display name."""

    @abstractmethod
    def get_info(self):
        """Return a human-readable multi-line description of the model."""


class KerasBackend(ModelBackend):
    """Backend wrapping an in-memory Keras SAM1Model."""

    def __init__(self, model, name, display_name):
        self.model = model
        self.name = name
        self.display_name = display_name
        # Count parameters (total vs. non-zero) to expose pruning sparsity.
        total, non_zero = count_parameters(model)
        self.total_params = total
        self.non_zero_params = non_zero
        self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
        # Read the actual architecture from the model config for display.
        self.n_heads = model.cfg.get('n_heads', 0)
        self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))

    def predict(self, input_ids):
        inputs = np.array([input_ids], dtype=np.int32)
        logits = self.model(inputs, training=False)
        # Only the last position matters for next-token prediction.
        return logits[0, -1, :].numpy()

    def get_name(self):
        return self.display_name

    def get_info(self):
        info = f"{self.display_name}\n"
        info += f"  Total params: {format_param_count(self.total_params)}\n"
        info += f"  Attention heads: {self.n_heads}\n"
        info += f"  FFN dimension: {self.ff_dim}\n"
        # Only surface sparsity when it is meaningful (> 1%).
        if self.sparsity > 1:
            info += f"  Sparsity: {self.sparsity:.1f}%\n"
        return info


# ==============================================================================
# EASY MODEL REGISTRY - ADD YOUR MODELS HERE!
# ==============================================================================

MODEL_REGISTRY = [
    # Format: (display_name, repo_id, weights_filename, config_filename)
    # Smaller models are ACTUALLY faster (fewer params = real speedup!)
    ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
    ("SAM-X-1-Fast โšก (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
    ("SAM-X-1-Mini ๐Ÿš€ (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights.h5", "sam1_mini_config.json"),
    ("SAM-X-1-Nano โšกโšก (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano.weights.h5", "sam1_nano_config.json"),
]
# To add a new model, just add a new line above! Format:
#   ("Display Name", "repo_id", "weights.h5", "config.json")
# If config_filename is None, uses the default config

# ==============================================================================
# Load Models
# ==============================================================================

CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"

print("=" * 80)
print("๐Ÿค– SAM-X-1 Multi-Model Chat Interface".center(80))
print("=" * 80)

# Download config and tokenizer
print(f"\n๐Ÿ“ฆ Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")

# Load config
with open(config_path, 'r') as f:
    base_config = json.load(f)
print("โœ… Base config loaded")

# Build base model config (map HF-style config keys to SAM1Model kwargs).
base_model_config = {
    'vocab_size': base_config['vocab_size'],
    'd_model': base_config['hidden_size'],
    'n_heads': base_config['num_attention_heads'],
    'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
    'dropout': base_config.get('dropout', 0.0),
    'max_len': base_config['max_position_embeddings'],
    'rope_theta': base_config['rope_theta'],
    'n_layers': base_config['num_hidden_layers'],
}

# Recreate tokenizer
print("\n๐Ÿ”ค Recreating tokenizer...")
tokenizer = Tokenizer.from_pretrained("gpt2")

# NOTE(review): the special-token string literals were blank in the original
# source (an empty token is degenerate: token_to_id("") and "" membership
# checks match everything). Reconstructed as GPT-2's eos and <think> markers
# based on how they are used downstream -- confirm against the training setup.
eos_token = "<|endoftext|>"
eos_token_id = tokenizer.token_to_id(eos_token)
if eos_token_id is None:
    tokenizer.add_special_tokens([eos_token])
    eos_token_id = tokenizer.token_to_id(eos_token)

custom_tokens = ["<think>", "</think>"]
for token in custom_tokens:
    if tokenizer.token_to_id(token) is None:
        tokenizer.add_special_tokens([token])

tokenizer.no_padding()
tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
print(f"โœ… Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")

# Load all models from registry
print("\n" + "=" * 80)
print("๐Ÿ“ฆ LOADING MODELS".center(80))
print("=" * 80)

available_models = {}
dummy_input = tf.zeros((1, 1), dtype=tf.int32)

for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
    try:
        print(f"\nโณ Loading: {display_name}")
        print(f"   Repo: {repo_id}")
        print(f"   Weights: {weights_filename}")

        # Download weights
        weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)

        # Load custom config if specified (pruned/smaller variants ship their own).
        if config_filename:
            print(f"   Config: {config_filename}")
            custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(custom_config_path, 'r') as f:
                model_config = json.load(f)
            print(f"   ๐Ÿ“ Custom architecture: {model_config['n_heads']} heads, "
                  f"{int(model_config['d_model'] * model_config['ff_mult'])} FFN dim")
        else:
            model_config = base_model_config.copy()

        # Create model with appropriate config; the dummy forward pass builds
        # all variables so load_weights can match them.
        model = SAM1Model(**model_config)
        model(dummy_input)
        model.load_weights(weights_path)
        model.trainable = False

        backend = KerasBackend(model, display_name, display_name)
        available_models[display_name] = backend

        print("   โœ… Loaded successfully!")
        print(f"   ๐Ÿ“Š Parameters: {format_param_count(backend.total_params)}")
        print(f"   ๐Ÿ“Š Attention heads: {backend.n_heads}")
        print(f"   ๐Ÿ“Š FFN dimension: {backend.ff_dim}")
    except Exception as e:
        # Best-effort loading: a missing/broken repo should not kill the app.
        print(f"   โš ๏ธ Failed to load: {e}")
        print(f"   Skipping {display_name}...")

if not available_models:
    raise RuntimeError("โŒ No models loaded! Check your MODEL_REGISTRY configuration.")

print(f"\nโœ… Successfully loaded {len(available_models)} model(s)")
print(f"   Device: {'GPU' if len(tf.config.list_physical_devices('GPU')) > 0 else 'CPU'}")

current_backend = list(available_models.values())[0]

# ==============================================================================
# Important Note About Pruning and Speed
# ==============================================================================

print("\n" + "=" * 80)
print("๐Ÿ’ก ABOUT PRUNING & SPEED".center(80))
print("=" * 80)
print("""
๐Ÿ“Œ Does pruning reduce parameter count? YES and NO:
   โ€ข Total param count stays the same (architecture unchanged)
   โ€ข BUT pruned weights are set to ZERO (sparse weights)
   โ€ข Active/non-zero params are reduced significantly

๐Ÿ“Œ Does pruning speed up inference? IT DEPENDS:
   โ€ข Dense operations (regular matrix multiply): NO speedup by default
   โ€ข Need sparse kernels or hardware support for actual speedup
   โ€ข HOWEVER: Smaller active weights = better cache utilization
   โ€ข Less computation on zeros = potential speedup on some hardware

๐Ÿ“Œ What DOES speed things up reliably?
   โœ… Quantization (FP16, INT8) - smaller types = faster compute
   โœ… Fewer layers (layer pruning)
   โœ… Smaller hidden dimensions (width reduction)
   โœ… Knowledge distillation to smaller architecture

๐Ÿ“Œ Why use structured pruning then?
   โœ… Reduces memory footprint (especially with sparse storage)
   โœ… Can be combined with quantization for real speedups
   โœ… Preserves quality better than aggressive dimension reduction
   โœ… Foundation for converting to truly smaller architecture
""")
def generate_response_stream(prompt, temperature=0.7, backend=None):
    """Generate a response and yield (new_chunk, in_thinking) pairs for streaming.

    ``in_thinking`` flags whether the emitted chunk falls inside a
    <think>...</think> span so the UI can render it differently.
    Sampling is temperature + top-k (k=50); greedy when temperature == 0.
    """
    if backend is None:
        backend = current_backend

    encoded_prompt = tokenizer.encode(prompt)
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    generated = input_ids.copy()
    current_text = ""
    in_thinking = False

    # Context window comes from the selected backend's own model config.
    max_len = backend.model.cfg['max_len']

    for _ in range(512):
        current_input = generated[-max_len:]

        # Get logits from selected backend
        next_token_logits = backend.predict(current_input)

        if temperature > 0:
            next_token_logits = next_token_logits / temperature
            # Top-k sampling over the 50 highest logits, softmax-normalized
            # with max subtracted for numerical stability.
            top_k_indices = np.argpartition(next_token_logits, -50)[-50:]
            top_k_logits = next_token_logits[top_k_indices]
            top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
            top_k_probs /= top_k_probs.sum()
            next_token = top_k_indices[np.random.choice(len(top_k_indices), p=top_k_probs)]
        else:
            next_token = np.argmax(next_token_logits)

        if next_token == eos_token_id:
            break
        generated.append(int(next_token))

        new_text = tokenizer.decode(generated[len(input_ids):])
        if len(new_text) > len(current_text):
            new_chunk = new_text[len(current_text):]
            current_text = new_text
            # NOTE(review): the tag literals were blank in the original source
            # ("" in s is always True, so in_thinking was stuck True);
            # reconstructed as the <think>/</think> special tokens -- confirm.
            if "<think>" in new_chunk:
                in_thinking = True
            elif "</think>" in new_chunk:
                in_thinking = False
            yield new_chunk, in_thinking


# ==============================================================================
# Gradio Interface
# ==============================================================================

if __name__ == "__main__":
    import gradio as gr

    custom_css = """
    .chat-container { height: 600px; overflow-y: auto; padding: 20px; background: #ffffff; }
    .user-message { background: #f7f7f8; padding: 16px; margin: 12px 0; border-radius: 8px; }
    .assistant-message { background: #ffffff; padding: 16px; margin: 12px 0; border-radius: 8px; border-left: 3px solid #10a37f; }
    .message-content { color: #353740; line-height: 1.6; font-size: 15px; }
    .message-header { font-weight: 600; margin-bottom: 8px; color: #353740; font-size: 14px; }
    .thinking-content { color: #6b7280; font-style: italic; border-left: 3px solid #d1d5db; margin: 8px 0; background: #f9fafb; padding: 8px 12px; border-radius: 4px; }
    .input-row { background: #ffffff; padding: 12px; border-radius: 8px; margin-top: 12px; border: 1px solid #e5e7eb; }
    .gradio-container { max-width: 900px !important; margin: auto !important; }
    .announcement-banner { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 16px 24px; border-radius: 12px; margin-bottom: 20px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); text-align: center; font-size: 16px; font-weight: 500; animation: slideIn 0.5s ease-out; }
    @keyframes slideIn { from { opacity: 0; transform: translateY(-20px); } to { opacity: 1; transform: translateY(0); } }
    .announcement-banner strong { font-weight: 700; font-size: 18px; }
    .settings-panel { background: #f9fafb; padding: 16px; border-radius: 8px; margin-bottom: 12px; border: 1px solid #e5e7eb; }
    .model-info { background: #f0f9ff; border: 1px solid #bae6fd; padding: 12px; border-radius: 8px; margin-top: 8px; font-size: 13px; font-family: monospace; white-space: pre-line; }
    """

    def format_message_html(role, content, show_thinking=True):
        """Format a single message as HTML.

        NOTE(review): the literal HTML tags and <think> markers were stripped
        from the source during extraction; the markup below is reconstructed
        from the CSS class names above -- confirm against the original layout.
        """
        role_class = "user-message" if role == "user" else "assistant-message"
        role_name = "You" if role == "user" else "SAM-X-1"
        thinking = ""
        answer = ""
        if "<think>" in content:
            parts = content.split("<think>", 1)
            before_think = parts[0].strip()
            if len(parts) > 1:
                after_think = parts[1]
                if "</think>" in after_think:
                    think_parts = after_think.split("</think>", 1)
                    thinking = think_parts[0].strip()
                    answer = (before_think + " " + think_parts[1]).strip()
                else:
                    # Unterminated think block: everything after the open tag
                    # is treated as thinking still in progress.
                    thinking = after_think.strip()
                    answer = before_think
            else:
                answer = before_think
        else:
            answer = content

        html = f'<div class="{role_class}">'
        html += f'<div class="message-header">{role_name}</div>'
        html += '<div class="message-content">'
        if thinking and show_thinking:
            html += f'<div class="thinking-content">๐Ÿ’ญ {thinking}</div>'
        if answer:
            html += f'<div>{answer}</div>'
        html += '</div></div>'
        return html

    def render_history(history, show_thinking):
        """Render chat history as HTML."""
        html = ""
        for msg in history:
            html += format_message_html(msg["role"], msg["content"], show_thinking)
        return html

    def send_message(message, history, show_thinking, temperature, model_choice):
        """Generator handler: append the user turn, then stream the reply."""
        if not message.strip():
            yield history, "", render_history(history, show_thinking), ""
            return

        # Switch backend based on selection
        backend = available_models[model_choice]

        # Add user message
        history.append({"role": "user", "content": message})
        yield history, "", render_history(history, show_thinking), backend.get_info()

        # Generate prompt
        prompt = f"User: {message}\nSam: "

        # Start assistant message and grow it chunk by chunk.
        history.append({"role": "assistant", "content": ""})
        for new_chunk, in_thinking in generate_response_stream(prompt, temperature, backend):
            history[-1]["content"] += new_chunk
            yield history, "", render_history(history, show_thinking), backend.get_info()

    # Create Gradio interface
    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="slate")) as demo:
        # Announcement Banner
        gr.HTML("""
        <div class="announcement-banner">
            ๐ŸŽ‰ <strong>NEW UPDATE:</strong> Multiple model variants now available!
            Choose Fast/Mini/Nano for 30-250% speed boost! โšก
            The models marked with (BETA) are not useful yet. They are still in development!
        </div>
        """)

        gr.Markdown("# ๐Ÿค– SAM-X-1 Multi-Model Chat")

        # Settings panel
        with gr.Accordion("โš™๏ธ Settings", open=False):
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(available_models.keys()),
                    value=list(available_models.keys())[0],
                    label="Model Selection",
                    info="Choose your speed/quality tradeoff"
                )
                model_info_box = gr.Textbox(
                    label="Selected Model Info",
                    value=list(available_models.values())[0].get_info(),
                    interactive=False,
                    lines=4,
                    elem_classes=["model-info"]
                )
            with gr.Row():
                temperature_slider = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature",
                    info="Higher = more creative, Lower = more focused"
                )
                show_thinking_checkbox = gr.Checkbox(
                    label="Show Thinking Process",
                    value=True,
                    info="Display model's reasoning"
                )

        # Chat state and display
        chatbot_state = gr.State([])
        chat_html = gr.HTML(value="", elem_classes=["chat-container"])

        # Input area
        with gr.Row(elem_classes=["input-row"]):
            msg_input = gr.Textbox(
                placeholder="Ask me anything...",
                show_label=False,
                container=False,
                scale=9
            )
            send_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear Chat", size="sm")

        # Event handlers
        msg_input.submit(
            send_message,
            inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
            outputs=[chatbot_state, msg_input, chat_html, model_info_box]
        )
        send_btn.click(
            send_message,
            inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
            outputs=[chatbot_state, msg_input, chat_html, model_info_box]
        )
        clear_btn.click(
            lambda: ([], ""),
            outputs=[chatbot_state, chat_html]
        )
        # Re-render history when the thinking toggle changes.
        show_thinking_checkbox.change(
            lambda h, st: render_history(h, st),
            inputs=[chatbot_state, show_thinking_checkbox],
            outputs=[chat_html]
        )
        # Update model info when selection changes
        model_selector.change(
            lambda choice: available_models[choice].get_info(),
            inputs=[model_selector],
            outputs=[model_info_box]
        )

    demo.launch(debug=True, share=True)