|
|
import os |
|
|
from typing import Any |
|
|
|
|
|
from astrapy import Collection, DataAPIClient, Database |
|
|
from langchain.pydantic_v1 import BaseModel, Field, create_model |
|
|
from langchain_core.tools import StructuredTool |
|
|
|
|
|
from langflow.base.langchain_utilities.model import LCToolComponent |
|
|
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput |
|
|
from langflow.schema import Data |
|
|
|
|
|
|
|
|
class AstraDBToolComponent(LCToolComponent):
    """Langflow component exposing an Astra DB collection lookup as a LangChain tool.

    The generated tool accepts the parameters declared in ``tool_params``, merges them
    with ``static_filters`` into a ``find`` filter, and returns the matching documents
    as ``Data`` objects.
    """

    display_name: str = "Astra DB Tool"
    description: str = "Create a tool to get transactional data from DataStax Astra DB Collection"
    documentation: str = "https://docs.langflow.org/Components/components-tools#astra-db-tool"
    icon: str = "AstraDB"

    inputs = [
        StrInput(
            name="tool_name",
            display_name="Tool Name",
            info="The name of the tool.",
            required=True,
        ),
        StrInput(
            name="tool_description",
            display_name="Tool Description",
            info="The description of the tool.",
            required=True,
        ),
        StrInput(
            name="namespace",
            display_name="Namespace Name",
            info="The name of the namespace within Astra where the collection is be stored.",
            value="default_keyspace",
            advanced=True,
        ),
        StrInput(
            name="collection_name",
            display_name="Collection Name",
            info="The name of the collection within Astra DB where the vectors will be stored.",
            required=True,
        ),
        SecretStrInput(
            name="token",
            display_name="Astra DB Application Token",
            info="Authentication token for accessing Astra DB.",
            value="ASTRA_DB_APPLICATION_TOKEN",
            required=True,
        ),
        SecretStrInput(
            name="api_endpoint",
            # Label is chosen once, at import time, from the ASTRA_ENHANCED env var.
            display_name="Database" if os.getenv("ASTRA_ENHANCED", "false").lower() == "true" else "API Endpoint",
            info="API endpoint URL for the Astra DB service.",
            value="ASTRA_DB_API_ENDPOINT",
            required=True,
        ),
        StrInput(
            name="projection_attributes",
            display_name="Projection Attributes",
            info="Attributes to return separated by comma.",
            required=True,
            value="*",
            advanced=True,
        ),
        DictInput(
            name="tool_params",
            info="Attributes to filter and description to the model. Add ! for mandatory (e.g: !customerId)",
            display_name="Tool params",
            is_list=True,
        ),
        DictInput(
            name="static_filters",
            info="Attributes to filter and correspoding value",
            display_name="Static filters",
            advanced=True,
            is_list=True,
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=5,
        ),
    ]

    # Lazily populated by _build_collection and reused across tool invocations.
    _cached_client: DataAPIClient | None = None
    _cached_db: Database | None = None
    _cached_collection: Collection | None = None

    def _build_collection(self) -> Collection:
        """Return the configured collection, creating the client/database handles on first use.

        The client, database and collection handles are stored on the instance so that
        subsequent calls return the same collection object without reconnecting.
        """
        # Compare against None explicitly: the cache sentinel is None, not "falsy".
        if self._cached_collection is not None:
            return self._cached_collection

        self._cached_client = DataAPIClient(self.token)
        self._cached_db = self._cached_client.get_database(self.api_endpoint, namespace=self.namespace)
        self._cached_collection = self._cached_db.get_collection(self.collection_name)
        return self._cached_collection

    def create_args_schema(self) -> dict[str, type[BaseModel]]:
        """Build the pydantic input schema the model uses to call the tool.

        Each key of ``tool_params`` becomes a string field whose description is the
        mapped value. Keys prefixed with ``!`` become required fields (the ``!`` is
        removed from the field name); all other keys are optional and default to None.

        Returns:
            A dict with a single ``"ToolInput"`` entry holding the generated model class.
        """
        args: dict[str, tuple[Any, Any]] = {}

        # An unset tool_params means "no parameters" rather than an iteration error.
        for key, field_description in (self.tool_params or {}).items():
            if key.startswith("!"):
                args[key[1:]] = (str, Field(description=field_description))
            else:
                args[key] = (str | None, Field(description=field_description, default=None))

        model = create_model("ToolInput", **args, __base__=BaseModel)
        return {"ToolInput": model}

    def build_tool(self) -> StructuredTool:
        """Builds an Astra DB Collection tool.

        The tool's argument schema comes from ``create_args_schema`` and each call is
        dispatched to ``run_model``.

        Returns:
            StructuredTool: The built Astra DB tool.
        """
        schema_dict = self.create_args_schema()

        tool = StructuredTool.from_function(
            name=self.tool_name,
            args_schema=schema_dict["ToolInput"],
            description=self.tool_description,
            func=self.run_model,
            return_direct=False,
        )
        self.status = "Astra DB Tool created"

        return tool

    def projection_args(self, input_str: str) -> dict:
        """Translate a comma-separated attribute list into a projection mapping.

        Attributes map to True; attributes prefixed with ``!`` map to False (with the
        ``!`` removed). Whitespace around each attribute is ignored and empty entries
        are skipped, e.g. ``"a, !b,"`` -> ``{"a": True, "b": False}``.

        Args:
            input_str: Comma-separated attribute names.

        Returns:
            Mapping of attribute name to inclusion flag.
        """
        result: dict[str, bool] = {}

        for raw_element in input_str.split(","):
            element = raw_element.strip()
            if not element:
                continue
            if element.startswith("!"):
                name = element[1:].strip()
                if name:
                    result[name] = False
            else:
                result[element] = True

        return result

    def run_model(self, **args) -> Data | list[Data]:
        """Query the collection with the tool arguments and return the matching documents.

        Arguments whose value is None (optional tool params the caller did not supply)
        are dropped so they are not sent as ``{field: None}`` filter conditions. Static
        filters are applied afterwards and override same-named tool arguments.

        Returns:
            One ``Data`` object per document returned by ``find``.
        """
        collection = self._build_collection()

        filters = {key: value for key, value in args.items() if value is not None}
        # static_filters may be unset; merging None would raise a TypeError.
        filters.update(self.static_filters or {})

        # An empty projection string falls back to the collection's default projection.
        projection = self.projection_args(self.projection_attributes) or None

        results = collection.find(
            filters,
            projection=projection,
            limit=self.number_of_results,
        )

        data: list[Data] = [Data(data=doc) for doc in results]
        self.status = data
        return data
|
|
|