Spaces:
Running
Running
fix policy model issues
Browse files- app/ml/policy_network.py +44 -17
app/ml/policy_network.py
CHANGED
|
@@ -396,30 +396,57 @@ def load_policy_model() -> PolicyNetwork:
|
|
| 396 |
hidden_size=hidden_size # ✅ Pass detected hidden size
|
| 397 |
)
|
| 398 |
|
| 399 |
-
# **KEY FIX**:
|
| 400 |
saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
|
| 401 |
current_vocab_size = len(POLICY_MODEL.tokenizer)
|
| 402 |
-
|
| 403 |
if saved_vocab_size != current_vocab_size:
|
| 404 |
print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
else:
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
POLICY_MODEL.eval()
|
| 420 |
-
|
| 421 |
-
|
| 422 |
POLICY_TOKENIZER = POLICY_MODEL.tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
print("✅ Policy network loaded and cached")
|
| 425 |
|
|
|
|
| 396 |
hidden_size=hidden_size # ✅ Pass detected hidden size
|
| 397 |
)
|
| 398 |
|
| 399 |
+
# **KEY FIX**: Handle vocab size mismatch
|
| 400 |
saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
|
| 401 |
current_vocab_size = len(POLICY_MODEL.tokenizer)
|
| 402 |
+
|
| 403 |
if saved_vocab_size != current_vocab_size:
|
| 404 |
print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
|
| 405 |
+
|
| 406 |
+
if abs(saved_vocab_size - current_vocab_size) <= 2:
|
| 407 |
+
# Small difference - just load with strict=False
|
| 408 |
+
print(f"✅ Loading with strict=False to handle minor vocab differences...")
|
| 409 |
+
|
| 410 |
+
# Move to device first
|
| 411 |
+
POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
|
| 412 |
+
|
| 413 |
+
# Load weights with strict=False
|
| 414 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 415 |
+
POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
| 416 |
+
else:
|
| 417 |
+
POLICY_MODEL.load_state_dict(checkpoint, strict=False)
|
| 418 |
+
else:
|
| 419 |
+
# Large difference - resize properly
|
| 420 |
+
print(f"✅ Resizing model to match saved vocab size...")
|
| 421 |
+
POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
|
| 422 |
+
|
| 423 |
+
# Move to device
|
| 424 |
+
POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
|
| 425 |
+
|
| 426 |
+
# Load weights
|
| 427 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 428 |
+
POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
|
| 429 |
+
else:
|
| 430 |
+
POLICY_MODEL.load_state_dict(checkpoint)
|
| 431 |
else:
|
| 432 |
+
# No mismatch
|
| 433 |
+
POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
|
| 434 |
+
|
| 435 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 436 |
+
POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
|
| 437 |
+
else:
|
| 438 |
+
POLICY_MODEL.load_state_dict(checkpoint)
|
| 439 |
+
|
| 440 |
+
# Set to evaluation mode
|
| 441 |
POLICY_MODEL.eval()
|
| 442 |
+
|
| 443 |
+
# Cache tokenizer
|
| 444 |
POLICY_TOKENIZER = POLICY_MODEL.tokenizer
|
| 445 |
+
|
| 446 |
+
print("✅ Policy network loaded and cached")
|
| 447 |
+
print(f" Model vocab size: {POLICY_MODEL.bert.embeddings.word_embeddings.num_embeddings}")
|
| 448 |
+
print(f" Tokenizer vocab size: {len(POLICY_MODEL.tokenizer)}")
|
| 449 |
+
|
| 450 |
|
| 451 |
print("✅ Policy network loaded and cached")
|
| 452 |
|