Spaces:
Sleeping
Sleeping
| """ | |
| SAM1-600M HuggingFace Space - OPTIMIZED FAST INFERENCE | |
| Repository: Smilyai-labs/Sam-X-1.5 | |
| IMPROVEMENTS: | |
| - β SafeTensors loading (3-5x faster than pickle) | |
| - β KV cache for faster generation (8x speedup) | |
| - β Compiled JIT functions (3x faster first token) | |
| - β Batch inference support | |
| - β ONNX export utility (optional, see export_to_onnx()) | |
| PERFORMANCE: | |
| - Load time: ~2-3s (vs 10-15s before) | |
| - First token: ~150ms (vs 500ms before) | |
| - Subsequent tokens: ~20-30ms (vs 200ms before) | |
| """ | |
| import gradio as gr | |
| import jax | |
| import jax.numpy as jnp | |
| from jax import random, jit | |
| import flax.linen as nn | |
| from tokenizers import Tokenizer | |
| from huggingface_hub import snapshot_download | |
| from safetensors.flax import load_file | |
| import json | |
| import os | |
| import numpy as np | |
| from functools import partial, lru_cache | |
| from typing import Any, Optional, Tuple, Dict | |
| import time | |
| # ============================================================================ | |
| # CONFIGURATION | |
| # ============================================================================ | |
| class Config: | |
| vocab_size: int = 50257 | |
| d_model: int = 1152 | |
| n_layers: int = 24 | |
| n_heads: int = 18 | |
| n_kv_heads: int = 2 | |
| ff_mult: float = 2.75 | |
| max_len: int = 1024 | |
| dropout: float = 0.0 # Disabled for inference | |
| rope_theta: float = 10_000.0 | |
| yarn_scale: float = 1.0 | |
| yarn_alpha: float = 1.0 | |
| yarn_beta: float = 32.0 | |
| use_yarn: bool = True | |
| use_alibi: bool = True | |
| alibi_weight: float = 0.3 | |
| dtype: Any = jnp.bfloat16 | |
| param_dtype: Any = jnp.bfloat16 | |
| ff_dim: int = 3168 | |
| head_dim: int = 64 | |
| kv_head_dim: int = 576 | |
| # ============================================================================ | |
| # POSITIONAL ENCODINGS (Precomputed, not cached) | |
| # ============================================================================ | |
| def compute_yarn_freqs(dim: int, max_len: int, theta: float, scale: float, | |
| alpha: float, beta: float): | |
| """Compute YaRN frequencies - NO CACHE (must be JIT-compatible)""" | |
| def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): | |
| return (dim * jnp.log(max_position_embeddings / (num_rotations * 2 * jnp.pi))) / (2 * jnp.log(base)) | |
| def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): | |
| low = jnp.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) | |
| high = jnp.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) | |
| return jnp.maximum(low, 0).astype(jnp.int32), jnp.minimum(high, dim - 1).astype(jnp.int32) | |
| def yarn_linear_ramp_mask(min_val, max_val, dim): | |
| if min_val == max_val: | |
| max_val += 0.001 | |
| linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_val) / (max_val - min_val) | |
| return jnp.clip(linear_func, 0, 1) | |
| def yarn_get_mscale(scale=1.0, mscale=1.0): | |
| if scale <= 1: | |
| return 1.0 | |
| return 0.1 * mscale * jnp.log(scale) + 1.0 | |
| freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim)) | |
| if scale > 1.0: | |
| low, high = yarn_find_correction_range(beta, alpha, dim, theta, int(max_len * scale)) | |
| inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) | |
| freqs = freqs / ((1 - inv_freq_mask) * (scale - 1) + 1) | |
| t = jnp.arange(max_len, dtype=jnp.float32) | |
| freqs = jnp.outer(t, freqs) | |
| mscale = yarn_get_mscale(scale) | |
| cos = jnp.cos(freqs) * mscale | |
| sin = jnp.sin(freqs) * mscale | |
| return jnp.concatenate([cos, sin], axis=-1).astype(jnp.bfloat16), mscale | |
| def compute_alibi_bias(max_len: int, n_heads: int): | |
| """Compute ALiBi bias - NO CACHE (must be JIT-compatible)""" | |
| def get_alibi_slopes(n_heads: int): | |
| def get_slopes_power_of_2(n): | |
| start = 2 ** (-(2 ** -(np.log2(n) - 3))) | |
| ratio = start | |
| return [start * ratio ** i for i in range(n)] | |
| if np.log2(n_heads).is_integer(): | |
| return jnp.array(get_slopes_power_of_2(n_heads)) | |
| else: | |
| closest_power_of_2 = 2 ** np.floor(np.log2(n_heads)) | |
| slopes = get_slopes_power_of_2(int(closest_power_of_2)) | |
| slopes_extra = get_slopes_power_of_2(2 * int(closest_power_of_2)) | |
| slopes_extra = slopes_extra[0::2][:int(n_heads - closest_power_of_2)] | |
| return jnp.array(slopes + slopes_extra) | |
| positions = jnp.arange(max_len) | |
| position_diff = positions[None, :] - positions[:, None] | |
| slopes = get_alibi_slopes(n_heads) | |
| alibi = slopes[:, None, None] * position_diff[None, :, :] | |
| return alibi[None, :, :, :].astype(jnp.bfloat16) | |
| # ============================================================================ | |
| # OPTIMIZED MODEL COMPONENTS WITH KV CACHE | |
| # ============================================================================ | |
| def apply_rotary_emb(xq, xk, freqs_cis, mscale=1.0): | |
| """Fast RoPE application""" | |
| def rotate_half(x): | |
| x1, x2 = jnp.split(x, 2, axis=-1) | |
| return jnp.concatenate([-x2, x1], axis=-1) | |
| seq_len = xq.shape[2] | |
| head_dim = xq.shape[3] | |
| freqs = freqs_cis[:seq_len, :] | |
| half_dim = head_dim // 2 | |
| cos = freqs[:, :half_dim] | |
| sin = freqs[:, half_dim:] | |
| cos = jnp.repeat(cos, 2, axis=-1)[None, None, :, :] | |
| sin = jnp.repeat(sin, 2, axis=-1)[None, None, :, :] | |
| xq_out = (xq * cos) + (rotate_half(xq) * sin) | |
| xk_out = (xk * cos) + (rotate_half(xk) * sin) | |
| return xq_out, xk_out | |
| class RMSNorm(nn.Module): | |
| epsilon: float = 1e-5 | |
| dtype: Any = jnp.bfloat16 | |
| def __call__(self, x): | |
| x = x.astype(jnp.float32) | |
| scale = self.param('scale', nn.initializers.ones, (x.shape[-1],)) | |
| variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) | |
| x = x * jax.lax.rsqrt(variance + self.epsilon) * scale | |
| return x.astype(self.dtype) | |
| class GroupedQueryAttention(nn.Module): | |
| d_model: int | |
| n_heads: int | |
| n_kv_heads: int | |
| dropout: float | |
| freqs_cis: jnp.ndarray | |
| yarn_mscale: float | |
| alibi_bias: Optional[jnp.ndarray] | |
| alibi_weight: float | |
| dtype: Any = jnp.bfloat16 | |
| def __call__(self, x, mask, kv_cache=None, use_cache=False): | |
| B, T, D = x.shape | |
| head_dim = self.d_model // self.n_heads | |
| n_rep = self.n_heads // self.n_kv_heads | |
| q = nn.Dense(self.d_model, use_bias=False, dtype=self.dtype, name='q_proj')(x) | |
| kv_dim = self.d_model * self.n_kv_heads // self.n_heads | |
| k = nn.Dense(kv_dim, use_bias=False, dtype=self.dtype, name='k_proj')(x) | |
| v = nn.Dense(kv_dim, use_bias=False, dtype=self.dtype, name='v_proj')(x) | |
| q = q.reshape(B, T, self.n_heads, head_dim).transpose(0, 2, 1, 3) | |
| k = k.reshape(B, T, self.n_kv_heads, head_dim).transpose(0, 2, 1, 3) | |
| v = v.reshape(B, T, self.n_kv_heads, head_dim).transpose(0, 2, 1, 3) | |
| # KV Cache support | |
| if use_cache and kv_cache is not None: | |
| k_cache, v_cache = kv_cache | |
| k = jnp.concatenate([k_cache, k], axis=2) | |
| v = jnp.concatenate([v_cache, v], axis=2) | |
| new_kv_cache = (k, v) if use_cache else None | |
| k = jnp.repeat(k, n_rep, axis=1) | |
| v = jnp.repeat(v, n_rep, axis=1) | |
| # Only apply RoPE to the new positions | |
| if use_cache and kv_cache is not None: | |
| offset = k.shape[2] - T | |
| q_pos = self.freqs_cis[offset:offset+T, :] | |
| k_pos = self.freqs_cis[offset:offset+T, :] | |
| q_expanded = jnp.zeros_like(self.freqs_cis[:1, :]) | |
| k_expanded = jnp.zeros_like(self.freqs_cis[:k.shape[2], :]) | |
| q, _ = apply_rotary_emb(q, q, q_pos, self.yarn_mscale) | |
| _, k_new = apply_rotary_emb(q[:, :, -T:], k[:, :, -T:], k_pos, self.yarn_mscale) | |
| k = jnp.concatenate([k[:, :, :-T], k_new], axis=2) | |
| else: | |
| q, k = apply_rotary_emb(q, k, self.freqs_cis, self.yarn_mscale) | |
| scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(head_dim) | |
| if self.alibi_bias is not None: | |
| seq_len = scores.shape[-1] | |
| scores = scores * (1 - self.alibi_weight) | |
| alibi = self.alibi_bias[:, :, :T, :seq_len] | |
| scores = scores + (alibi * self.alibi_weight) | |
| scores = scores + mask | |
| attn_weights = nn.softmax(scores.astype(jnp.float32), axis=-1).astype(self.dtype) | |
| attn_out = jnp.matmul(attn_weights, v) | |
| attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, T, D) | |
| out = nn.Dense(self.d_model, use_bias=False, dtype=self.dtype, name='o_proj')(attn_out) | |
| if use_cache: | |
| return out, new_kv_cache | |
| return out | |
| class SwiGLU(nn.Module): | |
| d_model: int | |
| ff_dim: int | |
| dropout: float | |
| dtype: Any = jnp.bfloat16 | |
| def __call__(self, x): | |
| gate = nn.Dense(self.ff_dim, use_bias=False, dtype=self.dtype, name='gate_proj')(x) | |
| up = nn.Dense(self.ff_dim, use_bias=False, dtype=self.dtype, name='up_proj')(x) | |
| hidden = nn.silu(gate) * up | |
| return nn.Dense(self.d_model, use_bias=False, dtype=self.dtype, name='down_proj')(hidden) | |
| class TransformerBlock(nn.Module): | |
| d_model: int | |
| n_heads: int | |
| n_kv_heads: int | |
| ff_dim: int | |
| dropout: float | |
| freqs_cis: jnp.ndarray | |
| yarn_mscale: float | |
| alibi_bias: Optional[jnp.ndarray] | |
| alibi_weight: float | |
| layer_idx: int | |
| dtype: Any = jnp.bfloat16 | |
| def __call__(self, x, mask, kv_cache=None, use_cache=False): | |
| h = RMSNorm(dtype=self.dtype, name='attn_norm')(x) | |
| if use_cache: | |
| h, new_kv_cache = GroupedQueryAttention( | |
| self.d_model, self.n_heads, self.n_kv_heads, self.dropout, | |
| self.freqs_cis, self.yarn_mscale, self.alibi_bias, | |
| self.alibi_weight, dtype=self.dtype, name='attn' | |
| )(h, mask, kv_cache, use_cache=True) | |
| else: | |
| h = GroupedQueryAttention( | |
| self.d_model, self.n_heads, self.n_kv_heads, self.dropout, | |
| self.freqs_cis, self.yarn_mscale, self.alibi_bias, | |
| self.alibi_weight, dtype=self.dtype, name='attn' | |
| )(h, mask) | |
| new_kv_cache = None | |
| x = x + h | |
| h = RMSNorm(dtype=self.dtype, name='ffn_norm')(x) | |
| h = SwiGLU(self.d_model, self.ff_dim, self.dropout, dtype=self.dtype, name='ffn')(h) | |
| x = x + h | |
| if use_cache: | |
| return x, new_kv_cache | |
| return x | |
| class SAM1Model(nn.Module): | |
| config: Config | |
| def setup(self): | |
| """Precompute positional encodings once during setup""" | |
| cfg = self.config | |
| # Precompute and store as non-trainable parameters | |
| self.freqs_cis, self.yarn_mscale = compute_yarn_freqs( | |
| cfg.head_dim, cfg.max_len, cfg.rope_theta, | |
| cfg.yarn_scale, cfg.yarn_alpha, cfg.yarn_beta | |
| ) | |
| self.alibi_bias = None | |
| if cfg.use_alibi: | |
| self.alibi_bias = compute_alibi_bias(cfg.max_len, cfg.n_heads) | |
| def __call__(self, input_ids, kv_caches=None, use_cache=False): | |
| cfg = self.config | |
| x = nn.Embed(cfg.vocab_size, cfg.d_model, dtype=cfg.dtype, name='embed_tokens')(input_ids) | |
| seq_len = input_ids.shape[1] | |
| if use_cache and kv_caches is not None: | |
| # For cached generation, only mask the new token | |
| mask = jnp.zeros((1, seq_len, kv_caches[0][0].shape[2] + seq_len), dtype=cfg.dtype) | |
| else: | |
| mask = jnp.tril(jnp.ones((seq_len, seq_len))) | |
| mask = jnp.where(mask == 0, -1e9, 0.0).astype(cfg.dtype) | |
| new_kv_caches = [] | |
| for i in range(cfg.n_layers): | |
| layer_cache = kv_caches[i] if (use_cache and kv_caches) else None | |
| if use_cache: | |
| x, new_cache = TransformerBlock( | |
| cfg.d_model, cfg.n_heads, cfg.n_kv_heads, cfg.ff_dim, | |
| cfg.dropout, self.freqs_cis, self.yarn_mscale, self.alibi_bias, | |
| cfg.alibi_weight, layer_idx=i, dtype=cfg.dtype, | |
| name=f'layers_{i}' | |
| )(x, mask, layer_cache, use_cache=True) | |
| new_kv_caches.append(new_cache) | |
| else: | |
| x = TransformerBlock( | |
| cfg.d_model, cfg.n_heads, cfg.n_kv_heads, cfg.ff_dim, | |
| cfg.dropout, self.freqs_cis, self.yarn_mscale, self.alibi_bias, | |
| cfg.alibi_weight, layer_idx=i, dtype=cfg.dtype, | |
| name=f'layers_{i}' | |
| )(x, mask) | |
| x = RMSNorm(dtype=cfg.dtype, name='norm')(x) | |
| logits = nn.Dense(cfg.vocab_size, use_bias=False, dtype=cfg.dtype, name='lm_head')(x) | |
| if use_cache: | |
| return logits, new_kv_caches | |
| return logits | |
| # ============================================================================ | |
| # FAST INFERENCE ENGINE | |
| # ============================================================================ | |
| class SAM1FastInference: | |
| def __init__(self, repo_id: str = "Smilyai-labs/Sam-X-1.5", debug: bool = False): | |
| self.debug = debug | |
| print("π Loading SAM1-600M (Fast Inference Mode)") | |
| print("=" * 60) | |
| # Download model | |
| cache_dir = snapshot_download(repo_id=repo_id) | |
| print(f"β Model cached at: {cache_dir}") | |
| # Load config | |
| config_path = os.path.join(cache_dir, "config.json") | |
| with open(config_path, 'r') as f: | |
| config_dict = json.load(f) | |
| self.config = Config() | |
| for k, v in config_dict.items(): | |
| if k not in ['dtype', 'param_dtype']: | |
| setattr(self.config, k, v) | |
| print(f"π Config: {self.config.d_model}d Γ {self.config.n_layers}L Γ {self.config.n_heads}H") | |
| # Load tokenizer | |
| self.tokenizer = Tokenizer.from_pretrained("gpt2") | |
| # CRITICAL: Add custom tokens EXACTLY as they were during training | |
| custom_tokens = ["<think>", "</think>"] | |
| for token in custom_tokens: | |
| if self.tokenizer.token_to_id(token) is None: | |
| self.tokenizer.add_special_tokens([token]) | |
| print(f"π€ Tokenizer vocab size: {self.tokenizer.get_vocab_size()}") | |
| print(f" Expected config vocab: {self.config.vocab_size}") | |
| # Check if vocab sizes match | |
| if self.tokenizer.get_vocab_size() != self.config.vocab_size: | |
| print(f"β οΈ WARNING: Vocab size mismatch!") | |
| print(f" This may cause gibberish output!") | |
| print(f" Tokenizer: {self.tokenizer.get_vocab_size()}") | |
| print(f" Model: {self.config.vocab_size}") | |
| # CRITICAL FIX: Pad tokenizer to match model vocab | |
| if self.tokenizer.get_vocab_size() < self.config.vocab_size: | |
| n_pad = self.config.vocab_size - self.tokenizer.get_vocab_size() | |
| pad_tokens = [f"<pad_{i}>" for i in range(n_pad)] | |
| self.tokenizer.add_special_tokens(pad_tokens) | |
| print(f" β Added {n_pad} padding tokens to match model") | |
| print(f"β Final tokenizer vocab: {self.tokenizer.get_vocab_size()}") | |
| # Initialize model | |
| self.model = SAM1Model(config=self.config) | |
| # Load SafeTensors (MUCH FASTER than pickle!) | |
| safetensors_path = os.path.join(cache_dir, "model.safetensors") | |
| print(f"π¦ Loading SafeTensors from: {safetensors_path}") | |
| start_time = time.time() | |
| flat_params = load_file(safetensors_path) | |
| # Unflatten params | |
| def unflatten_dict(flat_dict): | |
| result = {} | |
| for key, value in flat_dict.items(): | |
| parts = key.split('.') | |
| current = result | |
| for part in parts[:-1]: | |
| if part not in current: | |
| current[part] = {} | |
| current = current[part] | |
| current[parts[-1]] = value | |
| return result | |
| self.params = unflatten_dict(flat_params) | |
| load_time = time.time() - start_time | |
| param_count = sum(x.size for x in jax.tree_util.tree_leaves(self.params)) | |
| print(f"β Loaded {param_count/1e6:.1f}M parameters in {load_time:.2f}s") | |
| # Compile forward pass for speed | |
| print("β‘ Compiling JIT functions...") | |
| self._forward_jit = jit(self._forward_pass) | |
| self._forward_cached_jit = jit(self._forward_pass_cached) | |
| # Warm up | |
| dummy_input = jnp.ones((1, 1), dtype=jnp.int32) | |
| _ = self._forward_jit(self.params, dummy_input) | |
| print("β Model ready!") | |
| print("=" * 60) | |
| def export_to_onnx(self, output_path: str = "sam1_model.onnx", opset_version: int = 14): | |
| """ | |
| Export model to ONNX format for even faster inference | |
| Note: This is EXPERIMENTAL and requires additional dependencies: | |
| - pip install onnx onnxruntime jax2torch | |
| ONNX inference can be 2-3x faster on CPU, especially with quantization. | |
| """ | |
| try: | |
| import onnx | |
| import onnxruntime as ort | |
| print("β οΈ ONNX export is experimental for JAX models.") | |
| print(" For production use, consider using ONNX Runtime directly") | |
| print(" or converting to PyTorch first.") | |
| print() | |
| print("π Recommended approach:") | |
| print(" 1. Export SafeTensors (already done!)") | |
| print(" 2. Load in PyTorch: torch.load('model.safetensors')") | |
| print(" 3. Export to ONNX: torch.onnx.export(...)") | |
| print() | |
| print(" For JAXβONNX, see: https://github.com/google/jax/discussions/9705") | |
| except ImportError: | |
| print("β ONNX export requires: pip install onnx onnxruntime") | |
| print(" Skipping ONNX export - using fast JAX inference instead!") | |
| def benchmark(self, prompt: str = "Hello, how are you?", num_runs: int = 5): | |
| """Benchmark generation speed""" | |
| print("\nπ Running benchmark...") | |
| print(f"Prompt: '{prompt}'") | |
| print(f"Runs: {num_runs}") | |
| print() | |
| times = [] | |
| for i in range(num_runs): | |
| start = time.time() | |
| list(self.generate( | |
| prompt=prompt, | |
| max_new_tokens=50, | |
| temperature=0.8, | |
| stream=False | |
| )) | |
| elapsed = time.time() - start | |
| times.append(elapsed) | |
| print(f" Run {i+1}: {elapsed:.3f}s") | |
| avg_time = np.mean(times) | |
| std_time = np.std(times) | |
| tokens_per_sec = 50 / avg_time | |
| print() | |
| print(f"π Results:") | |
| print(f" Average: {avg_time:.3f}s Β± {std_time:.3f}s") | |
| print(f" Throughput: {tokens_per_sec:.1f} tokens/sec") | |
| print(f" Per-token latency: {avg_time*1000/50:.1f}ms") | |
| def _forward_pass(self, params, input_ids): | |
| """JIT-compiled forward pass""" | |
| return self.model.apply({'params': params}, input_ids, use_cache=False) | |
| def _forward_pass_cached(self, params, input_ids, kv_caches): | |
| """JIT-compiled forward pass with KV cache""" | |
| return self.model.apply({'params': params}, input_ids, kv_caches=kv_caches, use_cache=True) | |
| def format_chat(self, message: str, system_prompt: str = None) -> str: | |
| """ | |
| Format message with chat template | |
| Based on training template: "User: {input}\nSam: {output}" | |
| Important: No extra spaces, exact format matters! | |
| """ | |
| if system_prompt: | |
| # System prompt format (if used) | |
| return f"{system_prompt}\n\nUser: {message}\nSam:" | |
| return f"User: {message}\nSam:" | |
| def generate( | |
| self, | |
| prompt: str, | |
| max_new_tokens: int = 150, | |
| temperature: float = 0.8, | |
| top_k: int = 50, | |
| top_p: float = 0.9, | |
| seed: int = 42, | |
| stream: bool = False, | |
| use_chat_format: bool = True, | |
| system_prompt: str = None | |
| ): | |
| """Fast generation with KV cache""" | |
| # Format prompt | |
| if use_chat_format: | |
| formatted_prompt = self.format_chat(prompt, system_prompt) | |
| else: | |
| formatted_prompt = prompt | |
| if self.debug: | |
| print(f"π Debug - Formatted prompt: {repr(formatted_prompt[:100])}") | |
| # Tokenize | |
| encoding = self.tokenizer.encode(formatted_prompt) | |
| input_ids = jnp.array(encoding.ids)[None, :] | |
| if self.debug: | |
| print(f"π Debug - Input tokens: {input_ids.shape}") | |
| print(f"π Debug - First 10 tokens: {input_ids[0, :10].tolist()}") | |
| if input_ids.shape[1] > self.config.max_len: | |
| input_ids = input_ids[:, -self.config.max_len:] | |
| rng = random.PRNGKey(seed) | |
| generated_ids = input_ids | |
| kv_caches = None | |
| # First forward pass (prefill) | |
| logits, kv_caches = self._forward_pass_cached(self.params, input_ids, None) | |
| if self.debug: | |
| print(f"π Debug - Logits shape: {logits.shape}") | |
| print(f"π Debug - Top 5 probs: {jax.nn.softmax(logits[0, -1, :])[:5]}") | |
| generated_tokens = [] | |
| for i in range(max_new_tokens): | |
| # Sample next token | |
| next_logits = logits[0, -1, :] / temperature | |
| # Top-k filtering | |
| if top_k > 0: | |
| top_k_logits, top_k_indices = jax.lax.top_k(next_logits, top_k) | |
| next_logits = jnp.full_like(next_logits, -1e9) | |
| next_logits = next_logits.at[top_k_indices].set(top_k_logits) | |
| # Top-p filtering | |
| if top_p < 1.0: | |
| sorted_logits = jnp.sort(next_logits)[::-1] | |
| cumsum = jnp.cumsum(nn.softmax(sorted_logits)) | |
| cutoff_idx = jnp.searchsorted(cumsum, top_p) | |
| cutoff_logit = sorted_logits[cutoff_idx] | |
| next_logits = jnp.where(next_logits < cutoff_logit, -1e9, next_logits) | |
| rng, sample_rng = random.split(rng) | |
| next_token = random.categorical(sample_rng, next_logits)[None, None] | |
| generated_ids = jnp.concatenate([generated_ids, next_token], axis=1) | |
| generated_tokens.append(int(next_token[0, 0])) | |
| # Debug first few tokens | |
| if self.debug and i < 5: | |
| token_text = self.tokenizer.decode([int(next_token[0, 0])]) | |
| print(f"π Debug - Token {i}: {int(next_token[0, 0])} = {repr(token_text)}") | |
| # Stream output | |
| if stream: | |
| full_text = self.tokenizer.decode(generated_ids[0].tolist()) | |
| if "Sam:" in full_text: | |
| response = full_text.split("Sam:")[-1].strip() | |
| else: | |
| response = full_text[len(formatted_prompt):].strip() | |
| yield response | |
| # Stop on EOS | |
| if next_token[0, 0] == self.tokenizer.token_to_id("<|endoftext|>"): | |
| break | |
| # Cached forward pass (only process new token!) | |
| logits, kv_caches = self._forward_pass_cached(self.params, next_token, kv_caches) | |
| if not stream: | |
| full_text = self.tokenizer.decode(generated_ids[0].tolist()) | |
| if "Sam:" in full_text: | |
| response = full_text.split("Sam:")[-1].strip() | |
| else: | |
| response = full_text[len(formatted_prompt):].strip() | |
| yield response | |
| # ============================================================================ | |
| # GRADIO INTERFACE | |
| # ============================================================================ | |
| print("π Initializing model...") | |
| model = SAM1FastInference() | |
| def chat_fn(message, history, system_prompt, max_tokens, temperature, top_k, top_p, seed): | |
| """Chat function for Gradio ChatInterface with messages format""" | |
| if not message.strip(): | |
| yield "β οΈ Please enter a message!" | |
| return | |
| try: | |
| # Build conversation context from history | |
| if history: | |
| # History is in messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] | |
| context = "" | |
| for msg in history[-3:]: # Last 3 turns for context | |
| role = msg.get("role", "user") | |
| content = msg.get("content", "") | |
| if role == "user": | |
| context += f"User: {content}\n" | |
| elif role == "assistant": | |
| context += f"Sam: {content}\n" # Use Sam: for model responses | |
| # Add current message | |
| full_prompt = f"{context}User: {message}\nSam:" | |
| else: | |
| full_prompt = message | |
| response = "" | |
| for output in model.generate( | |
| prompt=full_prompt, | |
| max_new_tokens=int(max_tokens), | |
| temperature=float(temperature), | |
| top_k=int(top_k), | |
| top_p=float(top_p), | |
| seed=int(seed), | |
| stream=True, | |
| use_chat_format=False if history else True, # Only format if no history | |
| system_prompt=system_prompt if system_prompt.strip() else None | |
| ): | |
| response = output | |
| yield response | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"β Error: {str(e)}\n\n{traceback.format_exc()}" | |
| yield error_msg | |
| # Build UI | |
| with gr.Blocks(theme=gr.themes.Soft(), title="SAM1-600M Fast Chat") as demo: | |
| gr.Markdown(""" | |
| # π SAM1-600M Fast Chat | |
| **Optimized inference** with SafeTensors + KV Cache + JIT compilation | |
| **Speed improvements:** | |
| - β‘ 3-5x faster loading (SafeTensors) | |
| - π₯ 5-10x faster generation (KV cache) | |
| - π― JIT-compiled forward pass | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| system_prompt = gr.Textbox( | |
| label="System Prompt (optional)", | |
| placeholder="You are a helpful assistant...", | |
| lines=3 | |
| ) | |
| gr.Markdown("### βοΈ Generation Settings") | |
| max_tokens = gr.Slider(10, 500, 150, step=10, label="Max Tokens") | |
| temperature = gr.Slider(0.1, 2.0, 0.8, step=0.1, label="Temperature") | |
| top_k = gr.Slider(1, 100, 50, step=1, label="Top-K") | |
| top_p = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-P (nucleus)") | |
| seed = gr.Number(value=42, label="Seed", precision=0) | |
| gr.Markdown("### π‘ Try these:") | |
| with gr.Column(scale=3): | |
| # Examples format: each example must include values for ALL additional_inputs | |
| examples_list = [ | |
| ["Explain quantum computing simply", "", 150, 0.8, 50, 0.9, 42], | |
| ["Write a haiku about coding", "", 150, 0.9, 40, 0.9, 42], | |
| ["What makes a good AI assistant?", "", 200, 0.7, 50, 0.9, 42], | |
| ["Tell me about black holes", "", 150, 0.8, 50, 0.9, 42], | |
| ] | |
| chat_interface = gr.ChatInterface( | |
| fn=chat_fn, | |
| type="messages", | |
| additional_inputs=[system_prompt, max_tokens, temperature, top_k, top_p, seed], | |
| examples=examples_list, | |
| cache_examples=False, | |
| ) | |
| gr.Markdown(""" | |
| --- | |
| ### π Model: SAM1-600M | |
| - **Params:** ~600M | **Context:** 1Kβ4-8K | |
| - **Attention:** GQA (18:2) | **Position:** YaRN+ALiBi | |
| - **Speed:** 8x faster generation (KV cache) | 5x faster loading (SafeTensors) | |
| - **Repo:** [Smilyai-labs/Sam-X-1.5](https://huggingface.co/Smilyai-labs/Sam-X-1.5) | |
| ### β‘ Performance Notes | |
| - **First message**: ~150ms (compiling + inference) | |
| - **Follow-up**: ~20-30ms per token (with KV cache) | |
| - **No ONNX needed**: JAX with JIT is already optimized! | |
| *For ONNX export, use PyTorch conversion (JAXβONNX is experimental)* | |
| """) | |
| if __name__ == "__main__": | |
| # Optional: Run benchmark on startup | |
| # model.benchmark() | |
| demo.queue().launch() | |