| |
| import faiss |
| import numpy as np |
| import pandas as pd |
| from sentence_transformers import SentenceTransformer |
| from utils import setup_logger |
| from config import Config |
|
|
# Module-level logger for this file, created by the project's setup_logger helper (utils).
logger = setup_logger('search_engine')
|
|
class SearchEngine:
    """Semantic search over a document table backed by a FAISS inner-product index.

    Stored embeddings and query embeddings are both L2-normalised, so the
    inner-product scores FAISS returns are cosine similarities. Queries are
    encoded with the SentenceTransformer model named by ``Config.EMBEDDING_MODEL``.
    """

    def __init__(self, documents, embeddings):
        """Build the search index and load the query-encoding model.

        Args:
            documents: DataFrame of searchable rows. Row *position* i must
                correspond to row i of ``embeddings`` (results are looked up
                with ``iloc`` using FAISS ids).
            embeddings: 2-D array-like of document vectors, one row per document.
        """
        self.documents = documents
        self.embeddings = embeddings
        self.index = self._build_faiss_index(embeddings)
        self.model = SentenceTransformer(Config.EMBEDDING_MODEL)

    def _build_faiss_index(self, embeddings):
        """Return an exact (flat) inner-product FAISS index over normalised ``embeddings``.

        The vectors are normalised on a private float32, C-contiguous copy
        (``faiss.normalize_L2`` works in place), so the caller's array is never
        modified.

        Raises:
            ValueError: if ``embeddings`` is not two-dimensional.
        """
        # np.array copies by default; order="C" gives FAISS the layout it requires.
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        if vectors.ndim != 2:
            raise ValueError(
                f"embeddings must be 2-D (n_vectors, dimension), got shape {vectors.shape}"
            )

        index = faiss.IndexFlatIP(vectors.shape[1])
        faiss.normalize_L2(vectors)
        index.add(vectors)

        logger.info(f"FAISS index built with {vectors.shape[0]} vectors (cosine similarity)")
        return index

    @staticmethod
    def _select_hits(scores, ids, top_k, min_score):
        """Keep at most ``top_k`` hits whose score is at least ``min_score``.

        Args:
            scores: 1-D similarity scores for a single query, best first
                (the order FAISS returns them in).
            ids: 1-D FAISS ids aligned with ``scores``. An id of -1 marks an
                empty result slot and is always dropped.
            top_k: maximum number of hits to keep (non-positive keeps none).
            min_score: inclusive similarity threshold.

        Returns:
            ``(ids, scores)`` as NumPy arrays, filtered and truncated, in the
            original order.
        """
        scores = np.asarray(scores)
        ids = np.asarray(ids)
        keep = (ids >= 0) & (scores >= min_score)
        limit = max(top_k, 0)
        return ids[keep][:limit], scores[keep][:limit]

    def search(self, query):
        """Return the documents most similar to ``query``.

        Args:
            query: text to search for; it is encoded with ``self.model``.

        Returns:
            A copy of the matching ``self.documents`` rows (at most
            ``Config.TOP_K``) with an added ``similarity_score`` column, sorted
            by descending score. Only hits scoring at least
            ``Config.MIN_SIMILARITY_SCORE`` are kept. An empty DataFrame is
            returned when nothing qualifies. It is also returned when any error
            occurs; the error is logged with its traceback rather than raised.
        """
        try:
            query_embedding = self.model.encode([query]).astype('float32')
            faiss.normalize_L2(query_embedding)

            # Bound k by what the index actually holds. FAISS rejects k <= 0
            # and pads any shortfall with id -1.
            search_k = min(Config.TOP_K * 2, self.index.ntotal)
            if search_k <= 0:
                logger.warning("Search skipped: index is empty or Config.TOP_K is not positive")
                return pd.DataFrame()

            scores, indices = self.index.search(query_embedding, search_k)
            hit_ids, hit_scores = self._select_hits(
                scores[0], indices[0], Config.TOP_K, Config.MIN_SIMILARITY_SCORE
            )

            if hit_ids.size == 0:
                logger.warning(f"No results above similarity threshold {Config.MIN_SIMILARITY_SCORE}")
                return pd.DataFrame()

            results = self.documents.iloc[hit_ids].copy()
            results['similarity_score'] = hit_scores
            # FAISS already returns hits best-first; a stable sort preserves that
            # order for tied scores instead of shuffling them.
            results = results.sort_values('similarity_score', ascending=False, kind='stable')

            logger.info(f"Found {len(results)} chunks (scores: {hit_scores.min():.2f} - {hit_scores.max():.2f})")
            return results

        except Exception as e:
            # Best-effort contract: callers get an empty frame, but keep the traceback.
            logger.exception(f"Error searching documents: {e}")
            return pd.DataFrame()
|
|