File size: 2,964 Bytes
bfb6e70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from copy import deepcopy
from typing import Dict, List, Any, Optional

import faiss

from langchain.docstore import InMemoryDocstore
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.vectorstores import Chroma, FAISS
from langchain.vectorstores.base import VectorStoreRetriever

from flows.base_flows import AtomicFlow


class VectorStoreFlow(AtomicFlow):
    """Atomic flow that writes text to, and reads text from, a vector store.

    The store is wrapped in a LangChain retriever built from the flow config
    (see ``_set_up_retriever``). ``run`` dispatches on ``input_data["operation"]``:
    ``"write"`` adds documents, ``"read"`` returns the contents of the documents
    the retriever considers relevant to a query.
    """

    # Keys read from the config by ``_set_up_retriever``. Note that the
    # "chroma" backend additionally reads config["name"].
    # NOTE(review): presumably enforced by the AtomicFlow base class -- confirm.
    REQUIRED_KEYS_CONFIG = ["type", "api_keys"]

    # Retriever used by ``run`` for both reads and writes.
    vector_db: VectorStoreRetriever

    def __init__(self, vector_db: VectorStoreRetriever, **kwargs):
        """Store the retriever and forward all remaining kwargs to the base class."""
        super().__init__(**kwargs)
        self.vector_db = vector_db

    @classmethod
    def _set_up_retriever(cls, config: Dict[str, Any]) -> Dict[str, Any]:
        """Build the retriever described by ``config``.

        Keys read from ``config``:
            * ``api_keys["openai"]`` -- key handed to ``OpenAIEmbeddings``.
            * ``type`` -- ``"chroma"`` or ``"faiss"``.
            * ``name`` -- first positional argument to ``Chroma``
              (``"chroma"`` only; a missing key raises ``KeyError``).
            * ``embedding_size`` -- dimension of the FAISS ``IndexFlatL2``
              (``"faiss"`` only; defaults to 1536).
            * ``retriever_config`` -- optional dict of keyword arguments
              forwarded to ``as_retriever``.

        Returns:
            ``{"vector_db": <retriever>}``, ready to be splatted into ``__init__``.

        Raises:
            NotImplementedError: if ``type`` is neither ``"chroma"`` nor ``"faiss"``.
        """
        embeddings = OpenAIEmbeddings(openai_api_key=config["api_keys"]["openai"])
        kwargs = {}

        vs_type = config["type"]

        if vs_type == "chroma":
            vectorstore = Chroma(config["name"], embedding_function=embeddings)
        elif vs_type == "faiss":
            # Start from an empty index with an empty in-memory docstore.
            index = faiss.IndexFlatL2(config.get("embedding_size", 1536))
            vectorstore = FAISS(
                embedding_function=embeddings.embed_query,
                index=index,
                docstore=InMemoryDocstore({}),
                index_to_docstore_id={}
            )
        else:
            raise NotImplementedError(f"Vector store '{vs_type}' not implemented")

        kwargs["vector_db"] = vectorstore.as_retriever(**config.get("retriever_config", {}))

        return kwargs

    @classmethod
    def instantiate_from_config(cls, config: Dict[str, Any]):
        """Create a flow from ``config``.

        The config is deep-copied first, so the caller's dict is not shared
        with the new instance. The copy is passed on as ``flow_config``
        together with the retriever built by ``_set_up_retriever``.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        kwargs.update(cls._set_up_retriever(flow_config))

        return cls(**kwargs)

    @staticmethod
    def package_documents(documents: List[str]) -> List[Document]:
        """Wrap each string in a ``Document`` with placeholder metadata."""
        # TODO(yeeef): support metadata
        # NOTE(review): the {"": ""} placeholder is presumably a workaround for
        # stores that reject empty metadata -- confirm before changing it.
        return [Document(page_content=doc, metadata={"": ""}) for doc in documents]

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a read or write against the vector store.

        Args:
            input_data: must contain
                * ``"operation"`` -- ``"read"`` or ``"write"``;
                * ``"content"`` -- for ``"read"``, the query string; for
                  ``"write"``, a string or a list of strings to store.

        Returns:
            A dict with a single key ``"retrieved"``: the list of
            ``page_content`` strings of the retrieved documents for a read,
            or the empty string for a write.

        Raises:
            KeyError: if ``"operation"`` or ``"content"`` is missing.
            ValueError: if the operation is not ``"read"`` or ``"write"``.
            TypeError: if ``"content"`` has the wrong type for the operation.
        """
        # Validate with explicit raises rather than ``assert``: assertions are
        # stripped under ``python -O``, which would let bad input through.
        operation = input_data["operation"]
        if operation not in ("write", "read"):
            raise ValueError(f"Operation '{operation}' not supported")

        response = {}
        content = input_data["content"]

        if operation == "read":
            if not isinstance(content, str):
                raise TypeError(f"Content must be a string, got {type(content)}")
            retrieved_documents = self.vector_db.get_relevant_documents(content)
            response["retrieved"] = [doc.page_content for doc in retrieved_documents]
        else:  # operation == "write"
            # A single string is treated as a one-element batch.
            if isinstance(content, str):
                content = [content]
            if not isinstance(content, list):
                raise TypeError(f"Content must be a list of strings, got {type(content)}")
            self.vector_db.add_documents(self.package_documents(content))
            # Writes retrieve nothing; the key is still set so that both
            # operations return the same response shape.
            response["retrieved"] = ""

        return response