from copy import deepcopy
from typing import Dict, List, Any, Optional

import faiss
import hydra
from langchain.docstore import InMemoryDocstore
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.vectorstores import Chroma, FAISS
from langchain.vectorstores.base import VectorStoreRetriever

from flows.base_flows import AtomicFlow


class VectorStoreFlow(AtomicFlow):
    """Atomic flow that reads from and writes to a LangChain vector store.

    The store (Chroma or FAISS, selected by ``config["type"]``) is wrapped in a
    ``VectorStoreRetriever`` and exposed through ``run``, which supports two
    operations:

    * ``"read"``  -- retrieve documents relevant to a query string;
    * ``"write"`` -- add one string or a list of strings as documents.

    Config keys read by this class:

    * ``type``             -- ``"chroma"`` or ``"faiss"``;
    * ``name``             -- (chroma only) passed as Chroma's first positional
      argument, its collection name;
    * ``embedding_size``   -- (faiss only) dimension of the flat L2 index,
      default 1536 (presumably the OpenAI embedding dimension -- confirm);
    * ``retriever_config`` -- optional kwargs forwarded to ``as_retriever``;
    * ``backend``          -- Hydra config instantiated into the backend whose
      ``get_key()`` supplies the OpenAI API key.
    """

    # NOTE(review): presumably validated by the base flow class -- confirm.
    REQUIRED_KEYS_CONFIG = ["type"]

    vector_db: VectorStoreRetriever

    def __init__(self, backend, vector_db, **kwargs):
        """Store the backend and retriever; remaining kwargs go to ``AtomicFlow``.

        Args:
            backend: Backend object built by ``_set_up_backend``.
            vector_db: Retriever built by ``_set_up_retriever``.
        """
        super().__init__(**kwargs)
        # Previously accepted and silently dropped; keep it so the flow can
        # still reach its backend after construction.
        self.backend = backend
        self.vector_db = vector_db

    @classmethod
    def _set_up_backend(cls, config: Dict[str, Any]) -> Dict[str, Any]:
        """Instantiate the backend described by ``config["backend"]``.

        Returns:
            A kwargs dict ``{"backend": <instance>}`` for the constructor.
        """
        # _convert_="partial" has Hydra hand the target plain dicts/lists
        # instead of OmegaConf containers (Structured Configs excepted).
        backend = hydra.utils.instantiate(config["backend"], _convert_="partial")
        return {"backend": backend}

    @classmethod
    def _set_up_retriever(cls, api_information: Any, config: Dict[str, Any]) -> Dict[str, Any]:
        """Build the configured vector store and wrap it in a retriever.

        Args:
            api_information: Object returned by the backend's ``get_key()``;
                only its ``api_key`` attribute is used.
            config: Flow config (see the class docstring for the keys read).

        Returns:
            A kwargs dict ``{"vector_db": <VectorStoreRetriever>}``.

        Raises:
            NotImplementedError: If ``config["type"]`` is not a supported store.
        """
        embeddings = OpenAIEmbeddings(openai_api_key=api_information.api_key)

        vs_type = config["type"]
        if vs_type == "chroma":
            vectorstore = Chroma(config["name"], embedding_function=embeddings)
        elif vs_type == "faiss":
            index = faiss.IndexFlatL2(config.get("embedding_size", 1536))
            # Empty in-memory docstore; documents are added later via run().
            vectorstore = FAISS(
                embedding_function=embeddings.embed_query,
                index=index,
                docstore=InMemoryDocstore({}),
                index_to_docstore_id={},
            )
        else:
            raise NotImplementedError(f"Vector store '{vs_type}' not implemented")

        retriever = vectorstore.as_retriever(**config.get("retriever_config", {}))
        return {"vector_db": retriever}

    @classmethod
    def instantiate_from_config(cls, config: Dict[str, Any]):
        """Create the flow from a config dict (the input is deep-copied, not mutated)."""
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Set up retriever, authenticated with the backend's key ~~~
        api_information = kwargs["backend"].get_key()
        kwargs.update(cls._set_up_retriever(api_information, flow_config))

        return cls(**kwargs)

    @staticmethod
    def package_documents(documents: List[str]) -> List[Document]:
        """Wrap each string in a LangChain ``Document``."""
        # TODO(yeeef): support metadata
        # NOTE(review): the placeholder {"": ""} is presumably a workaround for
        # stores that reject empty metadata -- confirm before changing it.
        return [Document(page_content=doc, metadata={"": ""}) for doc in documents]

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a read or write against the vector store.

        Args:
            input_data: Must contain ``"operation"`` (``"read"`` or
                ``"write"``) and ``"content"``: a query string for reads, a
                string or list of strings for writes.

        Returns:
            ``{"retrieved": [page_content, ...]}`` for reads, and
            ``{"retrieved": ""}`` for writes.

        Raises:
            ValueError: If the operation is not supported.
            TypeError: If ``content`` has the wrong type for the operation.
        """
        # Explicit raises instead of `assert`: asserts vanish under `python -O`.
        operation = input_data["operation"]
        if operation not in ("write", "read"):
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]
        response: Dict[str, Any] = {}

        if operation == "read":
            if not isinstance(content, str):
                raise TypeError(f"Content must be a string, got {type(content)}")
            retrieved_documents = self.vector_db.get_relevant_documents(content)
            response["retrieved"] = [doc.page_content for doc in retrieved_documents]
        else:  # operation == "write"
            if isinstance(content, str):
                content = [content]
            if not isinstance(content, list):
                raise TypeError(f"Content must be a list of strings, got {type(content)}")
            non_strings = [doc for doc in content if not isinstance(doc, str)]
            if non_strings:
                raise TypeError(
                    f"Content must be a list of strings, got an element of type {type(non_strings[0])}"
                )
            # Skip the store call for an empty batch; nothing here shows the
            # backing stores accept one.
            if content:
                self.vector_db.add_documents(self.package_documents(content))
            response["retrieved"] = ""

        return response