Spaces:
Sleeping
Sleeping
| """ | |
| BERT-based Policy Network for FETCH/NO_FETCH decisions | |
| Trained with Reinforcement Learning (Policy Gradient + Entropy Regularization) | |
| This is adapted from your RL.py with: | |
| - PolicyNetwork class (BERT-based) | |
| - State encoding from conversation history | |
| - Action prediction (FETCH vs NO_FETCH) | |
| - Module-level caching (load once on startup) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import List, Dict, Optional, Tuple | |
| from transformers import AutoTokenizer, AutoModel | |
| from app.config import settings | |
| # ============================================================================ | |
| # POLICY NETWORK (From RL.py) | |
| # ============================================================================ | |
class PolicyNetwork(nn.Module):
    """
    BERT-based Policy Network for deciding FETCH vs NO_FETCH actions.

    Architecture:
        - Base: BERT-base-uncased (pre-trained)
        - Input: Current query + conversation history + previous actions
        - Output: 2-class softmax (FETCH=0, NO_FETCH=1)
        - Special tokens: [FETCH], [NO_FETCH] for action encoding

    Training Details:
        - Loss: Policy Gradient + Entropy Regularization
        - Optimizer: AdamW
        - Reward structure:
            * FETCH: +0.5 (always)
            * NO_FETCH + Good: +2.0
            * NO_FETCH + Bad: -0.5
    """

    def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1, use_multilayer: bool = True, hidden_size: int = 128):
        """
        Build the encoder, tokenizer and classification head.

        Args:
            model_name: HuggingFace model id used for both encoder and tokenizer.
            dropout_rate: Dropout probability applied to the [CLS] vector (and
                inside the multi-layer head).
            use_multilayer: If True, use a Linear-ReLU-Dropout-Linear head;
                otherwise a single Linear layer.
            hidden_size: Width of the hidden layer of the multi-layer head
                (ignored when use_multilayer is False).
        """
        super().__init__()
        # Load pre-trained BERT and its tokenizer
        self.bert = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Add special tokens for actions: [FETCH] and [NO_FETCH]
        special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
        self.tokenizer.add_special_tokens(special_tokens)
        # Resize BERT embeddings to accommodate the new tokens
        self.bert.resize_token_embeddings(len(self.tokenizer))
        # Initialize random embeddings for the special tokens
        self._init_action_embeddings()
        if use_multilayer:
            # Multi-layer classifier with a configurable hidden width
            self.classifier = nn.Sequential(
                nn.Linear(self.bert.config.hidden_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size, 2)
            )
        else:
            # Single-layer classifier (fallback)
            self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
        # Dropout applied to the [CLS] representation before the classifier
        self.dropout = nn.Dropout(dropout_rate)

    def _init_action_embeddings(self):
        """
        Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
        These are learned during training.
        """
        with torch.no_grad():
            embedding_dim = self.bert.config.hidden_size
            weight = self.bert.embeddings.word_embeddings.weight
            for token in ("[FETCH]", "[NO_FETCH]"):
                token_id = self.tokenizer.convert_tokens_to_ids(token)
                # Small random values (std 0.02)
                weight[token_id] = torch.randn(embedding_dim) * 0.02

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through BERT + classifier.

        Args:
            input_ids: Tokenized input IDs (shape: [batch_size, seq_len])
            attention_mask: Attention mask (shape: [batch_size, seq_len])

        Returns:
            logits: Raw logits (shape: [batch_size, 2])
            probs: Softmax probabilities (shape: [batch_size, 2])
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # [CLS] token representation is the first position of the sequence
        cls_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        logits = self.classifier(cls_output)
        probs = F.softmax(logits, dim=-1)
        return logits, probs

    def encode_state(
        self,
        state: Dict,
        max_length: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Encode conversation state into BERT input format.

        State structure:
            {
                'previous_queries': [query1, query2, ...],
                'previous_actions': ['FETCH', 'NO_FETCH', ...],
                'current_query': 'user query'
            }

        Encoding format:
            "Previous query 1: <text> [Action: [FETCH]] Previous query 2: <text> [Action: [NO_FETCH]] Current query: <text>"

        Queries and actions are paired positionally; unpaired trailing
        entries of the longer list are ignored.

        Args:
            state: State dictionary
            max_length: Maximum sequence length (defaults to settings.POLICY_MAX_LEN)

        Returns:
            dict: Tokenized inputs (input_ids, attention_mask)
        """
        if max_length is None:
            max_length = settings.POLICY_MAX_LEN
        prev_queries = state.get('previous_queries') or []
        prev_actions = state.get('previous_actions') or []
        # One segment per (query, action) pair, then the current query
        segments = [
            f"Previous query {index}: {prev_query} [Action: [{prev_action}]] "
            for index, (prev_query, prev_action) in enumerate(zip(prev_queries, prev_actions), start=1)
        ]
        segments.append(f"Current query: {state.get('current_query', '')}")
        state_text = "".join(segments)
        return self.tokenizer(
            state_text,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )

    def predict_action(
        self,
        state: Dict,
        use_dropout: bool = False,
        num_samples: int = 10
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Predict action probabilities for a given state.

        Args:
            state: Conversation state dictionary
            use_dropout: Whether to use MC Dropout for uncertainty estimation
            num_samples: Number of MC Dropout samples (if use_dropout=True)

        Returns:
            probs: Action probabilities (shape: [1, 2]) - [P(FETCH), P(NO_FETCH)]
            uncertainty: Standard deviation across samples (if use_dropout=True),
                otherwise None

        Raises:
            ValueError: If use_dropout is True and num_samples < 1.
        """
        if use_dropout and num_samples < 1:
            raise ValueError("num_samples must be >= 1 when use_dropout=True")

        device = next(self.parameters()).device
        # Tokenization is deterministic, so encode once and reuse for every sample
        encoding = self.encode_state(state)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        if not use_dropout:
            # Standard inference (no uncertainty estimation)
            self.eval()
            with torch.no_grad():
                _, probs = self.forward(input_ids, attention_mask)
            return probs.cpu().numpy(), None

        # MC Dropout: temporarily enable dropout, then restore the previous
        # mode so later calls are not left running in training mode.
        was_training = self.training
        self.train()
        try:
            with torch.no_grad():
                all_probs = [
                    self.forward(input_ids, attention_mask)[1].cpu().numpy()
                    for _ in range(num_samples)
                ]
        finally:
            self.train(was_training)
        # Mean is the prediction; standard deviation is the uncertainty
        return np.mean(all_probs, axis=0), np.std(all_probs, axis=0)
| # ============================================================================ | |
| # MODULE-LEVEL CACHING (Load once on import) | |
| # ============================================================================ | |
# Global variables for caching: populated by load_policy_model() on first success
POLICY_MODEL: Optional[PolicyNetwork] = None
POLICY_TOKENIZER: Optional[AutoTokenizer] = None


def load_policy_model() -> PolicyNetwork:
    """
    Load trained policy model (called once on startup).

    Downloads the checkpoint via settings.download_model_if_needed() and
    caches the loaded model in the module-level POLICY_MODEL global, so
    later calls return the same instance. The classifier architecture
    (single- vs multi-layer, hidden width) and the embedding vocabulary size
    are detected from the checkpoint so the model is built to match it.

    The global cache is only written after the weights have loaded
    successfully; a failed load therefore never leaves a half-initialised
    (untrained) model behind for later callers.

    Returns:
        PolicyNetwork: Loaded policy model in eval mode

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        Exception: Any other load error is logged and re-raised.
    """
    global POLICY_MODEL, POLICY_TOKENIZER
    if POLICY_MODEL is not None:
        return POLICY_MODEL

    # Download model from HF Hub if needed (for deployment)
    settings.download_model_if_needed(
        hf_filename="models/policy_query_only.pt",
        local_path=settings.POLICY_MODEL_PATH
    )
    print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
        # Unwrap {'model_state_dict': ...} checkpoints BEFORE inspecting keys,
        # so detection works for both supported checkpoint formats.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Auto-detect classifier architecture and hidden size from the weights
        has_multilayer = "classifier.0.weight" in state_dict
        if has_multilayer:
            # Output size of the first Linear layer is the hidden width
            hidden_size = state_dict['classifier.0.weight'].shape[0]
            print(f"📊 Detected: Multi-layer classifier (hidden_size={hidden_size})")
        else:
            hidden_size = 768  # Doesn't matter for single-layer
            print(f"📊 Detected: Single-layer classifier")

        model = PolicyNetwork(
            model_name="bert-base-uncased",
            dropout_rate=0.1,
            use_multilayer=has_multilayer,
            hidden_size=hidden_size
        )

        # The embedding table must have the checkpoint's shape before loading:
        # load_state_dict raises on shape mismatch even with strict=False.
        saved_vocab_size = state_dict['bert.embeddings.word_embeddings.weight'].shape[0]
        current_vocab_size = len(model.tokenizer)
        if saved_vocab_size != current_vocab_size:
            print(f"⚠️  Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
            print(f"✅ Resizing model to match saved vocab size...")
            model.bert.resize_token_embeddings(saved_vocab_size)
            if saved_vocab_size < current_vocab_size:
                print("⚠️  Tokenizer has more tokens than the checkpoint's embedding table; "
                      "token ids beyond it cannot be embedded")

        model = model.to(settings.DEVICE)
        model.load_state_dict(state_dict)
        model.eval()

        # Publish to the cache only after everything succeeded
        POLICY_MODEL = model
        POLICY_TOKENIZER = model.tokenizer
        print("✅ Policy network loaded and cached")
        print(f" Model vocab size: {model.bert.embeddings.word_embeddings.num_embeddings}")
        print(f" Tokenizer vocab size: {len(model.tokenizer)}")
    except FileNotFoundError:
        print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
        print(f"⚠️  Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
        raise
    except Exception as e:
        print(f"❌ Failed to load policy model: {e}")
        import traceback
        traceback.print_exc()
        raise
    return POLICY_MODEL
| # ============================================================================ | |
| # PREDICTION FUNCTIONS | |
| # ============================================================================ | |
def create_state_from_history(
    current_query: str,
    conversation_history: List[Dict],
    max_history: int = 2
) -> Dict:
    """
    Create state dictionary from conversation history.

    Looks at the last ``max_history * 2`` turns and extracts every user turn
    that is immediately followed by an assistant turn. Queries and actions are
    always appended together, so ``previous_queries[i]`` corresponds to
    ``previous_actions[i]``. A user turn with no assistant reply is skipped,
    because it has no recorded action to pair with.

    Args:
        current_query: Current user query
        conversation_history: List of conversation turns
            Each turn: {'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}
        max_history: Maximum number of previous user/assistant exchanges to
            include (default: 2). Values <= 0 include no history.

    Returns:
        dict: State dictionary for policy network with keys
            'current_query', 'previous_queries', 'previous_actions'
    """
    state = {
        'current_query': current_query,
        'previous_queries': [],
        'previous_actions': []
    }
    # Guard max_history <= 0 explicitly: a slice of [-0:] would return the
    # WHOLE history instead of none of it.
    if not conversation_history or max_history <= 0:
        return state

    relevant_history = conversation_history[-(max_history * 2):]
    # Walk adjacent (turn, next_turn) pairs looking for user -> assistant
    for user_turn, bot_turn in zip(relevant_history, relevant_history[1:]):
        if user_turn.get('role') != 'user' or bot_turn.get('role') != 'assistant':
            continue
        # 'metadata' may be missing or explicitly None
        metadata = bot_turn.get('metadata') or {}
        state['previous_queries'].append(user_turn.get('content', ''))
        # Turns without a recorded action default to FETCH
        state['previous_actions'].append(metadata.get('policy_action', 'FETCH'))
    return state
def predict_policy_action(
    query: str,
    history: List[Dict] = None,
    return_probs: bool = False
) -> Dict:
    """
    Decide whether a query needs retrieval (FETCH) or not (NO_FETCH).

    The cached policy model scores the state built from the query and the
    conversation history; the higher-probability class becomes the action.
    A NO_FETCH decision whose confidence is below
    settings.CONFIDENCE_THRESHOLD still results in retrieval.

    Args:
        query: User query text
        history: Conversation history (optional; treated as empty when None)
        return_probs: Whether to also return the full probability distribution

    Returns:
        dict: Prediction results
            {
                'action': 'FETCH' or 'NO_FETCH',
                'confidence': float (0-1), probability of the chosen action,
                'should_retrieve': bool,
                'policy_decision': same value as 'action',
                'fetch_prob': float (only when return_probs=True),
                'no_fetch_prob': float (only when return_probs=True)
            }
    """
    # The model is loaded once and cached by load_policy_model()
    policy = load_policy_model()

    turns = [] if history is None else history
    policy_state = create_state_from_history(query, turns)

    # Deterministic inference (no MC dropout); the uncertainty slot is unused
    distribution, _ = policy.predict_action(policy_state, use_dropout=False)
    row = distribution[0]
    p_fetch = float(row[0])
    p_no_fetch = float(row[1])

    # Index 0 is FETCH, index 1 is NO_FETCH
    best_idx = np.argmax(row)
    chosen = "NO_FETCH" if best_idx != 0 else "FETCH"
    chosen_prob = float(row[best_idx])

    # FETCH always retrieves; NO_FETCH retrieves only when it is not confident
    if chosen == "FETCH":
        retrieve = True
    else:
        retrieve = chosen_prob < settings.CONFIDENCE_THRESHOLD

    outcome = {
        'action': chosen,
        'confidence': chosen_prob,
        'should_retrieve': retrieve,
        'policy_decision': chosen
    }
    if return_probs:
        outcome['fetch_prob'] = p_fetch
        outcome['no_fetch_prob'] = p_no_fetch
    return outcome
| # ============================================================================ | |
| # USAGE EXAMPLE (for reference) | |
| # ============================================================================ | |
| """ | |
| # In your service file: | |
| from app.ml.policy_network import predict_policy_action | |
| # Predict action | |
| history = [ | |
| {'role': 'user', 'content': 'What is my balance?'}, | |
| {'role': 'assistant', 'content': '$1000', 'metadata': {'policy_action': 'FETCH'}} | |
| ] | |
| result = predict_policy_action( | |
| query="Thank you!", | |
| history=history, | |
| return_probs=True | |
| ) | |
| print(result) | |
| # { | |
| # 'action': 'NO_FETCH', | |
| # 'confidence': 0.95, | |
| # 'should_retrieve': False, | |
| # 'policy_decision': 'NO_FETCH', | |
| # 'fetch_prob': 0.05, | |
| # 'no_fetch_prob': 0.95 | |
| # } | |
| """ | |
| """ | |
| BERT-based Policy Network for FETCH/NO_FETCH decisions | |
| Trained with Reinforcement Learning (Policy Gradient + Entropy Regularization) | |
| This is adapted from your RL.py with: | |
| - PolicyNetwork class (BERT-based) | |
| - State encoding from conversation history | |
| - Action prediction (FETCH vs NO_FETCH) | |
| - Module-level caching (load once on startup) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import List, Dict, Optional, Tuple | |
| from transformers import AutoTokenizer, AutoModel | |
| from app.config import settings | |
| # ============================================================================ | |
| # POLICY NETWORK (From RL.py) | |
| # ============================================================================ | |
class PolicyNetwork(nn.Module):
    """
    BERT-based Policy Network for deciding FETCH vs NO_FETCH actions.

    Architecture:
        - Base: BERT-base-uncased (pre-trained)
        - Input: Current query + conversation history + previous actions
        - Output: 2-class softmax (FETCH=0, NO_FETCH=1)
        - Special tokens: [FETCH], [NO_FETCH] for action encoding

    Training Details:
        - Loss: Policy Gradient + Entropy Regularization
        - Optimizer: AdamW
        - Reward structure:
            * FETCH: +0.5 (always)
            * NO_FETCH + Good: +2.0
            * NO_FETCH + Bad: -0.5
    """

    def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1):
        """
        Build the encoder, tokenizer and single-layer classification head.

        Args:
            model_name: HuggingFace model id used for both encoder and tokenizer.
            dropout_rate: Dropout probability applied to the [CLS] vector.
        """
        super().__init__()
        # Load pre-trained BERT and its tokenizer
        self.bert = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Add special tokens for actions: [FETCH] and [NO_FETCH]
        special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
        self.tokenizer.add_special_tokens(special_tokens)
        # Resize BERT embeddings to accommodate the new tokens
        self.bert.resize_token_embeddings(len(self.tokenizer))
        # Initialize random embeddings for the special tokens
        self._init_action_embeddings()
        # Classification head: BERT hidden size -> 2 classes
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
        # Dropout applied to the [CLS] representation before the classifier
        self.dropout = nn.Dropout(dropout_rate)

    def _init_action_embeddings(self):
        """
        Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
        These are learned during training.
        """
        with torch.no_grad():
            embedding_dim = self.bert.config.hidden_size
            weight = self.bert.embeddings.word_embeddings.weight
            for token in ("[FETCH]", "[NO_FETCH]"):
                token_id = self.tokenizer.convert_tokens_to_ids(token)
                # Small random values (std 0.02)
                weight[token_id] = torch.randn(embedding_dim) * 0.02

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through BERT + classifier.

        Args:
            input_ids: Tokenized input IDs (shape: [batch_size, seq_len])
            attention_mask: Attention mask (shape: [batch_size, seq_len])

        Returns:
            logits: Raw logits (shape: [batch_size, 2])
            probs: Softmax probabilities (shape: [batch_size, 2])
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # [CLS] token representation is the first position of the sequence
        cls_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        logits = self.classifier(cls_output)
        probs = F.softmax(logits, dim=-1)
        return logits, probs

    def encode_state(
        self,
        state: Dict,
        max_length: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Encode conversation state into BERT input format.

        State structure:
            {
                'previous_queries': [query1, query2, ...],
                'previous_actions': ['FETCH', 'NO_FETCH', ...],
                'current_query': 'user query'
            }

        Encoding format:
            "Previous query 1: <text> [Action: [FETCH]] Previous query 2: <text> [Action: [NO_FETCH]] Current query: <text>"

        Queries and actions are paired positionally; unpaired trailing
        entries of the longer list are ignored.

        Args:
            state: State dictionary
            max_length: Maximum sequence length (defaults to settings.POLICY_MAX_LEN)

        Returns:
            dict: Tokenized inputs (input_ids, attention_mask)
        """
        if max_length is None:
            max_length = settings.POLICY_MAX_LEN
        prev_queries = state.get('previous_queries') or []
        prev_actions = state.get('previous_actions') or []
        # One segment per (query, action) pair, then the current query
        segments = [
            f"Previous query {index}: {prev_query} [Action: [{prev_action}]] "
            for index, (prev_query, prev_action) in enumerate(zip(prev_queries, prev_actions), start=1)
        ]
        segments.append(f"Current query: {state.get('current_query', '')}")
        state_text = "".join(segments)
        return self.tokenizer(
            state_text,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )

    def predict_action(
        self,
        state: Dict,
        use_dropout: bool = False,
        num_samples: int = 10
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Predict action probabilities for a given state.

        Args:
            state: Conversation state dictionary
            use_dropout: Whether to use MC Dropout for uncertainty estimation
            num_samples: Number of MC Dropout samples (if use_dropout=True)

        Returns:
            probs: Action probabilities (shape: [1, 2]) - [P(FETCH), P(NO_FETCH)]
            uncertainty: Standard deviation across samples (if use_dropout=True),
                otherwise None

        Raises:
            ValueError: If use_dropout is True and num_samples < 1.
        """
        if use_dropout and num_samples < 1:
            raise ValueError("num_samples must be >= 1 when use_dropout=True")

        device = next(self.parameters()).device
        # Tokenization is deterministic, so encode once and reuse for every sample
        encoding = self.encode_state(state)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        if not use_dropout:
            # Standard inference (no uncertainty estimation)
            self.eval()
            with torch.no_grad():
                _, probs = self.forward(input_ids, attention_mask)
            return probs.cpu().numpy(), None

        # MC Dropout: temporarily enable dropout, then restore the previous
        # mode so later calls are not left running in training mode.
        was_training = self.training
        self.train()
        try:
            with torch.no_grad():
                all_probs = [
                    self.forward(input_ids, attention_mask)[1].cpu().numpy()
                    for _ in range(num_samples)
                ]
        finally:
            self.train(was_training)
        # Mean is the prediction; standard deviation is the uncertainty
        return np.mean(all_probs, axis=0), np.std(all_probs, axis=0)
| # ============================================================================ | |
| # MODULE-LEVEL CACHING (Load once on import) | |
| # ============================================================================ | |
# Global variables for caching: populated by load_policy_model() on first success
POLICY_MODEL: Optional[PolicyNetwork] = None
POLICY_TOKENIZER: Optional[AutoTokenizer] = None

# =============================================================================================
# Latest version given by perplexity, should work, if not then use one of the other versions.
# =============================================================================================
def load_policy_model() -> PolicyNetwork:
    """
    Load trained policy model (called once on startup).

    Downloads the checkpoint via settings.download_model_if_needed() and
    caches the loaded model in the module-level POLICY_MODEL global, so
    later calls return the same instance. The embedding table is resized to
    the checkpoint's vocabulary size before the weights are loaded.

    The global cache is only written after the weights have loaded
    successfully; a failed load therefore never leaves a half-initialised
    (untrained) model behind for later callers.

    Returns:
        PolicyNetwork: Loaded policy model in eval mode

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        Exception: Any other load error is logged and re-raised.
    """
    global POLICY_MODEL, POLICY_TOKENIZER
    if POLICY_MODEL is not None:
        return POLICY_MODEL

    # Download model from HF Hub if needed (for deployment)
    settings.download_model_if_needed(
        hf_filename="models/policy_query_only.pt",
        local_path=settings.POLICY_MODEL_PATH
    )
    print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
        # Unwrap {'model_state_dict': ...} checkpoints BEFORE reading tensor
        # shapes, so both supported checkpoint formats work.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        model = PolicyNetwork(
            model_name="bert-base-uncased",
            dropout_rate=0.1
        )

        # Resize model embeddings to match the saved checkpoint BEFORE loading
        # weights; load_state_dict raises on any shape mismatch.
        saved_vocab_size = state_dict['bert.embeddings.word_embeddings.weight'].shape[0]
        current_vocab_size = len(model.tokenizer)
        if saved_vocab_size != current_vocab_size:
            print(f"⚠️  Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
            print(f"✅ Resizing tokenizer and embeddings to match saved model...")
            model.bert.resize_token_embeddings(saved_vocab_size)

        model = model.to(settings.DEVICE)
        model.load_state_dict(state_dict)
        model.eval()

        # Publish to the cache only after everything succeeded
        POLICY_MODEL = model
        POLICY_TOKENIZER = model.tokenizer
        print("✅ Policy network loaded and cached")
    except FileNotFoundError:
        print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
        print(f"⚠️  Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
        raise
    except Exception as e:
        print(f"❌ Failed to load policy model: {e}")
        raise
    return POLICY_MODEL
| # =========================================================================== | |
| # This version is used in the code, at least for localhost testing | |
| # =========================================================================== | |
| # def load_policy_model() -> PolicyNetwork: | |
| # """ | |
| # Load trained policy model (called once on startup). | |
| # Uses module-level caching - model stays in RAM. | |
| # Returns: | |
| # PolicyNetwork: Loaded policy model | |
| # """ | |
| # global POLICY_MODEL, POLICY_TOKENIZER | |
| # if POLICY_MODEL is None: | |
| # print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...") | |
| # try: | |
| # # Load checkpoint first to get vocab size | |
| # checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE) | |
| # # Create model instance | |
| # POLICY_MODEL = PolicyNetwork( | |
| # model_name="bert-base-uncased", | |
| # dropout_rate=0.1 | |
| # ) | |
| # # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights | |
| # saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0] | |
| # current_vocab_size = len(POLICY_MODEL.tokenizer) | |
| # if saved_vocab_size != current_vocab_size: | |
| # print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}") | |
| # print(f"✅ Resizing tokenizer and embeddings to match saved model...") | |
| # # Resize model to match saved checkpoint | |
| # POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size) | |
| # # Move to device | |
| # POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE) | |
| # # Now load trained weights (sizes will match!) | |
| # if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: | |
| # POLICY_MODEL.load_state_dict(checkpoint['model_state_dict']) | |
| # else: | |
| # POLICY_MODEL.load_state_dict(checkpoint) | |
| # # Set to evaluation mode | |
| # POLICY_MODEL.eval() | |
| # # Cache tokenizer | |
| # POLICY_TOKENIZER = POLICY_MODEL.tokenizer | |
| # print("✅ Policy network loaded and cached") | |
| # except FileNotFoundError: | |
| # print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}") | |
| # print("⚠️ You need to train the policy network first!") | |
| # raise | |
| # except Exception as e: | |
| # print(f"❌ Failed to load policy model: {e}") | |
| # raise | |
| # return POLICY_MODEL | |
| # ===================================================================================== | |
| # This is an older (or possibly a different) version, potentially still useful | |
| # ===================================================================================== | |
| # def load_policy_model() -> PolicyNetwork: | |
| # """ | |
| # Load trained policy model (called once on startup). | |
| # Uses module-level caching - model stays in RAM. | |
| # Returns: | |
| # PolicyNetwork: Loaded policy model | |
| # """ | |
| # global POLICY_MODEL, POLICY_TOKENIZER | |
| # if POLICY_MODEL is None: | |
| # print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...") | |
| # try: | |
| # # Create model instance | |
| # POLICY_MODEL = PolicyNetwork( | |
| # model_name="bert-base-uncased", | |
| # dropout_rate=0.1 | |
| # ).to(settings.DEVICE) | |
| # # Load trained weights | |
| # checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE) | |
| # # Handle different checkpoint formats | |
| # if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: | |
| # # Full checkpoint with metadata | |
| # POLICY_MODEL.load_state_dict(checkpoint['model_state_dict']) | |
| # else: | |
| # # Just state dict | |
| # POLICY_MODEL.load_state_dict(checkpoint) | |
| # # Set to evaluation mode | |
| # POLICY_MODEL.eval() | |
| # # Cache tokenizer | |
| # POLICY_TOKENIZER = POLICY_MODEL.tokenizer | |
| # print("✅ Policy network loaded and cached") | |
| # except FileNotFoundError: | |
| # print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}") | |
| # print("⚠️ You need to train the policy network first!") | |
| # raise | |
| # except Exception as e: | |
| # print(f"❌ Failed to load policy model: {e}") | |
| # raise | |
| # return POLICY_MODEL | |
| # ============================================================================ | |
| # PREDICTION FUNCTIONS | |
| # ============================================================================ | |
def create_state_from_history(
    current_query: str,
    conversation_history: List[Dict],
    max_history: int = 2
) -> Dict:
    """
    Create state dictionary from conversation history.

    Looks at the last ``max_history * 2`` turns (i.e. up to ``max_history``
    user/assistant pairs) and collects every user query in that window,
    together with the policy action recorded on the assistant turn that
    immediately follows it.

    Args:
        current_query: Current user query
        conversation_history: List of conversation turns
            Each turn: {'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}
        max_history: Maximum number of previous turns (pairs) to include
            (default: 2). Values <= 0 mean "use no history".

    Returns:
        dict: State dictionary for policy network with keys
            'current_query', 'previous_queries' and 'previous_actions'.
            A user turn that is not directly followed by an assistant turn
            contributes a query but no action, so the two lists may differ
            in length.
    """
    state = {
        'current_query': current_query,
        'previous_queries': [],
        'previous_actions': []
    }

    # A non-positive window means "no history". The guard is required because
    # conversation_history[-0:] would otherwise select the ENTIRE list.
    if not conversation_history or max_history <= 0:
        return state

    # Extract last N conversation turns (user + assistant pairs)
    relevant_history = conversation_history[-(max_history * 2):]

    for i, turn in enumerate(relevant_history):
        if turn.get('role') != 'user':
            continue

        state['previous_queries'].append(turn.get('content', ''))

        # Look for the assistant turn that directly answers this query
        if i + 1 < len(relevant_history):
            bot_turn = relevant_history[i + 1]
            if bot_turn.get('role') == 'assistant':
                # 'metadata' may be missing or explicitly None; treat both as empty
                metadata = bot_turn.get('metadata') or {}
                # Turns without a recorded action default to 'FETCH'
                state['previous_actions'].append(
                    metadata.get('policy_action', 'FETCH')
                )

    return state
def predict_policy_action(
    query: str,
    history: List[Dict] = None,
    return_probs: bool = False
) -> Dict:
    """
    Predict FETCH/NO_FETCH action for a query.

    Obtains the policy model via ``load_policy_model()``, builds the state
    with ``create_state_from_history`` and asks the model for action
    probabilities with dropout disabled.

    Args:
        query: User query text
        history: Conversation history (optional; ``None`` is treated as empty)
        return_probs: Whether to add the per-action probabilities to the result

    Returns:
        dict: Prediction results
            {
                'action': 'FETCH' or 'NO_FETCH',
                'confidence': float, probability of the chosen action,
                'should_retrieve': bool,
                'policy_decision': same value as 'action',
                'fetch_prob': float      (only if return_probs),
                'no_fetch_prob': float   (only if return_probs)
            }
    """
    policy = load_policy_model()

    # Build the model input from the query and (possibly empty) history
    state = create_state_from_history(query, [] if history is None else history)

    # Only the first returned value (the probabilities) is used here
    probs, _ = policy.predict_action(state, use_dropout=False)
    action_probs = probs[0]

    # Index 0 is read as FETCH, index 1 as NO_FETCH
    fetch_prob = float(action_probs[0])
    no_fetch_prob = float(action_probs[1])

    # Pick the most probable action
    best_idx = np.argmax(action_probs)
    if best_idx == 0:
        action = "FETCH"
    else:
        action = "NO_FETCH"
    confidence = float(action_probs[best_idx])

    # Retrieve on FETCH, and also on a NO_FETCH whose confidence is below
    # the configured threshold
    if action == "FETCH":
        should_retrieve = True
    else:
        should_retrieve = confidence < settings.CONFIDENCE_THRESHOLD

    result = {
        'action': action,
        'confidence': confidence,
        'should_retrieve': should_retrieve,
        'policy_decision': action
    }
    if return_probs:
        result.update(fetch_prob=fetch_prob, no_fetch_prob=no_fetch_prob)
    return result
| # ============================================================================ | |
| # USAGE EXAMPLE (for reference) | |
| # ============================================================================ | |
| """ | |
| # In your service file: | |
| from app.ml.policy_network import predict_policy_action | |
| # Predict action | |
| history = [ | |
| {'role': 'user', 'content': 'What is my balance?'}, | |
| {'role': 'assistant', 'content': '$1000', 'metadata': {'policy_action': 'FETCH'}} | |
| ] | |
| result = predict_policy_action( | |
| query="Thank you!", | |
| history=history, | |
| return_probs=True | |
| ) | |
| print(result) | |
| # { | |
| # 'action': 'NO_FETCH', | |
| # 'confidence': 0.95, | |
| # 'should_retrieve': False, | |
| # 'policy_decision': 'NO_FETCH', | |
| # 'fetch_prob': 0.05, | |
| # 'no_fetch_prob': 0.95 | |
| # } | |
| """ |