import os

# ============================================================================
# CPU Optimization - MUST be before TensorFlow import
# ============================================================================
NUM_CORES = os.cpu_count() or 4
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'   # Force CPU only
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'   # Intel optimization
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # Reduce TF logging

import gradio as gr
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download
import json
from tokenizers import Tokenizer
import numpy as np
import time

# Configure TF threading
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)

print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
# ============================================================================
# 🎉 FESTIVE MODE TOGGLE 🎉
# ============================================================================
FESTIVE = True

# ============================================================================
# Configuration & Model Loading
# ============================================================================
print("🚀 Loading Sam-large-2 Model...")

MODEL_REPO = "Smilyai-labs/Sam-large-2"
CACHE_DIR = "./model_cache"

# ============================================================================
# Model Architecture Definitions (Optimized with KV-Cache)
# ============================================================================
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
            self.built_cache = True

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k, offset=0):
        """Apply rotary embeddings with position offset for KV-cache."""
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
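
# Descriptive note (added comment, not part of the original logic): RoPE applies
# a position-dependent rotation to each query/key vector. With the rotate-half
# formulation used above, for a head-dim vector x at position p,
#     x_rot = x * cos(p * inv_freq) + rotate_half(x) * sin(p * inv_freq),
# which rotates each 2-D sub-pair (x_i, x_{i + dim/2}) by the angle
# p * theta^(-2i/dim). The `offset` argument shifts p so that cached decoding
# keeps absolute positions consistent across generation steps.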

class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.scale = None

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
        super().build(input_shape)

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
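
# Descriptive note: RMSNorm normalizes by the root-mean-square of the features,
#     y = x / sqrt(mean(x^2) + eps) * scale,
# i.e. LayerNorm without mean-centering or a bias term, which is cheaper and is
# the normalization used by LLaMA-style decoder blocks.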

class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

    def build(self, input_shape):
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        """
        Args:
            x: input tensor [B, T, D] (T=1 during cached generation)
            past_kv: tuple of (past_k, past_v), each [B, n_heads, past_len, head_dim]
            use_cache: whether to return updated kv cache
        Returns:
            (output, (new_k, new_v)) if use_cache else (output, None)
        """
        B = tf.shape(x)[0]
        T = tf.shape(x)[1]
        dtype = x.dtype

        res = x
        y = self.pre_attn_norm(x)

        # Project Q, K, V for current input
        q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
        q = tf.transpose(q, [0, 2, 1, 3])  # [B, n_heads, T, head_dim]
        k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
        k = tf.transpose(k, [0, 2, 1, 3])
        v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
        v = tf.transpose(v, [0, 2, 1, 3])

        # Determine position offset for RoPE
        if past_kv is not None:
            past_len = tf.shape(past_kv[0])[2]
        else:
            past_len = 0

        # Apply RoPE with position offset
        q, k = self.rope(q, k, offset=past_len)

        # Concatenate with past KV
        if past_kv is not None:
            k = tf.concat([past_kv[0], k], axis=2)
            v = tf.concat([past_kv[1], v], axis=2)
        new_kv = (k, v) if use_cache else None

        # Attention
        full_len = tf.shape(k)[2]
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))

        # Causal mask
        q_positions = tf.range(past_len, past_len + T)
        k_positions = tf.range(full_len)
        mask = tf.cast(q_positions[:, None] >= k_positions[None, :], dtype)
        mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores = scores + mask[None, None, :, :]

        attn = tf.nn.softmax(scores, axis=-1)
        attn_out = tf.matmul(attn, v)
        attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
        attn_out = tf.reshape(attn_out, [B, T, self.d_model])
        x = res + self.dropout(self.out_proj(attn_out), training=training)

        # FFN
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        output = res + self.dropout(ffn, training=training)
        return output, new_kv

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config
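
# Descriptive note on the KV cache in this block: during cached decoding the
# block is called with T == 1, `past_kv` holds the keys/values for all previous
# positions ([B, n_heads, past_len, head_dim]), and the freshly projected k/v of
# the new token are concatenated onto it. Only one new query row is scored
# against the full key history, so each decode step does O(past_len) attention
# work instead of recomputing O(past_len^2) over the whole sequence.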

class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }
        self.blocks = [
            TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            for i in range(self.cfg['n_layers'])
        ]
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None, past_kv=None, use_cache=False):
        """
        Args:
            input_ids: [B, T]
            past_kv: list of (k, v) tuples, one per layer
            use_cache: whether to return updated cache
        Returns:
            logits, new_past_kv (or None)
        """
        x = self.embed(input_ids)
        new_past_kv = [] if use_cache else None
        for i, block in enumerate(self.blocks):
            layer_past = past_kv[i] if past_kv is not None else None
            x, layer_kv = block(x, training=training, past_kv=layer_past, use_cache=use_cache)
            if use_cache:
                new_past_kv.append(layer_kv)
        logits = self.lm_head(self.norm(x))
        return logits, new_past_kv

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
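
# Illustrative (hypothetical) two-step cached call, assuming a built model `m`:
#     logits, cache = m(prompt_ids, use_cache=True)                 # prefill: [B, T] builds one (k, v) per layer
#     logits, cache = m(next_id, past_kv=cache, use_cache=True)     # decode: [B, 1] reuses the cache
# `cache` is a plain Python list with one (k, v) tuple per TransformerBlock.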

# --- Model and Tokenizer Loading ---
config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)

try:
    weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
    print("✅ Found checkpoint weights (ckpt.weights.h5)")
    use_checkpoint = True
except Exception as e:
    print(f"⚠️ Checkpoint not found, falling back to model.keras: {e}")
    try:
        model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
        use_checkpoint = False
    except Exception as e_model:
        print(f"❌ Also failed to find model.keras: {e_model}")
        raise RuntimeError("Could not load model weights")

with open(config_path, 'r') as f:
    config = json.load(f)

from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "</think>", "<CONTINUE>", "<im end for model tun>"]
hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
os.makedirs("./temp_tokenizer", exist_ok=True)
hf_tokenizer.save_pretrained("./temp_tokenizer")
tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
print(f"✅ Tokenizer created with vocab size: {tokenizer.get_vocab_size()}")

eos_token_id = config.get('eos_token_id', 50256)

print("\n🚀 Loading model...")
model = None

if use_checkpoint:
    print("📦 Building model from config and loading checkpoint weights...")
    model_config = {
        'vocab_size': config['vocab_size'],
        'd_model': config['hidden_size'],
        'n_layers': config['num_hidden_layers'],
        'n_heads': config['num_attention_heads'],
        'ff_mult': config['intermediate_size'] / config['hidden_size'],
        'max_len': config['max_position_embeddings'],
        'dropout': 0.1,
        'rope_theta': config['rope_theta']
    }
    model = SAM1Model(config=model_config)

    # Build model with dummy input
    dummy_input = tf.zeros((1, 16), dtype=tf.int32)
    _ = model(dummy_input, training=False, use_cache=False)
    print(f"✅ Model architecture built: {model.count_params():,} parameters")

    try:
        model.load_weights(weights_path)
        print("✅ Checkpoint weights loaded successfully!")
    except Exception as e:
        print(f"❌ Failed to load checkpoint weights: {e}")
        raise
else:
    print("📦 Loading full saved model...")
    try:
        custom_objects = {
            'SAM1Model': SAM1Model,
            'TransformerBlock': TransformerBlock,
            'RMSNorm': RMSNorm,
            'RotaryEmbedding': RotaryEmbedding
        }
        model = keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
        print("✅ Model loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        raise

if model:
    print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")

# Warm up the model
print("🔥 Warming up model...")
warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
_, _ = model(warmup_input, training=False, use_cache=True)
print("✅ Model warmed up")

# ============================================================================
# Optimized Inference Logic with KV-Cache
# ============================================================================
stop_generation = False

def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
    """Pure NumPy sampling for speed."""
    # Temperature scaling
    scaled_logits = logits / temperature

    # Repetition penalty
    if repetition_penalty != 1.0:
        for token_id, freq in token_freq.items():
            if token_id < len(scaled_logits):
                scaled_logits[token_id] /= (repetition_penalty ** freq)

    # Top-K filtering
    if top_k > 0 and top_k < len(scaled_logits):
        top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
        top_k_logits = scaled_logits[top_k_indices]
    else:
        top_k_indices = np.arange(len(scaled_logits))
        top_k_logits = scaled_logits

    # Softmax (numerically stable)
    top_k_logits = top_k_logits - np.max(top_k_logits)
    top_k_probs = np.exp(top_k_logits)
    top_k_probs /= top_k_probs.sum()

    # Top-P (nucleus) filtering
    if top_p < 1.0:
        sorted_idx = np.argsort(top_k_probs)[::-1]
        cumsum = np.cumsum(top_k_probs[sorted_idx])
        cutoff = np.searchsorted(cumsum, top_p) + 1
        nucleus_idx = sorted_idx[:cutoff]
        nucleus_probs = top_k_probs[nucleus_idx]
        nucleus_probs /= nucleus_probs.sum()
        sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
        return int(top_k_indices[nucleus_idx[sampled]])
    else:
        sampled = np.random.choice(len(top_k_probs), p=top_k_probs)
        return int(top_k_indices[sampled])
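
# Illustrative (hypothetical) call, assuming `logits` is a vocab-sized float array:
#     next_id = sample_token(logits, temperature=0.8, top_k=40, top_p=0.9,
#                            token_freq={}, repetition_penalty=1.1)
# The filters compose in order: temperature -> repetition penalty -> top-k ->
# softmax -> nucleus (top-p) -> categorical draw over the surviving tokens.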

def generate_stream(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1
):
    """Generate text with KV-cache for fast CPU inference."""
    global stop_generation
    stop_generation = False

    # Tokenize prompt
    prompt_ids = tokenizer.encode(prompt).ids
    input_ids = [i for i in prompt_ids if i != eos_token_id]
    if len(input_ids) == 0:
        yield "Error: Empty prompt after tokenization"
        return

    generated_text = ""
    token_count = 0
    token_freq = {}

    # Get special token IDs
    im_end_id = tokenizer.token_to_id("<|im_end|>")
    model_end_id = tokenizer.token_to_id("<im end for model tun>")
    stop_ids = {eos_token_id, im_end_id, model_end_id}
    stop_ids.discard(None)

    max_context = config['max_position_embeddings']
    start_time = time.time()

    # === PREFILL PHASE ===
    # Truncate if prompt is too long
    if len(input_ids) > max_context - max_tokens:
        input_ids = input_ids[-(max_context - max_tokens):]

    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    try:
        logits, past_kv = model(input_tensor, training=False, use_cache=True)
    except Exception as e:
        yield f"Error during prefill: {e}"
        return

    # Get logits for last position
    next_token_logits = logits[0, -1, :].numpy()
    prefill_time = time.time() - start_time
    print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")

    # === GENERATION LOOP ===
    decode_start = time.time()
    for step in range(max_tokens):
        if stop_generation:
            yield generated_text + "\n\n*[Generation stopped]*"
            return

        # Sample next token
        next_token_id = sample_token(
            next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
        )

        # Stop conditions
        if next_token_id in stop_ids:
            break

        # Update frequency tracking
        token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1

        # Decode and yield
        token_text = tokenizer.decode([next_token_id])
        generated_text += token_text
        token_count += 1
        yield generated_text

        # === DECODE PHASE (single token, reuse cache) ===
        next_input = tf.constant([[next_token_id]], dtype=tf.int32)
        try:
            logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
        except Exception as e:
            yield generated_text + f"\n\n*[Error during generation: {e}]*"
            return
        next_token_logits = logits[0, -1, :].numpy()

        # Truncate cache if too long
        current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
        if current_len > max_context:
            trim_amount = current_len - max_context + 100  # Keep some buffer
            past_kv = [
                (k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
                for k, v in past_kv
            ]

    decode_time = time.time() - decode_start
    total_time = time.time() - start_time
    if token_count > 0:
        decode_tps = token_count / decode_time if decode_time > 0 else 0
        total_tps = token_count / total_time if total_time > 0 else 0
        stats = (
            f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
            f"(prefill: {prefill_time:.1f}s, decode: {decode_tps:.1f} tok/s)]*"
        )
        if not stop_generation:
            generated_text += stats
            yield generated_text
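
# Descriptive note: generation is split into a prefill pass over the whole
# prompt (builds the cache once) and per-token decode passes that feed a single
# token plus the cache. When the cache grows past the model's context window it
# is trimmed from the oldest positions, with a 100-token buffer so the trim does
# not have to run on every step.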

# ============================================================================
# Chat Interface Logic
# ============================================================================
def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
    """Format message history and seed <think> if enabled."""
    prompt = ""
    for user_msg, assistant_msg in history:
        prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        if assistant_msg:
            # Clean up any stats from previous messages
            clean_msg = assistant_msg.split("\n\n*[")[0]
            prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
    if reasoning_enabled:
        prompt += "<think>"
    return prompt
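
# Illustrative example of the resulting ChatML-style prompt for a first turn
# with reasoning enabled (no prior history):
#     <|im_start|>user
#     What is 24 * 12?<|im_end|>
#     <|im_start|>assistant
#     <think>
# The trailing "<think>" seeds the model to emit its chain-of-thought before
# closing the tag and writing the visible answer.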

def chat_stream(
    message: str,
    history: list,
    max_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    reasoning_enabled: bool
):
    if not message.strip():
        yield history
        return

    prompt = format_chat_prompt(message, history, reasoning_enabled)
    partial_response = ""

    for generated in generate_stream(
        prompt, max_tokens, temperature, top_k, top_p, repetition_penalty
    ):
        partial_response = generated

        # Robust end-of-turn detection
        stop_tags = ["<|im_end|>", "<im end for model tun>"]
        earliest_stop = len(partial_response)
        should_stop = False
        for tag in stop_tags:
            if tag in partial_response:
                idx = partial_response.find(tag)
                if idx < earliest_stop:
                    earliest_stop = idx
                    should_stop = True

        display_response = partial_response
        if should_stop:
            # Keep the stats portion if present
            stats_start = partial_response.find("\n\n*[")
            if stats_start > earliest_stop:
                display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
            else:
                display_response = partial_response[:earliest_stop]

        # Post-process reasoning tags for display
        if reasoning_enabled:
            if '<think>' in display_response and '</think>' in display_response:
                start_idx = display_response.find('<think>')
                end_idx = display_response.find('</think>')
                if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
                    thought_content = display_response[start_idx + len('<think>'):end_idx].strip()
                    formatted_thought = thought_content.replace("\n", "<br>")
                    details_html = (
                        f'<details class="reasoning-block">'
                        f'<summary>🧠 Model Reasoning (Click to expand)</summary>'
                        f'<p>{formatted_thought}</p>'
                        f'</details>'
                    )
                    display_response = (
                        display_response[:start_idx] +
                        details_html +
                        display_response[end_idx + len('</think>'):]
                    )
            elif '<think>' in display_response and '</think>' not in display_response:
                display_response = display_response.replace('<think>', '**🧠 Thinking:** ')

        yield history + [[message, display_response.strip()]]

def stop_gen():
    global stop_generation
    stop_generation = True
    return None
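
# Descriptive note: `stop_generation` is a module-level flag shared by all
# requests in this process; generate_stream checks it once per decoded token,
# so pressing Stop takes effect at the next token boundary.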

# ============================================================================
# Gradio UI
# ============================================================================
custom_css = """
.gradio-container { max-width: 1200px !important; margin: auto !important; }
.header {
    text-align: center; padding: 2rem; background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
    color: white; border-radius: 12px; margin-bottom: 2rem; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);
    animation: pulse 2s ease-in-out infinite;
}
@keyframes pulse { 0%, 100% { transform: scale(1); } 50% { transform: scale(1.02); } }
.header h1 { font-size: 2.8rem; margin-bottom: 0.5rem; font-weight: 700; text-shadow: 2px 2px 4px rgba(0,0,0,0.2); }
.header p { font-size: 1.1rem; opacity: 0.95; }
.celebration { font-size: 2rem; margin: 0.5rem; animation: bounce 1s ease infinite; }
@keyframes bounce { 0%, 100% { transform: translateY(0); } 50% { transform: translateY(-10px); } }
.twin-badge {
    display: inline-block; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white; padding: 0.5rem 1rem; border-radius: 20px; font-weight: bold; margin: 0.5rem;
    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
}
footer { text-align: center; padding: 2rem; color: #666; border-top: 1px solid #eee; margin-top: 2rem; }
#reasoning-control-group { position: relative; display: flex; align-items: center; justify-content: center; margin-right: 10px; }
#reasoning-toggle-btn {
    font-size: 1.5rem; border-radius: 50%; width: 40px; height: 40px; padding: 0;
    min-width: 0 !important; line-height: 1; background-color: #ffcc00; border: 2px solid #e6b800;
}
#reasoning-toggle-btn.off { background-color: #e0e0e0; border: 2px solid #ccc; }
.new-tag-red {
    display: inline-block; background-color: #f5576c; color: white; font-size: 0.7em;
    font-weight: bold; padding: 2px 5px; border-radius: 4px; line-height: 1;
    position: absolute; top: -5px; right: -5px; z-index: 10; animation: blink 1s infinite;
}
@keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } }
.gradio-html details.reasoning-block {
    border: 1px solid #ddd; border-left: 5px solid #667eea; padding: 5px 10px;
    margin: 10px 0; border-radius: 4px; background-color: #f9f9ff;
}
.gradio-html details.reasoning-block summary { font-weight: bold; cursor: pointer; outline: none; color: #667eea; }
.gradio-html details.reasoning-block p { margin-top: 5px; padding-left: 10px; border-left: 1px dashed #ccc; white-space: pre-wrap; }
.modal-overlay {
    position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: rgba(0, 0, 0, 0.7);
    display: flex; justify-content: center; align-items: center; z-index: 1000;
}
.modal-content {
    background: white; padding: 30px; border-radius: 15px; width: 90%; max-width: 900px;
    box-shadow: 0 10px 50px rgba(0, 0, 0, 0.5); animation: slide-in 0.5s ease-out;
}
@keyframes slide-in { from { transform: translateY(-50px); opacity: 0; } to { transform: translateY(0); opacity: 1; } }
.modal-content h2 { color: #764ba2; border-bottom: 2px solid #eee; padding-bottom: 10px; margin-top: 0; }
.comparison-box { display: flex; gap: 20px; margin-top: 20px; }
.comparison-mode { flex: 1; padding: 15px; border-radius: 10px; }
.mode-reasoning { border: 2px solid #667eea; background-color: #f6f7ff; }
.mode-direct { border: 2px solid #fcb69f; background-color: #fffaf5; }
.comparison-mode h3 { margin-top: 0; font-size: 1.3rem; }
.comparison-mode pre { background-color: #eef; padding: 10px; border-radius: 5px; overflow-x: auto; }
.close-btn {
    margin-top: 20px; padding: 10px 20px; background-color: #764ba2; color: white;
    border: none; border-radius: 8px; cursor: pointer; font-size: 1rem; transition: background-color 0.3s;
}
.close-btn:hover { background-color: #5d3a84; }
.speed-indicator {
    background: linear-gradient(135deg, #00b894, #00cec9);
    color: white; padding: 5px 10px; border-radius: 10px; font-size: 0.8rem;
    display: inline-block; margin-left: 10px;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    reasoning_enabled = gr.State(False)

    welcome_modal_html = gr.HTML(
        """
        <div id="welcome-modal" class="modal-overlay" style="display:none;">
            <div class="modal-content">
                <h2>🧠 Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
                <p>Our latest model features <strong>Chain-of-Thought (CoT)</strong> functionality and <strong>KV-Cache optimization</strong> for fast CPU inference!</p>
                <div class="comparison-box">
                    <div class="comparison-mode mode-reasoning">
                        <h3>💡 Reasoning Mode (ON)</h3>
                        <p>The model performs a <strong>CoT step</strong> first. The internal thought process is contained within <code>&lt;think&gt;...&lt;/think&gt;</code> tags.</p>
                    </div>
                    <div class="comparison-mode mode-direct">
                        <h3>⚪ Direct Mode (OFF)</h3>
                        <p>The model generates the final answer immediately, maximizing speed.</p>
                    </div>
                </div>
                <button class="close-btn" onclick="document.getElementById('welcome-modal').style.display='none'">Got it! Start Chatting</button>
            </div>
        </div>
        """
    )

    if FESTIVE:
        gr.HTML("""
        <div class="header">
            <div class="celebration">🎉 🎊 ✨ 🎉 🎊</div>
            <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
                 alt="Sam-large-2" style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
            <h1>🤖 Sam-large-2 Chat 🤖</h1>
            <p><strong>LATEST RELEASE!</strong> Our <strong>BEST Reasoning Model</strong> - Now with KV-Cache! <span class="speed-indicator">⚡ 5-20x Faster</span></p>
            <div class="twin-badge">Reasoning Model</div>
            <div class="celebration">🚀 💫 🎯 ⚡ 🔥</div>
        </div>
        """)
    else:
        gr.HTML("""<div class="header"><h1>🤖 Sam-large-2 Chat</h1><p>Advanced Reasoning Model with KV-Cache</p></div>""")

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=600,
                show_label=False,
                avatar_images=(
                    None,
                    "https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/KtiMi-aDUOOeN--YNT-Fu.jpeg"
                ),
                bubble_full_width=False
            )
            with gr.Row():
                with gr.Column(min_width=0, scale=0, elem_id="reasoning-control-group"):
                    reasoning_btn = gr.Button("💡", size="sm", elem_id="reasoning-toggle-btn", elem_classes=["off"])
                    gr.HTML('<span class="new-tag-red">NEW</span>')
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    show_label=False,
                    scale=8,
                    container=False
                )
                submit_btn = gr.Button("Send 🚀" if FESTIVE else "Send", variant="primary", scale=1)
                stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
                retry_btn = gr.Button("🔄 Retry", size="sm")
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Generation Settings")
            max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=50, label="Max Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
            top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-K")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P")
            repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
            gr.Markdown("---")
            gr.Markdown(f"""### 📊 Sam-large-2 Model Info
**Type:** Chain-of-Thought Reasoning Model
**Vocab:** {config['vocab_size']:,}
**Layers:** {config['num_hidden_layers']}
**Context:** {config['max_position_embeddings']:,} tokens
**Optimization:** KV-Cache enabled ⚡
""")

    gr.Examples(
        examples=[
            "Explain quantum computing in simple terms",
            "Write a short poem about artificial intelligence",
            "What is 24 * 12? Show your reasoning.",
            "What are the main differences between Python and JavaScript?"
        ],
        inputs=msg
    )

    gr.HTML("""
    <footer>
        <p><strong>🎉 Sam-large-2 - LATEST RELEASE with KV-Cache! 🎉</strong></p>
        <p style="font-size: 0.9rem; color: #999;">Trained from scratch on TPU v5e-8 • Built by Smily studios with TensorFlow & Gradio</p>
    </footer>
    """)

    def show_modal_js():
        return """
        (function() {
            if (sessionStorage.getItem('sam2_modal_shown') !== 'true') {
                const modal = document.getElementById('welcome-modal');
                if (modal) { modal.style.display = 'flex'; sessionStorage.setItem('sam2_modal_shown', 'true'); }
            }
        })();
        """

    demo.load(None, inputs=None, outputs=None, js=show_modal_js())

    def toggle_reasoning(current_state):
        new_state = not current_state
        return new_state, gr.update(elem_classes="" if new_state else "off")

    reasoning_btn.click(
        fn=toggle_reasoning,
        inputs=[reasoning_enabled],
        outputs=[reasoning_enabled, reasoning_btn],
        preprocess=False
    )

    common_inputs = [msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled]

    submit_event = msg.submit(
        chat_stream,
        inputs=common_inputs,
        outputs=[chatbot]
    ).then(lambda: "", outputs=[msg])

    click_event = submit_btn.click(
        chat_stream,
        inputs=common_inputs,
        outputs=[chatbot]
    ).then(lambda: "", outputs=[msg])

    stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[submit_event, click_event])

    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])

    def retry_last(history, max_tok, temp, topk, topp, rep_pen, reasoning_en):
        # Generator: yield (not return) so Gradio still receives the unchanged history
        if not history:
            yield history
            return
        last_user_msg = history[-1][0]
        for update in chat_stream(last_user_msg, history[:-1], max_tok, temp, topk, topp, rep_pen, reasoning_en):
            yield update

    retry_event = retry_btn.click(
        retry_last,
        inputs=[chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled],
        outputs=[chatbot]
    )

    stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[retry_event])

if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("🚀 Starting Sam-large-2 Chat with KV-Cache Optimization")
    print("=" * 60 + "\n")
    demo.queue(max_size=20)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)