# NOTE: Hugging Face web-UI artifact from copy/paste, kept as a comment so the
# file remains valid Python: "Bc-AI's picture / Update app.py / 0a83aff verified"
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod
# ==============================================================================
# Model Architecture (Must match training code)
# ==============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied to query/key tensors.

    Precomputes cos/sin tables of shape (max_len, dim) once in build() and
    rotates q/k by absolute position at call time.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim          # per-head dimension (rotate_half assumes it is even)
        self.max_len = max_len  # longest sequence the cos/sin cache covers
        self.theta = theta      # RoPE base frequency
        self.built_cache = False

    def build(self, input_shape):
        # Build the cos/sin cache once; the flag keeps repeated build() calls cheap.
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            # FIX: tf.cos/tf.sin already return float32 tensors; wrapping them in
            # tf.constant() was redundant and raises when the layer is traced
            # inside a tf.function (tf.constant rejects symbolic tensors).
            self.cos_cached = tf.cos(emb)
            self.sin_cached = tf.sin(emb)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        """Rotate the last dimension: (x1, x2) -> (-x2, x1)."""
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        # q, k: (batch, heads, seq, head_dim) — slice the cache to the live length.
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)
        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square layer normalization (no mean-centering, no bias)."""

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon  # numerical-stability floor inside rsqrt

    def build(self, input_shape):
        # Learnable per-feature gain over the last axis.
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
        # FIX: mark the layer as built; the other custom layers in this file
        # call super().build() and omitting it here was inconsistent.
        super().build(input_shape)

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block: causal multi-head self-attention with RoPE,
    followed by a SwiGLU feed-forward network, each with a residual connection.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads  # NOTE(review): assumes d_model % n_heads == 0 — confirm upstream configs
        self.layer_idx = layer_idx
        # Pre-normalization: RMSNorm before attention and before the FFN.
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        # Attention projections (all bias-free).
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        # SwiGLU FFN: down(silu(gate(x)) * up(x)).
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        # Single Dropout layer reused after both the attention and FFN sub-layers.
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        # x: (batch, seq, d_model)
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype
        res = x
        y = self.pre_attn_norm(x)
        # Project then reshape/transpose to (batch, heads, seq, head_dim).
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        # Scaled dot-product scores: (batch, heads, seq, seq).
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: -1e9 above the diagonal blocks attention to future positions.
        # NOTE(review): -1e9 overflows to -inf in float16 — confirm if mixed precision is ever used.
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype)
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        # Merge heads back to (batch, seq, d_model).
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)
        # Feed-forward sub-layer with its own residual.
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only transformer LM: embedding -> N TransformerBlocks -> RMSNorm -> lm_head."""

    def __init__(self, **kwargs):
        super().__init__()
        # Accept any of three construction styles:
        #   SAM1Model(config={...})   — Keras deserialization path,
        #   SAM1Model(vocab_size=...) — flat config kwargs,
        #   SAM1Model(cfg={...})      — explicit cfg dict.
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)
        cfg = self.cfg
        self.embed = keras.layers.Embedding(cfg['vocab_size'], cfg['d_model'], name="embed_tokens")
        # All blocks share the same hyper-parameters; only name/layer_idx differ.
        shared = dict(
            d_model=cfg['d_model'],
            n_heads=cfg['n_heads'],
            ff_dim=int(cfg['d_model'] * cfg['ff_mult']),
            dropout=cfg['dropout'],
            max_len=cfg['max_len'],
            rope_theta=cfg['rope_theta'],
        )
        self.blocks = [
            TransformerBlock(name=f"block_{i}", layer_idx=i, **shared)
            for i in range(cfg['n_layers'])
        ]
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        """Map token ids (batch, seq) to logits (batch, seq, vocab_size)."""
        hidden = self.embed(input_ids)
        for block in self.blocks:
            hidden = block(hidden, training=training)
        return self.lm_head(self.norm(hidden))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
# ==============================================================================
# Helper Functions
# ==============================================================================
def count_parameters(model):
    """Count total and non-zero parameters in model.

    Returns a (total, non_zero) tuple; the gap between the two reflects
    sparsity introduced by pruning (pruned weights are stored as zeros).
    """
    arrays = [weight.numpy() for weight in model.weights]
    total = sum(a.size for a in arrays)
    non_zero = sum(int(np.count_nonzero(a)) for a in arrays)
    return total, non_zero
def format_param_count(count):
    """Format a parameter count as a short human-readable string (K/M/B)."""
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if count >= threshold:
            return f"{count / threshold:.2f}{suffix}"
    return str(count)
