import tensorflow as tf
import keras
import numpy as np


@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied jointly to queries and keys.

    A table of cos/sin values with shape ``(max_len, dim)`` is computed once,
    on first use, and sliced per call.

    Args:
        dim: Size of the last axis of the tensors to rotate. Must be even,
            because the rotation pairs the two halves of that axis.
        max_len: Number of positions held in the cos/sin table.
        theta: Base of the geometric frequency progression.
        **kwargs: Forwarded to ``keras.layers.Layer``.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        if dim % 2 != 0:
            # An odd dim would yield a (dim + 1)-wide table and an uneven
            # split in rotate_half; fail here with a readable message.
            raise ValueError(
                f"RotaryEmbedding requires an even dim, got {dim}"
            )
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        """Populate ``cos_cached`` / ``sin_cached`` once (idempotent)."""
        if self.built_cache:
            return

        # Pure NumPy: no ``.numpy()`` call is needed, so this also works when
        # the first call happens while a graph is being traced.
        exponents = np.arange(0, self.dim, 2, dtype=np.float32) / self.dim
        inv_freq = (1.0 / (self.theta ** exponents)).astype(np.float32)
        positions = np.arange(self.max_len, dtype=np.float32)
        freqs = np.outer(positions, inv_freq)
        emb = np.concatenate([freqs, freqs], axis=-1)

        # init_scope lifts constant creation out of any function-building
        # graph, so the tensors cached on ``self`` are not tied to one trace.
        with tf.init_scope():
            self.cos_cached = tf.constant(np.cos(emb), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb), dtype=tf.float32)
        self.built_cache = True

    def rotate_half(self, x):
        """Split the last axis in two halves ``(x1, x2)`` and return ``(-x2, x1)``."""
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k, offset=0):
        """Apply rotary embeddings with position offset.

        Args:
            q: Query tensor; axis 2 is read as the sequence length and the
                last axis must have size ``dim``. ``TransformerBlock`` passes
                ``(batch, heads, seq, head_dim)``.
            k: Key tensor, rotated with the same cos/sin slice as ``q``.
            offset: Index of the first position in the table to use.

        Returns:
            Tuple ``(q_embed, k_embed)`` with the dtype of ``q``.

        Note:
            The table slice is not bounds-checked; ``offset + seq_len`` is
            expected to stay within ``max_len``.
        """
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        end = offset + seq_len
        # Leading unit axes broadcast over the first two axes of q / k.
        cos = tf.cast(self.cos_cached[offset:end, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:end, :], dtype)[None, None, :, :]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def get_config(self):
        config = super().get_config()
        config.update(
            {"dim": self.dim, "max_len": self.max_len, "theta": self.theta}
        )
        return config


@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square normalisation over the last axis with a learned scale.

    Args:
        epsilon: Constant added to the mean of squares before the reciprocal
            square root.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.scale = None

    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones"
        )
        super().build(input_shape)

    def call(self, x):
        """Return ``x / sqrt(mean(x**2) + epsilon) * scale``."""
        # Reduce in float32 so squaring cannot overflow for half-precision
        # inputs; for float32 inputs both casts are no-ops.
        x32 = tf.cast(x, tf.float32)
        variance = tf.reduce_mean(tf.square(x32), axis=-1, keepdims=True)
        normed = x32 * tf.math.rsqrt(variance + self.epsilon)
        return tf.cast(normed, x.dtype) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config


