eeshanyaj committed on
Commit
7ed4c17
·
1 Parent(s): 52d7b60

fix policy model issues

Browse files
Files changed (1) hide show
  1. app/ml/policy_network.py +145 -22
app/ml/policy_network.py CHANGED
@@ -40,41 +40,82 @@ class PolicyNetwork(nn.Module):
40
  * NO_FETCH + Bad: -0.5
41
  """
42
 
43
- def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True):
44
- super(PolicyNetwork, self).__init__()
45
 
46
- # Load pre-trained BERT
47
- self.bert = AutoModel.from_pretrained(model_name)
48
 
49
- # Load tokenizer
50
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Add special tokens for actions: [FETCH] and [NO_FETCH]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
54
  self.tokenizer.add_special_tokens(special_tokens)
55
-
56
- # Resize BERT embeddings to accommodate new tokens
57
  self.bert.resize_token_embeddings(len(self.tokenizer))
58
-
59
- # Initialize random embeddings for special tokens
60
  self._init_action_embeddings()
61
-
62
- # ✅ FLEXIBLE CLASSIFIER ARCHITECTURE
63
  if use_multilayer:
64
- # Multi-layer classifier (your new trained model)
65
  self.classifier = nn.Sequential(
66
- nn.Linear(self.bert.config.hidden_size, 256),
67
  nn.ReLU(),
68
  nn.Dropout(dropout_rate),
69
- nn.Linear(256, 2)
70
  )
71
  else:
72
- # Single-layer classifier (fallback)
73
  self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
74
-
75
- # Dropout for regularization
76
- self.dropout = nn.Dropout(dropout_rate)
77
 
 
 
 
 
 
 
 
 
78
  def _init_action_embeddings(self):
79
  """
80
  Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
@@ -239,6 +280,80 @@ POLICY_MODEL: Optional[PolicyNetwork] = None
239
  POLICY_TOKENIZER: Optional[AutoTokenizer] = None
240
 
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  def load_policy_model() -> PolicyNetwork:
243
  """
244
  Load trained policy model (called once on startup).
@@ -265,13 +380,20 @@ def load_policy_model() -> PolicyNetwork:
265
  # ✅ AUTO-DETECT ARCHITECTURE from checkpoint keys
266
  has_multilayer = "classifier.0.weight" in checkpoint
267
 
268
- print(f"📊 Detected architecture: {'Multi-layer' if has_multilayer else 'Single-layer'} classifier")
 
 
 
 
 
 
269
 
270
  # Create model instance with correct architecture
271
  POLICY_MODEL = PolicyNetwork(
272
  model_name="bert-base-uncased",
273
  dropout_rate=0.1,
274
- use_multilayer=has_multilayer # ✅ Auto-detect!
 
275
  )
276
 
277
  # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
@@ -314,6 +436,7 @@ def load_policy_model() -> PolicyNetwork:
314
  return POLICY_MODEL
315
 
316
 
 
317
  # ============================================================================
318
  # PREDICTION FUNCTIONS
319
  # ============================================================================
 
40
  * NO_FETCH + Bad: -0.5
41
  """
42
 
43
+ # def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True):
44
+ # super(PolicyNetwork, self).__init__()
45
 
46
+ # # Load pre-trained BERT
47
+ # self.bert = AutoModel.from_pretrained(model_name)
48
 
49
+ # # Load tokenizer
50
+ # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
51
+
52
+ # # Add special tokens for actions: [FETCH] and [NO_FETCH]
53
+ # special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
54
+ # self.tokenizer.add_special_tokens(special_tokens)
55
+
56
+ # # Resize BERT embeddings to accommodate new tokens
57
+ # self.bert.resize_token_embeddings(len(self.tokenizer))
58
+
59
+ # # Initialize random embeddings for special tokens
60
+ # self._init_action_embeddings()
61
 
62
+ # # ✅ FLEXIBLE CLASSIFIER ARCHITECTURE
63
+ # if use_multilayer:
64
+ # # Multi-layer classifier (your new trained model)
65
+ # self.classifier = nn.Sequential(
66
+ # nn.Linear(self.bert.config.hidden_size, 256),
67
+ # nn.ReLU(),
68
+ # nn.Dropout(dropout_rate),
69
+ # nn.Linear(256, 2)
70
+ # )
71
+ # else:
72
+ # # Single-layer classifier (fallback)
73
+ # self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
74
+
75
+ # # Dropout for regularization
76
+ # self.dropout = nn.Dropout(dropout_rate)
77
+
78
+
79
+ def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True, hidden_size: int = 128):
80
+ super(PolicyNetwork, self).__init__()
81
+
82
+ # Load pre-trained BERT
83
+ self.bert = AutoModel.from_pretrained(model_name)
84
+
85
+ # Load tokenizer
86
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
87
+
88
+ # Add special tokens for actions: [FETCH] and [NO_FETCH]
89
  special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
90
  self.tokenizer.add_special_tokens(special_tokens)
91
+
92
+ # Resize BERT embeddings to accommodate new tokens
93
  self.bert.resize_token_embeddings(len(self.tokenizer))
