import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod

# ==============================================================================
# Model Architecture (Must match training code)
# ==============================================================================
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.cos(emb)
            self.sin_cached = tf.sin(emb)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)
        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config

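# A minimal sanity check for the layer above (illustrative only; nothing in the
# app calls this). RoPE is a pure rotation of (x_i, x_{i+dim/2}) pairs, so it
# should preserve per-head vector norms. The shape is an assumption matching
# the [batch, heads, seq, head_dim] layout used in TransformerBlock.call().
def _rope_norm_check():
    rope = RotaryEmbedding(dim=64, max_len=128)
    q = tf.random.normal([1, 8, 16, 64])
    q_rot, _ = rope(q, q)
    np.testing.assert_allclose(
        tf.norm(q, axis=-1).numpy(),
        tf.norm(q_rot, axis=-1).numpy(),
        rtol=1e-5,
    )
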
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

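# NumPy reference for the layer above (illustrative; `scale` defaults to ones
# in the Keras layer). Unlike LayerNorm, RMSNorm skips mean-centering and only
# rescales by the root-mean-square of the features.
def _rmsnorm_reference(x, scale, epsilon=1e-5):
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + epsilon)
    return x / rms * scale
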
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype
        # Pre-norm causal self-attention with RoPE
        res = x
        y = self.pre_attn_norm(x)
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype)
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)
        # Pre-norm SwiGLU feed-forward
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config

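# Illustrative check of the causal mask used above (safe to run standalone):
# band_part(ones, -1, 0) keeps the lower triangle, so position t attends only
# to positions <= t.
def _causal_mask_check(T=4):
    m = tf.linalg.band_part(tf.ones([T, T]), -1, 0).numpy()
    # m[i, j] == 1.0 where j <= i, else 0.0
    assert m[0, 1] == 0.0 and m[1, 0] == 1.0
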
class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }
        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config

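# Illustrative shape check (hypothetical config values, not the real SAM-X-1
# hyperparameters): building the model on a dummy batch yields logits of shape
# (batch, seq, vocab_size). Kept as a comment so nothing is built at import:
#
#   cfg = dict(vocab_size=50259, d_model=512, n_heads=8, ff_mult=4.0,
#              dropout=0.0, max_len=2048, rope_theta=10000, n_layers=8)
#   m = SAM1Model(**cfg)
#   logits = m(tf.zeros((1, 16), dtype=tf.int32))  # -> (1, 16, 50259)
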
# ==============================================================================
# Helper Functions
# ==============================================================================
def count_parameters(model):
    """Count total and non-zero parameters in model."""
    total_params = 0
    non_zero_params = 0
    for weight in model.weights:
        w = weight.numpy()
        total_params += w.size
        non_zero_params += np.count_nonzero(w)
    return total_params, non_zero_params

def format_param_count(count):
    """Format parameter count in human readable format."""
    if count >= 1e9:
        return f"{count/1e9:.2f}B"
    elif count >= 1e6:
        return f"{count/1e6:.2f}M"
    elif count >= 1e3:
        return f"{count/1e3:.2f}K"
    else:
        return str(count)

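# Example outputs (illustrative): format_param_count(1_250_000) -> "1.25M",
# format_param_count(3_100) -> "3.10K", format_param_count(950) -> "950".
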
# ==============================================================================
# Model Backend Interface
# ==============================================================================
class ModelBackend(ABC):
    @abstractmethod
    def predict(self, input_ids):
        pass

    @abstractmethod
    def get_name(self):
        pass

    @abstractmethod
    def get_info(self):
        pass

class KerasBackend(ModelBackend):
    def __init__(self, model, name, display_name):
        self.model = model
        self.name = name
        self.display_name = display_name
        # Count parameters
        total, non_zero = count_parameters(model)
        self.total_params = total
        self.non_zero_params = non_zero
        self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
        # Record the model config values shown in the info panel
        self.n_heads = model.cfg.get('n_heads', 0)
        self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))

    def predict(self, input_ids):
        inputs = np.array([input_ids], dtype=np.int32)
        logits = self.model(inputs, training=False)
        return logits[0, -1, :].numpy()

    def get_name(self):
        return self.display_name

    def get_info(self):
        info = f"{self.display_name}\n"
        info += f"  Total params: {format_param_count(self.total_params)}\n"
        info += f"  Attention heads: {self.n_heads}\n"
        info += f"  FFN dimension: {self.ff_dim}\n"
        if self.sparsity > 1:
            info += f"  Sparsity: {self.sparsity:.1f}%\n"
        return info

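# Illustrative usage (hypothetical names; assumes a built SAM1Model `m`):
#
#   backend = KerasBackend(m, "demo", "Demo Model")
#   next_logits = backend.predict([1, 2, 3])  # np.ndarray, shape (vocab_size,)
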
# ==============================================================================
# EASY MODEL REGISTRY - ADD YOUR MODELS HERE!
# ==============================================================================
MODEL_REGISTRY = [
    # Format: (display_name, repo_id, weights_filename, config_filename)
    # Smaller models are ACTUALLY faster (fewer params = real speedup!)
    ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
    ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
    ("SAM-X-1-Mini 🚀 (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights.h5", "sam1_mini_config.json"),
    ("SAM-X-1-Nano ⚡⚡ (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano.weights.h5", "sam1_nano_config.json"),
]
# To add a new model, just add a new line above, following the format:
# ("Display Name", "repo_id", "weights.h5", "config.json")
# If config_filename is None, the default config is used.
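# Example entry (hypothetical repo and filenames):
# ("SAM-X-1-Tiny", "Smilyai-labs/Sam-X-1-Tiny", "sam1_tiny.weights.h5", "sam1_tiny_config.json"),
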
# ==============================================================================
# Load Models
# ==============================================================================
CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"

print("=" * 80)
print("🤖 SAM-X-1 Multi-Model Chat Interface".center(80))
print("=" * 80)

# Download config and tokenizer
print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")

# Load config
with open(config_path, 'r') as f:
    base_config = json.load(f)
print("✅ Base config loaded")

# Build base model config
base_model_config = {
    'vocab_size': base_config['vocab_size'],
    'd_model': base_config['hidden_size'],
    'n_heads': base_config['num_attention_heads'],
    'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
    'dropout': base_config.get('dropout', 0.0),
    'max_len': base_config['max_position_embeddings'],
    'rope_theta': base_config['rope_theta'],
    'n_layers': base_config['num_hidden_layers']
}

# Recreate tokenizer
print("\n🔤 Recreating tokenizer...")
tokenizer = Tokenizer.from_pretrained("gpt2")
| eos_token = "" | |
eos_token_id = tokenizer.token_to_id(eos_token)
if eos_token_id is None:
    tokenizer.add_special_tokens([eos_token])
    eos_token_id = tokenizer.token_to_id(eos_token)

custom_tokens = ["<think>", "<think/>"]
for token in custom_tokens:
    if tokenizer.token_to_id(token) is None:
        tokenizer.add_special_tokens([token])

tokenizer.no_padding()
tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
# Load all models from registry
print("\n" + "=" * 80)
print("📦 LOADING MODELS".center(80))
print("=" * 80)

available_models = {}
dummy_input = tf.zeros((1, 1), dtype=tf.int32)

for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
    try:
        print(f"\n⏳ Loading: {display_name}")
        print(f"   Repo: {repo_id}")
        print(f"   Weights: {weights_filename}")
        # Download weights
        weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
        # Load custom config if specified (for pruned models)
        if config_filename:
            print(f"   Config: {config_filename}")
            custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(custom_config_path, 'r') as f:
                model_config = json.load(f)
            print(f"   📐 Custom architecture: {model_config['n_heads']} heads, {int(model_config['d_model'] * model_config['ff_mult'])} FFN dim")
        else:
            model_config = base_model_config.copy()
        # Create model with the appropriate config, build it once, then load weights
        model = SAM1Model(**model_config)
        model(dummy_input)
        model.load_weights(weights_path)
        model.trainable = False
        # Create backend
        backend = KerasBackend(model, display_name, display_name)
        available_models[display_name] = backend
        # Print stats
        print("   ✅ Loaded successfully!")
        print(f"   📊 Parameters: {format_param_count(backend.total_params)}")
        print(f"   📊 Attention heads: {backend.n_heads}")
        print(f"   📊 FFN dimension: {backend.ff_dim}")
    except Exception as e:
        print(f"   ⚠️ Failed to load: {e}")
        print(f"   Skipping {display_name}...")

if not available_models:
    raise RuntimeError("❌ No models loaded! Check your MODEL_REGISTRY configuration.")

print(f"\n✅ Successfully loaded {len(available_models)} model(s)")
print(f"   Device: {'GPU' if len(tf.config.list_physical_devices('GPU')) > 0 else 'CPU'}")

current_backend = list(available_models.values())[0]

# ==============================================================================
# Important Note About Pruning and Speed
# ==============================================================================
print("\n" + "=" * 80)
print("💡 ABOUT PRUNING & SPEED".center(80))
print("=" * 80)
print("""
📊 Does pruning reduce parameter count?
   YES and NO:
   • Total param count stays the same (architecture unchanged)
   • BUT pruned weights are set to ZERO (sparse weights)
   • Active/non-zero params are reduced significantly

🚀 Does pruning speed up inference?
   IT DEPENDS:
   • Dense operations (regular matrix multiply): NO speedup by default
   • Need sparse kernels or hardware support for actual speedup
   • HOWEVER: Smaller active weights = better cache utilization
   • Less computation on zeros = potential speedup on some hardware

🚀 What DOES speed things up reliably?
   ✅ Quantization (FP16, INT8) - smaller types = faster compute
   ✅ Fewer layers (layer pruning)
   ✅ Smaller hidden dimensions (width reduction)
   ✅ Knowledge distillation to smaller architecture

🚀 Why use structured pruning then?
   ✅ Reduces memory footprint (especially with sparse storage)
   ✅ Can be combined with quantization for real speedups
   ✅ Preserves quality better than aggressive dimension reduction
   ✅ Foundation for converting to truly smaller architecture
""")
def generate_response_stream(prompt, temperature=0.7, backend=None):
    """Generate a response, yielding decoded text chunks one by one for streaming."""
    if backend is None:
        backend = current_backend
    encoded_prompt = tokenizer.encode(prompt)
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    generated = input_ids.copy()
    current_text = ""
    in_thinking = False
    # Get max_len from the backend's model config
    max_len = backend.model.cfg['max_len']
    for _ in range(512):
        current_input = generated[-max_len:]
        # Get next-token logits from the selected backend
        next_token_logits = backend.predict(current_input)
        if temperature > 0:
            # Top-k (k=50) sampling with temperature
            next_token_logits = next_token_logits / temperature
            top_k_indices = np.argpartition(next_token_logits, -50)[-50:]
            top_k_logits = next_token_logits[top_k_indices]
            top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
            top_k_probs /= top_k_probs.sum()
            next_token = top_k_indices[np.random.choice(len(top_k_indices), p=top_k_probs)]
        else:
            # Greedy decoding
            next_token = np.argmax(next_token_logits)
        if next_token == eos_token_id:
            break
        generated.append(int(next_token))
        new_text = tokenizer.decode(generated[len(input_ids):])
        if len(new_text) > len(current_text):
            new_chunk = new_text[len(current_text):]
            current_text = new_text
            if "<think>" in new_chunk:
                in_thinking = True
            elif "</think>" in new_chunk or "<think/>" in new_chunk:
                in_thinking = False
            yield new_chunk, in_thinking

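# Illustrative usage (the Gradio handlers below drive this generator in
# practice; the prompt format mirrors theirs):
#
#   for chunk, thinking in generate_response_stream("User: hi\nSam: <think>"):
#       print(chunk, end="", flush=True)
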
# ==============================================================================
# Gradio Interface
# ==============================================================================
if __name__ == "__main__":
    import gradio as gr

    custom_css = """
    .chat-container {
        height: 600px;
        overflow-y: auto;
        padding: 20px;
        background: #ffffff;
    }
    .user-message {
        background: #f7f7f8;
        padding: 16px;
        margin: 12px 0;
        border-radius: 8px;
    }
    .assistant-message {
        background: #ffffff;
        padding: 16px;
        margin: 12px 0;
        border-radius: 8px;
        border-left: 3px solid #10a37f;
    }
    .message-content {
        color: #353740;
        line-height: 1.6;
        font-size: 15px;
    }
    .message-header {
        font-weight: 600;
        margin-bottom: 8px;
        color: #353740;
        font-size: 14px;
    }
    .thinking-content {
        color: #6b7280;
        font-style: italic;
        border-left: 3px solid #d1d5db;
        margin: 8px 0;
        background: #f9fafb;
        padding: 8px 12px;
        border-radius: 4px;
    }
    .input-row {
        background: #ffffff;
        padding: 12px;
        border-radius: 8px;
        margin-top: 12px;
        border: 1px solid #e5e7eb;
    }
    .gradio-container {
        max-width: 900px !important;
        margin: auto !important;
    }
    .announcement-banner {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 16px 24px;
        border-radius: 12px;
        margin-bottom: 20px;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
        text-align: center;
        font-size: 16px;
        font-weight: 500;
        animation: slideIn 0.5s ease-out;
    }
    @keyframes slideIn {
        from {
            opacity: 0;
            transform: translateY(-20px);
        }
        to {
            opacity: 1;
            transform: translateY(0);
        }
    }
    .announcement-banner strong {
        font-weight: 700;
        font-size: 18px;
    }
    .settings-panel {
        background: #f9fafb;
        padding: 16px;
        border-radius: 8px;
        margin-bottom: 12px;
        border: 1px solid #e5e7eb;
    }
    .model-info {
        background: #f0f9ff;
        border: 1px solid #bae6fd;
        padding: 12px;
        border-radius: 8px;
        margin-top: 8px;
        font-size: 13px;
        font-family: monospace;
        white-space: pre-line;
    }
    """

    def format_message_html(role, content, show_thinking=True):
        """Format a single message as HTML."""
        role_class = "user-message" if role == "user" else "assistant-message"
        role_name = "You" if role == "user" else "SAM-X-1"
        thinking = ""
        answer = ""
        if "<think>" in content:
            parts = content.split("<think>", 1)
            before_think = parts[0].strip()
            if len(parts) > 1:
                after_think = parts[1]
                if "</think>" in after_think:
                    think_parts = after_think.split("</think>", 1)
                    thinking = think_parts[0].strip()
                    answer = (before_think + " " + think_parts[1]).strip()
                elif "<think/>" in after_think:
                    think_parts = after_think.split("<think/>", 1)
                    thinking = think_parts[0].strip()
                    answer = (before_think + " " + think_parts[1]).strip()
                else:
                    # Still inside an unclosed <think> block
                    thinking = after_think.strip()
                    answer = before_think
            else:
                answer = before_think
        else:
            answer = content
        html = f'<div class="{role_class}">'
        html += f'<div class="message-header">{role_name}</div>'
        html += '<div class="message-content">'
        if thinking and show_thinking:
            html += f'<div class="thinking-content">💭 {thinking}</div>'
        if answer:
            html += f'<div>{answer}</div>'
        html += '</div></div>'
        return html

    def render_history(history, show_thinking):
        """Render chat history as HTML."""
        html = ""
        for msg in history:
            html += format_message_html(msg["role"], msg["content"], show_thinking)
        return html

    def send_message(message, history, show_thinking, temperature, model_choice):
        if not message.strip():
            yield history, "", render_history(history, show_thinking), ""
            return
        # Switch backend based on selection
        backend = available_models[model_choice]
        # Add user message
        history.append({"role": "user", "content": message})
        yield history, "", render_history(history, show_thinking), backend.get_info()
        # Generate prompt
        prompt = f"User: {message}\nSam: <think>"
        # Start assistant message
        history.append({"role": "assistant", "content": "<think>"})
        # Stream response
        for new_chunk, in_thinking in generate_response_stream(prompt, temperature, backend):
            history[-1]["content"] += new_chunk
            yield history, "", render_history(history, show_thinking), backend.get_info()

    # Create Gradio interface
    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="slate")) as demo:
        # Announcement Banner
        gr.HTML("""
        <div class="announcement-banner">
            🎉 <strong>NEW UPDATE:</strong> Multiple model variants now available!
            Choose Fast/Mini/Nano for a <strong>30-250% speed boost</strong>! ⚡
            The models marked with (BETA) are not useful yet. <strong>They are still in development!</strong>
        </div>
        """)
        gr.Markdown("# 🤖 SAM-X-1 Multi-Model Chat")

        # Settings panel
        with gr.Accordion("⚙️ Settings", open=False):
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(available_models.keys()),
                    value=list(available_models.keys())[0],
                    label="Model Selection",
                    info="Choose your speed/quality tradeoff"
                )
                model_info_box = gr.Textbox(
                    label="Selected Model Info",
                    value=list(available_models.values())[0].get_info(),
                    interactive=False,
                    lines=4,
                    elem_classes=["model-info"]
                )
            with gr.Row():
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative, lower = more focused"
                )
                show_thinking_checkbox = gr.Checkbox(
                    label="Show Thinking Process",
                    value=True,
                    info="Display the model's reasoning"
                )

        # Chat state and display
        chatbot_state = gr.State([])
        chat_html = gr.HTML(value="", elem_classes=["chat-container"])

        # Input area
        with gr.Row(elem_classes=["input-row"]):
            msg_input = gr.Textbox(
                placeholder="Ask me anything...",
                show_label=False,
                container=False,
                scale=9
            )
            send_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            clear_btn = gr.Button("🗑️ Clear Chat", size="sm")

        # Event handlers
        msg_input.submit(
            send_message,
            inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
            outputs=[chatbot_state, msg_input, chat_html, model_info_box]
        )
        send_btn.click(
            send_message,
            inputs=[msg_input, chatbot_state, show_thinking_checkbox, temperature_slider, model_selector],
            outputs=[chatbot_state, msg_input, chat_html, model_info_box]
        )
        clear_btn.click(
            lambda: ([], ""),
            outputs=[chatbot_state, chat_html]
        )
        show_thinking_checkbox.change(
            lambda h, st: render_history(h, st),
            inputs=[chatbot_state, show_thinking_checkbox],
            outputs=[chat_html]
        )
        # Update model info when selection changes
        model_selector.change(
            lambda choice: available_models[choice].get_info(),
            inputs=[model_selector],
            outputs=[model_info_box]
        )

    demo.launch(debug=True, share=True)