@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block: causal self-attention with RoPE + gated FFN.

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of attention heads; must divide ``d_model``.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout rate applied to both residual branches.
        max_len: Size of the rotary position table.
        rope_theta: Frequency base passed to ``RotaryEmbedding``.
        layer_idx: Index of this block; stored and serialised only.
        **kwargs: Forwarded to ``keras.layers.Layer``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len,
                 rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        if d_model % n_heads != 0:
            # Floor division would otherwise make the head reshape fail later.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

    def build(self, input_shape):
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(
            self.head_dim, max_len=self.max_len, theta=self.rope_theta
        )
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)

        # Build the weight-bearing sublayers now so the block owns all of its
        # variables as soon as build() returns, not only after a first call.
        if input_shape is not None:
            model_shape = tuple(input_shape)
            ff_shape = model_shape[:-1] + (self.ff_dim,)
            d_model_inputs = (
                self.pre_attn_norm,
                self.pre_ffn_norm,
                self.q_proj,
                self.k_proj,
                self.v_proj,
                self.out_proj,
                self.gate_proj,
                self.up_proj,
            )
            for sublayer in d_model_inputs:
                sublayer.build(model_shape)
            # down_proj consumes the ff_dim-wide gated activation.
            self.down_proj.build(ff_shape)

        super().build(input_shape)

    def _split_heads(self, projected, batch, seq_len):
        """Reshape (batch, seq, d_model) to (batch, heads, seq, head_dim)."""
        shaped = tf.reshape(
            projected, [batch, seq_len, self.n_heads, self.head_dim]
        )
        return tf.transpose(shaped, [0, 2, 1, 3])

    def call(self, x, training=None, past_kv=None, use_cache=False):
        """Simplified call without KV cache for this example.

        Args:
            x: Input of shape ``(batch, seq, d_model)``.
            training: Forwarded to the dropout layer.
            past_kv: Accepted for interface compatibility; ignored.
            use_cache: Accepted for interface compatibility; ignored.

        Returns:
            Tuple ``(output, None)``; the second element is the placeholder
            for a key/value cache, which this version never produces.
        """
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype

        res = x
        y = self.pre_attn_norm(x)

        # Multi-head attention
        q = self._split_heads(self.q_proj(y), B, T)
        k = self._split_heads(self.k_proj(y), B, T)
        v = self._split_heads(self.v_proj(y), B, T)

        # Apply RoPE (positions always start at 0 because no cache is used)
        q, k = self.rope(q, k, offset=0)

        # Attention scores
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(
            tf.cast(self.head_dim, dtype)
        )

        # Causal mask: band_part(-1, 0) keeps the LOWER triangle (each query
        # sees itself and earlier keys); zeros above the diagonal become a
        # large negative additive bias.
        mask = tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0)
        mask = tf.where(
            mask == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype),
        )
        scores = scores + mask[None, None, :, :]

        attn = tf.nn.softmax(scores, axis=-1)
        attn_out = tf.matmul(attn, v)
        attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
        attn_out = tf.reshape(attn_out, [B, T, self.d_model])

        x = res + self.dropout(self.out_proj(attn_out), training=training)

        # FFN: silu(gate) * up, projected back to d_model
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(
            keras.activations.silu(self.gate_proj(y)) * self.up_proj(y)
        )
        output = res + self.dropout(ffn, training=training)

        return output, None  # Return None for past_kv in this simplified version

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx,
        })
        return config


@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only language model: embedding, N blocks, final norm, LM head.

    The hyper-parameters may be supplied in any of three ways:

    * ``SAM1Model(config={...})`` (the form ``get_config`` produces),
    * ``SAM1Model(vocab_size=..., d_model=..., ...)`` as flat keywords,
    * ``SAM1Model(cfg={...})``.

    The configuration must provide ``vocab_size``, ``d_model``, ``ff_mult``,
    ``n_heads``, ``dropout``, ``max_len``, ``rope_theta`` and ``n_layers``.
    ``name``, ``trainable`` and ``dtype`` keywords are forwarded to
    ``keras.Model``.
    """

    def __init__(self, **kwargs):
        # Forward the standard Keras constructor arguments instead of dropping
        # them, so a get_config()/from_config() round-trip keeps e.g. the name.
        base_kwargs = {
            key: kwargs.pop(key)
            for key in ("name", "trainable", "dtype")
            if key in kwargs
        }
        super().__init__(**base_kwargs)

        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            cfg = kwargs
        else:
            cfg = kwargs.get('cfg', kwargs)
        # Shallow copy so later mutation of the caller's dict cannot change
        # what this model reports from get_config().
        self.cfg = dict(cfg)

        self.embed = keras.layers.Embedding(
            self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens"
        )

        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta'],
        }

        self.blocks = [
            TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            for i in range(self.cfg['n_layers'])
        ]

        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(
            self.cfg['vocab_size'], use_bias=False, name="lm_head"
        )

    def call(self, input_ids, training=None, past_kv=None, use_cache=False):
        """Simplified call without full KV cache implementation.

        Args:
            input_ids: Token ids fed to the embedding layer.
            training: Forwarded to every block.
            past_kv: Accepted for interface compatibility; ignored.
            use_cache: Accepted for interface compatibility; ignored.

        Returns:
            Tuple ``(logits, None)``; the second element is the placeholder
            for a key/value cache, which this version never produces.
        """
        x = self.embed(input_ids)

        for block in self.blocks:
            x, _ = block(x, training=training, past_kv=None, use_cache=False)

        logits = self.lm_head(self.norm(x))
        return logits, None  # Return None for past_kv in this simplified version

    def get_config(self):
        base_config = super().get_config()
        # Hand out a copy so callers cannot mutate the model's own settings.
        base_config['config'] = dict(self.cfg)
        return base_config


def count_parameters(model):
    """Count model parameters.

    Sums the element counts of every entry in ``model.weights`` using the
    static shapes, so no weight values are copied to host memory.

    Args:
        model: Object exposing a ``weights`` iterable whose items have a
            ``shape`` attribute.

    Returns:
        Total number of scalar elements as a Python ``int``.
    """
    return sum(
        int(np.prod(tuple(weight.shape), dtype=np.int64))
        for weight in model.weights
    )