"""
Custom Retriever with E5-Base-V2 and FAISS
Trained with InfoNCE + Triplet Loss for banking domain
This is adapted from your RAG.py with:
- CustomSentenceTransformer (e5-base-v2)
- Mean pooling + L2 normalization
- FAISS vector search
- Module-level caching (load once on startup)
- ✅ Compatible with v2 FAISS format
"""
import os
import json
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import faiss
import numpy as np
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModel
from app.config import settings
# ============================================================================
# CUSTOM SENTENCE TRANSFORMER (From RAG.py)
# ============================================================================
class CustomSentenceTransformer(nn.Module):
"""
Custom SentenceTransformer matching your training code.
Uses e5-base-v2 with mean pooling and L2 normalization.
Training Details:
- Base model: intfloat/e5-base-v2
- Loss: InfoNCE + Triplet Loss
- Pooling: Mean pooling on last hidden state
- Normalization: L2 normalization
"""
def __init__(self, model_name: str = "intfloat/e5-base-v2"):
super().__init__()
# Load pre-trained e5-base-v2 encoder
self.encoder = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.config = self.encoder.config
def forward(self, input_ids, attention_mask):
"""
Forward pass through BERT encoder.
Args:
input_ids: Tokenized input IDs
attention_mask: Attention mask for padding
Returns:
torch.Tensor: L2-normalized embeddings (shape: [batch_size, 768])
"""
# Get BERT outputs
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
# Mean pooling - same as training
# Take hidden states from last layer
token_embeddings = outputs.last_hidden_state
# Expand attention mask to match token embeddings shape
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
# Sum embeddings (weighted by attention mask) and divide by sum of mask
embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
# L2 normalize embeddings - same as training
embeddings = F.normalize(embeddings, p=2, dim=1)
return embeddings
def encode(
self,
sentences: List[str],
batch_size: int = 32,
convert_to_numpy: bool = True,
show_progress_bar: bool = False
) -> np.ndarray:
"""
Encode sentences using the same method as training.
Adds 'query: ' prefix for e5-base-v2 compatibility.
Args:
sentences: List of sentences to encode
batch_size: Batch size for encoding
convert_to_numpy: Whether to convert to numpy array
show_progress_bar: Whether to show progress bar
Returns:
np.ndarray: Encoded embeddings (shape: [num_sentences, 768])
"""
self.eval() # Set model to evaluation mode
# Handle single string input
if isinstance(sentences, str):
sentences = [sentences]
# Add 'query: ' prefix for e5-base-v2 (required by model)
# Handle None values and empty strings
processed_sentences = []
for sentence in sentences:
if sentence is None:
processed_sentences.append("query: ") # Default empty query
elif isinstance(sentence, str):
processed_sentences.append(f"query: {sentence.strip()}")
else:
processed_sentences.append(f"query: {str(sentence)}")
all_embeddings = []
# Encode in batches
with torch.no_grad(): # No gradient computation
for i in range(0, len(processed_sentences), batch_size):
batch_sentences = processed_sentences[i:i + batch_size]
# Tokenize batch
tokens = self.tokenizer(
batch_sentences,
truncation=True,
padding=True,
max_length=128, # Same as training
return_tensors='pt'
).to(next(self.parameters()).device)
# Get embeddings
embeddings = self.forward(tokens['input_ids'], tokens['attention_mask'])
# Convert to numpy if requested
if convert_to_numpy:
embeddings = embeddings.cpu().numpy()
all_embeddings.append(embeddings)
# Combine all batches
if convert_to_numpy:
all_embeddings = np.vstack(all_embeddings)
else:
all_embeddings = torch.cat(all_embeddings, dim=0)
return all_embeddings
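# Illustrative usage (not part of the pipeline): encode() prepends "query: "
# itself, so callers pass raw text, e.g.
#     model = CustomSentenceTransformer()
#     embs = model.encode(["reset my card pin", "open a savings account"])
#     embs.shape  # (2, 768); rows are already L2-normalized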
# ============================================================================
# CUSTOM RETRIEVER MODEL (Wrapper)
# ============================================================================
class CustomRetrieverModel:
"""
Wrapper for your custom trained retriever model.
Handles both knowledge base documents and query encoding.
"""
def __init__(self, model_path: str, device: str = "cpu"):
"""
Initialize retriever model.
Args:
model_path: Path to trained model weights (.pth file)
device: Device to load model on ('cpu' or 'cuda')
"""
self.device = device
# Create model instance
self.model = CustomSentenceTransformer("intfloat/e5-base-v2").to(device)
# Load your trained weights
try:
state_dict = torch.load(model_path, map_location=device)
self.model.load_state_dict(state_dict)
print(f"✅ Custom retriever model loaded from {model_path}")
except Exception as e:
print(f"❌ Failed to load custom model: {e}")
print("🔄 Using base e5-base-v2 model (not trained)...")
# Set to evaluation mode
self.model.eval()
def encode_documents(self, documents: List[str], batch_size: int = 32) -> np.ndarray:
"""
Encode knowledge base documents.
These are the responses/instructions we're retrieving.
Args:
documents: List of document texts
batch_size: Batch size for encoding
Returns:
np.ndarray: Document embeddings (shape: [num_docs, 768])
"""
return self.model.encode(documents, batch_size=batch_size, convert_to_numpy=True)
def encode_query(self, query: str) -> np.ndarray:
"""
Encode user query for retrieval.
Args:
query: User query text
Returns:
np.ndarray: Query embedding (shape: [1, 768])
"""
return self.model.encode([query], convert_to_numpy=True)
# ============================================================================
# MODULE-LEVEL CACHING (Load once on import)
# ============================================================================
# Global variables for caching
RETRIEVER_MODEL: Optional[CustomRetrieverModel] = None
FAISS_INDEX: Optional[faiss.Index] = None
KB_DATA: Optional[List[Dict]] = None
# ============================================================================
# ✅ UPDATED: COMPATIBLE FAISS LOADING (SUPPORTS V2 FORMAT)
# ============================================================================
def load_retriever() -> CustomRetrieverModel:
"""
Load custom retriever model (called once on startup).
Downloads from HuggingFace Hub if not present locally.
Uses module-level caching - model stays in RAM.
Returns:
CustomRetrieverModel: Loaded retriever model
"""
global RETRIEVER_MODEL
if RETRIEVER_MODEL is None:
# Download model from HF Hub if needed (for deployment)
settings.download_model_if_needed(
hf_filename="models/best_retriever_model.pth",
local_path=settings.RETRIEVER_MODEL_PATH
)
print(f"Loading custom retriever from {settings.RETRIEVER_MODEL_PATH}...")
RETRIEVER_MODEL = CustomRetrieverModel(
model_path=settings.RETRIEVER_MODEL_PATH,
device=settings.DEVICE
)
print("✅ Retriever model loaded and cached")
return RETRIEVER_MODEL
def load_faiss_index():
"""
✅ UPDATED: Load FAISS index with v2 format compatibility.
Supports multiple formats:
- v2 format (dict with version info) - RECOMMENDED
- Old tuple format (index_bytes, kb_data)
- Legacy format (direct FAISS object) - will show warning
Downloads from HuggingFace Hub if not present locally.
Uses module-level caching - loaded once on startup.
Returns:
tuple: (faiss.Index, List[Dict]) - FAISS index and KB data
"""
global FAISS_INDEX, KB_DATA
if FAISS_INDEX is None or KB_DATA is None:
# Download FAISS index from HF Hub if needed (for deployment)
settings.download_model_if_needed(
hf_filename="models/faiss_index.pkl",
local_path=settings.FAISS_INDEX_PATH
)
# Download knowledge base from HF Hub if needed (for deployment)
settings.download_model_if_needed(
hf_filename="data/final_knowledge_base.jsonl",
local_path=settings.KB_PATH
)
print(f"Loading FAISS index from {settings.FAISS_INDEX_PATH}...")
try:
# Load pickled data
with open(settings.FAISS_INDEX_PATH, 'rb') as f:
data = pickle.load(f)
print(f"📦 Pickle loaded successfully")
# ========================================================================
# FORMAT 1: v2 Dictionary Format (RECOMMENDED)
# ========================================================================
if isinstance(data, dict) and data.get('format_version') == 'v2':
print("📦 Detected v2 format (compatible)")
try:
FAISS_INDEX = faiss.deserialize_index(data['index_bytes'])
KB_DATA = data['kb_data']
print(f"✅ FAISS index loaded successfully")
print(f" Vectors: {FAISS_INDEX.ntotal}")
print(f" KB docs: {len(KB_DATA)}")
print(f" Dimension: {data.get('dimension', 'unknown')}")
print(f" Built with FAISS: {data.get('faiss_version', 'unknown')}")
print(f" Current FAISS: {faiss.__version__}")
except Exception as e:
print(f"❌ Failed to deserialize v2 format: {e}")
raise RuntimeError(
f"Failed to load v2 FAISS index: {e}\n"
f"Please rebuild using: python build_faiss_index.py"
)
# ========================================================================
# FORMAT 2: Old Tuple Format (index_bytes, kb_data)
# ========================================================================
elif isinstance(data, tuple) and len(data) == 2:
first_item, KB_DATA = data
# Check if first item is bytes (serialized index)
if isinstance(first_item, bytes):
print("📦 Detected old tuple format with bytes (attempting conversion)")
try:
FAISS_INDEX = faiss.deserialize_index(first_item)
print(f"✅ FAISS index deserialized from bytes")
print(f" Vectors: {FAISS_INDEX.ntotal}")
print(f" KB docs: {len(KB_DATA)}")
except Exception as e:
print(f"❌ Failed to deserialize index bytes: {e}")
raise RuntimeError(
f"Failed to deserialize FAISS index: {e}\n"
f"Please rebuild using: python build_faiss_index.py"
)
# Otherwise it's a direct FAISS object (LEGACY - DANGEROUS!)
else:
print(f"📦 Detected old tuple format with direct object")
print(f"⚠️ WARNING: Direct FAISS objects are not compatible across versions")
# Try to use it, but expect it might fail
try:
FAISS_INDEX = first_item
# Test if it works
num_vectors = FAISS_INDEX.ntotal
print(f"✅ FAISS index appears valid ({num_vectors} vectors)")
print(f" KB docs: {len(KB_DATA)}")
print(f"⚠️ However, this format may break across FAISS versions")
print(f"🔧 Recommended: Rebuild using: python build_faiss_index.py")
except Exception as e:
print(f"❌ FAISS index object is corrupted: {e}")
print(f" This usually means FAISS version mismatch")
raise RuntimeError(
f"FAISS index is corrupted or incompatible (version mismatch).\n"
f"Error: {e}\n\n"
f"🔧 SOLUTION: Rebuild FAISS index using:\n"
f" python build_faiss_index.py\n"
)
# ========================================================================
# FORMAT 3: Unknown Format
# ========================================================================
else:
print(f"❌ Unknown pickle format: {type(data)}")
if isinstance(data, dict):
print(f" Dict keys: {list(data.keys())}")
raise ValueError(
f"Unrecognized pickle format: {type(data)}.\n"
f"Please rebuild using: python build_faiss_index.py"
)
# Final validation
if FAISS_INDEX is None or KB_DATA is None:
raise RuntimeError("Failed to load FAISS index or KB data")
print(f"✅ FAISS index ready: {FAISS_INDEX.ntotal} vectors")
print(f"✅ Knowledge base ready: {len(KB_DATA)} documents")
except FileNotFoundError:
print(f"❌ FAISS index file not found: {settings.FAISS_INDEX_PATH}")
print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
raise
except RuntimeError:
raise # Re-raise our custom error with instructions
except Exception as e:
print(f"❌ Unexpected error loading FAISS index: {e}")
import traceback
traceback.print_exc()
raise RuntimeError(
f"Failed to load FAISS index: {e}\n"
f"Please rebuild using: python build_faiss_index.py"
)
return FAISS_INDEX, KB_DATA
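# ----------------------------------------------------------------------------
# For reference: a minimal sketch of the v2 pickle layout the loader above
# expects. Keys are inferred from the reads in load_faiss_index(); the real
# build_faiss_index.py may include additional fields.
#
#     index = faiss.IndexFlatIP(768)               # inner product on normalized vectors
#     index.add(doc_embeddings.astype('float32'))
#     payload = {
#         'format_version': 'v2',
#         'index_bytes': faiss.serialize_index(index),
#         'kb_data': kb_docs,                      # list of dicts (instruction/response/...)
#         'dimension': 768,
#         'faiss_version': faiss.__version__,
#     }
#     with open('models/faiss_index.pkl', 'wb') as f:
#         pickle.dump(payload, f)
# ----------------------------------------------------------------------------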
def retrieve_documents(
query: str,
    top_k: Optional[int] = None,
    min_similarity: Optional[float] = None
) -> List[Dict]:
"""
Retrieve top-k documents for a query using custom retriever + FAISS.
Args:
query: User query text
top_k: Number of documents to retrieve (default from config)
min_similarity: Minimum similarity threshold (default from config)
Returns:
List[Dict]: Retrieved documents with scores
Each dict contains:
- instruction: FAQ question
- response: FAQ answer
- category: Document category
- intent: Document intent
- score: Similarity score (0-1)
- rank: Rank in results (1-indexed)
- faq_id: Document ID
"""
# Use config defaults if not provided
if top_k is None:
top_k = settings.TOP_K
if min_similarity is None:
min_similarity = settings.SIMILARITY_THRESHOLD
# Validate query
if not query or query.strip() == "":
print("⚠️ Empty query provided")
return []
try:
# Load models (cached, no overhead after first call)
retriever = load_retriever()
index, kb = load_faiss_index()
# Step 1: Encode query
query_embedding = retriever.encode_query(query)
# Step 2: Normalize for cosine similarity
faiss.normalize_L2(query_embedding)
# Step 3: Search in FAISS index
similarities, indices = index.search(query_embedding, top_k)
# Step 4: Check similarity threshold for top result
if similarities[0][0] < min_similarity:
print(f"🚫 NO_FETCH (similarity: {similarities[0][0]:.3f} < {min_similarity})")
return []
print(f"✅ FETCH (similarity: {similarities[0][0]:.3f} >= {min_similarity})")
# Step 5: Format results
results = []
for rank, (similarity, idx) in enumerate(zip(similarities[0], indices[0])):
            # FAISS pads with -1 when it has fewer than top_k vectors; skip those
            if 0 <= idx < len(kb):
doc = kb[idx]
results.append({
'instruction': doc.get('instruction', ''),
'response': doc.get('response', ''),
'category': doc.get('category', 'Unknown'),
'intent': doc.get('intent', 'Unknown'),
'score': float(similarity),
'rank': rank + 1,
'faq_id': doc.get('faq_id', f'doc_{idx}')
})
return results
except RuntimeError as e:
# Handle our custom errors with clear messages
print(f"❌ Retrieval error: {e}")
return []
except Exception as e:
print(f"❌ Unexpected retrieval error: {e}")
import traceback
traceback.print_exc()
return []
def format_context(retrieved_docs: List[Dict], max_context_length: Optional[int] = None) -> str:
"""
Format retrieved documents into context string for LLM.
Prioritizes by score and limits total length.
Args:
retrieved_docs: List of retrieved documents
max_context_length: Maximum context length in characters
Returns:
str: Formatted context string
"""
if max_context_length is None:
max_context_length = settings.MAX_CONTEXT_LENGTH
if not retrieved_docs:
return ""
context_parts = []
current_length = 0
for doc in retrieved_docs:
# Create context entry with None checks
instruction = doc.get('instruction', '') or ''
response = doc.get('response', '') or ''
category = doc.get('category', 'N/A') or 'N/A'
context_entry = f"[Rank {doc['rank']}, Score: {doc['score']:.3f}]\n"
context_entry += f"Q: {instruction}\n"
context_entry += f"A: {response}\n"
context_entry += f"Category: {category}\n\n"
# Check length limit
if current_length + len(context_entry) > max_context_length:
break
context_parts.append(context_entry)
current_length += len(context_entry)
return "".join(context_parts)
# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# In your service file:
from app.ml.retriever import retrieve_documents, format_context
# Retrieve documents
query = "What is my account balance?"
docs = retrieve_documents(query, top_k=5)
# Format context for LLM
context = format_context(docs)
# Use context in LLM prompt
prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
"""
# """
# Custom Retriever with E5-Base-V2 and FAISS
# Trained with InfoNCE + Triplet Loss for banking domain
# This is adapted from your RAG.py with:
# - CustomSentenceTransformer (e5-base-v2)
# - Mean pooling + L2 normalization
# - FAISS vector search
# - Module-level caching (load once on startup)
# """
# import os
# import json
# import pickle
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import faiss
# import numpy as np
# from typing import List, Dict, Optional
# from transformers import AutoTokenizer, AutoModel
# from app.config import settings
# # ============================================================================
# # CUSTOM SENTENCE TRANSFORMER (From RAG.py)
# # ============================================================================
# class CustomSentenceTransformer(nn.Module):
# """
# Custom SentenceTransformer matching your training code.
# Uses e5-base-v2 with mean pooling and L2 normalization.
# Training Details:
# - Base model: intfloat/e5-base-v2
# - Loss: InfoNCE + Triplet Loss
# - Pooling: Mean pooling on last hidden state
# - Normalization: L2 normalization
# """
# def __init__(self, model_name: str = "intfloat/e5-base-v2"):
# super().__init__()
# # Load pre-trained e5-base-v2 encoder
# self.encoder = AutoModel.from_pretrained(model_name)
# self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# self.config = self.encoder.config
# def forward(self, input_ids, attention_mask):
# """
# Forward pass through BERT encoder.
# Args:
# input_ids: Tokenized input IDs
# attention_mask: Attention mask for padding
# Returns:
# torch.Tensor: L2-normalized embeddings (shape: [batch_size, 768])
# """
# # Get BERT outputs
# outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
# # Mean pooling - same as training
# # Take hidden states from last layer
# token_embeddings = outputs.last_hidden_state
# # Expand attention mask to match token embeddings shape
# input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
# # Sum embeddings (weighted by attention mask) and divide by sum of mask
# embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
# input_mask_expanded.sum(1), min=1e-9
# )
# # L2 normalize embeddings - same as training
# embeddings = F.normalize(embeddings, p=2, dim=1)
# return embeddings
# def encode(
# self,
# sentences: List[str],
# batch_size: int = 32,
# convert_to_numpy: bool = True,
# show_progress_bar: bool = False
# ) -> np.ndarray:
# """
# Encode sentences using the same method as training.
# Adds 'query: ' prefix for e5-base-v2 compatibility.
# Args:
# sentences: List of sentences to encode
# batch_size: Batch size for encoding
# convert_to_numpy: Whether to convert to numpy array
# show_progress_bar: Whether to show progress bar
# Returns:
# np.ndarray: Encoded embeddings (shape: [num_sentences, 768])
# """
# self.eval() # Set model to evaluation mode
# # Handle single string input
# if isinstance(sentences, str):
# sentences = [sentences]
# # Add 'query: ' prefix for e5-base-v2 (required by model)
# # Handle None values and empty strings
# processed_sentences = []
# for sentence in sentences:
# if sentence is None:
# processed_sentences.append("query: ") # Default empty query
# elif isinstance(sentence, str):
# processed_sentences.append(f"query: {sentence.strip()}")
# else:
# processed_sentences.append(f"query: {str(sentence)}")
# all_embeddings = []
# # Encode in batches
# with torch.no_grad(): # No gradient computation
# for i in range(0, len(processed_sentences), batch_size):
# batch_sentences = processed_sentences[i:i + batch_size]
# # Tokenize batch
# tokens = self.tokenizer(
# batch_sentences,
# truncation=True,
# padding=True,
# max_length=128, # Same as training
# return_tensors='pt'
# ).to(next(self.parameters()).device)
# # Get embeddings
# embeddings = self.forward(tokens['input_ids'], tokens['attention_mask'])
# # Convert to numpy if requested
# if convert_to_numpy:
# embeddings = embeddings.cpu().numpy()
# all_embeddings.append(embeddings)
# # Combine all batches
# if convert_to_numpy:
# all_embeddings = np.vstack(all_embeddings)
# else:
# all_embeddings = torch.cat(all_embeddings, dim=0)
# return all_embeddings
# # ============================================================================
# # CUSTOM RETRIEVER MODEL (Wrapper)
# # ============================================================================
# class CustomRetrieverModel:
# """
# Wrapper for your custom trained retriever model.
# Handles both knowledge base documents and query encoding.
# """
# def __init__(self, model_path: str, device: str = "cpu"):
# """
# Initialize retriever model.
# Args:
# model_path: Path to trained model weights (.pth file)
# device: Device to load model on ('cpu' or 'cuda')
# """
# self.device = device
# # Create model instance
# self.model = CustomSentenceTransformer("intfloat/e5-base-v2").to(device)
# # Load your trained weights
# try:
# state_dict = torch.load(model_path, map_location=device)
# self.model.load_state_dict(state_dict)
# print(f"✅ Custom retriever model loaded from {model_path}")
# except Exception as e:
# print(f"❌ Failed to load custom model: {e}")
# print("🔄 Using base e5-base-v2 model (not trained)...")
# # Set to evaluation mode
# self.model.eval()
# def encode_documents(self, documents: List[str], batch_size: int = 32) -> np.ndarray:
# """
# Encode knowledge base documents.
# These are the responses/instructions we're retrieving.
# Args:
# documents: List of document texts
# batch_size: Batch size for encoding
# Returns:
# np.ndarray: Document embeddings (shape: [num_docs, 768])
# """
# return self.model.encode(documents, batch_size=batch_size, convert_to_numpy=True)
# def encode_query(self, query: str) -> np.ndarray:
# """
# Encode user query for retrieval.
# Args:
# query: User query text
# Returns:
# np.ndarray: Query embedding (shape: [1, 768])
# """
# return self.model.encode([query], convert_to_numpy=True)
# # ============================================================================
# # MODULE-LEVEL CACHING (Load once on import)
# # ============================================================================
# # Global variables for caching
# RETRIEVER_MODEL: Optional[CustomRetrieverModel] = None
# FAISS_INDEX: Optional[faiss.Index] = None
# KB_DATA: Optional[List[Dict]] = None