| """ | |
| Custom Retriever with E5-Base-V2 and FAISS | |
| Trained with InfoNCE + Triplet Loss for banking domain | |
| This is adapted from your RAG.py with: | |
| - CustomSentenceTransformer (e5-base-v2) | |
| - Mean pooling + L2 normalization | |
| - FAISS vector search | |
| - Module-level caching (load once on startup) | |
| - ✅ Compatible with v2 FAISS format | |
| """ | |
| import os | |
| import json | |
| import pickle | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import faiss | |
| import numpy as np | |
| from typing import List, Dict, Optional | |
| from transformers import AutoTokenizer, AutoModel | |
| from app.config import settings | |

# ============================================================================
# CUSTOM SENTENCE TRANSFORMER (From RAG.py)
# ============================================================================

class CustomSentenceTransformer(nn.Module):
    """
    Custom SentenceTransformer matching the training code.
    Uses e5-base-v2 with mean pooling and L2 normalization.

    Training details:
    - Base model: intfloat/e5-base-v2
    - Loss: InfoNCE + Triplet Loss
    - Pooling: mean pooling over the last hidden state
    - Normalization: L2 normalization
    """

    def __init__(self, model_name: str = "intfloat/e5-base-v2"):
        super().__init__()
        # Load the pre-trained e5-base-v2 encoder
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.config = self.encoder.config

    def forward(self, input_ids, attention_mask):
        """
        Forward pass through the BERT encoder.

        Args:
            input_ids: Tokenized input IDs
            attention_mask: Attention mask for padding

        Returns:
            torch.Tensor: L2-normalized embeddings (shape: [batch_size, 768])
        """
        # Get encoder outputs
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Mean pooling over the last hidden state - same as training
        token_embeddings = outputs.last_hidden_state
        # Expand the attention mask to match the token embedding shape
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Sum embeddings (weighted by the mask) and divide by the mask sum
        embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

        # L2-normalize embeddings - same as training
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings
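
    # Hedged sanity check of masked mean pooling (illustration only, not part
    # of the pipeline): with two tokens where only the first is real, the
    # masked mean should equal the first token's embedding exactly.
    #
    #   emb = torch.tensor([[[2.0, 4.0], [8.0, 8.0]]])  # [batch=1, seq=2, dim=2]
    #   mask = torch.tensor([[1.0, 0.0]])               # second token is padding
    #   mean = (emb * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
    #   # -> tensor([[2.0, 4.0]])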

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        convert_to_numpy: bool = True,
        show_progress_bar: bool = False  # accepted for interface parity; not used here
    ) -> np.ndarray:
        """
        Encode sentences with the same preprocessing as training.
        Adds the 'query: ' prefix required by e5-base-v2.

        Args:
            sentences: Sentence or list of sentences to encode
            batch_size: Batch size for encoding
            convert_to_numpy: Whether to return a numpy array
            show_progress_bar: Unused; kept for interface compatibility

        Returns:
            np.ndarray: Encoded embeddings (shape: [num_sentences, 768])
        """
        self.eval()  # Set model to evaluation mode

        # Handle single-string input
        if isinstance(sentences, str):
            sentences = [sentences]

        # Add the 'query: ' prefix for e5-base-v2 (required by the model).
        # Handle None values and non-string inputs defensively.
        processed_sentences = []
        for sentence in sentences:
            if sentence is None:
                processed_sentences.append("query: ")  # Default empty query
            elif isinstance(sentence, str):
                processed_sentences.append(f"query: {sentence.strip()}")
            else:
                processed_sentences.append(f"query: {str(sentence)}")

        all_embeddings = []

        # Encode in batches without gradient computation
        with torch.no_grad():
            for i in range(0, len(processed_sentences), batch_size):
                batch_sentences = processed_sentences[i:i + batch_size]

                # Tokenize the batch
                tokens = self.tokenizer(
                    batch_sentences,
                    truncation=True,
                    padding=True,
                    max_length=128,  # Same as training
                    return_tensors='pt'
                ).to(next(self.parameters()).device)

                # Get embeddings
                embeddings = self.forward(tokens['input_ids'], tokens['attention_mask'])

                # Convert to numpy if requested
                if convert_to_numpy:
                    embeddings = embeddings.cpu().numpy()
                all_embeddings.append(embeddings)

        # Combine all batches
        if convert_to_numpy:
            all_embeddings = np.vstack(all_embeddings)
        else:
            all_embeddings = torch.cat(all_embeddings, dim=0)
        return all_embeddings
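

# Hedged usage sketch for CustomSentenceTransformer.encode (illustration only;
# assumes CPU execution and no trained weights loaded):
"""
model = CustomSentenceTransformer("intfloat/e5-base-v2")
vecs = model.encode(["how do I reset my card PIN?", "reset card PIN"])
print(vecs.shape)                       # (2, 768)
print(float(np.dot(vecs[0], vecs[1])))  # cosine similarity; rows are L2-normalized
"""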

# ============================================================================
# CUSTOM RETRIEVER MODEL (Wrapper)
# ============================================================================

class CustomRetrieverModel:
    """
    Wrapper for the custom trained retriever model.
    Handles both knowledge base document and query encoding.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """
        Initialize the retriever model.

        Args:
            model_path: Path to trained model weights (.pth file)
            device: Device to load the model on ('cpu' or 'cuda')
        """
        self.device = device

        # Create the model instance
        self.model = CustomSentenceTransformer("intfloat/e5-base-v2").to(device)

        # Load the trained weights; fall back to the base encoder on failure
        try:
            state_dict = torch.load(model_path, map_location=device)
            self.model.load_state_dict(state_dict)
            print(f"✅ Custom retriever model loaded from {model_path}")
        except Exception as e:
            print(f"❌ Failed to load custom model: {e}")
            print("🔄 Using base e5-base-v2 model (not trained)...")

        # Set to evaluation mode
        self.model.eval()

    def encode_documents(self, documents: List[str], batch_size: int = 32) -> np.ndarray:
        """
        Encode knowledge base documents (the responses/instructions being retrieved).

        Args:
            documents: List of document texts
            batch_size: Batch size for encoding

        Returns:
            np.ndarray: Document embeddings (shape: [num_docs, 768])
        """
        return self.model.encode(documents, batch_size=batch_size, convert_to_numpy=True)

    def encode_query(self, query: str) -> np.ndarray:
        """
        Encode a user query for retrieval.

        Args:
            query: User query text

        Returns:
            np.ndarray: Query embedding (shape: [1, 768])
        """
        return self.model.encode([query], convert_to_numpy=True)


# ============================================================================
# MODULE-LEVEL CACHING (loaded on first call)
# ============================================================================

# Global variables for caching
RETRIEVER_MODEL: Optional[CustomRetrieverModel] = None
FAISS_INDEX: Optional[faiss.Index] = None
KB_DATA: Optional[List[Dict]] = None

# ============================================================================
# COMPATIBLE FAISS LOADING (SUPPORTS V2 FORMAT)
# ============================================================================

def load_retriever() -> CustomRetrieverModel:
    """
    Load the custom retriever model (called once on startup).
    Downloads from the HuggingFace Hub if not present locally.
    Uses module-level caching - the model stays in RAM.

    Returns:
        CustomRetrieverModel: Loaded retriever model
    """
    global RETRIEVER_MODEL

    if RETRIEVER_MODEL is None:
        # Download the model from the HF Hub if needed (for deployment)
        settings.download_model_if_needed(
            hf_filename="models/best_retriever_model.pth",
            local_path=settings.RETRIEVER_MODEL_PATH
        )

        print(f"Loading custom retriever from {settings.RETRIEVER_MODEL_PATH}...")
        RETRIEVER_MODEL = CustomRetrieverModel(
            model_path=settings.RETRIEVER_MODEL_PATH,
            device=settings.DEVICE
        )
        print("✅ Retriever model loaded and cached")

    return RETRIEVER_MODEL


def load_faiss_index():
    """
    Load the FAISS index with v2 format compatibility.

    Supports multiple formats:
    - v2 format (dict with version info) - recommended
    - Old tuple format (index_bytes, kb_data)
    - Legacy format (directly pickled FAISS object) - loads with a warning

    Downloads from the HuggingFace Hub if not present locally.
    Uses module-level caching - loaded once on startup.

    Returns:
        tuple: (faiss.Index, List[Dict]) - FAISS index and KB data
    """
    global FAISS_INDEX, KB_DATA

    if FAISS_INDEX is None or KB_DATA is None:
        # Download the FAISS index from the HF Hub if needed (for deployment)
        settings.download_model_if_needed(
            hf_filename="models/faiss_index.pkl",
            local_path=settings.FAISS_INDEX_PATH
        )
        # Download the knowledge base from the HF Hub if needed (for deployment)
        settings.download_model_if_needed(
            hf_filename="data/final_knowledge_base.jsonl",
            local_path=settings.KB_PATH
        )

        print(f"Loading FAISS index from {settings.FAISS_INDEX_PATH}...")
        try:
            # Load the pickled data
            with open(settings.FAISS_INDEX_PATH, 'rb') as f:
                data = pickle.load(f)
            print("📦 Pickle loaded successfully")

            # ====================================================================
            # FORMAT 1: v2 dictionary format (RECOMMENDED)
            # ====================================================================
            if isinstance(data, dict) and data.get('format_version') == 'v2':
                print("📦 Detected v2 format (compatible)")
                try:
                    FAISS_INDEX = faiss.deserialize_index(data['index_bytes'])
                    KB_DATA = data['kb_data']
                    print("✅ FAISS index loaded successfully")
                    print(f"   Vectors: {FAISS_INDEX.ntotal}")
                    print(f"   KB docs: {len(KB_DATA)}")
                    print(f"   Dimension: {data.get('dimension', 'unknown')}")
                    print(f"   Built with FAISS: {data.get('faiss_version', 'unknown')}")
                    print(f"   Current FAISS: {faiss.__version__}")
                except Exception as e:
                    print(f"❌ Failed to deserialize v2 format: {e}")
                    raise RuntimeError(
                        f"Failed to load v2 FAISS index: {e}\n"
                        f"Please rebuild using: python build_faiss_index.py"
                    )

            # ====================================================================
            # FORMAT 2: old tuple format (index_bytes, kb_data)
            # ====================================================================
            elif isinstance(data, tuple) and len(data) == 2:
                first_item, KB_DATA = data

                # Check whether the first item is bytes (a serialized index)
                if isinstance(first_item, bytes):
                    print("📦 Detected old tuple format with bytes (attempting conversion)")
                    try:
                        FAISS_INDEX = faiss.deserialize_index(first_item)
                        print("✅ FAISS index deserialized from bytes")
                        print(f"   Vectors: {FAISS_INDEX.ntotal}")
                        print(f"   KB docs: {len(KB_DATA)}")
                    except Exception as e:
                        print(f"❌ Failed to deserialize index bytes: {e}")
                        raise RuntimeError(
                            f"Failed to deserialize FAISS index: {e}\n"
                            f"Please rebuild using: python build_faiss_index.py"
                        )
                # Otherwise it is a directly pickled FAISS object
                # (LEGACY - not stable across FAISS versions!)
                else:
                    print("📦 Detected old tuple format with a direct object")
                    print("⚠️ WARNING: direct FAISS objects are not compatible across versions")
                    # Try to use it, but expect that it might fail
                    try:
                        FAISS_INDEX = first_item
                        # Probe the object to see whether it is usable
                        num_vectors = FAISS_INDEX.ntotal
                        print(f"✅ FAISS index appears valid ({num_vectors} vectors)")
                        print(f"   KB docs: {len(KB_DATA)}")
                        print("⚠️ However, this format may break across FAISS versions")
                        print("🔧 Recommended: rebuild using: python build_faiss_index.py")
                    except Exception as e:
                        print(f"❌ FAISS index object is corrupted: {e}")
                        print("   This usually means a FAISS version mismatch")
                        raise RuntimeError(
                            f"FAISS index is corrupted or incompatible (version mismatch).\n"
                            f"Error: {e}\n\n"
                            f"🔧 SOLUTION: rebuild the FAISS index using:\n"
                            f"   python build_faiss_index.py\n"
                        )

            # ====================================================================
            # FORMAT 3: unknown format
            # ====================================================================
            else:
                print(f"❌ Unknown pickle format: {type(data)}")
                if isinstance(data, dict):
                    print(f"   Dict keys: {list(data.keys())}")
                raise ValueError(
                    f"Unrecognized pickle format: {type(data)}.\n"
                    f"Please rebuild using: python build_faiss_index.py"
                )

            # Final validation
            if FAISS_INDEX is None or KB_DATA is None:
                raise RuntimeError("Failed to load FAISS index or KB data")

            print(f"✅ FAISS index ready: {FAISS_INDEX.ntotal} vectors")
            print(f"✅ Knowledge base ready: {len(KB_DATA)} documents")

        except FileNotFoundError:
            print(f"❌ FAISS index file not found: {settings.FAISS_INDEX_PATH}")
            print(f"⚠️ Make sure models are uploaded to the HuggingFace Hub: {settings.HF_MODEL_REPO}")
            raise
        except RuntimeError:
            raise  # Re-raise our custom errors with instructions intact
        except Exception as e:
            print(f"❌ Unexpected error loading FAISS index: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(
                f"Failed to load FAISS index: {e}\n"
                f"Please rebuild using: python build_faiss_index.py"
            )

    return FAISS_INDEX, KB_DATA
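

# Hedged sketch of the v2 writer this loader expects. build_faiss_index.py is
# referenced above but not included here, so everything below is inferred from
# the reader: the dict keys ('format_version', 'index_bytes', 'kb_data',
# 'dimension', 'faiss_version') match what load_faiss_index() consumes; the
# rest is an assumption, not the actual build script.
"""
import pickle
import faiss
import numpy as np

def save_index_v2(index: faiss.Index, kb_data: list, path: str) -> None:
    payload = {
        'format_version': 'v2',
        'index_bytes': faiss.serialize_index(index),  # version-portable uint8 array
        'kb_data': kb_data,
        'dimension': index.d,
        'faiss_version': faiss.__version__,
    }
    with open(path, 'wb') as f:
        pickle.dump(payload, f)

# Toy build: inner product over L2-normalized vectors == cosine similarity
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((100, 768)).astype(np.float32)
faiss.normalize_L2(embeddings)
kb_data = [{"instruction": f"q{i}", "response": f"a{i}"} for i in range(100)]

index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
save_index_v2(index, kb_data, "models/faiss_index.pkl")
"""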

# ============================================================================
# RETRIEVAL FUNCTIONS
# ============================================================================

def retrieve_documents(
    query: str,
    top_k: Optional[int] = None,
    min_similarity: Optional[float] = None
) -> List[Dict]:
    """
    Retrieve the top-k documents for a query using the custom retriever + FAISS.

    Args:
        query: User query text
        top_k: Number of documents to retrieve (default from config)
        min_similarity: Minimum similarity threshold (default from config)

    Returns:
        List[Dict]: Retrieved documents with scores. Each dict contains:
            - instruction: FAQ question
            - response: FAQ answer
            - category: Document category
            - intent: Document intent
            - score: Similarity score (0-1)
            - rank: Rank in results (1-indexed)
            - faq_id: Document ID
    """
    # Use config defaults if not provided
    if top_k is None:
        top_k = settings.TOP_K
    if min_similarity is None:
        min_similarity = settings.SIMILARITY_THRESHOLD

    # Validate the query
    if not query or query.strip() == "":
        print("⚠️ Empty query provided")
        return []

    try:
        # Load models (cached, no overhead after the first call)
        retriever = load_retriever()
        index, kb = load_faiss_index()

        # Step 1: Encode the query
        query_embedding = retriever.encode_query(query)

        # Step 2: Normalize for cosine similarity (a no-op here since the
        # model already L2-normalizes, but harmless and explicit)
        faiss.normalize_L2(query_embedding)

        # Step 3: Search the FAISS index
        similarities, indices = index.search(query_embedding, top_k)

        # Step 4: Check the similarity threshold on the top result
        if similarities[0][0] < min_similarity:
            print(f"🚫 NO_FETCH (similarity: {similarities[0][0]:.3f} < {min_similarity})")
            return []
        print(f"✅ FETCH (similarity: {similarities[0][0]:.3f} >= {min_similarity})")

        # Step 5: Format results. FAISS pads with index -1 when fewer than
        # top_k vectors match, so guard the lower bound as well.
        results = []
        for rank, (similarity, idx) in enumerate(zip(similarities[0], indices[0])):
            if 0 <= idx < len(kb):
                doc = kb[idx]
                results.append({
                    'instruction': doc.get('instruction', ''),
                    'response': doc.get('response', ''),
                    'category': doc.get('category', 'Unknown'),
                    'intent': doc.get('intent', 'Unknown'),
                    'score': float(similarity),
                    'rank': rank + 1,
                    'faq_id': doc.get('faq_id', f'doc_{idx}')
                })
        return results

    except RuntimeError as e:
        # Handle our custom errors with clear messages
        print(f"❌ Retrieval error: {e}")
        return []
    except Exception as e:
        print(f"❌ Unexpected retrieval error: {e}")
        import traceback
        traceback.print_exc()
        return []

def format_context(retrieved_docs: List[Dict], max_context_length: Optional[int] = None) -> str:
    """
    Format retrieved documents into a context string for the LLM.
    Preserves the score-ranked input order and truncates at the length limit.

    Args:
        retrieved_docs: List of retrieved documents
        max_context_length: Maximum context length in characters

    Returns:
        str: Formatted context string
    """
    if max_context_length is None:
        max_context_length = settings.MAX_CONTEXT_LENGTH

    if not retrieved_docs:
        return ""

    context_parts = []
    current_length = 0

    for doc in retrieved_docs:
        # Build the context entry with None checks
        instruction = doc.get('instruction', '') or ''
        response = doc.get('response', '') or ''
        category = doc.get('category', 'N/A') or 'N/A'

        context_entry = f"[Rank {doc['rank']}, Score: {doc['score']:.3f}]\n"
        context_entry += f"Q: {instruction}\n"
        context_entry += f"A: {response}\n"
        context_entry += f"Category: {category}\n\n"

        # Stop once the length limit would be exceeded
        if current_length + len(context_entry) > max_context_length:
            break

        context_parts.append(context_entry)
        current_length += len(context_entry)

    return "".join(context_parts)

# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# In your service file:
from app.ml.retriever import retrieve_documents, format_context

query = "What is my account balance?"

# Retrieve documents
docs = retrieve_documents(query, top_k=5)

# Format context for the LLM
context = format_context(docs)

# Use the context in the LLM prompt
prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
"""
| # """ | |
| # Custom Retriever with E5-Base-V2 and FAISS | |
| # Trained with InfoNCE + Triplet Loss for banking domain | |
| # This is adapted from your RAG.py with: | |
| # - CustomSentenceTransformer (e5-base-v2) | |
| # - Mean pooling + L2 normalization | |
| # - FAISS vector search | |
| # - Module-level caching (load once on startup) | |
| # """ | |
| # import os | |
| # import json | |
| # import pickle | |
| # import torch | |
| # import torch.nn as nn | |
| # import torch.nn.functional as F | |
| # import faiss | |
| # import numpy as np | |
| # from typing import List, Dict, Optional | |
| # from transformers import AutoTokenizer, AutoModel | |
| # from app.config import settings | |
| # # ============================================================================ | |
| # # CUSTOM SENTENCE TRANSFORMER (From RAG.py) | |
| # # ============================================================================ | |
| # class CustomSentenceTransformer(nn.Module): | |
| # """ | |
| # Custom SentenceTransformer matching your training code. | |
| # Uses e5-base-v2 with mean pooling and L2 normalization. | |
| # Training Details: | |
| # - Base model: intfloat/e5-base-v2 | |
| # - Loss: InfoNCE + Triplet Loss | |
| # - Pooling: Mean pooling on last hidden state | |
| # - Normalization: L2 normalization | |
| # """ | |
| # def __init__(self, model_name: str = "intfloat/e5-base-v2"): | |
| # super().__init__() | |
| # # Load pre-trained e5-base-v2 encoder | |
| # self.encoder = AutoModel.from_pretrained(model_name) | |
| # self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # self.config = self.encoder.config | |
| # def forward(self, input_ids, attention_mask): | |
| # """ | |
| # Forward pass through BERT encoder. | |
| # Args: | |
| # input_ids: Tokenized input IDs | |
| # attention_mask: Attention mask for padding | |
| # Returns: | |
| # torch.Tensor: L2-normalized embeddings (shape: [batch_size, 768]) | |
| # """ | |
| # # Get BERT outputs | |
| # outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) | |
| # # Mean pooling - same as training | |
| # # Take hidden states from last layer | |
| # token_embeddings = outputs.last_hidden_state | |
| # # Expand attention mask to match token embeddings shape | |
| # input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | |
| # # Sum embeddings (weighted by attention mask) and divide by sum of mask | |
| # embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( | |
| # input_mask_expanded.sum(1), min=1e-9 | |
| # ) | |
| # # L2 normalize embeddings - same as training | |
| # embeddings = F.normalize(embeddings, p=2, dim=1) | |
| # return embeddings | |
| # def encode( | |
| # self, | |
| # sentences: List[str], | |
| # batch_size: int = 32, | |
| # convert_to_numpy: bool = True, | |
| # show_progress_bar: bool = False | |
| # ) -> np.ndarray: | |
| # """ | |
| # Encode sentences using the same method as training. | |
| # Adds 'query: ' prefix for e5-base-v2 compatibility. | |
| # Args: | |
| # sentences: List of sentences to encode | |
| # batch_size: Batch size for encoding | |
| # convert_to_numpy: Whether to convert to numpy array | |
| # show_progress_bar: Whether to show progress bar | |
| # Returns: | |
| # np.ndarray: Encoded embeddings (shape: [num_sentences, 768]) | |
| # """ | |
| # self.eval() # Set model to evaluation mode | |
| # # Handle single string input | |
| # if isinstance(sentences, str): | |
| # sentences = [sentences] | |
| # # Add 'query: ' prefix for e5-base-v2 (required by model) | |
| # # Handle None values and empty strings | |
| # processed_sentences = [] | |
| # for sentence in sentences: | |
| # if sentence is None: | |
| # processed_sentences.append("query: ") # Default empty query | |
| # elif isinstance(sentence, str): | |
| # processed_sentences.append(f"query: {sentence.strip()}") | |
| # else: | |
| # processed_sentences.append(f"query: {str(sentence)}") | |
| # all_embeddings = [] | |
| # # Encode in batches | |
| # with torch.no_grad(): # No gradient computation | |
| # for i in range(0, len(processed_sentences), batch_size): | |
| # batch_sentences = processed_sentences[i:i + batch_size] | |
| # # Tokenize batch | |
| # tokens = self.tokenizer( | |
| # batch_sentences, | |
| # truncation=True, | |
| # padding=True, | |
| # max_length=128, # Same as training | |
| # return_tensors='pt' | |
| # ).to(next(self.parameters()).device) | |
| # # Get embeddings | |
| # embeddings = self.forward(tokens['input_ids'], tokens['attention_mask']) | |
| # # Convert to numpy if requested | |
| # if convert_to_numpy: | |
| # embeddings = embeddings.cpu().numpy() | |
| # all_embeddings.append(embeddings) | |
| # # Combine all batches | |
| # if convert_to_numpy: | |
| # all_embeddings = np.vstack(all_embeddings) | |
| # else: | |
| # all_embeddings = torch.cat(all_embeddings, dim=0) | |
| # return all_embeddings | |
| # # ============================================================================ | |
| # # CUSTOM RETRIEVER MODEL (Wrapper) | |
| # # ============================================================================ | |
| # class CustomRetrieverModel: | |
| # """ | |
| # Wrapper for your custom trained retriever model. | |
| # Handles both knowledge base documents and query encoding. | |
| # """ | |
| # def __init__(self, model_path: str, device: str = "cpu"): | |
| # """ | |
| # Initialize retriever model. | |
| # Args: | |
| # model_path: Path to trained model weights (.pth file) | |
| # device: Device to load model on ('cpu' or 'cuda') | |
| # """ | |
| # self.device = device | |
| # # Create model instance | |
| # self.model = CustomSentenceTransformer("intfloat/e5-base-v2").to(device) | |
| # # Load your trained weights | |
| # try: | |
| # state_dict = torch.load(model_path, map_location=device) | |
| # self.model.load_state_dict(state_dict) | |
| # print(f"✅ Custom retriever model loaded from {model_path}") | |
| # except Exception as e: | |
| # print(f"❌ Failed to load custom model: {e}") | |
| # print("🔄 Using base e5-base-v2 model (not trained)...") | |
| # # Set to evaluation mode | |
| # self.model.eval() | |
| # def encode_documents(self, documents: List[str], batch_size: int = 32) -> np.ndarray: | |
| # """ | |
| # Encode knowledge base documents. | |
| # These are the responses/instructions we're retrieving. | |
| # Args: | |
| # documents: List of document texts | |
| # batch_size: Batch size for encoding | |
| # Returns: | |
| # np.ndarray: Document embeddings (shape: [num_docs, 768]) | |
| # """ | |
| # return self.model.encode(documents, batch_size=batch_size, convert_to_numpy=True) | |
| # def encode_query(self, query: str) -> np.ndarray: | |
| # """ | |
| # Encode user query for retrieval. | |
| # Args: | |
| # query: User query text | |
| # Returns: | |
| # np.ndarray: Query embedding (shape: [1, 768]) | |
| # """ | |
| # return self.model.encode([query], convert_to_numpy=True) | |
| # # ============================================================================ | |
| # # MODULE-LEVEL CACHING (Load once on import) | |
| # # ============================================================================ | |
| # # Global variables for caching | |
| # RETRIEVER_MODEL: Optional[CustomRetrieverModel] = None | |
| # FAISS_INDEX: Optional[faiss.Index] = None | |
| # KB_DATA: Optional[List[Dict]] = None | |
| # # ============================================================================================= | |
| # # Latest version given by perplexity, should work, if not then use one of the other versions. | |
| # # ============================================================================================= | |
| # def load_retriever() -> CustomRetrieverModel: | |
| # """ | |
| # Load custom retriever model (called once on startup). | |
| # Downloads from HuggingFace Hub if not present locally. | |
| # Uses module-level caching - model stays in RAM. | |
| # Returns: | |
| # CustomRetrieverModel: Loaded retriever model | |
| # """ | |
| # global RETRIEVER_MODEL | |
| # if RETRIEVER_MODEL is None: | |
| # # Download model from HF Hub if needed (for deployment) | |
| # settings.download_model_if_needed( | |
| # hf_filename="models/best_retriever_model.pth", | |
| # local_path=settings.RETRIEVER_MODEL_PATH | |
| # ) | |
| # print(f"Loading custom retriever from {settings.RETRIEVER_MODEL_PATH}...") | |
| # RETRIEVER_MODEL = CustomRetrieverModel( | |
| # model_path=settings.RETRIEVER_MODEL_PATH, | |
| # device=settings.DEVICE | |
| # ) | |
| # print("✅ Retriever model loaded and cached") | |
| # return RETRIEVER_MODEL | |
| # # =========================================================================== | |
| # # This version is used in the code, atleast for localhost testing | |
| # # =========================================================================== | |
| # # def load_retriever() -> CustomRetrieverModel: | |
| # # """ | |
| # # Load custom retriever model (called once on startup). | |
| # # Uses module-level caching - model stays in RAM. | |
| # # Returns: | |
| # # CustomRetrieverModel: Loaded retriever model | |
| # # """ | |
| # # global RETRIEVER_MODEL | |
| # # if RETRIEVER_MODEL is None: | |
| # # print(f"Loading custom retriever from {settings.RETRIEVER_MODEL_PATH}...") | |
| # # RETRIEVER_MODEL = CustomRetrieverModel( | |
| # # model_path=settings.RETRIEVER_MODEL_PATH, | |
| # # device=settings.DEVICE | |
| # # ) | |
| # # print("✅ Retriever model loaded and cached") | |
| # # return RETRIEVER_MODEL | |
| # # ================================================================================================== | |
| # # Latest version given by perplexity, should work, if not then use one of the other versions. | |
| # # ================================================================================================== | |
| # def load_faiss_index(): | |
| # """ | |
| # Load FAISS index + knowledge base from pickle file. | |
| # Downloads from HuggingFace Hub if not present locally. | |
| # Uses module-level caching - loaded once on startup. | |
| # Returns: | |
| # tuple: (faiss.Index, List[Dict]) - FAISS index and KB data | |
| # """ | |
| # global FAISS_INDEX, KB_DATA | |
| # if FAISS_INDEX is None or KB_DATA is None: | |
| # # Download FAISS index from HF Hub if needed (for deployment) | |
| # settings.download_model_if_needed( | |
| # hf_filename="models/faiss_index.pkl", | |
| # local_path=settings.FAISS_INDEX_PATH | |
| # ) | |
| # # Download knowledge base from HF Hub if needed (for deployment) | |
| # settings.download_model_if_needed( | |
| # hf_filename="data/final_knowledge_base.jsonl", | |
| # local_path=settings.KB_PATH | |
| # ) | |
| # print(f"Loading FAISS index from {settings.FAISS_INDEX_PATH}...") | |
| # try: | |
| # # Load pickled data | |
| # with open(settings.FAISS_INDEX_PATH, 'rb') as f: | |
| # loaded_data = pickle.load(f) | |
| # print(f"📦 Pickle loaded successfully") | |
| # # ✅ Handle both formats: (index, kb_data) OR (index_bytes, kb_data) | |
| # if isinstance(loaded_data, tuple) and len(loaded_data) == 2: | |
| # first_item, KB_DATA = loaded_data | |
| # # Check if first item is bytes (new format) - SAFE to check | |
| # if isinstance(first_item, bytes): | |
| # print("📦 Detected new format (serialized bytes)") | |
| # FAISS_INDEX = faiss.deserialize_index(first_item) | |
| # print(f"✅ FAISS index deserialized successfully") | |
| # # Otherwise assume it's old format and try to use it | |
| # else: | |
| # print(f"📦 Detected old format (attempting to use directly)") | |
| # # ❌ DON'T use hasattr() - it crashes on corrupted FAISS! | |
| # # Instead, try to use it and catch errors | |
| # try: | |
| # FAISS_INDEX = first_item | |
| # # Test if it works by accessing ntotal | |
| # num_vectors = FAISS_INDEX.ntotal | |
| # print(f"✅ FAISS index is valid ({num_vectors} vectors)") | |
| # except Exception as e: | |
| # print(f"❌ FAISS index object is corrupted: {e}") | |
| # print(f"⚠️ This pickle was created with incompatible FAISS version") | |
| # print(f"") | |
| # print(f"🔧 SOLUTION: Rebuild FAISS index using:") | |
| # print(f" python build_faiss_index.py") | |
| # print(f"") | |
| # raise RuntimeError( | |
| # f"FAISS index is corrupted or incompatible (FAISS version mismatch). " | |
| # f"Please rebuild using: python build_faiss_index.py" | |
| # ) | |
| # else: | |
| # raise ValueError(f"Invalid pickle format: expected tuple, got {type(loaded_data)}") | |
| # print(f"✅ FAISS index loaded: {FAISS_INDEX.ntotal} vectors") | |
| # print(f"✅ Knowledge base loaded: {len(KB_DATA)} documents") | |
| # except FileNotFoundError: | |
| # print(f"❌ FAISS index file not found: {settings.FAISS_INDEX_PATH}") | |
| # print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}") | |
| # raise | |
| # except RuntimeError: | |
| # raise # Re-raise our custom error | |
| # except Exception as e: | |
| # print(f"❌ Failed to load FAISS index: {e}") | |
| # import traceback | |
| # traceback.print_exc() | |
| # raise | |
| # return FAISS_INDEX, KB_DATA | |
| # # ================================================================================================== | |
| # # Second Latest version given by perplexity, should work, if not then use one of the other versions. | |
| # # ================================================================================================== | |
| # # def load_faiss_index(): | |
| # # """ | |
| # # Load FAISS index + knowledge base from pickle file. | |
| # # Downloads from HuggingFace Hub if not present locally. | |
| # # Uses module-level caching - loaded once on startup. | |
| # # Returns: | |
| # # tuple: (faiss.Index, List[Dict]) - FAISS index and KB data | |
| # # """ | |
| # # global FAISS_INDEX, KB_DATA | |
| # # if FAISS_INDEX is None or KB_DATA is None: | |
| # # # Download FAISS index from HF Hub if needed (for deployment) | |
| # # settings.download_model_if_needed( | |
| # # hf_filename="models/faiss_index.pkl", | |
| # # local_path=settings.FAISS_INDEX_PATH | |
| # # ) | |
| # # # Download knowledge base from HF Hub if needed (for deployment) | |
| # # settings.download_model_if_needed( | |
| # # hf_filename="data/final_knowledge_base.jsonl", | |
| # # local_path=settings.KB_PATH | |
| # # ) | |
| # # print(f"Loading FAISS index from {settings.FAISS_INDEX_PATH}...") | |
| # # try: | |
| # # # Load pickled FAISS index + KB data | |
| # # with open(settings.FAISS_INDEX_PATH, 'rb') as f: | |
| # # FAISS_INDEX, KB_DATA = pickle.load(f) | |
| # # print(f"✅ FAISS index loaded: {FAISS_INDEX.ntotal} vectors") | |
| # # print(f"✅ Knowledge base loaded: {len(KB_DATA)} documents") | |
| # # except FileNotFoundError: | |
| # # print(f"❌ FAISS index file not found: {settings.FAISS_INDEX_PATH}") | |
| # # print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}") | |
| # # raise | |
| # # except Exception as e: | |
| # # print(f"❌ Failed to load FAISS index: {e}") | |
| # # raise | |
| # # return FAISS_INDEX, KB_DATA | |
| # # =========================================================================== | |
| # # This version is used in the code, atleast for localhost testing | |
| # # =========================================================================== | |
| # # def load_faiss_index(): | |
| # # """ | |
| # # Load FAISS index + knowledge base from pickle file. | |
| # # Uses module-level caching - loaded once on startup. | |
| # # Returns: | |
| # # tuple: (faiss.Index, List[Dict]) - FAISS index and KB data | |
| # # """ | |
| # # global FAISS_INDEX, KB_DATA | |
| # # if FAISS_INDEX is None or KB_DATA is None: | |
| # # print(f"Loading FAISS index from {settings.FAISS_INDEX_PATH}...") | |
| # # try: | |
| # # # Load pickled FAISS index + KB data | |
| # # with open(settings.FAISS_INDEX_PATH, 'rb') as f: | |
| # # FAISS_INDEX, KB_DATA = pickle.load(f) | |
| # # print(f"✅ FAISS index loaded: {FAISS_INDEX.ntotal} vectors") | |
| # # print(f"✅ Knowledge base loaded: {len(KB_DATA)} documents") | |
| # # except FileNotFoundError: | |
| # # print(f"❌ FAISS index file not found: {settings.FAISS_INDEX_PATH}") | |
| # # print("⚠️ You need to create the FAISS index first!") | |
| # # raise | |
| # # except Exception as e: | |
| # # print(f"❌ Failed to load FAISS index: {e}") | |
| # # raise | |
| # # return FAISS_INDEX, KB_DATA | |
| # # ============================================================================ | |
| # # RETRIEVAL FUNCTIONS | |
| # # ============================================================================ | |
| # def retrieve_documents( | |
| # query: str, | |
| # top_k: int = None, | |
| # min_similarity: float = None | |
| # ) -> List[Dict]: | |
| # """ | |
| # Retrieve top-k documents for a query using custom retriever + FAISS. | |
| # Args: | |
| # query: User query text | |
| # top_k: Number of documents to retrieve (default from config) | |
| # min_similarity: Minimum similarity threshold (default from config) | |
| # Returns: | |
| # List[Dict]: Retrieved documents with scores | |
| # Each dict contains: | |
| # - instruction: FAQ question | |
| # - response: FAQ answer | |
| # - category: Document category | |
| # - intent: Document intent | |
| # - score: Similarity score (0-1) | |
| # - rank: Rank in results (1-indexed) | |
| # - faq_id: Document ID | |
| # """ | |
| # # Use config defaults if not provided | |
| # if top_k is None: | |
| # top_k = settings.TOP_K | |
| # if min_similarity is None: | |
| # min_similarity = settings.SIMILARITY_THRESHOLD | |
| # # Validate query | |
| # if not query or query.strip() == "": | |
| # print("⚠️ Empty query provided") | |
| # return [] | |
| # # Load models (cached, no overhead after first call) | |
| # retriever = load_retriever() | |
| # index, kb = load_faiss_index() | |
| # try: | |
| # # Step 1: Encode query | |
| # query_embedding = retriever.encode_query(query) | |
| # # Step 2: Normalize for cosine similarity | |
| # faiss.normalize_L2(query_embedding) | |
| # # Step 3: Search in FAISS index | |
| # similarities, indices = index.search(query_embedding, top_k) | |
| # # Step 4: Check similarity threshold for top result | |
| # if similarities[0][0] < min_similarity: | |
| # print(f"🚫 NO_FETCH (similarity: {similarities[0][0]:.3f} < {min_similarity})") | |
| # return [] | |
| # print(f"✅ FETCH (similarity: {similarities[0][0]:.3f} >= {min_similarity})") | |
| # # Step 5: Format results | |
| # results = [] | |
| # for rank, (similarity, idx) in enumerate(zip(similarities[0], indices[0])): | |
| # if idx < len(kb): | |
| # doc = kb[idx] | |
| # results.append({ | |
| # 'instruction': doc.get('instruction', ''), | |
| # 'response': doc.get('response', ''), | |
| # 'category': doc.get('category', 'Unknown'), | |
| # 'intent': doc.get('intent', 'Unknown'), | |
| # 'score': float(similarity), | |
| # 'rank': rank + 1, | |
| # 'faq_id': doc.get('faq_id', f'doc_{idx}') | |
| # }) | |
| # return results | |
| # except Exception as e: | |
| # print(f"❌ Retrieval error: {e}") | |
| # import traceback | |
| # traceback.print_exc() | |
| # return [] | |
| # def format_context(retrieved_docs: List[Dict], max_context_length: int = None) -> str: | |
| # """ | |
| # Format retrieved documents into context string for LLM. | |
| # Prioritizes by score and limits total length. | |
| # Args: | |
| # retrieved_docs: List of retrieved documents | |
| # max_context_length: Maximum context length in characters | |
| # Returns: | |
| # str: Formatted context string | |
| # """ | |
| # if max_context_length is None: | |
| # max_context_length = settings.MAX_CONTEXT_LENGTH | |
| # if not retrieved_docs: | |
| # return "" | |
| # context_parts = [] | |
| # current_length = 0 | |
| # for doc in retrieved_docs: | |
| # # Create context entry with None checks | |
| # instruction = doc.get('instruction', '') or '' | |
| # response = doc.get('response', '') or '' | |
| # category = doc.get('category', 'N/A') or 'N/A' | |
| # context_entry = f"[Rank {doc['rank']}, Score: {doc['score']:.3f}]\n" | |
| # context_entry += f"Q: {instruction}\n" | |
| # context_entry += f"A: {response}\n" | |
| # context_entry += f"Category: {category}\n\n" | |
| # # Check length limit | |
| # if current_length + len(context_entry) > max_context_length: | |
| # break | |
| # context_parts.append(context_entry) | |
| # current_length += len(context_entry) | |
| # return "".join(context_parts) | |
| # # ============================================================================ | |
| # # USAGE EXAMPLE (for reference) | |
| # # ============================================================================ | |
| # """ | |
| # # In your service file: | |
| # from app.ml.retriever import retrieve_documents, format_context | |
| # # Retrieve documents | |
| # docs = retrieve_documents("What is my account balance?", top_k=5) | |
| # # Format context for LLM | |
| # context = format_context(docs) | |
| # # Use context in LLM prompt | |
| # prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:" | |
| # """ |