Spaces:
Running
Running
fix policy model issues
Browse files- app/ml/policy_network.py +145 -22
app/ml/policy_network.py
CHANGED
|
@@ -40,41 +40,82 @@ class PolicyNetwork(nn.Module):
|
|
| 40 |
* NO_FETCH + Bad: -0.5
|
| 41 |
"""
|
| 42 |
|
| 43 |
-
def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True):
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
|
| 54 |
self.tokenizer.add_special_tokens(special_tokens)
|
| 55 |
-
|
| 56 |
-
|
| 57 |
self.bert.resize_token_embeddings(len(self.tokenizer))
|
| 58 |
-
|
| 59 |
-
|
| 60 |
self._init_action_embeddings()
|
| 61 |
-
|
| 62 |
-
|
| 63 |
if use_multilayer:
|
| 64 |
-
|
| 65 |
self.classifier = nn.Sequential(
|
| 66 |
-
nn.Linear(self.bert.config.hidden_size,
|
| 67 |
nn.ReLU(),
|
| 68 |
nn.Dropout(dropout_rate),
|
| 69 |
-
nn.Linear(
|
| 70 |
)
|
| 71 |
else:
|
| 72 |
-
|
| 73 |
self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
|
| 74 |
-
|
| 75 |
-
# Dropout for regularization
|
| 76 |
-
self.dropout = nn.Dropout(dropout_rate)
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def _init_action_embeddings(self):
|
| 79 |
"""
|
| 80 |
Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
|
|
@@ -239,6 +280,80 @@ POLICY_MODEL: Optional[PolicyNetwork] = None
|
|
| 239 |
POLICY_TOKENIZER: Optional[AutoTokenizer] = None
|
| 240 |
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
def load_policy_model() -> PolicyNetwork:
|
| 243 |
"""
|
| 244 |
Load trained policy model (called once on startup).
|
|
@@ -265,13 +380,20 @@ def load_policy_model() -> PolicyNetwork:
|
|
| 265 |
# β
AUTO-DETECT ARCHITECTURE from checkpoint keys
|
| 266 |
has_multilayer = "classifier.0.weight" in checkpoint
|
| 267 |
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
# Create model instance with correct architecture
|
| 271 |
POLICY_MODEL = PolicyNetwork(
|
| 272 |
model_name="bert-base-uncased",
|
| 273 |
dropout_rate=0.1,
|
| 274 |
-
use_multilayer=has_multilayer
|
|
|
|
| 275 |
)
|
| 276 |
|
| 277 |
# **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
|
|
@@ -314,6 +436,7 @@ def load_policy_model() -> PolicyNetwork:
|
|
| 314 |
return POLICY_MODEL
|
| 315 |
|
| 316 |
|
|
|
|
| 317 |
# ============================================================================
|
| 318 |
# PREDICTION FUNCTIONS
|
| 319 |
# ============================================================================
|
|
|
|
| 40 |
* NO_FETCH + Bad: -0.5
|
| 41 |
"""
|
| 42 |
|
| 43 |
+
# def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True):
|
| 44 |
+
# super(PolicyNetwork, self).__init__()
|
| 45 |
|
| 46 |
+
# # Load pre-trained BERT
|
| 47 |
+
# self.bert = AutoModel.from_pretrained(model_name)
|
| 48 |
|
| 49 |
+
# # Load tokenizer
|
| 50 |
+
# self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 51 |
+
|
| 52 |
+
# # Add special tokens for actions: [FETCH] and [NO_FETCH]
|
| 53 |
+
# special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
|
| 54 |
+
# self.tokenizer.add_special_tokens(special_tokens)
|
| 55 |
+
|
| 56 |
+
# # Resize BERT embeddings to accommodate new tokens
|
| 57 |
+
# self.bert.resize_token_embeddings(len(self.tokenizer))
|
| 58 |
+
|
| 59 |
+
# # Initialize random embeddings for special tokens
|
| 60 |
+
# self._init_action_embeddings()
|
| 61 |
|
| 62 |
+
# # β
FLEXIBLE CLASSIFIER ARCHITECTURE
|
| 63 |
+
# if use_multilayer:
|
| 64 |
+
# # Multi-layer classifier (your new trained model)
|
| 65 |
+
# self.classifier = nn.Sequential(
|
| 66 |
+
# nn.Linear(self.bert.config.hidden_size, 256),
|
| 67 |
+
# nn.ReLU(),
|
| 68 |
+
# nn.Dropout(dropout_rate),
|
| 69 |
+
# nn.Linear(256, 2)
|
| 70 |
+
# )
|
| 71 |
+
# else:
|
| 72 |
+
# # Single-layer classifier (fallback)
|
| 73 |
+
# self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
|
| 74 |
+
|
| 75 |
+
# # Dropout for regularization
|
| 76 |
+
# self.dropout = nn.Dropout(dropout_rate)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True, hidden_size: int = 128):
|
| 80 |
+
super(PolicyNetwork, self).__init__()
|
| 81 |
+
|
| 82 |
+
# Load pre-trained BERT
|
| 83 |
+
self.bert = AutoModel.from_pretrained(model_name)
|
| 84 |
+
|
| 85 |
+
# Load tokenizer
|
| 86 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 87 |
+
|
| 88 |
+
# Add special tokens for actions: [FETCH] and [NO_FETCH]
|
| 89 |
special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
|
| 90 |
self.tokenizer.add_special_tokens(special_tokens)
|
| 91 |
+
|
| 92 |
+
# Resize BERT embeddings to accommodate new tokens
|
| 93 |
self.bert.resize_token_embeddings(len(self.tokenizer))
|
| 94 |
+
|
| 95 |
+
# Initialize random embeddings for special tokens
|
| 96 |
self._init_action_embeddings()
|
| 97 |
+
|
| 98 |
+
# β
FLEXIBLE CLASSIFIER ARCHITECTURE (with configurable hidden size)
|
| 99 |
if use_multilayer:
|
| 100 |
+
# Multi-layer classifier with specified hidden size (128 or 256)
|
| 101 |
self.classifier = nn.Sequential(
|
| 102 |
+
nn.Linear(self.bert.config.hidden_size, hidden_size), # β
Use hidden_size param
|
| 103 |
nn.ReLU(),
|
| 104 |
nn.Dropout(dropout_rate),
|
| 105 |
+
nn.Linear(hidden_size, 2) # β
Use hidden_size param
|
| 106 |
)
|
| 107 |
else:
|
| 108 |
+
# Single-layer classifier (fallback)
|
| 109 |
self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
+
# Dropout for regularization
|
| 112 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
def _init_action_embeddings(self):
|
| 120 |
"""
|
| 121 |
Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
|
|
|
|
| 280 |
POLICY_TOKENIZER: Optional[AutoTokenizer] = None
|
| 281 |
|
| 282 |
|
| 283 |
+
# def load_policy_model() -> PolicyNetwork:
|
| 284 |
+
# """
|
| 285 |
+
# Load trained policy model (called once on startup).
|
| 286 |
+
# Downloads from HuggingFace Hub if not present locally.
|
| 287 |
+
# Uses module-level caching - model stays in RAM.
|
| 288 |
+
|
| 289 |
+
# Returns:
|
| 290 |
+
# PolicyNetwork: Loaded policy model
|
| 291 |
+
# """
|
| 292 |
+
# global POLICY_MODEL, POLICY_TOKENIZER
|
| 293 |
+
|
| 294 |
+
# if POLICY_MODEL is None:
|
| 295 |
+
# # Download model from HF Hub if needed (for deployment)
|
| 296 |
+
# settings.download_model_if_needed(
|
| 297 |
+
# hf_filename="models/policy_query_only.pt",
|
| 298 |
+
# local_path=settings.POLICY_MODEL_PATH
|
| 299 |
+
# )
|
| 300 |
+
|
| 301 |
+
# print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
|
| 302 |
+
# try:
|
| 303 |
+
# # Load checkpoint first to detect architecture
|
| 304 |
+
# checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
|
| 305 |
+
|
| 306 |
+
# # β
AUTO-DETECT ARCHITECTURE from checkpoint keys
|
| 307 |
+
# has_multilayer = "classifier.0.weight" in checkpoint
|
| 308 |
+
|
| 309 |
+
# print(f"π Detected architecture: {'Multi-layer' if has_multilayer else 'Single-layer'} classifier")
|
| 310 |
+
|
| 311 |
+
# # Create model instance with correct architecture
|
| 312 |
+
# POLICY_MODEL = PolicyNetwork(
|
| 313 |
+
# model_name="bert-base-uncased",
|
| 314 |
+
# dropout_rate=0.1,
|
| 315 |
+
# use_multilayer=has_multilayer # β
Auto-detect!
|
| 316 |
+
# )
|
| 317 |
+
|
| 318 |
+
# # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
|
| 319 |
+
# saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
|
| 320 |
+
# current_vocab_size = len(POLICY_MODEL.tokenizer)
|
| 321 |
+
|
| 322 |
+
# if saved_vocab_size != current_vocab_size:
|
| 323 |
+
# print(f"β οΈ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
|
| 324 |
+
# print(f"β
Resizing tokenizer and embeddings to match saved model...")
|
| 325 |
+
# # Resize model to match saved checkpoint
|
| 326 |
+
# POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
|
| 327 |
+
|
| 328 |
+
# # Move to device
|
| 329 |
+
# POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
|
| 330 |
+
|
| 331 |
+
# # Now load trained weights (sizes and architecture will match!)
|
| 332 |
+
# if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 333 |
+
# POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
|
| 334 |
+
# else:
|
| 335 |
+
# POLICY_MODEL.load_state_dict(checkpoint)
|
| 336 |
+
|
| 337 |
+
# # Set to evaluation mode
|
| 338 |
+
# POLICY_MODEL.eval()
|
| 339 |
+
|
| 340 |
+
# # Cache tokenizer
|
| 341 |
+
# POLICY_TOKENIZER = POLICY_MODEL.tokenizer
|
| 342 |
+
|
| 343 |
+
# print("β
Policy network loaded and cached")
|
| 344 |
+
|
| 345 |
+
# except FileNotFoundError:
|
| 346 |
+
# print(f"β Policy model file not found: {settings.POLICY_MODEL_PATH}")
|
| 347 |
+
# print(f"β οΈ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
|
| 348 |
+
# raise
|
| 349 |
+
# except Exception as e:
|
| 350 |
+
# print(f"β Failed to load policy model: {e}")
|
| 351 |
+
# import traceback
|
| 352 |
+
# traceback.print_exc()
|
| 353 |
+
# raise
|
| 354 |
+
|
| 355 |
+
# return POLICY_MODEL
|
| 356 |
+
|
| 357 |
def load_policy_model() -> PolicyNetwork:
|
| 358 |
"""
|
| 359 |
Load trained policy model (called once on startup).
|
|
|
|
| 380 |
# β
AUTO-DETECT ARCHITECTURE from checkpoint keys
|
| 381 |
has_multilayer = "classifier.0.weight" in checkpoint
|
| 382 |
|
| 383 |
+
# β
AUTO-DETECT HIDDEN SIZE from checkpoint
|
| 384 |
+
if has_multilayer:
|
| 385 |
+
hidden_size = checkpoint['classifier.0.weight'].shape[0] # Get output size of first layer
|
| 386 |
+
print(f"π Detected: Multi-layer classifier (hidden_size={hidden_size})")
|
| 387 |
+
else:
|
| 388 |
+
hidden_size = 768 # Doesn't matter for single-layer
|
| 389 |
+
print(f"π Detected: Single-layer classifier")
|
| 390 |
|
| 391 |
# Create model instance with correct architecture
|
| 392 |
POLICY_MODEL = PolicyNetwork(
|
| 393 |
model_name="bert-base-uncased",
|
| 394 |
dropout_rate=0.1,
|
| 395 |
+
use_multilayer=has_multilayer,
|
| 396 |
+
hidden_size=hidden_size # β
Pass detected hidden size
|
| 397 |
)
|
| 398 |
|
| 399 |
# **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
|
|
|
|
| 436 |
return POLICY_MODEL
|
| 437 |
|
| 438 |
|
| 439 |
+
|
| 440 |
# ============================================================================
|
| 441 |
# PREDICTION FUNCTIONS
|
| 442 |
# ============================================================================
|