import tensorflow as tf
import keras
import numpy as np
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        if not self.built_cache:
            # Inverse frequencies for each dimension pair: theta^(-2i/dim).
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)  # (max_len, dim // 2)
            emb = tf.concat([freqs, freqs], axis=-1)   # (max_len, dim)
            # Use tf.cos/tf.sin directly; round-tripping through .numpy() here
            # would fail in graph mode (e.g. inside tf.function).
            self.cos_cached = tf.cos(emb)
            self.sin_cached = tf.sin(emb)
            self.built_cache = True

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k, offset=0):
        """Apply rotary embeddings with position offset."""
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
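

# Minimal sanity check for RotaryEmbedding (illustrative, not part of the
# original file; the shapes below are assumptions for a tiny test). RoPE only
# rotates pairs of dimensions, so it should preserve each head vector's norm.
def _demo_rotary_embedding():
    rope = RotaryEmbedding(dim=8, max_len=16)
    q = tf.random.normal([1, 2, 4, 8])  # (batch, heads, seq_len, head_dim)
    k = tf.random.normal([1, 2, 4, 8])
    q_rot, k_rot = rope(q, k, offset=0)
    # Norms along the last axis are unchanged by the rotation.
    tf.debugging.assert_near(tf.norm(q, axis=-1), tf.norm(q_rot, axis=-1), atol=1e-4)
    return q_rot, k_rot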
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.scale = None

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
        super().build(input_shape)

    def call(self, x):
        # Normalize by the root mean square of the features, then rescale.
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
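

# Illustrative check (not in the original file; shape and scale below are
# arbitrary assumptions): with the default "ones" scale, RMSNorm should map
# any input to features whose root mean square is approximately 1.
def _demo_rms_norm():
    norm = RMSNorm()
    x = tf.random.normal([2, 4, 16]) * 5.0
    y = norm(x)
    rms = tf.sqrt(tf.reduce_mean(tf.square(y), axis=-1))
    tf.debugging.assert_near(rms, tf.ones_like(rms), atol=1e-2)
    return y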
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

    def build(self, input_shape):
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        """Simplified call without a KV cache; past_kv/use_cache are accepted but ignored."""
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype
        res = x
        y = self.pre_attn_norm(x)
        # Multi-head attention: project, then reshape to (batch, heads, seq, head_dim).
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        # Apply RoPE to queries and keys.
        q, k = self.rope(q, k, offset=0)
        # Scaled dot-product attention scores.
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: band_part(..., -1, 0) keeps the lower triangle, so each
        # position attends only to itself and earlier positions.
        mask = tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0)
        mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores = scores + mask[None, None, :, :]
        attn = tf.nn.softmax(scores, axis=-1)
        attn_out = tf.matmul(attn, v)
        attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
        attn_out = tf.reshape(attn_out, [B, T, self.d_model])
        x = res + self.dropout(self.out_proj(attn_out), training=training)
        # SwiGLU feed-forward network.
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        output = res + self.dropout(ffn, training=training)
        return output, None  # None stands in for past_kv in this simplified version

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx,
        })
        return config
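

# Illustrative smoke test (the hyperparameters below are assumptions for a
# tiny block, not values from the original file): one forward pass should
# preserve the (batch, seq_len, d_model) shape thanks to the residual paths.
def _demo_transformer_block():
    block = TransformerBlock(d_model=32, n_heads=4, ff_dim=64, dropout=0.0,
                             max_len=16, rope_theta=10000)
    x = tf.random.normal([2, 8, 32])  # (batch, seq_len, d_model)
    y, _ = block(x, training=False)
    assert y.shape == x.shape
    return y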
class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        # Accept the config as a dict under 'config' (Keras deserialization),
        # as flat kwargs, or under 'cfg'.
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta'],
        }
        self.blocks = [
            TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            for i in range(self.cfg['n_layers'])
        ]
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None, past_kv=None, use_cache=False):
        """Simplified call without a full KV-cache implementation."""
        x = self.embed(input_ids)
        for block in self.blocks:
            x, _ = block(x, training=training, past_kv=None, use_cache=False)
        logits = self.lm_head(self.norm(x))
        return logits, None  # None stands in for past_kv in this simplified version

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
def count_parameters(model):
    """Count the total number of scalar weights in the model."""
    total_params = 0
    for weight in model.weights:
        total_params += int(np.prod(weight.shape))
    return total_params
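

if __name__ == "__main__":
    # Hypothetical smoke run tying the assumption-labeled sketches above
    # together (not part of the original Space). Note the model's weights only
    # exist after it has been built by a forward pass.
    _demo_rotary_embedding()
    _demo_rms_norm()
    _demo_transformer_block()
    model, logits = _demo_sam1_model()
    print(f"logits shape: {logits.shape}, parameters: {count_parameters(model):,}")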