Keeby-smilyai committed on
Commit
ac2a3fe
·
verified Β·
1 Parent(s): 4e1d66f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -8
app.py CHANGED
@@ -23,7 +23,7 @@ MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
23
  CACHE_DIR = "./model_cache"
24
 
25
  # ============================================================================
26
- # Model Architecture Definitions (Required for Loading)
27
  # ============================================================================
28
 
29
  @keras.saving.register_keras_serializable()
@@ -35,13 +35,13 @@ class RotaryEmbedding(keras.layers.Layer):
35
  self.theta = theta
36
 
37
  def build(self, input_shape):
38
- # Compute embeddings using numpy, then convert to TF tensors
39
  inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
40
  t = np.arange(self.max_len, dtype=np.float32)
41
  freqs = np.outer(t, inv_freq)
42
  emb = np.concatenate([freqs, freqs], axis=-1)
43
 
44
- # Create non-trainable weights for cos and sin embeddings
45
  self.cos_cached = self.add_weight(
46
  name="cos_cached",
47
  shape=(self.max_len, self.dim),
@@ -225,7 +225,16 @@ print("✅ Model architecture registered")
225
 
226
  # Download model files
227
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
228
- model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
229
 
230
  # Load config
231
  with open(config_path, 'r') as f:
@@ -257,8 +266,72 @@ if tokenizer.get_vocab_size() != config.get('vocab_size'):
257
 
258
  eos_token_id = config.get('eos_token_id', 50256)
259
 
260
- # Load model with TF function optimization
261
- model = keras.models.load_model(model_path, compile=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  # Create optimized inference function
264
  @tf.function(reduce_retracing=True)
@@ -619,11 +692,14 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
619
 
620
  with gr.Row():
621
  with gr.Column(scale=4):
622
- # Chat interface
623
  chatbot = gr.Chatbot(
624
  height=600,
625
  show_label=False,
626
- avatar_images=(None, "🤖" if not FESTIVE else "🎉"),
 
 
 
627
  bubble_full_width=False
628
  )
629
 
 
23
  CACHE_DIR = "./model_cache"
24
 
25
  # ============================================================================
26
+ # Model Architecture Definitions (FIXED for model loading)
27
  # ============================================================================
28
 
29
  @keras.saving.register_keras_serializable()
 
35
  self.theta = theta
36
 
37
  def build(self, input_shape):
38
+ # FIXED: Compute in numpy first to avoid symbolic tensor issues
39
  inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
40
  t = np.arange(self.max_len, dtype=np.float32)
41
  freqs = np.outer(t, inv_freq)
42
  emb = np.concatenate([freqs, freqs], axis=-1)
43
 
44
+ # Create as non-trainable weights instead of tf.constant
45
  self.cos_cached = self.add_weight(
46
  name="cos_cached",
47
  shape=(self.max_len, self.dim),
 
225
 
226
  # Download model files
227
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
228
+
229
+ # Try to download checkpoint weights first (more reliable)
230
+ try:
231
+ weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
232
+ print("✅ Found checkpoint weights (ckpt.weights.h5)")
233
+ use_checkpoint = True
234
+ except Exception as e:
235
+ print(f"⚠️ Checkpoint not found, falling back to model.keras: {e}")
236
+ model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
237
+ use_checkpoint = False
238
 
239
  # Load config
240
  with open(config_path, 'r') as f:
 
266
 
267
  eos_token_id = config.get('eos_token_id', 50256)
268
 
269
+ # ==============================================================================
270
+ # Load Model - Priority: checkpoint weights > saved model
271
+ # ==============================================================================
272
+ print("\n🔄 Loading model...")
273
+
274
+ if use_checkpoint:
275
+ print("📦 Building model from config and loading checkpoint weights...")
276
+
277
+ # Build model from scratch with config
278
+ model_config = {
279
+ 'vocab_size': config['vocab_size'],
280
+ 'd_model': config['hidden_size'],
281
+ 'n_layers': config['num_hidden_layers'],
282
+ 'n_heads': config['num_attention_heads'],
283
+ 'ff_mult': config['intermediate_size'] / config['hidden_size'],
284
+ 'max_len': config['max_position_embeddings'],
285
+ 'dropout': 0.1, # Default dropout
286
+ 'rope_theta': config['rope_theta']
287
+ }
288
+
289
+ model = SAM1Model(config=model_config)
290
+
291
+ # Build model by running a dummy forward pass
292
+ dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
293
+ _ = model(dummy_input, training=False)
294
+
295
+ print(f"✅ Model architecture built: {model.count_params():,} parameters")
296
+
297
+ # Load checkpoint weights
298
+ print(f"📥 Loading checkpoint weights from: {weights_path}")
299
+ model.load_weights(weights_path)
300
+ print("✅ Checkpoint weights loaded successfully!")
301
+
302
+ else:
303
+ print("📦 Loading full saved model...")
304
+ try:
305
+ model = keras.models.load_model(model_path, compile=False)
306
+ print("✅ Model loaded successfully")
307
+ except Exception as e:
308
+ print(f"❌ Failed to load model: {e}")
309
+ print("\n🔄 Trying alternative: building from config + loading weights...")
310
+
311
+ # Fallback to building model
312
+ model_config = {
313
+ 'vocab_size': config['vocab_size'],
314
+ 'd_model': config['hidden_size'],
315
+ 'n_layers': config['num_hidden_layers'],
316
+ 'n_heads': config['num_attention_heads'],
317
+ 'ff_mult': config['intermediate_size'] / config['hidden_size'],
318
+ 'max_len': config['max_position_embeddings'],
319
+ 'dropout': 0.1,
320
+ 'rope_theta': config['rope_theta']
321
+ }
322
+
323
+ model = SAM1Model(config=model_config)
324
+ dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
325
+ _ = model(dummy_input, training=False)
326
+
327
+ # Try to load weights from model.keras
328
+ try:
329
+ temp_model = keras.models.load_model(model_path, compile=False)
330
+ model.set_weights(temp_model.get_weights())
331
+ print("✅ Weights transferred successfully")
332
+ except:
333
+ print("❌ Could not load weights - model may not work correctly!")
334
+ raise
335
 
336
  # Create optimized inference function
337
  @tf.function(reduce_retracing=True)
 
692
 
693
  with gr.Row():
694
  with gr.Column(scale=4):
695
+ # Chat interface with bot avatar
696
  chatbot = gr.Chatbot(
697
  height=600,
698
  show_label=False,
699
+ avatar_images=(
700
+ None,
701
+ "https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/KtiMi-aDUOOeN--YNT-Fu.jpeg"
702
+ ),
703
  bubble_full_width=False
704
  )
705