|
|
import os |
|
|
from typing import Dict, List, Any |
|
|
|
|
|
import uuid |
|
|
from copy import deepcopy |
|
|
from langchain.embeddings import OpenAIEmbeddings |
|
|
|
|
|
from chromadb import Client as ChromaClient |
|
|
|
|
|
from flows.base_flows import AtomicFlow |
|
|
|
|
|
import hydra |
|
|
|
|
|
class ChromaDBFlow(AtomicFlow):
    """Atomic flow that writes text into, and retrieves text from, a Chroma collection.

    Texts are embedded with ``OpenAIEmbeddings`` before being stored or queried.
    The collection is named after ``flow_config["name"]``. ``flow_config`` is also
    read for ``input_keys``, ``output_keys``, ``n_results`` and ``backend`` (a Hydra
    config that is instantiated into the object supplying API keys).
    """

    def __init__(self, backend, **kwargs):
        """Create the Chroma client and collection, and keep the key-providing backend.

        Args:
            backend: Object whose ``get_key()`` returns something exposing
                ``backend_used`` and ``api_key`` attributes.
            **kwargs: Forwarded to ``AtomicFlow.__init__``. Presumably this must
                include ``flow_config``, since ``self.flow_config`` is read right
                after the super call (confirm against AtomicFlow).
        """
        super().__init__(**kwargs)
        # NOTE(review): ChromaClient() is built with chromadb's default settings,
        # so whether stored data persists depends on that default. Confirm this is intended.
        self.client = ChromaClient()
        self.collection = self.client.get_or_create_collection(name=self.flow_config["name"])
        self.backend = backend

    @classmethod
    def _set_up_backend(cls, config):
        """Instantiate ``config['backend']`` with Hydra and return it as constructor kwargs.

        ``_convert_="partial"`` asks Hydra to hand plain Python containers (not
        OmegaConf objects) to the target, except for structured configs.
        """
        kwargs = {}
        kwargs["backend"] = hydra.utils.instantiate(config['backend'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config mapping.

        The config is deep-copied first, so the caller's object is never mutated.
        The copy is stored as ``flow_config``, and its ``backend`` entry is
        instantiated via ``_set_up_backend``.
        """
        flow_config = deepcopy(config)
        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_backend(flow_config))
        return cls(**kwargs)

    def get_input_keys(self) -> List[str]:
        """Return the input keys declared in ``flow_config["input_keys"]``."""
        return self.flow_config["input_keys"]

    def get_output_keys(self) -> List[str]:
        """Return the output keys declared in ``flow_config["output_keys"]``."""
        return self.flow_config["output_keys"]

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a ``read`` or ``write`` operation against the collection.

        Args:
            input_data: Must contain ``"operation"`` (``"read"`` or ``"write"``) and
                ``"content"``. For ``read``, ``content`` is the query string. For
                ``write``, it is a single document or a list of documents.

        Returns:
            ``{"retrieved": ...}``. For ``read``, this holds the ``documents`` field
            of the Chroma query result (``[[""]]`` for an empty query). For
            ``write``, it holds the empty string.

        Raises:
            ValueError: If the operation is unsupported, or a read query is not a str.
            KeyError: If ``"operation"`` or ``"content"`` is missing.
        """
        # Validate before touching the backend so bad input never consumes a key.
        operation = input_data["operation"]
        if operation not in ("write", "read"):
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]
        if operation == "read":
            return {"retrieved": self._read(content)}

        self._write(content)
        return {"retrieved": ""}

    def _get_embeddings(self):
        """Build an ``OpenAIEmbeddings`` client with a key fetched from the backend.

        The backend's key is used when it reports ``backend_used == "openai"``.
        Otherwise the ``OPENAI_API_KEY`` environment variable is used.
        """
        api_information = self.backend.get_key()
        if api_information.backend_used == "openai":
            api_key = api_information.api_key
        else:
            api_key = os.getenv("OPENAI_API_KEY")
        return OpenAIEmbeddings(openai_api_key=api_key)

    def _read(self, content):
        """Return the documents nearest to the query string ``content``."""
        if not isinstance(content, str):
            raise ValueError(f"content(query) must be a string during read, got {type(content)}: {content}")
        if content == "":
            # Mirror the nested shape of Chroma's ``documents`` (one list per query)
            # without spending an embedding call on an empty query.
            return [[""]]

        embeddings = self._get_embeddings()
        query_result = self.collection.query(
            query_embeddings=embeddings.embed_query(content),
            n_results=self.flow_config["n_results"],
        )
        return list(query_result["documents"])

    def _write(self, content) -> None:
        """Embed and store ``content`` (one document or a list of them) under fresh UUIDs."""
        if content == "":
            return
        documents = content if isinstance(content, list) else [content]
        # An empty list also means "nothing to store": chromadb's add() rejects an
        # empty ids list, so skip the call (and the key fetch) entirely.
        if not documents:
            return

        embeddings = self._get_embeddings()
        self.collection.add(
            ids=[str(uuid.uuid4()) for _ in documents],
            embeddings=embeddings.embed_documents(documents),
            documents=documents,
        )
|
|
|