# worker-mini / model_architecture.py
# Uploaded by Bc-AI using huggingface_hub (commit 0b11938, verified).
import tensorflow as tf
import keras
import numpy as np
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied to a query/key pair.

    A cos/sin lookup table with one row per position (``max_len`` rows) is
    built lazily on the first call and reused afterwards.

    Args:
        dim: Size of the last axis of the tensors being rotated.
        max_len: Number of positions stored in the lookup table.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        # The table is created on first use, see _build_cache().
        self.built_cache = False
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        """Create the cos/sin tables once; later calls are no-ops."""
        if self.built_cache:
            return
        # The first call may happen while a tf.function is being traced.
        # There, ``.numpy()`` is unavailable and any tensor created would
        # belong to that function's graph. ``tf.init_scope`` lifts the
        # computation out of the traced graph so the tables are ordinary
        # constants that every later call can capture.
        with tf.init_scope():
            inv_freq = 1.0 / (
                self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim)
            )
            t = tf.range(self.max_len, dtype=tf.float32)
            # Outer product: one row per position, one column per frequency.
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            # Duplicate the frequencies so both halves used by rotate_half()
            # share the same angles.
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
        self.built_cache = True

    def rotate_half(self, x):
        """Split the last axis in two halves (x1, x2) and return (-x2, x1)."""
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k, offset=0):
        """Apply rotary embeddings with position offset.

        Args:
            q: Query tensor; axis 2 is read as the sequence axis and the
                table is broadcast over the two leading axes.
            k: Key tensor, combined with the same cos/sin slices as ``q``.
            offset: Index of the first table row to use.

        Returns:
            Tuple ``(q_embed, k_embed)`` of the rotated tensors.
        """
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        # The table is stored in float32; cast the slice to the input dtype.
        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square normalisation over the last axis.

    The input is divided by the RMS of its last axis (with ``epsilon`` added
    under the square root) and multiplied by a learned per-feature scale
    that is initialised to ones.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.scale = None  # weight is created in build()
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create the scale weight, one entry per feature of the last axis."""
        feature_dim = input_shape[-1]
        self.scale = self.add_weight(
            name="scale",
            shape=(feature_dim,),
            initializer="ones",
        )
        super().build(input_shape)

    def call(self, x):
        """Return ``x`` normalised by its RMS and multiplied by the scale."""
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        inv_rms = tf.math.rsqrt(mean_square + self.epsilon)
        return x * inv_rms * self.scale

    def get_config(self):
        """Return the base layer config extended with ``epsilon``."""
        return {**super().get_config(), "epsilon": self.epsilon}
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm transformer block with RoPE attention and a gated FFN.

    Data flow: RMSNorm -> causal multi-head self-attention with rotary
    embeddings -> residual add, then RMSNorm -> SiLU-gated feed-forward ->
    residual add. One Dropout layer is applied to both sub-layer outputs.

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of attention heads; must divide ``d_model``.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout rate.
        max_len: Number of positions in the rotary table.
        rope_theta: Frequency base of the rotary embedding.
        layer_idx: Index of this block; stored and serialised only.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        # Fail early: call() reshapes the projections to
        # (n_heads, d_model // n_heads), which cannot work otherwise.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

    def build(self, input_shape):
        """Create the sub-layers (norms, projections, RoPE, FFN, dropout)."""
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        """Run the block on ``x`` (simplified: no KV cache).

        Args:
            x: Input of shape (batch, seq_len, d_model).
            training: Forwarded to the dropout layer.
            past_kv: Accepted for interface compatibility; ignored.
            use_cache: Accepted for interface compatibility; ignored.

        Returns:
            Tuple ``(output, None)``; the second item is the placeholder
            for the KV cache this simplified version does not produce.
        """
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype
        res = x
        y = self.pre_attn_norm(x)
        # Multi-head attention: (B, T, d_model) -> (B, n_heads, T, head_dim)
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        # Apply RoPE; positions always start at 0 because nothing is cached
        q, k = self.rope(q, k, offset=0)
        # Scaled dot-product attention scores
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: band_part(-1, 0) keeps the LOWER triangle (incl. the
        # diagonal), so each position sees itself and earlier positions.
        mask = tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0)
        # Zeroed (future) entries get a large negative bias, kept entries 0.
        mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores = scores + mask[None, None, :, :]
        attn = tf.nn.softmax(scores, axis=-1)
        attn_out = tf.matmul(attn, v)
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, d_model)
        attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
        attn_out = tf.reshape(attn_out, [B, T, self.d_model])
        x = res + self.dropout(self.out_proj(attn_out), training=training)
        # FFN: down_proj(silu(gate_proj(y)) * up_proj(y))
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        output = res + self.dropout(ffn, training=training)
        return output, None  # Return None for past_kv in this simplified version

    def get_config(self):
        """Return the constructor arguments needed to re-create the block."""
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx,
        })
        return config
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only language model: embedding -> blocks -> norm -> LM head.

    The hyper-parameter dict can be supplied in three ways:

    * ``SAM1Model(config={...})`` - the form produced by ``get_config()``;
    * ``SAM1Model(vocab_size=..., d_model=..., ...)`` - flat keywords;
    * ``SAM1Model(cfg={...})``.

    Keys read from the dict: ``vocab_size``, ``d_model``, ``ff_mult``,
    ``n_heads``, ``dropout``, ``max_len``, ``rope_theta``, ``n_layers``.
    """

    def __init__(self, **kwargs):
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            # Nested form, as emitted by get_config(): top-level entries
            # such as name/trainable/dtype belong to the Keras base class.
            # Forward them so a from_config round trip restores them
            # instead of silently dropping them.
            base_kwargs = {
                key: kwargs[key]
                for key in ('name', 'trainable', 'dtype')
                if key in kwargs
            }
            super().__init__(**base_kwargs)
            self.cfg = kwargs['config']
        else:
            # Flat forms: every keyword is treated as a hyper-parameter, so
            # nothing is forwarded to the base class.
            super().__init__()
            if 'vocab_size' in kwargs:
                self.cfg = kwargs
            else:
                self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        # Feed-forward width is d_model * ff_mult, truncated to an int.
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta'],
        }
        self.blocks = [
            TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            for i in range(self.cfg['n_layers'])
        ]
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None, past_kv=None, use_cache=False):
        """Compute logits for ``input_ids`` (simplified: no KV cache).

        Args:
            input_ids: Token ids fed to the embedding layer.
            training: Forwarded to every block.
            past_kv: Accepted for interface compatibility; ignored.
            use_cache: Accepted for interface compatibility; ignored.

        Returns:
            Tuple ``(logits, None)``; the second item is the placeholder
            for the KV cache this simplified version does not produce.
        """
        x = self.embed(input_ids)
        for block in self.blocks:
            x, _ = block(x, training=training, past_kv=None, use_cache=False)
        logits = self.lm_head(self.norm(x))
        return logits, None  # Return None for past_kv in this simplified version

    def get_config(self):
        """Return the base config plus the hyper-parameters under 'config'."""
        base_config = super().get_config()
        # Hand out a copy so callers editing the result cannot mutate the
        # model's own hyper-parameter dict.
        base_config['config'] = dict(self.cfg)
        return base_config
def count_parameters(model):
    """Return the total number of scalar entries across ``model.weights``."""
    return sum(weight.numpy().size for weight in model.weights)