| from pathlib import Path |
| from typing import List, Optional |
| import logging |
| from langchain_huggingface import HuggingFaceEmbeddings |
| from langchain_community.vectorstores import FAISS |
| from src.llm.utils.logging import TheryBotLogger |
|
|
class FAISSVectorSearch:
    """Similarity search over a FAISS vector store that is persisted on disk.

    On construction the store is loaded from ``db_path`` when that path
    exists; otherwise a fresh in-memory store is created.  ``add_texts``
    writes the store back to ``db_path`` after every non-empty update.
    """

    # Text used to seed a brand-new store (the store is built with
    # ``FAISS.from_texts``, which is given this single entry).  It carries
    # no information, so ``search`` never reports it as a hit.
    _PLACEHOLDER_TEXT = ""

    def __init__(
        self,
        embedding_model: "Optional[HuggingFaceEmbeddings]" = None,
        db_path: Path = Path("vector_embedding/mental_health_vector_db"),
        k: int = 5,
        logger: "Optional[TheryBotLogger]" = None,
    ):
        """Create the searcher and load (or create) its vector store.

        Args:
            embedding_model: Embedding model used by the store.  When
                ``None`` the model from ``_get_default_embedding_model``
                is used.
            db_path: Location of the persisted store.  A plain string is
                accepted as well and converted to ``Path``.
            k: Default number of results returned by ``search``.
            logger: Logger used to report search failures.  When ``None``
                a new ``TheryBotLogger`` is created.
        """
        if embedding_model is None:
            embedding_model = self._get_default_embedding_model()
        self.embedding_model = embedding_model
        # Coerce so that callers passing a str do not break ``.exists()``
        # and ``.parent`` further down.
        self.db_path = Path(db_path)
        self.k = k
        self.logger = logger if logger is not None else TheryBotLogger()
        self._initialize_store()

    def _get_default_embedding_model(self) -> "HuggingFaceEmbeddings":
        """Return the default CPU-based MiniLM sentence-embedding model."""
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
            # NOTE(review): "padding", "max_length" and "truncation" look
            # like tokenizer options rather than encode() options; confirm
            # that the installed sentence-transformers version accepts
            # them here.
            encode_kwargs={
                "padding": "max_length",
                "max_length": 512,
                "truncation": True,
                "normalize_embeddings": True,
            },
        )

    def _initialize_store(self) -> None:
        """Load the store from ``db_path`` or create a new one in memory."""
        if self.db_path.exists():
            # SECURITY: allow_dangerous_deserialization=True opts in to
            # LangChain's pickle-based loading.  Only point ``db_path`` at
            # stores written by this application, never at untrusted files.
            self.vectorstore = FAISS.load_local(
                str(self.db_path),
                self.embedding_model,
                allow_dangerous_deserialization=True,
            )
        else:
            # Seed the new store with one placeholder entry; ``search``
            # filters it out again.
            self.vectorstore = FAISS.from_texts(
                [self._PLACEHOLDER_TEXT], self.embedding_model
            )

    def search(self, query: str, k: Optional[int] = None) -> List[str]:
        """Return the text of the documents most similar to ``query``.

        Args:
            query: Text to search for.
            k: Maximum number of results.  ``None`` means "use the
                instance default ``self.k``"; a value of zero or less
                yields an empty list without querying the store.

        Returns:
            The ``page_content`` of each hit, in the order returned by
            the store, with placeholder entries removed (so fewer than
            ``k`` items may come back).  An empty list is returned when
            the lookup fails; the failure is logged, not raised.
        """
        # ``k or self.k`` would silently turn an explicit 0 into the default.
        limit = self.k if k is None else k
        if limit <= 0:
            return []

        try:
            hits = self.vectorstore.similarity_search(query, k=limit)
            return [
                hit.page_content
                for hit in hits
                if hit.page_content != self._PLACEHOLDER_TEXT
            ]
        except Exception as e:
            # Deliberate best-effort: a failed lookup is reported through
            # the logger and degrades to "no results".
            self.logger.log_interaction(
                interaction_type="vector_search_error",
                data={"error": str(e)},
                level=logging.ERROR,
            )
            return []

    def add_texts(self, texts: List[str]) -> None:
        """Add new texts to the vector store and persist it.

        An empty ``texts`` is a no-op: there is nothing to embed, so both
        the store update and the disk write are skipped.
        """
        if not texts:
            return
        self.vectorstore.add_texts(texts)
        self.save()

    def save(self) -> None:
        """Save the vector store to ``db_path``, creating parent directories."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.vectorstore.save_local(str(self.db_path))