eeshanyaj committed on
Commit
3986a1e
·
1 Parent(s): 7ed4c17

fix policy model issues

Browse files
Files changed (1) hide show
  1. app/ml/policy_network.py +44 -17
app/ml/policy_network.py CHANGED
@@ -396,30 +396,57 @@ def load_policy_model() -> PolicyNetwork:
396
  hidden_size=hidden_size # βœ… Pass detected hidden size
397
  )
398
 
399
- # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
400
  saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
401
  current_vocab_size = len(POLICY_MODEL.tokenizer)
402
-
403
  if saved_vocab_size != current_vocab_size:
404
  print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
405
- print(f"βœ… Resizing tokenizer and embeddings to match saved model...")
406
- # Resize model to match saved checkpoint
407
- POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
408
-
409
- # Move to device
410
- POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
411
-
412
- # Now load trained weights (sizes and architecture will match!)
413
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
414
- POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  else:
416
- POLICY_MODEL.load_state_dict(checkpoint)
417
-
418
- # Set to evaluation mode
 
 
 
 
 
 
419
  POLICY_MODEL.eval()
420
-
421
- # Cache tokenizer
422
  POLICY_TOKENIZER = POLICY_MODEL.tokenizer
 
 
 
 
 
423
 
424
  print("βœ… Policy network loaded and cached")
425
 
 
396
  hidden_size=hidden_size # βœ… Pass detected hidden size
397
  )
398
 
399
+ # **KEY FIX**: Handle vocab size mismatch
400
  saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
401
  current_vocab_size = len(POLICY_MODEL.tokenizer)
402
+
403
  if saved_vocab_size != current_vocab_size:
404
  print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
405
+
406
+ if abs(saved_vocab_size - current_vocab_size) <= 2:
407
+ # Small difference - just load with strict=False
408
+ print(f"βœ… Loading with strict=False to handle minor vocab differences...")
409
+
410
+ # Move to device first
411
+ POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
412
+
413
+ # Load weights with strict=False
414
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
415
+ POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'], strict=False)
416
+ else:
417
+ POLICY_MODEL.load_state_dict(checkpoint, strict=False)
418
+ else:
419
+ # Large difference - resize properly
420
+ print(f"βœ… Resizing model to match saved vocab size...")
421
+ POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
422
+
423
+ # Move to device
424
+ POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
425
+
426
+ # Load weights
427
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
428
+ POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
429
+ else:
430
+ POLICY_MODEL.load_state_dict(checkpoint)
431
  else:
432
+ # No mismatch
433
+ POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
434
+
435
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
436
+ POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
437
+ else:
438
+ POLICY_MODEL.load_state_dict(checkpoint)
439
+
440
+ # Set to evaluation mode
441
  POLICY_MODEL.eval()
442
+
443
+ # Cache tokenizer
444
  POLICY_TOKENIZER = POLICY_MODEL.tokenizer
445
+
446
+ print("βœ… Policy network loaded and cached")
447
+ print(f" Model vocab size: {POLICY_MODEL.bert.embeddings.word_embeddings.num_embeddings}")
448
+ print(f" Tokenizer vocab size: {len(POLICY_MODEL.tokenizer)}")
449
+
450
 
451
  print("βœ… Policy network loaded and cached")
452