eeshanyaj committed
Commit 52d7b60 · 1 Parent(s): 42fe10b

fixed policy model weights error
app/ml/backup_policynetwork.py ADDED
@@ -0,0 +1,610 @@
+ """
+ BERT-based Policy Network for FETCH/NO_FETCH decisions
+ Trained with Reinforcement Learning (Policy Gradient + Entropy Regularization)
+
+ This is adapted from your RL.py with:
+ - PolicyNetwork class (BERT-based)
+ - State encoding from conversation history
+ - Action prediction (FETCH vs NO_FETCH)
+ - Module-level caching (load once on startup)
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from typing import List, Dict, Optional, Tuple
+ from transformers import AutoTokenizer, AutoModel
+
+ from app.config import settings
+
+
+ # ============================================================================
+ # POLICY NETWORK (From RL.py)
+ # ============================================================================
+
+ class PolicyNetwork(nn.Module):
+     """
+     BERT-based Policy Network for deciding FETCH vs NO_FETCH actions.
+
+     Architecture:
+     - Base: BERT-base-uncased (pre-trained)
+     - Input: Current query + conversation history + previous actions
+     - Output: 2-class softmax (FETCH=0, NO_FETCH=1)
+     - Special tokens: [FETCH], [NO_FETCH] for action encoding
+
+     Training Details:
+     - Loss: Policy Gradient + Entropy Regularization
+     - Optimizer: AdamW
+     - Reward structure:
+         * FETCH: +0.5 (always)
+         * NO_FETCH + Good: +2.0
+         * NO_FETCH + Bad: -0.5
+     """
+
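For illustration only, the reward scheme listed in the docstring above can be written as a tiny helper; the actual reward computation lives in the training script (RL.py), not in this module:

def reward(action: str, answer_was_good: bool) -> float:
    # Numbers taken from the docstring: FETCH always earns +0.5;
    # NO_FETCH earns +2.0 when the answer was good and -0.5 when it was not.
    if action == "FETCH":
        return 0.5
    return 2.0 if answer_was_good else -0.5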
+     def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1):
+         super(PolicyNetwork, self).__init__()
+
+         # Load pre-trained BERT
+         self.bert = AutoModel.from_pretrained(model_name)
+
+         # Load tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+         # Add special tokens for actions: [FETCH] and [NO_FETCH]
+         special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
+         self.tokenizer.add_special_tokens(special_tokens)
+
+         # Resize BERT embeddings to accommodate new tokens
+         self.bert.resize_token_embeddings(len(self.tokenizer))
+
+         # Initialize random embeddings for special tokens
+         self._init_action_embeddings()
+
+         # Classification head: BERT hidden size (768) → 2 classes
+         self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
+
+         # Dropout for regularization
+         self.dropout = nn.Dropout(dropout_rate)
+
+     def _init_action_embeddings(self):
+         """
+         Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
+         These are learned during training.
+         """
+         with torch.no_grad():
+             # Get token IDs for special tokens
+             fetch_id = self.tokenizer.convert_tokens_to_ids("[FETCH]")
+             no_fetch_id = self.tokenizer.convert_tokens_to_ids("[NO_FETCH]")
+
+             # Get embedding dimension
+             embedding_dim = self.bert.config.hidden_size
+
+             # Initialize with small random values (same as BERT initialization)
+             self.bert.embeddings.word_embeddings.weight[fetch_id] = torch.randn(embedding_dim) * 0.02
+             self.bert.embeddings.word_embeddings.weight[no_fetch_id] = torch.randn(embedding_dim) * 0.02
+
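As a standalone sketch of the special-token setup above (independent of this class): bert-base-uncased ships with a 30,522-token vocabulary, so the two action tokens land just past it, which is why the embedding matrix has to be resized to len(tokenizer). The exact IDs can vary with the tokenizer version:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
tok.add_special_tokens({"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]})
print(len(tok))                                 # 30524
print(tok.convert_tokens_to_ids("[FETCH]"))     # 30522
print(tok.convert_tokens_to_ids("[NO_FETCH]"))  # 30523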
+     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Forward pass through BERT + classifier.
+
+         Args:
+             input_ids: Tokenized input IDs (shape: [batch_size, seq_len])
+             attention_mask: Attention mask (shape: [batch_size, seq_len])
+
+         Returns:
+             logits: Raw logits (shape: [batch_size, 2])
+             probs: Softmax probabilities (shape: [batch_size, 2])
+         """
+         # Pass through BERT
+         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+
+         # Extract [CLS] token representation (first token)
+         cls_output = outputs.last_hidden_state[:, 0, :]
+
+         # Apply dropout
+         cls_output = self.dropout(cls_output)
+
+         # Classification
+         logits = self.classifier(cls_output)
+
+         # Softmax for probabilities
+         probs = F.softmax(logits, dim=-1)
+
+         return logits, probs
+
+     def encode_state(
+         self,
+         state: Dict,
+         max_length: int = None
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Encode conversation state into BERT input format.
+
+         State structure:
+             {
+                 'previous_queries': [query1, query2, ...],
+                 'previous_actions': ['FETCH', 'NO_FETCH', ...],
+                 'current_query': 'user query'
+             }
+
+         Encoding format:
+             "Previous query 1: <text> [Action: [FETCH]] Previous query 2: <text> [Action: [NO_FETCH]] Current query: <text>"
+
+         Args:
+             state: State dictionary
+             max_length: Maximum sequence length (default from config)
+
+         Returns:
+             dict: Tokenized inputs (input_ids, attention_mask)
+         """
+         if max_length is None:
+             max_length = settings.POLICY_MAX_LEN
+
+         # Build state text from conversation history
+         state_text = ""
+
+         # Add previous queries and their actions
+         prev_queries = state.get('previous_queries', [])
+         prev_actions = state.get('previous_actions', [])
+
+         if prev_queries and prev_actions:
+             for i, (prev_query, prev_action) in enumerate(zip(prev_queries, prev_actions)):
+                 state_text += f"Previous query {i+1}: {prev_query} [Action: [{prev_action}]] "
+
+         # Add current query
+         current_query = state.get('current_query', '')
+         state_text += f"Current query: {current_query}"
+
+         # Tokenize
+         encoding = self.tokenizer(
+             state_text,
+             truncation=True,
+             padding='max_length',
+             max_length=max_length,
+             return_tensors='pt'
+         )
+
+         return encoding
+
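A quick sketch of the string encode_state builds before tokenization, using a made-up single-turn state with the field names from the docstring:

state = {
    'previous_queries': ['What is my balance?'],
    'previous_actions': ['FETCH'],
    'current_query': 'Thank you!'
}
# encode_state flattens this to:
#   "Previous query 1: What is my balance? [Action: [FETCH]] Current query: Thank you!"
# and then tokenizes it with truncation and padding to settings.POLICY_MAX_LEN.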
+     def predict_action(
+         self,
+         state: Dict,
+         use_dropout: bool = False,
+         num_samples: int = 10
+     ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+         """
+         Predict action probabilities for a given state.
+
+         Args:
+             state: Conversation state dictionary
+             use_dropout: Whether to use MC Dropout for uncertainty estimation
+             num_samples: Number of MC Dropout samples (if use_dropout=True)
+
+         Returns:
+             probs: Action probabilities (shape: [1, 2]) - [P(FETCH), P(NO_FETCH)]
+             uncertainty: Standard deviation across samples (if use_dropout=True)
+         """
+         device = next(self.parameters()).device
+
+         if use_dropout:
+             # MC Dropout for uncertainty estimation
+             self.train()  # Enable dropout during inference
+             all_probs = []
+
+             for _ in range(num_samples):
+                 with torch.no_grad():
+                     encoding = self.encode_state(state)
+                     input_ids = encoding['input_ids'].to(device)
+                     attention_mask = encoding['attention_mask'].to(device)
+
+                     _, probs = self.forward(input_ids, attention_mask)
+                     all_probs.append(probs.cpu().numpy())
+
+             # Average probabilities across samples
+             avg_probs = np.mean(all_probs, axis=0)
+
+             # Calculate uncertainty (standard deviation)
+             uncertainty = np.std(all_probs, axis=0)
+
+             return avg_probs, uncertainty
+
+         else:
+             # Standard inference (no uncertainty estimation)
+             self.eval()
+
+             with torch.no_grad():
+                 encoding = self.encode_state(state)
+                 input_ids = encoding['input_ids'].to(device)
+                 attention_mask = encoding['attention_mask'].to(device)
+
+                 _, probs = self.forward(input_ids, attention_mask)
+
+             return probs.cpu().numpy(), None
+
+
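Illustrative call patterns for predict_action, assuming model is a loaded PolicyNetwork and state is a dict like the one above; the MC Dropout path trades extra forward passes for an uncertainty estimate:

# Deterministic inference: probs has shape [1, 2] = [P(FETCH), P(NO_FETCH)]
probs, _ = model.predict_action(state, use_dropout=False)

# MC Dropout: mean of 10 stochastic passes plus the per-class standard deviation
avg_probs, uncertainty = model.predict_action(state, use_dropout=True, num_samples=10)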
+ # ============================================================================
+ # MODULE-LEVEL CACHING (Load once on import)
+ # ============================================================================
+
+ # Global variables for caching
+ POLICY_MODEL: Optional[PolicyNetwork] = None
+ POLICY_TOKENIZER: Optional[AutoTokenizer] = None
+
+
+
+ # =============================================================================================
+ # Latest version (generated with Perplexity). It should work; if not, fall back to one of the
+ # commented-out versions below.
+ # =============================================================================================
+
+ def load_policy_model() -> PolicyNetwork:
+     """
+     Load trained policy model (called once on startup).
+     Downloads from HuggingFace Hub if not present locally.
+     Uses module-level caching - model stays in RAM.
+
+     Returns:
+         PolicyNetwork: Loaded policy model
+     """
+     global POLICY_MODEL, POLICY_TOKENIZER
+
+     if POLICY_MODEL is None:
+         # Download model from HF Hub if needed (for deployment)
+         settings.download_model_if_needed(
+             hf_filename="models/policy_query_only.pt",
+             local_path=settings.POLICY_MODEL_PATH
+         )
+
+         print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
+
+         try:
+             # Load checkpoint first to get vocab size
+             checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
+
+             # Create model instance
+             POLICY_MODEL = PolicyNetwork(
+                 model_name="bert-base-uncased",
+                 dropout_rate=0.1
+             )
+
+             # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
+             saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
+             current_vocab_size = len(POLICY_MODEL.tokenizer)
+
+             if saved_vocab_size != current_vocab_size:
+                 print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
+                 print("✅ Resizing tokenizer and embeddings to match saved model...")
+                 # Resize model to match saved checkpoint
+                 POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
+
+             # Move to device
+             POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
+
+             # Now load trained weights (sizes will match!)
+             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                 POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
+             else:
+                 POLICY_MODEL.load_state_dict(checkpoint)
+
+             # Set to evaluation mode
+             POLICY_MODEL.eval()
+
+             # Cache tokenizer
+             POLICY_TOKENIZER = POLICY_MODEL.tokenizer
+
+             print("✅ Policy network loaded and cached")
+
+         except FileNotFoundError:
+             print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
+             print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
+             raise
+         except Exception as e:
+             print(f"❌ Failed to load policy model: {e}")
+             raise
+
+     return POLICY_MODEL
+
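The error this commit addresses comes from loading a checkpoint whose embedding table is larger than a freshly built model's. Stripped down to the bare pattern (the path and the plain-CPU map_location are illustrative, not taken from this repo's config):

checkpoint = torch.load("models/policy_query_only.pt", map_location="cpu")
state_dict = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint else checkpoint

model = PolicyNetwork()
saved_vocab = state_dict['bert.embeddings.word_embeddings.weight'].shape[0]
if saved_vocab != len(model.tokenizer):
    # Resize the randomly initialised embeddings to the trained shape first,
    # otherwise load_state_dict raises a size-mismatch error.
    model.bert.resize_token_embeddings(saved_vocab)
model.load_state_dict(state_dict)
model.eval()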
+
+
+
+
+
+
+
+
+
+
+
+
+
+ # ===========================================================================
+ # This version is used in the code, at least for localhost testing
+ # ===========================================================================
+
+ # def load_policy_model() -> PolicyNetwork:
+ #     """
+ #     Load trained policy model (called once on startup).
+ #     Uses module-level caching - model stays in RAM.
+
+ #     Returns:
+ #         PolicyNetwork: Loaded policy model
+ #     """
+ #     global POLICY_MODEL, POLICY_TOKENIZER
+
+ #     if POLICY_MODEL is None:
+ #         print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
+
+ #         try:
+ #             # Load checkpoint first to get vocab size
+ #             checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
+
+ #             # Create model instance
+ #             POLICY_MODEL = PolicyNetwork(
+ #                 model_name="bert-base-uncased",
+ #                 dropout_rate=0.1
+ #             )
+
+ #             # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
+ #             saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
+ #             current_vocab_size = len(POLICY_MODEL.tokenizer)
+
+ #             if saved_vocab_size != current_vocab_size:
+ #                 print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
+ #                 print(f"✅ Resizing tokenizer and embeddings to match saved model...")
+
+ #                 # Resize model to match saved checkpoint
+ #                 POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
+
+ #             # Move to device
+ #             POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
+
+ #             # Now load trained weights (sizes will match!)
+ #             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+ #                 POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
+ #             else:
+ #                 POLICY_MODEL.load_state_dict(checkpoint)
+
+ #             # Set to evaluation mode
+ #             POLICY_MODEL.eval()
+
+ #             # Cache tokenizer
+ #             POLICY_TOKENIZER = POLICY_MODEL.tokenizer
+
+ #             print("✅ Policy network loaded and cached")
+
+ #         except FileNotFoundError:
+ #             print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
+ #             print("⚠️ You need to train the policy network first!")
+ #             raise
+
+ #         except Exception as e:
+ #             print(f"❌ Failed to load policy model: {e}")
+ #             raise
+
+ #     return POLICY_MODEL
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ # =====================================================================================
+ # This is the older version (or probably a different one); potentially still useful
+ # =====================================================================================
+
+ # def load_policy_model() -> PolicyNetwork:
+ #     """
+ #     Load trained policy model (called once on startup).
+ #     Uses module-level caching - model stays in RAM.
+
+ #     Returns:
+ #         PolicyNetwork: Loaded policy model
+ #     """
+ #     global POLICY_MODEL, POLICY_TOKENIZER
+
+ #     if POLICY_MODEL is None:
+ #         print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
+
+ #         try:
+ #             # Create model instance
+ #             POLICY_MODEL = PolicyNetwork(
+ #                 model_name="bert-base-uncased",
+ #                 dropout_rate=0.1
+ #             ).to(settings.DEVICE)
+
+ #             # Load trained weights
+ #             checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
+
+ #             # Handle different checkpoint formats
+ #             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+ #                 # Full checkpoint with metadata
+ #                 POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
+ #             else:
+ #                 # Just state dict
+ #                 POLICY_MODEL.load_state_dict(checkpoint)
+
+ #             # Set to evaluation mode
+ #             POLICY_MODEL.eval()
+
+ #             # Cache tokenizer
+ #             POLICY_TOKENIZER = POLICY_MODEL.tokenizer
+
+ #             print("✅ Policy network loaded and cached")
+
+ #         except FileNotFoundError:
+ #             print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
+ #             print("⚠️ You need to train the policy network first!")
+ #             raise
+
+ #         except Exception as e:
+ #             print(f"❌ Failed to load policy model: {e}")
+ #             raise
+
+ #     return POLICY_MODEL
+
+
+
+
+
+
+
+
+
+ # ============================================================================
+ # PREDICTION FUNCTIONS
+ # ============================================================================
+
+ def create_state_from_history(
+     current_query: str,
+     conversation_history: List[Dict],
+     max_history: int = 2
+ ) -> Dict:
+     """
+     Create state dictionary from conversation history.
+     Extracts last N query-action pairs.
+
+     Args:
+         current_query: Current user query
+         conversation_history: List of conversation turns
+             Each turn: {'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}
+         max_history: Maximum number of previous turns to include (default: 2)
+
+     Returns:
+         dict: State dictionary for policy network
+     """
+     state = {
+         'current_query': current_query,
+         'previous_queries': [],
+         'previous_actions': []
+     }
+
+     if not conversation_history:
+         return state
+
+     # Extract last N conversation turns (user + assistant pairs)
+     relevant_history = conversation_history[-(max_history * 2):]
+
+     for i, turn in enumerate(relevant_history):
+         # User turns
+         if turn.get('role') == 'user':
+             query = turn.get('content', '')
+             state['previous_queries'].append(query)
+
+             # Look for corresponding assistant turn
+             if i + 1 < len(relevant_history):
+                 bot_turn = relevant_history[i + 1]
+                 if bot_turn.get('role') == 'assistant':
+                     metadata = bot_turn.get('metadata', {})
+                     action = metadata.get('policy_action', 'FETCH')
+                     state['previous_actions'].append(action)
+
+     return state
+
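A small worked example of create_state_from_history with the turn format described in its docstring (contents made up):

history = [
    {'role': 'user', 'content': 'What is my balance?'},
    {'role': 'assistant', 'content': '$1000', 'metadata': {'policy_action': 'FETCH'}},
]
state = create_state_from_history("Thank you!", history)
# state == {'current_query': 'Thank you!',
#           'previous_queries': ['What is my balance?'],
#           'previous_actions': ['FETCH']}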
+
+ def predict_policy_action(
+     query: str,
+     history: List[Dict] = None,
+     return_probs: bool = False
+ ) -> Dict:
+     """
+     Predict FETCH/NO_FETCH action for a query.
+
+     Args:
+         query: User query text
+         history: Conversation history (optional)
+         return_probs: Whether to return full probability distribution
+
+     Returns:
+         dict: Prediction results
+             {
+                 'action': 'FETCH' or 'NO_FETCH',
+                 'confidence': float (0-1),
+                 'fetch_prob': float,
+                 'no_fetch_prob': float,
+                 'should_retrieve': bool
+             }
+     """
+     # Load model (cached after first call)
+     model = load_policy_model()
+
+     # Create state from history
+     if history is None:
+         history = []
+
+     state = create_state_from_history(query, history)
+
+     # Predict action
+     probs, _ = model.predict_action(state, use_dropout=False)
+
+     # Extract probabilities
+     fetch_prob = float(probs[0][0])
+     no_fetch_prob = float(probs[0][1])
+
+     # Determine action (argmax)
+     action_idx = np.argmax(probs[0])
+     action = "FETCH" if action_idx == 0 else "NO_FETCH"
+     confidence = float(probs[0][action_idx])
+
+     # Check confidence threshold
+     should_retrieve = (action == "FETCH") or (action == "NO_FETCH" and confidence < settings.CONFIDENCE_THRESHOLD)
+
+     result = {
+         'action': action,
+         'confidence': confidence,
+         'should_retrieve': should_retrieve,
+         'policy_decision': action
+     }
+
+     if return_probs:
+         result['fetch_prob'] = fetch_prob
+         result['no_fetch_prob'] = no_fetch_prob
+
+     return result
+
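The should_retrieve rule above means a low-confidence NO_FETCH still triggers retrieval. A minimal sketch of that rule in isolation (the 0.7 threshold is hypothetical; the real value comes from settings.CONFIDENCE_THRESHOLD):

CONFIDENCE_THRESHOLD = 0.7  # hypothetical stand-in for settings.CONFIDENCE_THRESHOLD

def should_retrieve(action: str, confidence: float) -> bool:
    # Fetch when the policy says FETCH, or when it says NO_FETCH without enough confidence.
    return action == "FETCH" or (action == "NO_FETCH" and confidence < CONFIDENCE_THRESHOLD)

print(should_retrieve("NO_FETCH", 0.55))  # True  -> retrieve anyway
print(should_retrieve("NO_FETCH", 0.95))  # False -> safe to skip retrieval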
+
+ # ============================================================================
+ # USAGE EXAMPLE (for reference)
+ # ============================================================================
+ """
+ # In your service file:
+
+ from app.ml.policy_network import predict_policy_action
+
+ # Predict action
+ history = [
+     {'role': 'user', 'content': 'What is my balance?'},
+     {'role': 'assistant', 'content': '$1000', 'metadata': {'policy_action': 'FETCH'}}
+ ]
+
+ result = predict_policy_action(
+     query="Thank you!",
+     history=history,
+     return_probs=True
+ )
+
+ print(result)
+ # {
+ #     'action': 'NO_FETCH',
+ #     'confidence': 0.95,
+ #     'should_retrieve': False,
+ #     'fetch_prob': 0.05,
+ #     'no_fetch_prob': 0.95
+ # }
+ """
app/ml/policy_network.py CHANGED
@@ -15,10 +15,8 @@ import torch.nn.functional as F
 import numpy as np
 from typing import List, Dict, Optional, Tuple
 from transformers import AutoTokenizer, AutoModel
-
 from app.config import settings
 
-
 # ============================================================================
 # POLICY NETWORK (From RL.py)
 # ============================================================================
@@ -37,12 +35,12 @@ class PolicyNetwork(nn.Module):
     - Loss: Policy Gradient + Entropy Regularization
     - Optimizer: AdamW
     - Reward structure:
-        * FETCH: +0.5 (always)
-        * NO_FETCH + Good: +2.0
-        * NO_FETCH + Bad: -0.5
+        * FETCH: +0.5 (always)
+        * NO_FETCH + Good: +2.0
+        * NO_FETCH + Bad: -0.5
     """
 
-    def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1):
+    def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True):
         super(PolicyNetwork, self).__init__()
 
         # Load pre-trained BERT
@@ -61,8 +59,18 @@ class PolicyNetwork(nn.Module):
         # Initialize random embeddings for special tokens
         self._init_action_embeddings()
 
-        # Classification head: BERT hidden size (768) → 2 classes
-        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
+        # FLEXIBLE CLASSIFIER ARCHITECTURE
+        if use_multilayer:
+            # Multi-layer classifier (your new trained model)
+            self.classifier = nn.Sequential(
+                nn.Linear(self.bert.config.hidden_size, 256),
+                nn.ReLU(),
+                nn.Dropout(dropout_rate),
+                nn.Linear(256, 2)
+            )
+        else:
+            # Single-layer classifier (fallback)
+            self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
 
         # Dropout for regularization
         self.dropout = nn.Dropout(dropout_rate)
@@ -114,8 +122,8 @@
         return logits, probs
 
     def encode_state(
-        self,
-        state: Dict,
+        self,
+        state: Dict,
         max_length: int = None
     ) -> Dict[str, torch.Tensor]:
         """
@@ -129,7 +137,7 @@
         }
 
         Encoding format:
-            "Previous query 1: <text> [Action: [FETCH]] Previous query 2: <text> [Action: [NO_FETCH]] Current query: <text>"
+            "Previous query 1: [Action: [FETCH]] Previous query 2: [Action: [NO_FETCH]] Current query: <query>"
 
         Args:
             state: State dictionary
@@ -168,9 +176,9 @@
         return encoding
 
     def predict_action(
-        self,
-        state: Dict,
-        use_dropout: bool = False,
+        self,
+        state: Dict,
+        use_dropout: bool = False,
         num_samples: int = 10
     ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
         """
@@ -212,15 +220,14 @@
         else:
             # Standard inference (no uncertainty estimation)
             self.eval()
-
            with torch.no_grad():
                 encoding = self.encode_state(state)
                 input_ids = encoding['input_ids'].to(device)
                 attention_mask = encoding['attention_mask'].to(device)
 
                 _, probs = self.forward(input_ids, attention_mask)
-
-            return probs.cpu().numpy(), None
+
+            return probs.cpu().numpy(), None
 
 
 # ============================================================================
@@ -232,11 +239,6 @@ POLICY_MODEL: Optional[PolicyNetwork] = None
 POLICY_TOKENIZER: Optional[AutoTokenizer] = None
 
 
-
-# =============================================================================================
-# Latest version given by perplexity, should work, if not then use one of the other versions.
-# =============================================================================================
-
 def load_policy_model() -> PolicyNetwork:
     """
     Load trained policy model (called once on startup).
@@ -256,15 +258,20 @@ def load_policy_model():
         )
 
         print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
-
         try:
-            # Load checkpoint first to get vocab size
+            # Load checkpoint first to detect architecture
            checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
 
-            # Create model instance
+            # AUTO-DETECT ARCHITECTURE from checkpoint keys
+            has_multilayer = "classifier.0.weight" in checkpoint
+
+            print(f"📊 Detected architecture: {'Multi-layer' if has_multilayer else 'Single-layer'} classifier")
+
+            # Create model instance with correct architecture
            POLICY_MODEL = PolicyNetwork(
                model_name="bert-base-uncased",
-                dropout_rate=0.1
+                dropout_rate=0.1,
+                use_multilayer=has_multilayer  # ✅ Auto-detect!
            )
 
            # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
@@ -280,7 +287,7 @@ def load_policy_model():
            # Move to device
            POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
 
-            # Now load trained weights (sizes will match!)
+            # Now load trained weights (sizes and architecture will match!)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
            else:
@@ -300,179 +307,19 @@ def load_policy_model():
            raise
        except Exception as e:
            print(f"❌ Failed to load policy model: {e}")
+            import traceback
+            traceback.print_exc()
            raise
 
    return POLICY_MODEL
 
 
-
-
-
-
-
-
-
-
-
-
-
-
- # ===========================================================================
- # This version is used in the code, atleast for localhost testing
- # ===========================================================================
-
- # def load_policy_model() -> PolicyNetwork:
- #     """
- #     Load trained policy model (called once on startup).
- #     Uses module-level caching - model stays in RAM.
-
- #     Returns:
- #         PolicyNetwork: Loaded policy model
- #     """
- #     global POLICY_MODEL, POLICY_TOKENIZER
-
- #     if POLICY_MODEL is None:
- #         print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
-
- #         try:
- #             # Load checkpoint first to get vocab size
- #             checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
-
- #             # Create model instance
- #             POLICY_MODEL = PolicyNetwork(
- #                 model_name="bert-base-uncased",
- #                 dropout_rate=0.1
- #             )
-
- #             # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
- #             saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
- #             current_vocab_size = len(POLICY_MODEL.tokenizer)
-
- #             if saved_vocab_size != current_vocab_size:
- #                 print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
- #                 print(f"✅ Resizing tokenizer and embeddings to match saved model...")
-
- #                 # Resize model to match saved checkpoint
- #                 POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
-
- #             # Move to device
- #             POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
-
- #             # Now load trained weights (sizes will match!)
- #             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
- #                 POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
- #             else:
- #                 POLICY_MODEL.load_state_dict(checkpoint)
-
- #             # Set to evaluation mode
- #             POLICY_MODEL.eval()
-
- #             # Cache tokenizer
- #             POLICY_TOKENIZER = POLICY_MODEL.tokenizer
-
- #             print("✅ Policy network loaded and cached")
-
- #         except FileNotFoundError:
- #             print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
- #             print("⚠️ You need to train the policy network first!")
- #             raise
-
- #         except Exception as e:
- #             print(f"❌ Failed to load policy model: {e}")
- #             raise
-
- #     return POLICY_MODEL
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- # =====================================================================================
- # This is the older version or proably a different version, potentially still useful
- # =====================================================================================
-
- # def load_policy_model() -> PolicyNetwork:
- #     """
- #     Load trained policy model (called once on startup).
- #     Uses module-level caching - model stays in RAM.
-
- #     Returns:
- #         PolicyNetwork: Loaded policy model
- #     """
- #     global POLICY_MODEL, POLICY_TOKENIZER
-
- #     if POLICY_MODEL is None:
- #         print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
-
- #         try:
- #             # Create model instance
- #             POLICY_MODEL = PolicyNetwork(
- #                 model_name="bert-base-uncased",
- #                 dropout_rate=0.1
- #             ).to(settings.DEVICE)
-
- #             # Load trained weights
- #             checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
-
- #             # Handle different checkpoint formats
- #             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
- #                 # Full checkpoint with metadata
- #                 POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
- #             else:
- #                 # Just state dict
- #                 POLICY_MODEL.load_state_dict(checkpoint)
-
- #             # Set to evaluation mode
- #             POLICY_MODEL.eval()
-
- #             # Cache tokenizer
- #             POLICY_TOKENIZER = POLICY_MODEL.tokenizer
-
- #             print("✅ Policy network loaded and cached")
-
- #         except FileNotFoundError:
- #             print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
- #             print("⚠️ You need to train the policy network first!")
- #             raise
-
- #         except Exception as e:
- #             print(f"❌ Failed to load policy model: {e}")
- #             raise
-
- #     return POLICY_MODEL
-
-
-
-
-
-
-
-
 # ============================================================================
 # PREDICTION FUNCTIONS
 # ============================================================================
 
 def create_state_from_history(
-    current_query: str,
+    current_query: str,
     conversation_history: List[Dict],
     max_history: int = 2
 ) -> Dict:
@@ -519,7 +366,7 @@ def create_state_from_history(
 
 
 def predict_policy_action(
-    query: str,
+    query: str,
     history: List[Dict] = None,
     return_probs: bool = False
 ) -> Dict:
@@ -533,13 +380,13 @@ def predict_policy_action(
 
     Returns:
         dict: Prediction results
-            {
-                'action': 'FETCH' or 'NO_FETCH',
-                'confidence': float (0-1),
-                'fetch_prob': float,
-                'no_fetch_prob': float,
-                'should_retrieve': bool
-            }
+            {
+                'action': 'FETCH' or 'NO_FETCH',
+                'confidence': float (0-1),
+                'fetch_prob': float,
+                'no_fetch_prob': float,
+                'should_retrieve': bool
+            }
     """
     # Load model (cached after first call)
     model = load_policy_model()
@@ -584,7 +431,6 @@ def predict_policy_action(
 # ============================================================================
 """
 # In your service file:
-
 from app.ml.policy_network import predict_policy_action
 
 # Predict action
@@ -607,4 +453,4 @@ print(result)
 #     'fetch_prob': 0.05,
 #     'no_fetch_prob': 0.95
 # }
-"""
+"""