Arabic-Rag-Chatbot / search_engine.py
Ahmed-Alghamdi's picture
Update search_engine.py
42febe9 verified
# search_engine.py
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from utils import setup_logger
from config import Config
logger = setup_logger('search_engine')
class SearchEngine:
    """Semantic search over document chunks backed by a FAISS inner-product index.

    Document embeddings are L2-normalised before being added to an
    ``IndexFlatIP`` index, and queries are normalised the same way, so the
    inner product returned by FAISS equals the cosine similarity.

    Configuration read from ``Config``:
        EMBEDDING_MODEL: model name passed to ``SentenceTransformer``.
        TOP_K: maximum number of rows returned by ``search``.
        MIN_SIMILARITY_SCORE: minimum cosine similarity a hit must reach.
    """

    def __init__(self, documents: pd.DataFrame, embeddings: np.ndarray):
        """Build the index and load the query encoder.

        Args:
            documents: Table of chunks; row ``i`` is looked up positionally
                (``iloc``) for the vector at position ``i`` of the index.
            embeddings: 2-D array with one embedding per row; presumably
                aligned row-for-row with ``documents`` (not verified here).
        """
        self.documents = documents
        # The caller's array is stored untouched; the index works on its own
        # normalised copy.
        self.embeddings = embeddings
        self.index = self._build_faiss_index(embeddings)
        self.model = SentenceTransformer(Config.EMBEDDING_MODEL)

    def _build_faiss_index(self, embeddings: np.ndarray):
        """Return an ``IndexFlatIP`` filled with L2-normalised ``embeddings``.

        Args:
            embeddings: 2-D array of shape ``(n_vectors, dimension)``.

        Returns:
            The populated FAISS index.
        """
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)

        # normalize_L2 works in place, so take exactly one private float32,
        # C-ordered copy: it protects the caller's array without copying the
        # whole matrix twice.
        embeddings_normalized = np.array(embeddings, dtype='float32', order='C')
        faiss.normalize_L2(embeddings_normalized)
        index.add(embeddings_normalized)

        logger.info(f"FAISS index built with {embeddings.shape[0]} vectors (cosine similarity)")
        return index

    @staticmethod
    def _select_hits(scores, indices, min_score, top_k):
        """Keep at most ``top_k`` hits whose score is at least ``min_score``.

        Hits labelled with a negative index are dropped as well: FAISS uses
        ``-1`` to mark result slots it could not fill, and using such a label
        positionally would select the last row instead of "nothing".

        Args:
            scores: 1-D array of similarity scores, in rank order.
            indices: 1-D array of row labels, parallel to ``scores``.
            min_score: Inclusive lower bound on the score.
            top_k: Maximum number of hits to keep; values <= 0 keep none.

        Returns:
            Tuple ``(kept_indices, kept_scores)`` in the input order.
        """
        scores = np.asarray(scores)
        indices = np.asarray(indices)
        keep = (indices >= 0) & (scores >= min_score)
        limit = max(int(top_k), 0)
        return indices[keep][:limit], scores[keep][:limit]

    def search(self, query: str) -> pd.DataFrame:
        """Return the chunks most similar to ``query``.

        Args:
            query: Text to encode and search for.

        Returns:
            A copy of the matching rows of ``documents`` with an added
            ``similarity_score`` column, sorted best first and limited to
            ``Config.TOP_K`` rows. An empty ``DataFrame`` is returned when
            nothing reaches ``Config.MIN_SIMILARITY_SCORE``, when there is
            nothing to search, or when any error occurs (errors are logged,
            never raised).
        """
        try:
            # Over-fetch (2 x TOP_K) so the threshold filter still has
            # candidates to choose from, but never ask for more hits than
            # there are documents or indexed vectors.
            search_k = min(Config.TOP_K * 2, len(self.documents), self.index.ntotal)
            if search_k <= 0:
                logger.warning("Nothing to search: index or document set is empty")
                return pd.DataFrame()

            query_embedding = self.model.encode([query]).astype('float32')
            # Same normalisation as the indexed vectors -> cosine similarity.
            faiss.normalize_L2(query_embedding)

            scores, indices = self.index.search(query_embedding, search_k)

            filtered_indices, filtered_scores = self._select_hits(
                scores[0], indices[0], Config.MIN_SIMILARITY_SCORE, Config.TOP_K
            )

            if len(filtered_indices) == 0:
                logger.warning(f"No results above similarity threshold {Config.MIN_SIMILARITY_SCORE}")
                return pd.DataFrame()

            results = self.documents.iloc[filtered_indices].copy()
            results['similarity_score'] = filtered_scores
            results = results.sort_values('similarity_score', ascending=False)

            logger.info(f"Found {len(results)} chunks (scores: {filtered_scores.min():.2f} - {filtered_scores.max():.2f})")
            return results

        except Exception as e:
            # Deliberate best-effort boundary: a failed search yields
            # "no results" for the caller instead of an exception.
            logger.error(f"Error searching documents: {e}")
            return pd.DataFrame()