# ==============================================================================
# Model Backend Interface
# ==============================================================================
class ModelBackend(ABC):
    """Minimal interface every inference backend must implement."""

    @abstractmethod
    def predict(self, input_ids):
        """Return next-token logits (1-D array) for a sequence of token ids."""

    @abstractmethod
    def get_name(self):
        """Return the backend's display name."""

    @abstractmethod
    def get_info(self):
        """Return a human-readable multi-line description of the backend."""
class KerasBackend(ModelBackend):
    """ModelBackend wrapper around an in-process Keras SAM1Model."""

    def __init__(self, model, name, display_name):
        self.model = model
        self.name = name
        self.display_name = display_name
        # Parameter statistics; pruned models show up here as sparsity.
        total, non_zero = count_parameters(model)
        self.total_params = total
        self.non_zero_params = non_zero
        self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
        # Architecture facts surfaced by get_info().
        self.n_heads = model.cfg.get('n_heads', 0)
        self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))

    def predict(self, input_ids):
        """Run one forward pass and return the logits of the final position."""
        batch = np.array([input_ids], dtype=np.int32)
        logits = self.model(batch, training=False)
        return logits[0, -1, :].numpy()

    def get_name(self):
        return self.display_name

    def get_info(self):
        """Build the multi-line summary shown in the UI's model-info box."""
        lines = [
            f"{self.display_name}",
            f" Total params: {format_param_count(self.total_params)}",
            f" Attention heads: {self.n_heads}",
            f" FFN dimension: {self.ff_dim}",
        ]
        # Only surface sparsity when it is meaningful (> 1%).
        if self.sparsity > 1:
            lines.append(f" Sparsity: {self.sparsity:.1f}%")
        return "\n".join(lines) + "\n"
