eeshanyaj committed on
Commit
cbf7898
·
1 Parent(s): 3986a1e

fix policy model issues

app/ml/backup_policynetwork.py CHANGED
@@ -1,3 +1,705 @@
1
  """
2
  BERT-based Policy Network for FETCH/NO_FETCH decisions
3
  Trained with Reinforcement Learning (Policy Gradient + Entropy Regularization)
 
1
+ """
2
+ BERT-based Policy Network for FETCH/NO_FETCH decisions
3
+ Trained with Reinforcement Learning (Policy Gradient + Entropy Regularization)
4
+
5
+ This is adapted from your RL.py with:
6
+ - PolicyNetwork class (BERT-based)
7
+ - State encoding from conversation history
8
+ - Action prediction (FETCH vs NO_FETCH)
9
+ - Module-level caching (load once on startup)
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ from typing import List, Dict, Optional, Tuple
17
+ from transformers import AutoTokenizer, AutoModel
18
+ from app.config import settings
19
+
20
+ # ============================================================================
21
+ # POLICY NETWORK (From RL.py)
22
+ # ============================================================================
23
+
24
+ class PolicyNetwork(nn.Module):
25
+ """
26
+ BERT-based Policy Network for deciding FETCH vs NO_FETCH actions.
27
+
28
+ Architecture:
29
+ - Base: BERT-base-uncased (pre-trained)
30
+ - Input: Current query + conversation history + previous actions
31
+ - Output: 2-class softmax (FETCH=0, NO_FETCH=1)
32
+ - Special tokens: [FETCH], [NO_FETCH] for action encoding
33
+
34
+ Training Details:
35
+ - Loss: Policy Gradient + Entropy Regularization
36
+ - Optimizer: AdamW
37
+ - Reward structure:
38
+ * FETCH: +0.5 (always)
39
+ * NO_FETCH + Good: +2.0
40
+ * NO_FETCH + Bad: -0.5
41
+ """
42
+
43
+ # def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True):
44
+ # super(PolicyNetwork, self).__init__()
45
+
46
+ # # Load pre-trained BERT
47
+ # self.bert = AutoModel.from_pretrained(model_name)
48
+
49
+ # # Load tokenizer
50
+ # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
51
+
52
+ # # Add special tokens for actions: [FETCH] and [NO_FETCH]
53
+ # special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
54
+ # self.tokenizer.add_special_tokens(special_tokens)
55
+
56
+ # # Resize BERT embeddings to accommodate new tokens
57
+ # self.bert.resize_token_embeddings(len(self.tokenizer))
58
+
59
+ # # Initialize random embeddings for special tokens
60
+ # self._init_action_embeddings()
61
+
62
+ # # ✅ FLEXIBLE CLASSIFIER ARCHITECTURE
63
+ # if use_multilayer:
64
+ # # Multi-layer classifier (your new trained model)
65
+ # self.classifier = nn.Sequential(
66
+ # nn.Linear(self.bert.config.hidden_size, 256),
67
+ # nn.ReLU(),
68
+ # nn.Dropout(dropout_rate),
69
+ # nn.Linear(256, 2)
70
+ # )
71
+ # else:
72
+ # # Single-layer classifier (fallback)
73
+ # self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
74
+
75
+ # # Dropout for regularization
76
+ # self.dropout = nn.Dropout(dropout_rate)
77
+
78
+
79
+ def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True, hidden_size: int = 128):
80
+ super(PolicyNetwork, self).__init__()
81
+
82
+ # Load pre-trained BERT
83
+ self.bert = AutoModel.from_pretrained(model_name)
84
+
85
+ # Load tokenizer
86
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
87
+
88
+ # Add special tokens for actions: [FETCH] and [NO_FETCH]
89
+ special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
90
+ self.tokenizer.add_special_tokens(special_tokens)
91
+
92
+ # Resize BERT embeddings to accommodate new tokens
93
+ self.bert.resize_token_embeddings(len(self.tokenizer))
94
+
95
+ # Initialize random embeddings for special tokens
96
+ self._init_action_embeddings()
97
+
98
+ # ✅ FLEXIBLE CLASSIFIER ARCHITECTURE (with configurable hidden size)
99
+ if use_multilayer:
100
+ # Multi-layer classifier with specified hidden size (128 or 256)
101
+ self.classifier = nn.Sequential(
102
+ nn.Linear(self.bert.config.hidden_size, hidden_size), # ✅ Use hidden_size param
103
+ nn.ReLU(),
104
+ nn.Dropout(dropout_rate),
105
+ nn.Linear(hidden_size, 2) # ✅ Use hidden_size param
106
+ )
107
+ else:
108
+ # Single-layer classifier (fallback)
109
+ self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
110
+
111
+ # Dropout for regularization
112
+ self.dropout = nn.Dropout(dropout_rate)
113
+
114
+
115
+
116
+
117
+
118
+
119
+ def _init_action_embeddings(self):
120
+ """
121
+ Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
122
+ These are learned during training.
123
+ """
124
+ with torch.no_grad():
125
+ # Get token IDs for special tokens
126
+ fetch_id = self.tokenizer.convert_tokens_to_ids("[FETCH]")
127
+ no_fetch_id = self.tokenizer.convert_tokens_to_ids("[NO_FETCH]")
128
+
129
+ # Get embedding dimension
130
+ embedding_dim = self.bert.config.hidden_size
131
+
132
+ # Initialize with small random values (same as BERT initialization)
133
+ self.bert.embeddings.word_embeddings.weight[fetch_id] = torch.randn(embedding_dim) * 0.02
134
+ self.bert.embeddings.word_embeddings.weight[no_fetch_id] = torch.randn(embedding_dim) * 0.02
135
+
136
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """
138
+ Forward pass through BERT + classifier.
139
+
140
+ Args:
141
+ input_ids: Tokenized input IDs (shape: [batch_size, seq_len])
142
+ attention_mask: Attention mask (shape: [batch_size, seq_len])
143
+
144
+ Returns:
145
+ logits: Raw logits (shape: [batch_size, 2])
146
+ probs: Softmax probabilities (shape: [batch_size, 2])
147
+ """
148
+ # Pass through BERT
149
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
150
+
151
+ # Extract [CLS] token representation (first token)
152
+ cls_output = outputs.last_hidden_state[:, 0, :]
153
+
154
+ # Apply dropout
155
+ cls_output = self.dropout(cls_output)
156
+
157
+ # Classification
158
+ logits = self.classifier(cls_output)
159
+
160
+ # Softmax for probabilities
161
+ probs = F.softmax(logits, dim=-1)
162
+
163
+ return logits, probs
164
+
165
+ def encode_state(
166
+ self,
167
+ state: Dict,
168
+ max_length: int = None
169
+ ) -> Dict[str, torch.Tensor]:
170
+ """
171
+ Encode conversation state into BERT input format.
172
+
173
+ State structure:
174
+ {
175
+ 'previous_queries': [query1, query2, ...],
176
+ 'previous_actions': ['FETCH', 'NO_FETCH', ...],
177
+ 'current_query': 'user query'
178
+ }
179
+
180
+ Encoding format:
181
+ "Previous query 1: [Action: [FETCH]] Previous query 2: [Action: [NO_FETCH]] Current query: <query>"
182
+
183
+ Args:
184
+ state: State dictionary
185
+ max_length: Maximum sequence length (default from config)
186
+
187
+ Returns:
188
+ dict: Tokenized inputs (input_ids, attention_mask)
189
+ """
190
+ if max_length is None:
191
+ max_length = settings.POLICY_MAX_LEN
192
+
193
+ # Build state text from conversation history
194
+ state_text = ""
195
+
196
+ # Add previous queries and their actions
197
+ prev_queries = state.get('previous_queries', [])
198
+ prev_actions = state.get('previous_actions', [])
199
+
200
+ if prev_queries and prev_actions:
201
+ for i, (prev_query, prev_action) in enumerate(zip(prev_queries, prev_actions)):
202
+ state_text += f"Previous query {i+1}: {prev_query} [Action: [{prev_action}]] "
203
+
204
+ # Add current query
205
+ current_query = state.get('current_query', '')
206
+ state_text += f"Current query: {current_query}"
207
+
208
+ # Tokenize
209
+ encoding = self.tokenizer(
210
+ state_text,
211
+ truncation=True,
212
+ padding='max_length',
213
+ max_length=max_length,
214
+ return_tensors='pt'
215
+ )
216
+
217
+ return encoding
218
+
219
+ def predict_action(
220
+ self,
221
+ state: Dict,
222
+ use_dropout: bool = False,
223
+ num_samples: int = 10
224
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
225
+ """
226
+ Predict action probabilities for a given state.
227
+
228
+ Args:
229
+ state: Conversation state dictionary
230
+ use_dropout: Whether to use MC Dropout for uncertainty estimation
231
+ num_samples: Number of MC Dropout samples (if use_dropout=True)
232
+
233
+ Returns:
234
+ probs: Action probabilities (shape: [1, 2]) - [P(FETCH), P(NO_FETCH)]
235
+ uncertainty: Standard deviation across samples (if use_dropout=True)
236
+ """
237
+ device = next(self.parameters()).device
238
+
239
+ if use_dropout:
240
+ # MC Dropout for uncertainty estimation
241
+ self.train() # Enable dropout during inference
242
+ all_probs = []
243
+
244
+ for _ in range(num_samples):
245
+ with torch.no_grad():
246
+ encoding = self.encode_state(state)
247
+ input_ids = encoding['input_ids'].to(device)
248
+ attention_mask = encoding['attention_mask'].to(device)
249
+
250
+ _, probs = self.forward(input_ids, attention_mask)
251
+ all_probs.append(probs.cpu().numpy())
252
+
253
+ # Average probabilities across samples
254
+ avg_probs = np.mean(all_probs, axis=0)
255
+
256
+ # Calculate uncertainty (standard deviation)
257
+ uncertainty = np.std(all_probs, axis=0)
258
+
259
+ return avg_probs, uncertainty
260
+
261
+ else:
262
+ # Standard inference (no uncertainty estimation)
263
+ self.eval()
264
+ with torch.no_grad():
265
+ encoding = self.encode_state(state)
266
+ input_ids = encoding['input_ids'].to(device)
267
+ attention_mask = encoding['attention_mask'].to(device)
268
+
269
+ _, probs = self.forward(input_ids, attention_mask)
270
+
271
+ return probs.cpu().numpy(), None
272
+
273
+
274
+ # ============================================================================
275
+ # MODULE-LEVEL CACHING (Load once on import)
276
+ # ============================================================================
277
+
278
+ # Global variables for caching
279
+ POLICY_MODEL: Optional[PolicyNetwork] = None
280
+ POLICY_TOKENIZER: Optional[AutoTokenizer] = None
281
+
282
+
283
+ # def load_policy_model() -> PolicyNetwork:
284
+ # """
285
+ # Load trained policy model (called once on startup).
286
+ # Downloads from HuggingFace Hub if not present locally.
287
+ # Uses module-level caching - model stays in RAM.
288
+
289
+ # Returns:
290
+ # PolicyNetwork: Loaded policy model
291
+ # """
292
+ # global POLICY_MODEL, POLICY_TOKENIZER
293
+
294
+ # if POLICY_MODEL is None:
295
+ # # Download model from HF Hub if needed (for deployment)
296
+ # settings.download_model_if_needed(
297
+ # hf_filename="models/policy_query_only.pt",
298
+ # local_path=settings.POLICY_MODEL_PATH
299
+ # )
300
+
301
+ # print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
302
+ # try:
303
+ # # Load checkpoint first to detect architecture
304
+ # checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
305
+
306
+ # # ✅ AUTO-DETECT ARCHITECTURE from checkpoint keys
307
+ # has_multilayer = "classifier.0.weight" in checkpoint
308
+
309
+ # print(f"📊 Detected architecture: {'Multi-layer' if has_multilayer else 'Single-layer'} classifier")
310
+
311
+ # # Create model instance with correct architecture
312
+ # POLICY_MODEL = PolicyNetwork(
313
+ # model_name="bert-base-uncased",
314
+ # dropout_rate=0.1,
315
+ # use_multilayer=has_multilayer # ✅ Auto-detect!
316
+ # )
317
+
318
+ # # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
319
+ # saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
320
+ # current_vocab_size = len(POLICY_MODEL.tokenizer)
321
+
322
+ # if saved_vocab_size != current_vocab_size:
323
+ # print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
324
+ # print(f"✅ Resizing tokenizer and embeddings to match saved model...")
325
+ # # Resize model to match saved checkpoint
326
+ # POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
327
+
328
+ # # Move to device
329
+ # POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
330
+
331
+ # # Now load trained weights (sizes and architecture will match!)
332
+ # if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
333
+ # POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
334
+ # else:
335
+ # POLICY_MODEL.load_state_dict(checkpoint)
336
+
337
+ # # Set to evaluation mode
338
+ # POLICY_MODEL.eval()
339
+
340
+ # # Cache tokenizer
341
+ # POLICY_TOKENIZER = POLICY_MODEL.tokenizer
342
+
343
+ # print("✅ Policy network loaded and cached")
344
+
345
+ # except FileNotFoundError:
346
+ # print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
347
+ # print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
348
+ # raise
349
+ # except Exception as e:
350
+ # print(f"❌ Failed to load policy model: {e}")
351
+ # import traceback
352
+ # traceback.print_exc()
353
+ # raise
354
+
355
+ # return POLICY_MODEL
356
+
357
+ def load_policy_model() -> PolicyNetwork:
358
+ """
359
+ Load trained policy model (called once on startup).
360
+ Downloads from HuggingFace Hub if not present locally.
361
+ Uses module-level caching - model stays in RAM.
362
+
363
+ Returns:
364
+ PolicyNetwork: Loaded policy model
365
+ """
366
+ global POLICY_MODEL, POLICY_TOKENIZER
367
+
368
+ if POLICY_MODEL is None:
369
+ # Download model from HF Hub if needed (for deployment)
370
+ settings.download_model_if_needed(
371
+ hf_filename="models/policy_query_only.pt",
372
+ local_path=settings.POLICY_MODEL_PATH
373
+ )
374
+
375
+ print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
376
+ try:
377
+ # Load checkpoint first to detect architecture
378
+ checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
379
+
380
+ # ✅ AUTO-DETECT ARCHITECTURE from checkpoint keys
381
+ has_multilayer = "classifier.0.weight" in checkpoint
382
+
383
+ # ✅ AUTO-DETECT HIDDEN SIZE from checkpoint
384
+ if has_multilayer:
385
+ hidden_size = checkpoint['classifier.0.weight'].shape[0] # Get output size of first layer
386
+ print(f"📊 Detected: Multi-layer classifier (hidden_size={hidden_size})")
387
+ else:
388
+ hidden_size = 768 # Doesn't matter for single-layer
389
+ print(f"📊 Detected: Single-layer classifier")
390
+
391
+ # Create model instance with correct architecture
392
+ POLICY_MODEL = PolicyNetwork(
393
+ model_name="bert-base-uncased",
394
+ dropout_rate=0.1,
395
+ use_multilayer=has_multilayer,
396
+ hidden_size=hidden_size # ✅ Pass detected hidden size
397
+ )
398
+
399
+ # **KEY FIX**: Handle vocab size mismatch
400
+ saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
401
+ current_vocab_size = len(POLICY_MODEL.tokenizer)
402
+
403
+ if saved_vocab_size != current_vocab_size:
404
+ print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
405
+
406
+ if abs(saved_vocab_size - current_vocab_size) <= 2:
407
+ # Small difference - just load with strict=False
408
+ print(f"✅ Loading with strict=False to handle minor vocab differences...")
409
+
410
+ # Move to device first
411
+ POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
412
+
413
+ # Load weights with strict=False
414
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
415
+ POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'], strict=False)
416
+ else:
417
+ POLICY_MODEL.load_state_dict(checkpoint, strict=False)
418
+ else:
419
+ # Large difference - resize properly
420
+ print(f"✅ Resizing model to match saved vocab size...")
421
+ POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
422
+
423
+ # Move to device
424
+ POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
425
+
426
+ # Load weights
427
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
428
+ POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
429
+ else:
430
+ POLICY_MODEL.load_state_dict(checkpoint)
431
+ else:
432
+ # No mismatch
433
+ POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
434
+
435
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
436
+ POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
437
+ else:
438
+ POLICY_MODEL.load_state_dict(checkpoint)
439
+
440
+ # Set to evaluation mode
441
+ POLICY_MODEL.eval()
442
+
443
+ # Cache tokenizer
444
+ POLICY_TOKENIZER = POLICY_MODEL.tokenizer
445
+
446
+ print("✅ Policy network loaded and cached")
447
+ print(f" Model vocab size: {POLICY_MODEL.bert.embeddings.word_embeddings.num_embeddings}")
448
+ print(f" Tokenizer vocab size: {len(POLICY_MODEL.tokenizer)}")
449
+
450
+
451
+ print("✅ Policy network loaded and cached")
452
+
453
+ except FileNotFoundError:
454
+ print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
455
+ print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
456
+ raise
457
+ except Exception as e:
458
+ print(f"❌ Failed to load policy model: {e}")
459
+ import traceback
460
+ traceback.print_exc()
461
+ raise
462
+
463
+ return POLICY_MODEL
464
+
465
+
466
+
467
+ # ============================================================================
468
+ # PREDICTION FUNCTIONS
469
+ # ============================================================================
470
+
471
+ def create_state_from_history(
472
+ current_query: str,
473
+ conversation_history: List[Dict],
474
+ max_history: int = 2
475
+ ) -> Dict:
476
+ """
477
+ Create state dictionary from conversation history.
478
+ Extracts last N query-action pairs.
479
+
480
+ Args:
481
+ current_query: Current user query
482
+ conversation_history: List of conversation turns
483
+ Each turn: {'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}
484
+ max_history: Maximum number of previous turns to include (default: 2)
485
+
486
+ Returns:
487
+ dict: State dictionary for policy network
488
+ """
489
+ state = {
490
+ 'current_query': current_query,
491
+ 'previous_queries': [],
492
+ 'previous_actions': []
493
+ }
494
+
495
+ if not conversation_history:
496
+ return state
497
+
498
+ # Extract last N conversation turns (user + assistant pairs)
499
+ relevant_history = conversation_history[-(max_history * 2):]
500
+
501
+ for i, turn in enumerate(relevant_history):
502
+ # User turns
503
+ if turn.get('role') == 'user':
504
+ query = turn.get('content', '')
505
+ state['previous_queries'].append(query)
506
+
507
+ # Look for corresponding assistant turn
508
+ if i + 1 < len(relevant_history):
509
+ bot_turn = relevant_history[i + 1]
510
+ if bot_turn.get('role') == 'assistant':
511
+ metadata = bot_turn.get('metadata', {})
512
+ action = metadata.get('policy_action', 'FETCH')
513
+ state['previous_actions'].append(action)
514
+
515
+ return state
516
+
517
+
518
+ def predict_policy_action(
519
+ query: str,
520
+ history: List[Dict] = None,
521
+ return_probs: bool = False
522
+ ) -> Dict:
523
+ """
524
+ Predict FETCH/NO_FETCH action for a query.
525
+
526
+ Args:
527
+ query: User query text
528
+ history: Conversation history (optional)
529
+ return_probs: Whether to return full probability distribution
530
+
531
+ Returns:
532
+ dict: Prediction results
533
+ {
534
+ 'action': 'FETCH' or 'NO_FETCH',
535
+ 'confidence': float (0-1),
536
+ 'fetch_prob': float,
537
+ 'no_fetch_prob': float,
538
+ 'should_retrieve': bool
539
+ }
540
+ """
541
+ # Load model (cached after first call)
542
+ model = load_policy_model()
543
+
544
+ # Create state from history
545
+ if history is None:
546
+ history = []
547
+
548
+ state = create_state_from_history(query, history)
549
+
550
+ # Predict action
551
+ probs, _ = model.predict_action(state, use_dropout=False)
552
+
553
+ # Extract probabilities
554
+ fetch_prob = float(probs[0][0])
555
+ no_fetch_prob = float(probs[0][1])
556
+
557
+ # Determine action (argmax)
558
+ action_idx = np.argmax(probs[0])
559
+ action = "FETCH" if action_idx == 0 else "NO_FETCH"
560
+ confidence = float(probs[0][action_idx])
561
+
562
+ # Check confidence threshold
563
+ should_retrieve = (action == "FETCH") or (action == "NO_FETCH" and confidence < settings.CONFIDENCE_THRESHOLD)
564
+
565
+ result = {
566
+ 'action': action,
567
+ 'confidence': confidence,
568
+ 'should_retrieve': should_retrieve,
569
+ 'policy_decision': action
570
+ }
571
+
572
+ if return_probs:
573
+ result['fetch_prob'] = fetch_prob
574
+ result['no_fetch_prob'] = no_fetch_prob
575
+
576
+ return result
577
+
578
+
579
+ # ============================================================================
580
+ # USAGE EXAMPLE (for reference)
581
+ # ============================================================================
582
+ """
583
+ # In your service file:
584
+ from app.ml.policy_network import predict_policy_action
585
+
586
+ # Predict action
587
+ history = [
588
+ {'role': 'user', 'content': 'What is my balance?'},
589
+ {'role': 'assistant', 'content': '$1000', 'metadata': {'policy_action': 'FETCH'}}
590
+ ]
591
+
592
+ result = predict_policy_action(
593
+ query="Thank you!",
594
+ history=history,
595
+ return_probs=True
596
+ )
597
+
598
+ print(result)
599
+ # {
600
+ # 'action': 'NO_FETCH',
601
+ # 'confidence': 0.95,
602
+ # 'should_retrieve': False,
603
+ # 'fetch_prob': 0.05,
604
+ # 'no_fetch_prob': 0.95
605
+ # }
606
+ """
607
+
608
+
609
+
610
+
611
+
612
+
613
+
614
+
615
+
616
+
617
+
618
+
619
+
620
+
621
+
622
+
623
+
624
+
625
+
626
+
627
+
628
+
629
+
630
+
631
+
632
+
633
+
634
+
635
+
636
+
637
+
638
+
639
+
640
+
641
+
642
+
643
+
644
+
645
+
646
+
647
+
648
+
649
+
650
+
651
+
652
+
653
+
654
+
655
+
656
+
657
+
658
+
659
+
660
+
661
+
662
+
663
+
664
+
665
+
666
+
667
+
668
+
669
+
670
+
671
+
672
+
673
+
674
+
675
+
676
+
677
+
678
+
679
+
680
+
681
+
682
+
683
+
684
+
685
+
686
+
687
+
688
+
689
+
690
+
691
+
692
+
693
+
694
+
695
+
696
+
697
+
698
+
699
+
700
+
701
+
702
+
703
  """
704
  BERT-based Policy Network for FETCH/NO_FETCH decisions
705
  Trained with Reinforcement Learning (Policy Gradient + Entropy Regularization)
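
Note: the training objective referenced just above (policy gradient with entropy regularization, rewards of +0.5 for FETCH, +2.0 for NO_FETCH with a good answer, -0.5 for NO_FETCH with a bad answer) lives in RL.py and is not part of this commit. As a rough sketch only, assuming standard REINFORCE-style tensors, such a loss could look like the following; compute_policy_loss and entropy_coef are illustrative names, not code from this repository.

import torch

def compute_policy_loss(log_probs, rewards, probs, entropy_coef=0.01):
    # log_probs: log pi(a_t | s_t) for the actions actually taken, shape [batch]
    # rewards:   scalar reward per FETCH/NO_FETCH decision, shape [batch]
    # probs:     full action distribution over (FETCH, NO_FETCH), shape [batch, 2]
    pg_loss = -(log_probs * rewards).mean()  # reward-weighted log-likelihood (REINFORCE)
    entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1).mean()
    return pg_loss - entropy_coef * entropy  # entropy bonus discourages early collapse
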
app/ml/policy_network.py CHANGED
@@ -21,16 +21,17 @@ from app.config import settings
21
  # POLICY NETWORK (From RL.py)
22
  # ============================================================================
23
 
 
24
  class PolicyNetwork(nn.Module):
25
  """
26
  BERT-based Policy Network for deciding FETCH vs NO_FETCH actions.
27
-
28
  Architecture:
29
  - Base: BERT-base-uncased (pre-trained)
30
  - Input: Current query + conversation history + previous actions
31
  - Output: 2-class softmax (FETCH=0, NO_FETCH=1)
32
- - Special tokens: [FETCH], [NO_FETCH] for action encoding
33
-
34
  Training Details:
35
  - Loss: Policy Gradient + Entropy Regularization
36
  - Optimizer: AdamW
@@ -39,235 +40,207 @@ class PolicyNetwork(nn.Module):
39
  * NO_FETCH + Good: +2.0
40
  * NO_FETCH + Bad: -0.5
41
  """
42
-
43
- # def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True):
44
- # super(PolicyNetwork, self).__init__()
45
-
46
- # # Load pre-trained BERT
47
- # self.bert = AutoModel.from_pretrained(model_name)
48
-
49
- # # Load tokenizer
50
- # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
51
-
52
- # # Add special tokens for actions: [FETCH] and [NO_FETCH]
53
- # special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
54
- # self.tokenizer.add_special_tokens(special_tokens)
55
-
56
- # # Resize BERT embeddings to accommodate new tokens
57
- # self.bert.resize_token_embeddings(len(self.tokenizer))
58
-
59
- # # Initialize random embeddings for special tokens
60
- # self._init_action_embeddings()
61
-
62
- # # ✅ FLEXIBLE CLASSIFIER ARCHITECTURE
63
- # if use_multilayer:
64
- # # Multi-layer classifier (your new trained model)
65
- # self.classifier = nn.Sequential(
66
- # nn.Linear(self.bert.config.hidden_size, 256),
67
- # nn.ReLU(),
68
- # nn.Dropout(dropout_rate),
69
- # nn.Linear(256, 2)
70
- # )
71
- # else:
72
- # # Single-layer classifier (fallback)
73
- # self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
74
-
75
- # # Dropout for regularization
76
- # self.dropout = nn.Dropout(dropout_rate)
77
-
78
-
79
- def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True, hidden_size: int = 128):
80
  super(PolicyNetwork, self).__init__()
81
-
82
- # Load pre-trained BERT
83
  self.bert = AutoModel.from_pretrained(model_name)
84
-
85
- # Load tokenizer
86
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
87
-
88
- # Add special tokens for actions: [FETCH] and [NO_FETCH]
89
- special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
90
- self.tokenizer.add_special_tokens(special_tokens)
91
-
92
- # Resize BERT embeddings to accommodate new tokens
93
- self.bert.resize_token_embeddings(len(self.tokenizer))
94
-
95
- # Initialize random embeddings for special tokens
96
- self._init_action_embeddings()
97
-
98
- # ✅ FLEXIBLE CLASSIFIER ARCHITECTURE (with configurable hidden size)
99
  if use_multilayer:
100
- # Multi-layer classifier with specified hidden size (128 or 256)
101
  self.classifier = nn.Sequential(
102
- nn.Linear(self.bert.config.hidden_size, hidden_size), # ✅ Use hidden_size param
103
  nn.ReLU(),
104
  nn.Dropout(dropout_rate),
105
- nn.Linear(hidden_size, 2) # ✅ Use hidden_size param
106
  )
107
  else:
108
- # Single-layer classifier (fallback)
109
  self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
110
-
111
- # Dropout for regularization
112
- self.dropout = nn.Dropout(dropout_rate)
113
-
114
-
115
-
116
-
117
 
 
 
118
 
119
  def _init_action_embeddings(self):
120
  """
121
- Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
122
- These are learned during training.
 
 
 
 
123
  """
124
  with torch.no_grad():
125
- # Get token IDs for special tokens
126
  fetch_id = self.tokenizer.convert_tokens_to_ids("[FETCH]")
127
  no_fetch_id = self.tokenizer.convert_tokens_to_ids("[NO_FETCH]")
128
-
129
- # Get embedding dimension
130
  embedding_dim = self.bert.config.hidden_size
131
-
132
- # Initialize with small random values (same as BERT initialization)
133
- self.bert.embeddings.word_embeddings.weight[fetch_id] = torch.randn(embedding_dim) * 0.02
134
- self.bert.embeddings.word_embeddings.weight[no_fetch_id] = torch.randn(embedding_dim) * 0.02
135
-
136
- def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
137
  """
138
  Forward pass through BERT + classifier.
139
-
140
  Args:
141
  input_ids: Tokenized input IDs (shape: [batch_size, seq_len])
142
  attention_mask: Attention mask (shape: [batch_size, seq_len])
143
-
144
  Returns:
145
  logits: Raw logits (shape: [batch_size, 2])
146
  probs: Softmax probabilities (shape: [batch_size, 2])
147
  """
148
- # Pass through BERT
149
  outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
150
-
151
- # Extract [CLS] token representation (first token)
152
  cls_output = outputs.last_hidden_state[:, 0, :]
153
-
154
  # Apply dropout
155
  cls_output = self.dropout(cls_output)
156
-
157
  # Classification
158
  logits = self.classifier(cls_output)
159
-
160
  # Softmax for probabilities
161
  probs = F.softmax(logits, dim=-1)
162
-
163
  return logits, probs
164
-
165
  def encode_state(
166
  self,
167
  state: Dict,
168
- max_length: int = None
169
  ) -> Dict[str, torch.Tensor]:
170
  """
171
  Encode conversation state into BERT input format.
172
-
173
  State structure:
174
  {
175
  'previous_queries': [query1, query2, ...],
176
  'previous_actions': ['FETCH', 'NO_FETCH', ...],
177
  'current_query': 'user query'
178
  }
179
-
180
  Encoding format:
181
- "Previous query 1: [Action: [FETCH]] Previous query 2: [Action: [NO_FETCH]] Current query: <query>"
182
-
183
  Args:
184
  state: State dictionary
185
  max_length: Maximum sequence length (default from config)
186
-
187
  Returns:
188
  dict: Tokenized inputs (input_ids, attention_mask)
189
  """
190
  if max_length is None:
191
  max_length = settings.POLICY_MAX_LEN
192
-
193
  # Build state text from conversation history
194
  state_text = ""
195
-
196
  # Add previous queries and their actions
197
- prev_queries = state.get('previous_queries', [])
198
- prev_actions = state.get('previous_actions', [])
199
-
200
  if prev_queries and prev_actions:
201
- for i, (prev_query, prev_action) in enumerate(zip(prev_queries, prev_actions)):
202
- state_text += f"Previous query {i+1}: {prev_query} [Action: [{prev_action}]] "
203
-
 
 
 
 
204
  # Add current query
205
- current_query = state.get('current_query', '')
206
  state_text += f"Current query: {current_query}"
207
-
208
  # Tokenize
209
  encoding = self.tokenizer(
210
  state_text,
211
  truncation=True,
212
- padding='max_length',
213
  max_length=max_length,
214
- return_tensors='pt'
215
  )
216
-
217
  return encoding
218
-
219
  def predict_action(
220
  self,
221
  state: Dict,
222
  use_dropout: bool = False,
223
- num_samples: int = 10
224
  ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
225
  """
226
  Predict action probabilities for a given state.
227
-
228
  Args:
229
  state: Conversation state dictionary
230
  use_dropout: Whether to use MC Dropout for uncertainty estimation
231
  num_samples: Number of MC Dropout samples (if use_dropout=True)
232
-
233
  Returns:
234
  probs: Action probabilities (shape: [1, 2]) - [P(FETCH), P(NO_FETCH)]
235
  uncertainty: Standard deviation across samples (if use_dropout=True)
236
  """
237
  device = next(self.parameters()).device
238
-
239
  if use_dropout:
240
  # MC Dropout for uncertainty estimation
241
  self.train() # Enable dropout during inference
242
  all_probs = []
243
-
244
  for _ in range(num_samples):
245
  with torch.no_grad():
246
  encoding = self.encode_state(state)
247
- input_ids = encoding['input_ids'].to(device)
248
- attention_mask = encoding['attention_mask'].to(device)
249
-
250
  _, probs = self.forward(input_ids, attention_mask)
251
  all_probs.append(probs.cpu().numpy())
252
-
253
  # Average probabilities across samples
254
  avg_probs = np.mean(all_probs, axis=0)
255
-
256
  # Calculate uncertainty (standard deviation)
257
  uncertainty = np.std(all_probs, axis=0)
258
-
259
  return avg_probs, uncertainty
260
-
261
  else:
262
  # Standard inference (no uncertainty estimation)
263
  self.eval()
264
  with torch.no_grad():
265
  encoding = self.encode_state(state)
266
- input_ids = encoding['input_ids'].to(device)
267
- attention_mask = encoding['attention_mask'].to(device)
268
-
269
  _, probs = self.forward(input_ids, attention_mask)
270
-
271
  return probs.cpu().numpy(), None
272
 
273
 
@@ -280,254 +253,173 @@ POLICY_MODEL: Optional[PolicyNetwork] = None
280
  POLICY_TOKENIZER: Optional[AutoTokenizer] = None
281
 
282
 
283
- # def load_policy_model() -> PolicyNetwork:
284
- # """
285
- # Load trained policy model (called once on startup).
286
- # Downloads from HuggingFace Hub if not present locally.
287
- # Uses module-level caching - model stays in RAM.
288
-
289
- # Returns:
290
- # PolicyNetwork: Loaded policy model
291
- # """
292
- # global POLICY_MODEL, POLICY_TOKENIZER
293
-
294
- # if POLICY_MODEL is None:
295
- # # Download model from HF Hub if needed (for deployment)
296
- # settings.download_model_if_needed(
297
- # hf_filename="models/policy_query_only.pt",
298
- # local_path=settings.POLICY_MODEL_PATH
299
- # )
300
-
301
- # print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
302
- # try:
303
- # # Load checkpoint first to detect architecture
304
- # checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
305
-
306
- # # ✅ AUTO-DETECT ARCHITECTURE from checkpoint keys
307
- # has_multilayer = "classifier.0.weight" in checkpoint
308
-
309
- # print(f"📊 Detected architecture: {'Multi-layer' if has_multilayer else 'Single-layer'} classifier")
310
-
311
- # # Create model instance with correct architecture
312
- # POLICY_MODEL = PolicyNetwork(
313
- # model_name="bert-base-uncased",
314
- # dropout_rate=0.1,
315
- # use_multilayer=has_multilayer # ✅ Auto-detect!
316
- # )
317
-
318
- # # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
319
- # saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
320
- # current_vocab_size = len(POLICY_MODEL.tokenizer)
321
-
322
- # if saved_vocab_size != current_vocab_size:
323
- # print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
324
- # print(f"✅ Resizing tokenizer and embeddings to match saved model...")
325
- # # Resize model to match saved checkpoint
326
- # POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
327
-
328
- # # Move to device
329
- # POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
330
-
331
- # # Now load trained weights (sizes and architecture will match!)
332
- # if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
333
- # POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
334
- # else:
335
- # POLICY_MODEL.load_state_dict(checkpoint)
336
-
337
- # # Set to evaluation mode
338
- # POLICY_MODEL.eval()
339
-
340
- # # Cache tokenizer
341
- # POLICY_TOKENIZER = POLICY_MODEL.tokenizer
342
-
343
- # print("✅ Policy network loaded and cached")
344
-
345
- # except FileNotFoundError:
346
- # print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
347
- # print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
348
- # raise
349
- # except Exception as e:
350
- # print(f"❌ Failed to load policy model: {e}")
351
- # import traceback
352
- # traceback.print_exc()
353
- # raise
354
-
355
- # return POLICY_MODEL
356
-
357
  def load_policy_model() -> PolicyNetwork:
358
  """
359
  Load trained policy model (called once on startup).
360
  Downloads from HuggingFace Hub if not present locally.
361
  Uses module-level caching - model stays in RAM.
362
-
363
  Returns:
364
  PolicyNetwork: Loaded policy model
365
  """
366
  global POLICY_MODEL, POLICY_TOKENIZER
367
-
368
  if POLICY_MODEL is None:
369
  # Download model from HF Hub if needed (for deployment)
370
  settings.download_model_if_needed(
371
  hf_filename="models/policy_query_only.pt",
372
- local_path=settings.POLICY_MODEL_PATH
373
  )
374
-
375
  print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
376
  try:
377
  # Load checkpoint first to detect architecture
378
- checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
379
-
 
 
 
 
 
 
 
 
380
  # ✅ AUTO-DETECT ARCHITECTURE from checkpoint keys
381
- has_multilayer = "classifier.0.weight" in checkpoint
382
-
383
- # ✅ AUTO-DETECT HIDDEN SIZE from checkpoint
384
  if has_multilayer:
385
- hidden_size = checkpoint['classifier.0.weight'].shape[0] # Get output size of first layer
386
- print(f"📊 Detected: Multi-layer classifier (hidden_size={hidden_size})")
 
 
387
  else:
388
- hidden_size = 768 # Doesn't matter for single-layer
389
- print(f"📊 Detected: Single-layer classifier")
390
-
391
  # Create model instance with correct architecture
392
  POLICY_MODEL = PolicyNetwork(
393
  model_name="bert-base-uncased",
394
  dropout_rate=0.1,
395
  use_multilayer=has_multilayer,
396
- hidden_size=hidden_size # ✅ Pass detected hidden size
 
 
 
 
 
 
 
 
397
  )
398
-
399
- # **KEY FIX**: Handle vocab size mismatch
400
- saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
401
- current_vocab_size = len(POLICY_MODEL.tokenizer)
402
 
403
  if saved_vocab_size != current_vocab_size:
404
- print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
405
-
406
- if abs(saved_vocab_size - current_vocab_size) <= 2:
407
- # Small difference - just load with strict=False
408
- print(f"✅ Loading with strict=False to handle minor vocab differences...")
409
-
410
- # Move to device first
411
- POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
412
-
413
- # Load weights with strict=False
414
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
415
- POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'], strict=False)
416
- else:
417
- POLICY_MODEL.load_state_dict(checkpoint, strict=False)
418
- else:
419
- # Large difference - resize properly
420
- print(f"✅ Resizing model to match saved vocab size...")
421
- POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
422
-
423
- # Move to device
424
- POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
425
-
426
- # Load weights
427
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
428
- POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
429
- else:
430
- POLICY_MODEL.load_state_dict(checkpoint)
431
- else:
432
- # No mismatch
433
- POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
434
-
435
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
436
- POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
437
- else:
438
- POLICY_MODEL.load_state_dict(checkpoint)
439
-
440
- # Set to evaluation mode
441
  POLICY_MODEL.eval()
442
 
443
- # Cache tokenizer
444
  POLICY_TOKENIZER = POLICY_MODEL.tokenizer
445
 
446
  print("✅ Policy network loaded and cached")
447
- print(f" Model vocab size: {POLICY_MODEL.bert.embeddings.word_embeddings.num_embeddings}")
 
 
448
  print(f" Tokenizer vocab size: {len(POLICY_MODEL.tokenizer)}")
449
 
450
-
451
- print("✅ Policy network loaded and cached")
452
-
453
  except FileNotFoundError:
454
  print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
455
- print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
 
 
456
  raise
457
  except Exception as e:
458
  print(f"❌ Failed to load policy model: {e}")
459
  import traceback
 
460
  traceback.print_exc()
461
  raise
462
-
463
- return POLICY_MODEL
464
 
 
465
 
466
 
467
  # ============================================================================
468
  # PREDICTION FUNCTIONS
469
  # ============================================================================
470
 
 
471
  def create_state_from_history(
472
  current_query: str,
473
  conversation_history: List[Dict],
474
- max_history: int = 2
475
  ) -> Dict:
476
  """
477
  Create state dictionary from conversation history.
478
  Extracts last N query-action pairs.
479
-
480
  Args:
481
  current_query: Current user query
482
  conversation_history: List of conversation turns
483
  Each turn: {'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}
484
  max_history: Maximum number of previous turns to include (default: 2)
485
-
486
  Returns:
487
  dict: State dictionary for policy network
488
  """
489
  state = {
490
- 'current_query': current_query,
491
- 'previous_queries': [],
492
- 'previous_actions': []
493
  }
494
-
495
  if not conversation_history:
496
  return state
497
-
498
  # Extract last N conversation turns (user + assistant pairs)
499
- relevant_history = conversation_history[-(max_history * 2):]
500
-
501
  for i, turn in enumerate(relevant_history):
502
  # User turns
503
- if turn.get('role') == 'user':
504
- query = turn.get('content', '')
505
- state['previous_queries'].append(query)
506
-
507
  # Look for corresponding assistant turn
508
  if i + 1 < len(relevant_history):
509
  bot_turn = relevant_history[i + 1]
510
- if bot_turn.get('role') == 'assistant':
511
- metadata = bot_turn.get('metadata', {})
512
- action = metadata.get('policy_action', 'FETCH')
513
- state['previous_actions'].append(action)
514
-
515
  return state
516
 
517
 
518
  def predict_policy_action(
519
  query: str,
520
  history: List[Dict] = None,
521
- return_probs: bool = False
522
  ) -> Dict:
523
  """
524
  Predict FETCH/NO_FETCH action for a query.
525
-
526
  Args:
527
  query: User query text
528
  history: Conversation history (optional)
529
  return_probs: Whether to return full probability distribution
530
-
531
  Returns:
532
  dict: Prediction results
533
  {
@@ -540,39 +432,41 @@ def predict_policy_action(
540
  """
541
  # Load model (cached after first call)
542
  model = load_policy_model()
543
-
544
  # Create state from history
545
  if history is None:
546
  history = []
547
-
548
  state = create_state_from_history(query, history)
549
-
550
  # Predict action
551
  probs, _ = model.predict_action(state, use_dropout=False)
552
-
553
  # Extract probabilities
554
  fetch_prob = float(probs[0][0])
555
  no_fetch_prob = float(probs[0][1])
556
-
557
  # Determine action (argmax)
558
- action_idx = np.argmax(probs[0])
559
  action = "FETCH" if action_idx == 0 else "NO_FETCH"
560
  confidence = float(probs[0][action_idx])
561
-
562
  # Check confidence threshold
563
- should_retrieve = (action == "FETCH") or (action == "NO_FETCH" and confidence < settings.CONFIDENCE_THRESHOLD)
564
-
 
 
565
  result = {
566
- 'action': action,
567
- 'confidence': confidence,
568
- 'should_retrieve': should_retrieve,
569
- 'policy_decision': action
570
  }
571
-
572
  if return_probs:
573
- result['fetch_prob'] = fetch_prob
574
- result['no_fetch_prob'] = no_fetch_prob
575
-
576
  return result
577
 
578
 
 
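
Note: the core of this commit is the load_policy_model() rewrite shown in the hunks above: it unwraps an optional model_state_dict, detects the multi-layer classifier from the classifier.0.weight key, reads the hidden size from that weight's shape, and resizes the BERT embeddings to the checkpoint's vocabulary before loading. A standalone sketch of that inspection logic, handy for checking a checkpoint offline, might look like this (describe_checkpoint and the example path are assumptions, not part of the commit):

import torch

def describe_checkpoint(path: str) -> dict:
    ckpt = torch.load(path, map_location="cpu")
    # Unwrap {"model_state_dict": ...} checkpoints, mirroring load_policy_model()
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt
    multilayer = "classifier.0.weight" in state_dict
    return {
        "multilayer_classifier": multilayer,
        "hidden_size": int(state_dict["classifier.0.weight"].shape[0]) if multilayer else None,
        "saved_vocab_size": int(state_dict["bert.embeddings.word_embeddings.weight"].shape[0]),
    }

# Example call (path is illustrative): describe_checkpoint("models/policy_query_only.pt")
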
21
  # POLICY NETWORK (From RL.py)
22
  # ============================================================================
23
 
24
+
25
  class PolicyNetwork(nn.Module):
26
  """
27
  BERT-based Policy Network for deciding FETCH vs NO_FETCH actions.
28
+
29
  Architecture:
30
  - Base: BERT-base-uncased (pre-trained)
31
  - Input: Current query + conversation history + previous actions
32
  - Output: 2-class softmax (FETCH=0, NO_FETCH=1)
33
+ - Special tokens: [FETCH], [NO_FETCH] for action encoding (encoded as plain text)
34
+
35
  Training Details:
36
  - Loss: Policy Gradient + Entropy Regularization
37
  - Optimizer: AdamW
 
40
  * NO_FETCH + Good: +2.0
41
  * NO_FETCH + Bad: -0.5
42
  """
43
+
44
+ def __init__(
45
+ self,
46
+ model_name: str = "bert-base-uncased",
47
+ dropout_rate: float = 0.1,
48
+ use_multilayer: bool = True,
49
+ hidden_size: int = 128,
50
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  super(PolicyNetwork, self).__init__()
52
+
53
+ # Load pre-trained BERT and tokenizer
54
  self.bert = AutoModel.from_pretrained(model_name)
 
 
55
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
56
+
57
+ # IMPORTANT:
58
+ # Do NOT add extra special tokens or resize embeddings here.
59
+ # The saved checkpoint was trained with the ORIGINAL BERT vocab
60
+ # (vocab_size=30522). Changing vocab size before loading will cause
61
+ # the size mismatch error:
62
+ # saved=30522, current=30524
63
+
64
+ self.use_multilayer = use_multilayer
65
+
66
+ # ✅ FLEXIBLE CLASSIFIER ARCHITECTURE (with configurable hidden size)
 
67
  if use_multilayer:
68
+ # Multi-layer classifier with specified hidden size (128 or 256)
69
  self.classifier = nn.Sequential(
70
+ nn.Linear(self.bert.config.hidden_size, hidden_size),
71
  nn.ReLU(),
72
  nn.Dropout(dropout_rate),
73
+ nn.Linear(hidden_size, 2),
74
  )
75
  else:
76
+ # Single-layer classifier (fallback)
77
  self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
 
 
 
 
 
 
 
78
 
79
+ # Dropout for regularization
80
+ self.dropout = nn.Dropout(dropout_rate)
81
 
82
  def _init_action_embeddings(self):
83
  """
84
+ (Currently unused)
85
+
86
+ In an alternative setup, this could initialize random embeddings
87
+ for [FETCH] and [NO_FETCH] tokens if they were added as true
88
+ special tokens. For this checkpoint we DO NOT change the vocab
89
+ size, so we leave this unused to avoid shape mismatches.
90
  """
91
  with torch.no_grad():
 
92
  fetch_id = self.tokenizer.convert_tokens_to_ids("[FETCH]")
93
  no_fetch_id = self.tokenizer.convert_tokens_to_ids("[NO_FETCH]")
94
+
 
95
  embedding_dim = self.bert.config.hidden_size
96
+
97
+ self.bert.embeddings.word_embeddings.weight[fetch_id] = (
98
+ torch.randn(embedding_dim) * 0.02
99
+ )
100
+ self.bert.embeddings.word_embeddings.weight[no_fetch_id] = (
101
+ torch.randn(embedding_dim) * 0.02
102
+ )
103
+
104
+ def forward(
105
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor
106
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
107
  """
108
  Forward pass through BERT + classifier.
109
+
110
  Args:
111
  input_ids: Tokenized input IDs (shape: [batch_size, seq_len])
112
  attention_mask: Attention mask (shape: [batch_size, seq_len])
113
+
114
  Returns:
115
  logits: Raw logits (shape: [batch_size, 2])
116
  probs: Softmax probabilities (shape: [batch_size, 2])
117
  """
 
118
  outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
119
+
120
+ # [CLS] token representation (first token)
121
  cls_output = outputs.last_hidden_state[:, 0, :]
122
+
123
  # Apply dropout
124
  cls_output = self.dropout(cls_output)
125
+
126
  # Classification
127
  logits = self.classifier(cls_output)
128
+
129
  # Softmax for probabilities
130
  probs = F.softmax(logits, dim=-1)
131
+
132
  return logits, probs
133
+
134
  def encode_state(
135
  self,
136
  state: Dict,
137
+ max_length: int = None,
138
  ) -> Dict[str, torch.Tensor]:
139
  """
140
  Encode conversation state into BERT input format.
141
+
142
  State structure:
143
  {
144
  'previous_queries': [query1, query2, ...],
145
  'previous_actions': ['FETCH', 'NO_FETCH', ...],
146
  'current_query': 'user query'
147
  }
148
+
149
  Encoding format:
150
+ "Previous query 1: {q1} [Action: [FETCH]] Previous query 2: {q2} [Action: [NO_FETCH]] Current query: <query>"
151
+
152
  Args:
153
  state: State dictionary
154
  max_length: Maximum sequence length (default from config)
155
+
156
  Returns:
157
  dict: Tokenized inputs (input_ids, attention_mask)
158
  """
159
  if max_length is None:
160
  max_length = settings.POLICY_MAX_LEN
161
+
162
  # Build state text from conversation history
163
  state_text = ""
164
+
165
  # Add previous queries and their actions
166
+ prev_queries = state.get("previous_queries", [])
167
+ prev_actions = state.get("previous_actions", [])
168
+
169
  if prev_queries and prev_actions:
170
+ for i, (prev_query, prev_action) in enumerate(
171
+ zip(prev_queries, prev_actions)
172
+ ):
173
+ state_text += (
174
+ f"Previous query {i+1}: {prev_query} [Action: [{prev_action}]] "
175
+ )
176
+
177
  # Add current query
178
+ current_query = state.get("current_query", "")
179
  state_text += f"Current query: {current_query}"
180
+
181
  # Tokenize
182
  encoding = self.tokenizer(
183
  state_text,
184
  truncation=True,
185
+ padding="max_length",
186
  max_length=max_length,
187
+ return_tensors="pt",
188
  )
189
+
190
  return encoding
191
+
192
  def predict_action(
193
  self,
194
  state: Dict,
195
  use_dropout: bool = False,
196
+ num_samples: int = 10,
197
  ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
198
  """
199
  Predict action probabilities for a given state.
200
+
201
  Args:
202
  state: Conversation state dictionary
203
  use_dropout: Whether to use MC Dropout for uncertainty estimation
204
  num_samples: Number of MC Dropout samples (if use_dropout=True)
205
+
206
  Returns:
207
  probs: Action probabilities (shape: [1, 2]) - [P(FETCH), P(NO_FETCH)]
208
  uncertainty: Standard deviation across samples (if use_dropout=True)
209
  """
210
  device = next(self.parameters()).device
211
+
212
  if use_dropout:
213
  # MC Dropout for uncertainty estimation
214
  self.train() # Enable dropout during inference
215
  all_probs = []
216
+
217
  for _ in range(num_samples):
218
  with torch.no_grad():
219
  encoding = self.encode_state(state)
220
+ input_ids = encoding["input_ids"].to(device)
221
+ attention_mask = encoding["attention_mask"].to(device)
222
+
223
  _, probs = self.forward(input_ids, attention_mask)
224
  all_probs.append(probs.cpu().numpy())
225
+
226
  # Average probabilities across samples
227
  avg_probs = np.mean(all_probs, axis=0)
228
+
229
  # Calculate uncertainty (standard deviation)
230
  uncertainty = np.std(all_probs, axis=0)
231
+
232
  return avg_probs, uncertainty
233
+
234
  else:
235
  # Standard inference (no uncertainty estimation)
236
  self.eval()
237
  with torch.no_grad():
238
  encoding = self.encode_state(state)
239
+ input_ids = encoding["input_ids"].to(device)
240
+ attention_mask = encoding["attention_mask"].to(device)
241
+
242
  _, probs = self.forward(input_ids, attention_mask)
243
+
244
  return probs.cpu().numpy(), None
245
 
246
 
 
253
  POLICY_TOKENIZER: Optional[AutoTokenizer] = None
254
 
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  def load_policy_model() -> PolicyNetwork:
257
  """
258
  Load trained policy model (called once on startup).
259
  Downloads from HuggingFace Hub if not present locally.
260
  Uses module-level caching - model stays in RAM.
261
+
262
  Returns:
263
  PolicyNetwork: Loaded policy model
264
  """
265
  global POLICY_MODEL, POLICY_TOKENIZER
266
+
267
  if POLICY_MODEL is None:
268
  # Download model from HF Hub if needed (for deployment)
269
  settings.download_model_if_needed(
270
  hf_filename="models/policy_query_only.pt",
271
+ local_path=settings.POLICY_MODEL_PATH,
272
  )
273
+
274
  print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
275
  try:
276
  # Load checkpoint first to detect architecture
277
+ checkpoint = torch.load(
278
+ settings.POLICY_MODEL_PATH, map_location=settings.DEVICE
279
+ )
280
+
281
+ # Unwrap if saved as {"model_state_dict": ...}
282
+ if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
283
+ state_dict = checkpoint["model_state_dict"]
284
+ else:
285
+ state_dict = checkpoint
286
+
287
  # ✅ AUTO-DETECT ARCHITECTURE from checkpoint keys
288
+ has_multilayer = "classifier.0.weight" in state_dict
289
+
 
290
  if has_multilayer:
291
+ hidden_size = state_dict["classifier.0.weight"].shape[0]
292
+ print(
293
+ f"📊 Detected: Multi-layer classifier (hidden_size={hidden_size})"
294
+ )
295
  else:
296
+ hidden_size = 768 # not really used for single-layer
297
+ print("📊 Detected: Single-layer classifier")
298
+
299
  # Create model instance with correct architecture
300
  POLICY_MODEL = PolicyNetwork(
301
  model_name="bert-base-uncased",
302
  dropout_rate=0.1,
303
  use_multilayer=has_multilayer,
304
+ hidden_size=hidden_size,
305
+ )
306
+
307
+ # Align vocab size / embeddings with checkpoint
308
+ saved_vocab_size = state_dict[
309
+ "bert.embeddings.word_embeddings.weight"
310
+ ].shape[0]
311
+ current_vocab_size = (
312
+ POLICY_MODEL.bert.embeddings.word_embeddings.num_embeddings
313
  )
 
 
 
 
314
 
315
  if saved_vocab_size != current_vocab_size:
316
+ print(
317
+ f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}"
318
+ )
319
+ print(
320
+ "✅ Resizing BERT embeddings to match saved checkpoint vocab size..."
321
+ )
322
+ POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
323
+
324
+ # Move to device
325
+ POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
326
+
327
+ # Load weights (shapes now match, so strict=False is just safety)
328
+ POLICY_MODEL.load_state_dict(state_dict, strict=False)
329
+
330
+ # Set to evaluation mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  POLICY_MODEL.eval()
332
 
333
+ # Cache tokenizer
334
  POLICY_TOKENIZER = POLICY_MODEL.tokenizer
335
 
336
  print("✅ Policy network loaded and cached")
337
+ print(
338
+ f" Model vocab size: {POLICY_MODEL.bert.embeddings.word_embeddings.num_embeddings}"
339
+ )
340
  print(f" Tokenizer vocab size: {len(POLICY_MODEL.tokenizer)}")
341
 
 
 
 
342
  except FileNotFoundError:
343
  print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
344
+ print(
345
+ f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}"
346
+ )
347
  raise
348
  except Exception as e:
349
  print(f"❌ Failed to load policy model: {e}")
350
  import traceback
351
+
352
  traceback.print_exc()
353
  raise
 
 
354
 
355
+ return POLICY_MODEL
356
 
357
 
358
  # ============================================================================
359
  # PREDICTION FUNCTIONS
360
  # ============================================================================
361
 
362
+
363
  def create_state_from_history(
364
  current_query: str,
365
  conversation_history: List[Dict],
366
+ max_history: int = 2,
367
  ) -> Dict:
368
  """
369
  Create state dictionary from conversation history.
370
  Extracts last N query-action pairs.
371
+
372
  Args:
373
  current_query: Current user query
374
  conversation_history: List of conversation turns
375
  Each turn: {'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}
376
  max_history: Maximum number of previous turns to include (default: 2)
377
+
378
  Returns:
379
  dict: State dictionary for policy network
380
  """
381
  state = {
382
+ "current_query": current_query,
383
+ "previous_queries": [],
384
+ "previous_actions": [],
385
  }
386
+
387
  if not conversation_history:
388
  return state
389
+
390
  # Extract last N conversation turns (user + assistant pairs)
391
+ relevant_history = conversation_history[-(max_history * 2) :]
392
+
393
  for i, turn in enumerate(relevant_history):
394
  # User turns
395
+ if turn.get("role") == "user":
396
+ query = turn.get("content", "")
397
+ state["previous_queries"].append(query)
398
+
399
  # Look for corresponding assistant turn
400
  if i + 1 < len(relevant_history):
401
  bot_turn = relevant_history[i + 1]
402
+ if bot_turn.get("role") == "assistant":
403
+ metadata = bot_turn.get("metadata", {})
404
+ action = metadata.get("policy_action", "FETCH")
405
+ state["previous_actions"].append(action)
406
+
407
  return state
408
 
409
 
410
  def predict_policy_action(
411
  query: str,
412
  history: List[Dict] = None,
413
+ return_probs: bool = False,
414
  ) -> Dict:
415
  """
416
  Predict FETCH/NO_FETCH action for a query.
417
+
418
  Args:
419
  query: User query text
420
  history: Conversation history (optional)
421
  return_probs: Whether to return full probability distribution
422
+
423
  Returns:
424
  dict: Prediction results
425
  {
 
432
  """
433
  # Load model (cached after first call)
434
  model = load_policy_model()
435
+
436
  # Create state from history
437
  if history is None:
438
  history = []
439
+
440
  state = create_state_from_history(query, history)
441
+
442
  # Predict action
443
  probs, _ = model.predict_action(state, use_dropout=False)
444
+
445
  # Extract probabilities
446
  fetch_prob = float(probs[0][0])
447
  no_fetch_prob = float(probs[0][1])
448
+
449
  # Determine action (argmax)
450
+ action_idx = int(np.argmax(probs[0]))
451
  action = "FETCH" if action_idx == 0 else "NO_FETCH"
452
  confidence = float(probs[0][action_idx])
453
+
454
  # Check confidence threshold
455
+ should_retrieve = (action == "FETCH") or (
456
+ action == "NO_FETCH" and confidence < settings.CONFIDENCE_THRESHOLD
457
+ )
458
+
459
  result = {
460
+ "action": action,
461
+ "confidence": confidence,
462
+ "should_retrieve": should_retrieve,
463
+ "policy_decision": action,
464
  }
465
+
466
  if return_probs:
467
+ result["fetch_prob"] = fetch_prob
468
+ result["no_fetch_prob"] = no_fetch_prob
469
+
470
  return result
471
 
472
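
Note: for reference, encode_state() (shown in the file view above) flattens the conversation state into a single text sequence before tokenization. A small illustrative example, with made-up queries and actions:

# Made-up sample state for illustration only
state = {
    "previous_queries": ["What is my balance?", "Show my recent transactions"],
    "previous_actions": ["FETCH", "NO_FETCH"],
    "current_query": "Thanks, that's all",
}

# encode_state(state) builds the string:
#   "Previous query 1: What is my balance? [Action: [FETCH]] "
#   "Previous query 2: Show my recent transactions [Action: [NO_FETCH]] "
#   "Current query: Thanks, that's all"
# and tokenizes it with truncation, max_length padding, and return_tensors='pt'.
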