YuITC
Add application file
c8e875f
"""
Vector storage and retrieval implementation.
"""
import uuid
from typing import List, Any
from langchain_chroma import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from src.config import EMBEDDING_MODEL, DEVICE, COLLECTION_NAME
class VectorStore:
    """
    Multi-vector storage and retrieval.

    Summaries of each element are embedded and indexed in a Chroma collection,
    while the original elements are kept in an ``InMemoryStore`` under a random
    UUID. A ``MultiVectorRetriever`` links the two through the ``doc_id``
    metadata key, so a query matched against a summary returns the original
    element.
    """

    def __init__(
        self,
        collection_name: str = COLLECTION_NAME,
        embedding_model: str = EMBEDDING_MODEL,
        passage_prefix: str = "passage: ",
        query_prefix: str = "",
    ):
        """
        Initialize the vector store.

        Args:
            collection_name (str): Name of the vector store collection
            embedding_model (str): Name of the embedding model to use
            passage_prefix (str): Text prepended to every summary before it is
                embedded. The default matches the previously hard-coded value;
                adjust it to the input convention of the embedding model.
            query_prefix (str): Text prepended to every query in ``retrieve``.
                Defaults to ``""`` (query passed through unchanged). Models
                trained with asymmetric prefixes (e.g. E5's ``"query: "`` /
                ``"passage: "``) should set both prefixes accordingly.
        """
        self.passage_prefix = passage_prefix
        self.query_prefix = query_prefix
        # Metadata key linking an embedded summary to its original element.
        self.id_key = 'doc_id'
        self.embedding_function = self._create_embedding_function(embedding_model)
        self.vector_store = self._create_vector_store(collection_name)
        self.doc_store = InMemoryStore()
        self.retriever = self._create_retriever()

    def _create_embedding_function(self, model_name: str) -> HuggingFaceEmbeddings:
        """
        Create an embedding function.

        Args:
            model_name (str): Name of the embedding model

        Returns:
            HuggingFaceEmbeddings: The embedding function, running on ``DEVICE``
        """
        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': DEVICE},
            # Unit-normalize vectors; may be disabled if the model already
            # outputs normalized embeddings.
            encode_kwargs={'normalize_embeddings': True},
        )

    def _create_vector_store(self, collection_name: str) -> Chroma:
        """
        Create a vector store.

        No ``persist_directory`` or client is passed, so Chroma's defaults
        apply.

        Args:
            collection_name (str): Name of the vector store collection

        Returns:
            Chroma: The vector store
        """
        return Chroma(
            collection_name=collection_name,
            embedding_function=self.embedding_function,
        )

    def _create_retriever(self) -> MultiVectorRetriever:
        """
        Create a multi-vector retriever over the current vector and doc stores.

        Returns:
            MultiVectorRetriever: The retriever
        """
        return MultiVectorRetriever(
            vectorstore=self.vector_store,
            docstore=self.doc_store,
            id_key=self.id_key,
        )

    def add_to_retriever(self, data: List[Any], data_summaries: List[str]) -> None:
        """
        Add data elements and their summaries to the retriever.

        Each summary (prefixed with ``passage_prefix``) is embedded into the
        vector store, and the matching original element is stored in the
        document store under the same freshly generated UUID.

        Args:
            data (List[Any]): Original elements returned on retrieval
            data_summaries (List[str]): One summary per element, in the same order

        Raises:
            ValueError: If ``data`` and ``data_summaries`` differ in length.
        """
        # Validate before the empty check so a call like ([], ["x"]) is
        # reported instead of silently dropping the summaries.
        if len(data) != len(data_summaries):
            raise ValueError(f"Length mismatch: {len(data)} data but {len(data_summaries)} summaries")
        if not data:
            return

        ids = [str(uuid.uuid4()) for _ in data]
        summary_docs = [
            Document(
                page_content=f"{self.passage_prefix}{summary}",
                metadata={self.id_key: doc_id},
            )
            for doc_id, summary in zip(ids, data_summaries)
        ]
        self.retriever.vectorstore.add_documents(summary_docs)
        self.retriever.docstore.mset(list(zip(ids, data)))

    def add_contents(self,
                     texts: List[Any], text_summaries: List[str],
                     tables: List[Any], table_summaries: List[str],
                     images: List[Any], image_summaries: List[str]) -> None:
        """
        Add all content types and their summaries to the retriever.

        Args:
            texts (List[Any]): List of text elements
            text_summaries (List[str]): List of text summaries
            tables (List[Any]): List of table elements
            table_summaries (List[str]): List of table summaries
            images (List[Any]): List of image elements
            image_summaries (List[str]): List of image summaries

        Raises:
            ValueError: If any element list and its summary list differ in length.
        """
        self.add_to_retriever(texts, text_summaries)
        self.add_to_retriever(tables, table_summaries)
        self.add_to_retriever(images, image_summaries)

    def reset(self) -> None:
        """
        Clear all indexed summaries and stored documents.

        Raises:
            RuntimeError: If the Chroma collection could not be reset; the
                original exception is chained as ``__cause__``.
        """
        try:
            self.vector_store.reset_collection()
        except Exception as e:
            raise RuntimeError(f"Failed to reset vector store: {e}") from e
        # The retriever holds a reference to the old docstore, so it must be
        # rebuilt after swapping in a fresh one.
        self.doc_store = InMemoryStore()
        self.retriever = self._create_retriever()

    def retrieve(self, query: str) -> List[Any]:
        """
        Retrieve original elements whose summaries best match a query.

        Args:
            query (str): The query string; ``query_prefix`` is prepended
                before it is passed to the retriever.

        Returns:
            List[Any]: List of retrieved documents
        """
        return self.retriever.invoke(f"{self.query_prefix}{query}")