# DemoChatBot / vector_store.py
# Provenance: Hugging Face Space "DemoChatBot" by OnlyTheTruth03,
# "Initial Commit" 721ca73 (verified).
"""
vector_store.py
───────────────
Handles text chunking, embedding, and FAISS vector store creation/querying.
Responsibilities:
- Split raw Documents into overlapping chunks
- Embed chunks using a local HuggingFace sentence-transformer
- Build and expose a FAISS index for similarity search
- Provide a clean retrieve() function used by the RAG pipeline
"""
import functools
import logging

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

from config import cfg
logger = logging.getLogger(__name__)
# ── Public API ────────────────────────────────────────────────────────────────
def build_index(documents: list[Document]) -> FAISS:
    """
    Build a queryable FAISS vector store from raw documents.

    Pipeline: split the documents into overlapping chunks, embed each
    chunk with the local sentence-transformer, then index the vectors.

    Parameters
    ----------
    documents : list[Document]
        Raw documents returned by data_loader.load_documents().

    Returns
    -------
    FAISS
        A ready-to-query FAISS vector store.
    """
    pieces = _chunk_documents(documents)
    embedder = _load_embeddings()
    return _create_faiss_index(pieces, embedder)
def retrieve(index: FAISS, query: str, k: int | None = None) -> list[Document]:
    """
    Retrieve the top-k most relevant chunks for a given query.

    Parameters
    ----------
    index : FAISS
        The FAISS vector store built by build_index().
    query : str
        The user's natural-language question.
    k : int, optional
        Number of results to return. Defaults to cfg.top_k when None.

    Returns
    -------
    list[Document]
        Retrieved chunks, most relevant first.
    """
    # Explicit None check: the previous `k or cfg.top_k` treated an
    # intentional k=0 (caller asking for zero results) as "use default".
    if k is None:
        k = cfg.top_k
    results = index.similarity_search(query, k=k)
    # Truncate the query in the log line to keep records bounded.
    logger.debug("Retrieved %d chunks for query: '%s'", len(results), query[:80])
    return results
# ── Internal helpers ──────────────────────────────────────────────────────────
def _chunk_documents(documents: list[Document]) -> list[Document]:
    """Break each raw document into overlapping pieces sized by the config."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=cfg.chunk_size,
        chunk_overlap=cfg.chunk_overlap,
        # Prefer splitting at paragraph, then line, then sentence boundaries
        # before falling back to word- and character-level splits.
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    pieces = text_splitter.split_documents(documents)
    logger.info(
        "Chunking: %d raw docs β†’ %d chunks (size=%d, overlap=%d)",
        len(documents),
        len(pieces),
        cfg.chunk_size,
        cfg.chunk_overlap,
    )
    return pieces
@functools.lru_cache(maxsize=1)
def _load_embeddings() -> HuggingFaceEmbeddings:
    """Load the local sentence-transformer embedding model.

    The result is memoized with lru_cache so repeated build_index() calls
    reuse one model instance instead of reloading weights each time — the
    original docstring promised caching that the code did not implement.

    Returns
    -------
    HuggingFaceEmbeddings
        CPU-bound embedder producing L2-normalized vectors.
    """
    logger.info("Loading embedding model: %s", cfg.embed_model)
    return HuggingFaceEmbeddings(
        model_name=cfg.embed_model,
        model_kwargs={"device": "cpu"},
        # Normalized embeddings make inner-product and cosine ranking agree.
        encode_kwargs={"normalize_embeddings": True},
    )
def _create_faiss_index(chunks: list[Document], embeddings: HuggingFaceEmbeddings) -> FAISS:
    """Embed every chunk and assemble the FAISS similarity index."""
    logger.info("Building FAISS index over %d chunks …", len(chunks))
    store = FAISS.from_documents(chunks, embeddings)
    logger.info("FAISS index built βœ“ (vectors: %d)", store.index.ntotal)
    return store