Keeby-smilyai committed
Commit d243125 · verified · 1 parent: 9822eb9

Update app.py

Files changed (1): app.py (+189, −0)
app.py CHANGED
@@ -21,6 +21,195 @@ print("🚀 Loading SAM-Z-1 Model...")
 MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
 CACHE_DIR = "./model_cache"
 
+# ============================================================================
+# Model Architecture Definitions (Required for Loading)
+# ============================================================================
+
+@tf.keras.saving.register_keras_serializable()
+class RotaryEmbedding(tf.keras.layers.Layer):
+    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
+        super().__init__(**kwargs)
+        self.dim = dim
+        self.max_len = max_len
+        self.theta = theta
+        self.built_cache = False
+
+    def build(self, input_shape):
+        if not self.built_cache:
+            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
+            t = tf.range(self.max_len, dtype=tf.float32)
+            freqs = tf.einsum("i,j->ij", t, inv_freq)
+            emb = tf.concat([freqs, freqs], axis=-1)
+
+            self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
+            self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
+            self.built_cache = True
+
+        super().build(input_shape)
+
+    def rotate_half(self, x):
+        x1, x2 = tf.split(x, 2, axis=-1)
+        return tf.concat([-x2, x1], axis=-1)
+
+    def call(self, q, k):
+        seq_len = tf.shape(q)[2]
+        dtype = q.dtype
+        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
+        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
+
+        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
+        k_rotated = (k * cos) + (self.rotate_half(k) * sin)
+
+        return q_rotated, k_rotated
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
+        return config
+
+
+@tf.keras.saving.register_keras_serializable()
+class RMSNorm(tf.keras.layers.Layer):
+    def __init__(self, epsilon=1e-5, **kwargs):
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
+
+    def call(self, x):
+        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
+        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
+
+
+@tf.keras.saving.register_keras_serializable()
+class TransformerBlock(tf.keras.layers.Layer):
+    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
+        super().__init__(**kwargs)
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.ff_dim = ff_dim
+        self.dropout_rate = dropout
+        self.max_len = max_len
+        self.rope_theta = rope_theta
+        self.head_dim = d_model // n_heads
+        self.layer_idx = layer_idx
+
+        self.pre_attn_norm = RMSNorm()
+        self.pre_ffn_norm = RMSNorm()
+
+        self.q_proj = tf.keras.layers.Dense(d_model, use_bias=False, name="q_proj")
+        self.k_proj = tf.keras.layers.Dense(d_model, use_bias=False, name="k_proj")
+        self.v_proj = tf.keras.layers.Dense(d_model, use_bias=False, name="v_proj")
+        self.out_proj = tf.keras.layers.Dense(d_model, use_bias=False, name="o_proj")
+
+        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
+
+        self.gate_proj = tf.keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
+        self.up_proj = tf.keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
+        self.down_proj = tf.keras.layers.Dense(d_model, use_bias=False, name="down_proj")
+
+        self.dropout = tf.keras.layers.Dropout(dropout)
+
+    def call(self, x, training=None):
+        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
+        dtype = x.dtype
+
+        # Attention
+        res = x
+        y = self.pre_attn_norm(x)
+
+        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
+        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
+        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
+
+        q, k = self.rope(q, k)
+
+        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
+
+        mask = tf.where(
+            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
+            tf.constant(-1e9, dtype=dtype),
+            tf.constant(0.0, dtype=dtype)
+        )
+        scores += mask
+        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
+
+        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
+        x = res + self.dropout(self.out_proj(attn), training=training)
+
+        # FFN (SwiGLU)
+        res = x
+        y = self.pre_ffn_norm(x)
+        ffn = self.down_proj(tf.keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
+
+        return res + self.dropout(ffn, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "d_model": self.d_model,
+            "n_heads": self.n_heads,
+            "ff_dim": self.ff_dim,
+            "dropout": self.dropout_rate,
+            "max_len": self.max_len,
+            "rope_theta": self.rope_theta,
+            "layer_idx": self.layer_idx
+        })
+        return config
+
+
+@tf.keras.saving.register_keras_serializable()
+class SAM1Model(tf.keras.Model):
+    def __init__(self, **kwargs):
+        super().__init__()
+        if 'config' in kwargs and isinstance(kwargs['config'], dict):
+            self.cfg = kwargs['config']
+        elif 'vocab_size' in kwargs:
+            self.cfg = kwargs
+        else:
+            self.cfg = kwargs.get('cfg', kwargs)
+
+        self.embed = tf.keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
+
+        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
+        block_args = {
+            'd_model': self.cfg['d_model'],
+            'n_heads': self.cfg['n_heads'],
+            'ff_dim': ff_dim,
+            'dropout': self.cfg['dropout'],
+            'max_len': self.cfg['max_len'],
+            'rope_theta': self.cfg['rope_theta']
+        }
+
+        self.blocks = []
+        for i in range(self.cfg['n_layers']):
+            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
+            self.blocks.append(block)
+
+        self.norm = RMSNorm(name="final_norm")
+        self.lm_head = tf.keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
+
+    def call(self, input_ids, training=None):
+        x = self.embed(input_ids)
+
+        for block in self.blocks:
+            x = block(x, training=training)
+
+        return self.lm_head(self.norm(x))
+
+    def get_config(self):
+        base_config = super().get_config()
+        base_config['config'] = self.cfg
+        return base_config
+
+print("✅ Model architecture registered")
+
 # Download model files
 config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
 model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
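
The registration decorators are what make the downloaded model.keras loadable later in app.py. A minimal sketch of that load path, assuming the script uses the standard Keras API (this snippet is illustrative and not part of the commit):

import tensorflow as tf

# Assumption: model_path is the model.keras file downloaded above.
# Because RotaryEmbedding, RMSNorm, TransformerBlock and SAM1Model are
# registered via @tf.keras.saving.register_keras_serializable(), Keras can
# resolve them by name, so no custom_objects mapping is required here.
model = tf.keras.models.load_model(model_path, compile=False)

# Forward pass on dummy token ids; output is logits of shape (batch, seq_len, vocab_size).
logits = model(tf.constant([[1, 2, 3]], dtype=tf.int32))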