VectorStoreFlowModule / ChromaDBFlow.py
nbaldwin's picture
modified for new backend
46d0705
raw
history blame
2.92 kB
import os
from typing import Dict, List, Any
import uuid
from copy import deepcopy
from langchain.embeddings import OpenAIEmbeddings
from chromadb import Client as ChromaClient
from flows.base_flows import AtomicFlow
import hydra
class ChromaDBFlow(AtomicFlow):
    """Atomic flow that stores documents in, and retrieves them from, a ChromaDB collection.

    The collection is looked up (or created) under ``flow_config["name"]``.
    Texts are embedded with ``OpenAIEmbeddings``; the API key comes from the
    ``backend`` object handed to the constructor.
    """

    def __init__(self, backend, **kwargs):
        """Create the Chroma client/collection and remember the backend.

        Args:
            backend: Object whose ``get_key()`` returns something exposing
                ``backend_used`` and ``api_key`` (this is how ``run`` uses it).
            **kwargs: Forwarded unchanged to ``AtomicFlow.__init__``.
                NOTE(review): ``self.flow_config`` is read right after the
                super call, so the base class is presumably the one storing
                the ``flow_config`` keyword -- confirm against ``AtomicFlow``.
        """
        super().__init__(**kwargs)
        self.client = ChromaClient()
        self.collection = self.client.get_or_create_collection(
            name=self.flow_config["name"]
        )
        self.backend = backend

    @classmethod
    def _set_up_backend(cls, config):
        """Instantiate the backend described by ``config["backend"]`` via Hydra.

        Args:
            config: Mapping containing a ``"backend"`` entry that
                ``hydra.utils.instantiate`` can build.

        Returns:
            A dict ``{"backend": <instantiated backend>}`` suitable for
            merging into the constructor keyword arguments.
        """
        backend = hydra.utils.instantiate(config["backend"], _convert_="partial")
        return {"backend": backend}

    @classmethod
    def instantiate_from_config(cls, config):
        """Build a flow instance from a configuration mapping.

        The configuration is deep-copied first, so the caller's object is not
        the one stored on (or mutated through) the new flow.

        Args:
            config: Flow configuration; must contain a ``"backend"`` entry.

        Returns:
            A new instance of ``cls``.
        """
        flow_config = deepcopy(config)
        kwargs = {"flow_config": flow_config}

        # ~~~ Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def get_input_keys(self) -> List[str]:
        """Return the input keys declared in ``flow_config["input_keys"]``."""
        return self.flow_config["input_keys"]

    def get_output_keys(self) -> List[str]:
        """Return the output keys declared in ``flow_config["output_keys"]``."""
        return self.flow_config["output_keys"]

    def _build_embeddings(self):
        """Create the embedding client using the key supplied by the backend.

        Returns:
            An ``OpenAIEmbeddings`` instance. When the backend reports
            ``"openai"`` its key is used; for any other backend the
            ``OPENAI_API_KEY`` environment variable is used instead.
        """
        api_information = self.backend.get_key()
        if api_information.backend_used == "openai":
            return OpenAIEmbeddings(openai_api_key=api_information.api_key)
        # ToDo: Add support for Azure
        return OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a ``"read"`` (query) or ``"write"`` (insert) against the collection.

        Args:
            input_data: Must contain ``"operation"`` (``"read"`` or
                ``"write"``) and ``"content"``. For a read, ``content`` is the
                query string. For a write, ``content`` is a document or a
                list of documents.

        Returns:
            A dict with the single key ``"retrieved"``:

            * read, non-empty query: the ``"documents"`` entry of the Chroma
              query result, copied into a list (``flow_config["n_results"]``
              results are requested);
            * read, empty query: ``[[""]]``;
            * write: ``""``.

        Raises:
            ValueError: If the operation is not supported, or if ``content``
                is not a string during a read.
            KeyError: If ``"operation"`` or ``"content"`` is missing.
        """
        # Validate the request before touching the backend: an invalid or
        # no-op request should not fetch a key or build an embedding client.
        operation = input_data["operation"]
        if operation not in ("write", "read"):
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]

        if operation == "read":
            if not isinstance(content, str):
                raise ValueError(
                    f"content(query) must be a string during read, got {type(content)}: {content}"
                )
            if content == "":
                # Nothing to search for; return an empty placeholder result.
                return {"retrieved": [[""]]}

            embeddings = self._build_embeddings()
            query_result = self.collection.query(
                query_embeddings=embeddings.embed_query(content),
                n_results=self.flow_config["n_results"],
            )
            return {"retrieved": list(query_result["documents"])}

        # operation == "write"
        documents = content if isinstance(content, list) else [content]
        # Skip the store entirely when there is nothing to add: both the
        # empty string and an empty list are treated as a no-op write.
        if content != "" and documents:
            embeddings = self._build_embeddings()
            self.collection.add(
                # One fresh random id per document.
                ids=[str(uuid.uuid4()) for _ in documents],
                embeddings=embeddings.embed_documents(documents),
                documents=documents,
            )
        return {"retrieved": ""}