# ==============================================================================
# EASY MODEL REGISTRY - ADD YOUR MODELS HERE!
# ==============================================================================
# Registry of selectable models. Each entry is:
#   (display_name, repo_id, weights_filename, config_filename)
# config_filename=None means the model reuses the base config downloaded below.
MODEL_REGISTRY = [
    # Format: (display_name, repo_id, weights_filename, config_filename)
    # Smaller models are ACTUALLY faster (fewer params = real speedup!)
    ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
    ("SAM-X-1-Fast โšก (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
    ("SAM-X-1-Mini ๐Ÿš€ (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights.h5", "sam1_mini_config.json"),
    ("SAM-X-1-Nano โšกโšก (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano.weights.h5", "sam1_nano_config.json"),
]
# To add a new model, just add a new line above! Format:
# ("Display Name", "repo_id", "weights.h5", "config.json")
# If config_filename is None, uses the default config
# ==============================================================================
# Load Models
# ==============================================================================
# Repo that provides the shared base config; the tokenizer is rebuilt from gpt2 below.
CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"
print("="*80)
print("๐Ÿค– SAM-X-1 Multi-Model Chat Interface".center(80))
print("="*80)
# Download config and tokenizer
print(f"\n๐Ÿ“ฆ Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
# Load config
with open(config_path, 'r') as f:
    base_config = json.load(f)
print(f"โœ… Base config loaded")
# Map the HF-style config keys onto the kwargs SAM1Model expects.
base_model_config = {
    'vocab_size': base_config['vocab_size'],
    'd_model': base_config['hidden_size'],
    'n_heads': base_config['num_attention_heads'],
    'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
    'dropout': base_config.get('dropout', 0.0),
    'max_len': base_config['max_position_embeddings'],
    'rope_theta': base_config['rope_theta'],
    'n_layers': base_config['num_hidden_layers']
}
# Recreate tokenizer
print("\n๐Ÿ”ค Recreating tokenizer...")
tokenizer = Tokenizer.from_pretrained("gpt2")
# NOTE(review): eos_token is the empty string; token_to_id("") presumably
# returns None, so an empty-string special token is registered below. This
# looks like a placeholder for a real EOS marker (e.g. "<|endoftext|>") — confirm.
eos_token = ""
eos_token_id = tokenizer.token_to_id(eos_token)
if eos_token_id is None:
    tokenizer.add_special_tokens([eos_token])
    eos_token_id = tokenizer.token_to_id(eos_token)
# NOTE(review): "<think/>" is registered here but generation also looks for
# "</think>"; verify which closing form the model actually emits.
custom_tokens = ["<think>", "<think/>"]
for token in custom_tokens:
    if tokenizer.token_to_id(token) is None:
        tokenizer.add_special_tokens([token])
tokenizer.no_padding()
tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
print(f"โœ… Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
# Load every registry entry; a failure is logged and skipped so one bad repo
# doesn't take down the whole app.
print("\n" + "="*80)
print("๐Ÿ“ฆ LOADING MODELS".center(80))
print("="*80)
available_models = {}
# Dummy forward pass input used to force variable creation before load_weights.
dummy_input = tf.zeros((1, 1), dtype=tf.int32)
for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
    try:
        print(f"\nโณ Loading: {display_name}")
        print(f" Repo: {repo_id}")
        print(f" Weights: {weights_filename}")
        # Download weights
        weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
        # Load custom config if specified (for pruned models)
        if config_filename:
            print(f" Config: {config_filename}")
            custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(custom_config_path, 'r') as f:
                model_config = json.load(f)
            print(f" ๐Ÿ“ Custom architecture: {model_config['n_heads']} heads, {int(model_config['d_model'] * model_config['ff_mult'])} FFN dim")
        else:
            model_config = base_model_config.copy()
        # Create model with appropriate config, build variables via a dummy
        # forward pass, then load checkpoint weights and freeze.
        model = SAM1Model(**model_config)
        model(dummy_input)
        model.load_weights(weights_path)
        model.trainable = False
        # Create backend
        backend = KerasBackend(model, display_name, display_name)
        available_models[display_name] = backend
        # Print stats
        print(f" โœ… Loaded successfully!")
        print(f" ๐Ÿ“Š Parameters: {format_param_count(backend.total_params)}")
        print(f" ๐Ÿ“Š Attention heads: {backend.n_heads}")
        print(f" ๐Ÿ“Š FFN dimension: {backend.ff_dim}")
    except Exception as e:
        # Broad catch is deliberate best-effort loading: report and move on.
        print(f" โš ๏ธ Failed to load: {e}")
        print(f" Skipping {display_name}...")
if not available_models:
    raise RuntimeError("โŒ No models loaded! Check your MODEL_REGISTRY configuration.")
print(f"\nโœ… Successfully loaded {len(available_models)} model(s)")
print(f" Device: {'GPU' if len(tf.config.list_physical_devices('GPU')) > 0 else 'CPU'}")
# Default backend: the first successfully loaded registry entry.
current_backend = list(available_models.values())[0]
# ==============================================================================
# Important Note About Pruning and Speed
# ==============================================================================
# Console-only explainer about pruning vs. speed; the triple-quoted banner is
# user-facing text and is kept verbatim.
print("\n" + "="*80)
print("๐Ÿ’ก ABOUT PRUNING & SPEED".center(80))
print("="*80)
print("""
๐Ÿ“Œ Does pruning reduce parameter count?
YES and NO:
โ€ข Total param count stays the same (architecture unchanged)
โ€ข BUT pruned weights are set to ZERO (sparse weights)
โ€ข Active/non-zero params are reduced significantly
๐Ÿ“Œ Does pruning speed up inference?
IT DEPENDS:
โ€ข Dense operations (regular matrix multiply): NO speedup by default
โ€ข Need sparse kernels or hardware support for actual speedup
โ€ข HOWEVER: Smaller active weights = better cache utilization
โ€ข Less computation on zeros = potential speedup on some hardware
๐Ÿ“Œ What DOES speed things up reliably?
โœ… Quantization (FP16, INT8) - smaller types = faster compute
โœ… Fewer layers (layer pruning)
โœ… Smaller hidden dimensions (width reduction)
โœ… Knowledge distillation to smaller architecture
๐Ÿ“Œ Why use structured pruning then?
โœ… Reduces memory footprint (especially with sparse storage)
โœ… Can be combined with quantization for real speedups
โœ… Preserves quality better than aggressive dimension reduction
โœ… Foundation for converting to truly smaller architecture
""")
def generate_response_stream(prompt, temperature=0.7, backend=None):
    """Generate a response token-by-token, yielding (new_text_chunk, in_thinking).

    Args:
        prompt: Full prompt string fed to the model.
        temperature: 0 => greedy argmax; >0 => temperature-scaled top-k
            (k <= 50) sampling.
        backend: ModelBackend to run; defaults to the globally selected one.

    Yields:
        (chunk, in_thinking): the newly decoded text and whether the stream
        is currently inside a <think>...</think> span.
    """
    if backend is None:
        backend = current_backend
    encoded_prompt = tokenizer.encode(prompt)
    # Strip any EOS ids the tokenizer may have inserted into the prompt.
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    generated = input_ids.copy()
    current_text = ""
    in_thinking = False
    # Context window comes from the selected backend's own config.
    max_len = backend.model.cfg['max_len']
    for _ in range(512):
        current_input = generated[-max_len:]
        # Next-token logits from the selected backend.
        next_token_logits = backend.predict(current_input)
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
            # FIX: clamp k to the vocab size so a vocabulary smaller than 50
            # doesn't crash argpartition.
            k = min(50, next_token_logits.shape[-1])
            top_k_indices = np.argpartition(next_token_logits, -k)[-k:]
            top_k_logits = next_token_logits[top_k_indices]
            # Softmax over the top-k (max-subtracted for numerical stability).
            top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
            top_k_probs /= top_k_probs.sum()
            next_token = top_k_indices[np.random.choice(len(top_k_indices), p=top_k_probs)]
        else:
            next_token = np.argmax(next_token_logits)
        if next_token == eos_token_id:
            break
        generated.append(int(next_token))
        new_text = tokenizer.decode(generated[len(input_ids):])
        if len(new_text) > len(current_text):
            new_chunk = new_text[len(current_text):]
            current_text = new_text
            # FIX: decide the thinking state from the full decoded text so a
            # tag split across two decode steps is still detected (the old
            # per-chunk check missed those).
            last_open = current_text.rfind("<think>")
            last_close = max(current_text.rfind("</think>"), current_text.rfind("<think/>"))
            in_thinking = last_open > last_close
            yield new_chunk, in_thinking
# ==============================================================================
# Gradio Interface
# ==============================================================================
if __name__ == "__main__":
    import gradio as gr

    # ChatGPT-inspired styling for the Gradio app: chat bubbles, a muted
    # "thinking" panel, an animated announcement banner, and a monospace
    # model-info box.
    custom_css = """
.chat-container {
height: 600px;
overflow-y: auto;
padding: 20px;
background: #ffffff;
}
.user-message {
background: #f7f7f8;
padding: 16px;
margin: 12px 0;
border-radius: 8px;
}
.assistant-message {
background: #ffffff;
padding: 16px;
margin: 12px 0;
border-radius: 8px;
border-left: 3px solid #10a37f;
}
.message-content {
color: #353740;
line-height: 1.6;
font-size: 15px;
}
.message-header {
font-weight: 600;
margin-bottom: 8px;
color: #353740;
font-size: 14px;
}
.thinking-content {
color: #6b7280;
font-style: italic;
border-left: 3px solid #d1d5db;
padding-left: 12px;
margin: 8px 0;
background: #f9fafb;
padding: 8px 12px;
border-radius: 4px;
}
.input-row {
background: #ffffff;
padding: 12px;
border-radius: 8px;
margin-top: 12px;
border: 1px solid #e5e7eb;
}
.gradio-container {
max-width: 900px !important;
margin: auto !important;
}
.announcement-banner {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 16px 24px;
border-radius: 12px;
margin-bottom: 20px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
text-align: center;
font-size: 16px;
font-weight: 500;
animation: slideIn 0.5s ease-out;
}
@keyframes slideIn {
from {
opacity: 0;
transform: translateY(-20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.announcement-banner strong {
font-weight: 700;
font-size: 18px;
}
.settings-panel {
background: #f9fafb;
padding: 16px;
border-radius: 8px;
margin-bottom: 12px;
border: 1px solid #e5e7eb;
}
.model-info {
background: #f0f9ff;
border: 1px solid #bae6fd;
padding: 12px;
border-radius: 8px;
margin-top: 8px;
font-size: 13px;
font-family: monospace;
white-space: pre-line;
}
"""
def format_message_html(role, content, show_thinking=True):
    """Format a single chat message as an HTML bubble.

    Splits the assistant's <think>...</think> (or <think/>) span from the
    visible answer and renders the thinking part in its own styled box.

    Args:
        role: "user" for the human turn; anything else is treated as assistant.
        content: Raw message text, possibly containing think tags.
        show_thinking: When False, the thinking span is omitted entirely.

    Returns:
        An HTML string for one message.
    """
    import html as _html  # local import: only needed for output escaping

    role_class = "user-message" if role == "user" else "assistant-message"
    role_name = "You" if role == "user" else "SAM-X-1"
    thinking = ""
    answer = ""
    if "<think>" in content:
        parts = content.split("<think>", 1)
        before_think = parts[0].strip()
        if len(parts) > 1:
            after_think = parts[1]
            if "</think>" in after_think:
                think_parts = after_think.split("</think>", 1)
                thinking = think_parts[0].strip()
                answer = (before_think + " " + think_parts[1]).strip()
            elif "<think/>" in after_think:
                think_parts = after_think.split("<think/>", 1)
                thinking = think_parts[0].strip()
                answer = (before_think + " " + think_parts[1]).strip()
            else:
                # Unclosed tag: everything after <think> is still "thinking".
                thinking = after_think.strip()
                answer = before_think
        else:
            answer = before_think
    else:
        answer = content
    html = f'<div class="{role_class}">'
    html += f'<div class="message-header">{role_name}</div>'
    html += f'<div class="message-content">'
    # FIX: escape user/model text before interpolating it into HTML so it
    # cannot inject markup or scripts into the chat view.
    if thinking and show_thinking:
        html += f'<div class="thinking-content">๐Ÿ’ญ {_html.escape(thinking)}</div>'
    if answer:
        html += f'<div>{_html.escape(answer)}</div>'
    html += '</div></div>'
    return html
def render_history(history, show_thinking):
    """Render the full chat history to a single HTML string."""
    return "".join(
        format_message_html(entry["role"], entry["content"], show_thinking)
        for entry in history
    )
def send_message(message, history, show_thinking, temperature, model_choice):
    """Gradio handler: append the user turn, then stream the assistant reply.

    Yields (history, cleared_input, chat_html, model_info) after each update
    so the UI re-renders incrementally.
    """
    if not message.strip():
        # Nothing to send; just re-render the existing conversation.
        yield history, "", render_history(history, show_thinking), ""
        return
    # Resolve the backend chosen in the dropdown.
    backend = available_models[model_choice]
    # Record and show the user's turn immediately.
    history.append({"role": "user", "content": message})
    yield history, "", render_history(history, show_thinking), backend.get_info()
    # Seed the assistant turn with an open think tag, matching the prompt.
    prompt = f"User: {message}\nSam: <think>"
    history.append({"role": "assistant", "content": "<think>"})
    # Append each streamed chunk to the open assistant message and re-render.
    for chunk, _in_thinking in generate_response_stream(prompt, temperature, backend):
        history[-1]["content"] += chunk
        yield history, "", render_history(history, show_thinking), backend.get_info()
# Create Gradio interface: settings accordion, HTML chat view, input row,
# and the event wiring that drives send_message.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="slate")) as demo:
    # Announcement Banner
    gr.HTML("""
<div class="announcement-banner">
๐ŸŽ‰ <strong>NEW UPDATE:</strong> Multiple model variants now available!
Choose Fast/Mini/Nano for <strong>30-250% speed boost</strong>! โšก
The models marked with (BETA) are not useful yet. <strong>They are still in development!</strong>
</div>
""")
    gr.Markdown("# ๐Ÿค– SAM-X-1 Multi-Model Chat")
    # Settings panel
    with gr.Accordion("โš™๏ธ Settings", open=False):
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(available_models.keys()),
                value=list(available_models.keys())[0],
                label="Model Selection",
                info="Choose your speed/quality tradeoff"
            )
            model_info_box = gr.Textbox(
                label="Selected Model Info",
                value=list(available_models.values())[0].get_info(),
                interactive=False,
                lines=4,
                elem_classes=["model-info"]
            )
        with gr.Row():
            temperature_slider = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Higher = more creative, Lower = more focused"
            )
            show_thinking_checkbox = gr.Checkbox(
                label="Show Thinking Process",
                value=True,
                info="Display model's reasoning"
            )
    # Chat state and display: history lives in gr.State, rendered into raw HTML.
    chatbot_state = gr.State([])
    chat_html = gr.HTML(value="", elem_classes=["chat-container"])
    # Input area
    with gr.Row(elem_classes=["input-row"]):
        msg_input = gr.Textbox(
            placeholder="Ask me anything...",
            show_label=False,
            container=False,
            scale=9
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)
    with gr.Row():
        clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear Chat", size="sm")
    # Event handlers: Enter-submit and the Send button both stream through
    # send_message, which also clears the input box via its second output.
    msg_input.submit(
        send_message,
        inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
        outputs=[chatbot_state, msg_input, chat_html, model_info_box]
    )
    send_btn.click(
        send_message,
        inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
        outputs=[chatbot_state, msg_input, chat_html, model_info_box]
    )
    # Reset both the stored history and the rendered view.
    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot_state, chat_html]
    )
    # Re-render the existing history when the thinking toggle flips.
    show_thinking_checkbox.change(
        lambda h, st: render_history(h, st),
        inputs=[chatbot_state, show_thinking_checkbox],
        outputs=[chat_html]
    )
    # Update model info when selection changes
    model_selector.change(
        lambda choice: available_models[choice].get_info(),
        inputs=[model_selector],
        outputs=[model_info_box]
    )
demo.launch(debug=True, share=True)