Spaces:
Running
Running
| import urllib | |
| from http import HTTPStatus | |
| from typing import Any | |
| import requests | |
| from langchain.pydantic_v1 import BaseModel, Field, create_model | |
| from langchain_core.tools import StructuredTool | |
| from langflow.base.langchain_utilities.model import LCToolComponent | |
| from langflow.io import DictInput, IntInput, SecretStrInput, StrInput | |
| from langflow.schema import Data | |
class AstraDBCQLToolComponent(LCToolComponent):
    """Langflow tool component that reads rows from an Astra DB CQL table via the REST v2 API."""

    display_name: str = "Astra DB CQL"
    description: str = "Create a tool to get transactional data from DataStax Astra DB CQL Table"
    documentation: str = "https://docs.langflow.org/Components/components-tools#astra-db-cql-tool"
    icon: str = "AstraDB"

    inputs = [
        StrInput(name="tool_name", display_name="Tool Name", info="The name of the tool.", required=True),
        StrInput(
            name="tool_description",
            display_name="Tool Description",
            info="The tool description to be passed to the model.",
            required=True,
        ),
        StrInput(
            name="keyspace",
            display_name="Keyspace",
            value="default_keyspace",
            info="The keyspace name within Astra DB where the data is stored.",
            required=True,
            advanced=True,
        ),
        StrInput(
            name="table_name",
            display_name="Table Name",
            info="The name of the table within Astra DB where the data is stored.",
            required=True,
        ),
        SecretStrInput(
            name="token",
            display_name="Astra DB Application Token",
            info="Authentication token for accessing Astra DB.",
            value="ASTRA_DB_APPLICATION_TOKEN",
            required=True,
        ),
        StrInput(
            name="api_endpoint",
            display_name="API Endpoint",
            info="API endpoint URL for the Astra DB service.",
            value="ASTRA_DB_API_ENDPOINT",
            required=True,
        ),
        StrInput(
            name="projection_fields",
            display_name="Projection fields",
            info="Attributes to return separated by comma.",
            required=True,
            value="*",
            advanced=True,
        ),
        DictInput(
            name="partition_keys",
            display_name="Partition Keys",
            is_list=True,
            info="Field name and description to the model",
            required=True,
        ),
        DictInput(
            name="clustering_keys",
            display_name="Clustering Keys",
            is_list=True,
            info="Field name and description to the model",
        ),
        DictInput(
            name="static_filters",
            display_name="Static Filters",
            is_list=True,
            advanced=True,
            info="Field name and value. When filled, it will not be generated by the LLM.",
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=5,
        ),
    ]

    def astra_rest(self, args):
        """Fetch rows from the configured table through the Astra DB REST v2 API.

        The primary key is sent as URL path segments: all partition keys first,
        then the clustering keys in their configured order. Each value is taken
        from ``args`` (the tool-call arguments) or, failing that, from
        ``static_filters``. Clustering keys prefixed with ``"!"`` are mandatory;
        the tool arguments use the column name without the ``"!"``.

        Args:
            args: Mapping of column name to value supplied to the tool.

        Returns:
            The ``"data"`` member of the JSON response on success (or the whole
            JSON payload if it has no ``"data"`` member), the response text when
            the HTTP status is >= 400, or the status code when the body is not
            valid JSON.

        Raises:
            ValueError: If a partition key or a mandatory clustering key has no
                value in either ``args`` or ``static_filters``.
        """
        from urllib.parse import quote

        static_filters = self.static_filters or {}

        def lookup(name, raw_key):
            # A tool argument wins; a missing or explicit-null argument falls back to
            # the static filter (accepted under the bare name or the raw "!"-prefixed key).
            value = args.get(name)
            if value is None:
                value = static_filters.get(name, static_filters.get(raw_key))
            return value

        segments = []

        # Partition keys are mandatory.
        for key in self.partition_keys or {}:
            value = lookup(key, key)
            if value is None:
                msg = f"Missing value for partition key '{key}'."
                raise ValueError(msg)
            # Escape everything (including "/") so a value stays a single path segment.
            segments.append(quote(str(value), safe=""))

        # Clustering keys are optional unless marked with "!".
        for raw_key in self.clustering_keys or {}:
            mandatory = raw_key.startswith("!")
            name = raw_key[1:] if mandatory else raw_key
            value = lookup(name, raw_key)
            if value is None:
                if mandatory:
                    msg = f"Missing value for clustering key '{name}'."
                    raise ValueError(msg)
                # Primary-key path segments are positional, so later clustering keys
                # cannot be sent once one is missing; stop here instead of misaligning them.
                break
            segments.append(quote(str(value), safe=""))

        headers = {"Accept": "application/json", "X-Cassandra-Token": f"{self.token}"}
        base_url = f"{self.api_endpoint.rstrip('/')}/api/rest/v2/keyspaces/{self.keyspace}/{self.table_name}/"
        url = f"{base_url}{'/'.join(segments)}?page-size={self.number_of_results}"

        fields = (self.projection_fields or "*").replace(" ", "")
        if fields != "*":
            url += f"&fields={quote(fields)}"

        res = requests.request("GET", url=url, headers=headers, timeout=10)
        if int(res.status_code) >= HTTPStatus.BAD_REQUEST:
            return res.text

        try:
            res_data = res.json()
        except ValueError:
            return res.status_code

        if isinstance(res_data, dict) and "data" in res_data:
            return res_data["data"]
        return res_data

    def create_args_schema(self) -> dict[str, type[BaseModel]]:
        """Build the pydantic model describing the arguments the LLM must supply.

        Keys that have a static filter are excluded (their value is fixed).
        Partition keys are required strings. Clustering keys prefixed with
        ``"!"`` are required strings exposed without the prefix; the others are
        optional strings defaulting to ``None``.

        Returns:
            ``{"ToolInput": <generated BaseModel subclass>}``.
        """
        static_filters = self.static_filters or {}
        fields: dict[str, tuple[Any, Any]] = {}

        # Partition keys are mandatory if they don't have a static filter.
        for key, description in (self.partition_keys or {}).items():
            if key not in static_filters:
                fields[key] = (str, Field(description=description))

        # Clustering keys are mandatory if they have the exclamation mark and no static filter.
        for key, description in (self.clustering_keys or {}).items():
            mandatory = key.startswith("!")
            name = key[1:] if mandatory else key
            if key in static_filters or name in static_filters:
                continue
            if mandatory:
                fields[name] = (str, Field(description=description))
            else:
                fields[name] = (str | None, Field(description=description, default=None))

        model = create_model("ToolInput", **fields, __base__=BaseModel)
        return {"ToolInput": model}

    def build_tool(self) -> StructuredTool:
        """Build the Astra DB CQL table tool.

        The tool's name and description come from the ``tool_name`` and
        ``tool_description`` inputs, its argument schema from
        ``create_args_schema``, and each call is dispatched to ``run_model``.

        Returns:
            StructuredTool: The configured tool.
        """
        schema_dict = self.create_args_schema()
        return StructuredTool.from_function(
            name=self.tool_name,
            args_schema=schema_dict["ToolInput"],
            description=self.tool_description,
            func=self.run_model,
            return_direct=False,
        )

    def projection_args(self, input_str: str) -> dict:
        """Convert a comma-separated field list into a ``{field: include}`` mapping.

        A leading ``"!"`` marks a field as excluded (``False``); any other field
        is included (``True``). Surrounding whitespace and empty entries are ignored.
        """
        result = {}
        for element in input_str.split(","):
            name = element.strip()
            if not name:
                continue
            if name.startswith("!"):
                result[name[1:]] = False
            else:
                result[name] = True
        return result

    def run_model(self, **args) -> Data | list[Data]:
        """Run the query for ``args`` and record the outcome in ``self.status``.

        Returns the raw result of ``astra_rest``. When that result is a list of
        rows, ``self.status`` is set to one ``Data`` per row. Otherwise (error
        text, status code, or an unexpected payload), ``self.status`` is set to
        the raw result instead of crashing while wrapping it.
        """
        results = self.astra_rest(args)
        if isinstance(results, list):
            self.status = [Data(data=doc) for doc in results]
        else:
            self.status = results
        return results