94
+
95
+ # Initialize random embeddings for special tokens
96
  self._init_action_embeddings()
97
+
98
+ # ✅ FLEXIBLE CLASSIFIER ARCHITECTURE (with configurable hidden size)
99
  if use_multilayer:
100
+ # Multi-layer classifier with specified hidden size (128 or 256)
101
  self.classifier = nn.Sequential(
102
+ nn.Linear(self.bert.config.hidden_size, hidden_size), # ✅ Use hidden_size param
103
  nn.ReLU(),
104
  nn.Dropout(dropout_rate),
105
+ nn.Linear(hidden_size, 2) # ✅ Use hidden_size param
106
  )
107
  else:
108
+ # Single-layer classifier (fallback)
109
  self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
 
 
 
110
 
111
+ # Dropout for regularization
112
+ self.dropout = nn.Dropout(dropout_rate)
113
+
114
+
115
+
116
+
117
+
118
+
119
  def _init_action_embeddings(self):
120
  """
121
  Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
 
280
  POLICY_TOKENIZER: Optional[AutoTokenizer] = None
281
 
282
 
283
+ # def load_policy_model() -> PolicyNetwork:
284
+ # """
285
+ # Load trained policy model (called once on startup).
286
+ # Downloads from HuggingFace Hub if not present locally.
287
+ # Uses module-level caching - model stays in RAM.
288
+
289
+ # Returns:
290
+ # PolicyNetwork: Loaded policy model
291
+ # """
292
+ # global POLICY_MODEL, POLICY_TOKENIZER
293
+
294
+ # if POLICY_MODEL is None:
295
+ # # Download model from HF Hub if needed (for deployment)
296
+ # settings.download_model_if_needed(
297
+ # hf_filename="models/policy_query_only.pt",
298
+ # local_path=settings.POLICY_MODEL_PATH
299
+ # )
300
+
301
+ # print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
302
+ # try:
303
+ # # Load checkpoint first to detect architecture
304
+ # checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
305
+
306
+ # # ✅ AUTO-DETECT ARCHITECTURE from checkpoint keys
307
+ # has_multilayer = "classifier.0.weight" in checkpoint
308
+
309
+ # print(f"📊 Detected architecture: {'Multi-layer' if has_multilayer else 'Single-layer'} classifier")
310
+
311
+ # # Create model instance with correct architecture
312
+ # POLICY_MODEL = PolicyNetwork(
313
+ # model_name="bert-base-uncased",
314
+ # dropout_rate=0.1,
315
+ # use_multilayer=has_multilayer # ✅ Auto-detect!
316
+ # )
317
+
318
+ # # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
319
+ # saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
320
+ # current_vocab_size = len(POLICY_MODEL.tokenizer)
321
+
322
+ # if saved_vocab_size != current_vocab_size:
323
+ # print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
324
+ # print(f"✅ Resizing tokenizer and embeddings to match saved model...")
325
+ # # Resize model to match saved checkpoint
326
+ # POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
327
+
328
+ # # Move to device
329
+ # POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
330
+
331
+ # # Now load trained weights (sizes and architecture will match!)
332
+ # if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
333
+ # POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
334
+ # else:
335
+ # POLICY_MODEL.load_state_dict(checkpoint)
336
+
337
+ # # Set to evaluation mode
338
+ # POLICY_MODEL.eval()
339
+
340
+ # # Cache tokenizer
341
+ # POLICY_TOKENIZER = POLICY_MODEL.tokenizer
342
+
343
+ # print("✅ Policy network loaded and cached")
344
+
345
+ # except FileNotFoundError:
346
+ # print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
347
+ # print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
348
+ # raise
349
+ # except Exception as e:
350
+ # print(f"❌ Failed to load policy model: {e}")
351
+ # import traceback
352
+ # traceback.print_exc()
353
+ # raise
354
+
355
+ # return POLICY_MODEL
356
+
357
  def load_policy_model() -> PolicyNetwork:
358
  """
359
  Load trained policy model (called once on startup).
 
380
  # βœ… AUTO-DETECT ARCHITECTURE from checkpoint keys
381
  has_multilayer = "classifier.0.weight" in checkpoint
382
 
383
+ # ✅ AUTO-DETECT HIDDEN SIZE from checkpoint
384
+ if has_multilayer:
385
+ hidden_size = checkpoint['classifier.0.weight'].shape[0] # Get output size of first layer
386
+ print(f"📊 Detected: Multi-layer classifier (hidden_size={hidden_size})")
387
+ else:
388
+ hidden_size = 768 # Doesn't matter for single-layer
389
+ print(f"📊 Detected: Single-layer classifier")
390
 
391
  # Create model instance with correct architecture
392
  POLICY_MODEL = PolicyNetwork(
393
  model_name="bert-base-uncased",
394
  dropout_rate=0.1,
395
+ use_multilayer=has_multilayer,
396
+ hidden_size=hidden_size # ✅ Pass detected hidden size
397
  )
398
 
399
  # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
 
436
  return POLICY_MODEL
437
 
438
 
439
+
440
  # ============================================================================
441
  # PREDICTION FUNCTIONS
442
  # ============================================================================