Spaces:
Running
Running
| from loguru import logger | |
| from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store | |
| from langflow.helpers import docs_to_data | |
| from langflow.inputs import DictInput, FloatInput | |
| from langflow.io import ( | |
| BoolInput, | |
| DataInput, | |
| DropdownInput, | |
| HandleInput, | |
| IntInput, | |
| MultilineInput, | |
| SecretStrInput, | |
| StrInput, | |
| ) | |
| from langflow.schema import Data | |
class HCDVectorStoreComponent(LCVectorStoreComponent):
    """Vector store component backed by a Hyper-Converged Database (HCD) deployment.

    The store is a ``langchain_astradb.AstraDBVectorStore`` configured with
    ``Environment.HCD`` and an ``astrapy`` ``UsernamePasswordTokenProvider``.
    Documents from ``ingest_data`` are added when the store is built, and
    searches honour the search type, result count, score threshold and
    metadata filter inputs.
    """

    display_name: str = "Hyper-Converged Database"
    description: str = "Implementation of Vector Store using Hyper-Converged Database (HCD) with search capabilities"
    documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb"
    name = "HCD"
    icon: str = "HCD"

    inputs = [
        StrInput(
            name="collection_name",
            display_name="Collection Name",
            info="The name of the collection within HCD where the vectors will be stored.",
            required=True,
        ),
        StrInput(
            name="username",
            display_name="HCD Username",
            info="Authentication username for accessing HCD.",
            value="hcd-superuser",
            required=True,
        ),
        SecretStrInput(
            name="password",
            display_name="HCD Password",
            info="Authentication password for accessing HCD.",
            value="HCD_PASSWORD",
            required=True,
        ),
        SecretStrInput(
            name="api_endpoint",
            display_name="HCD API Endpoint",
            info="API endpoint URL for the HCD service.",
            value="HCD_API_ENDPOINT",
            required=True,
        ),
        MultilineInput(
            name="search_input",
            display_name="Search Input",
        ),
        DataInput(
            name="ingest_data",
            display_name="Ingest Data",
            is_list=True,
        ),
        StrInput(
            name="namespace",
            display_name="Namespace",
            info="Optional namespace within HCD to use for the collection.",
            value="default_namespace",
            advanced=True,
        ),
        MultilineInput(
            name="ca_certificate",
            display_name="CA Certificate",
            info="Optional CA certificate for TLS connections to HCD.",
            advanced=True,
        ),
        DropdownInput(
            name="metric",
            display_name="Metric",
            info="Optional distance metric for vector comparisons in the vector store.",
            options=["cosine", "dot_product", "euclidean"],
            advanced=True,
        ),
        IntInput(
            name="batch_size",
            display_name="Batch Size",
            info="Optional number of data to process in a single batch.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_batch_concurrency",
            display_name="Bulk Insert Batch Concurrency",
            info="Optional concurrency level for bulk insert operations.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_overwrite_concurrency",
            display_name="Bulk Insert Overwrite Concurrency",
            info="Optional concurrency level for bulk insert operations that overwrite existing data.",
            advanced=True,
        ),
        IntInput(
            name="bulk_delete_concurrency",
            display_name="Bulk Delete Concurrency",
            info="Optional concurrency level for bulk delete operations.",
            advanced=True,
        ),
        DropdownInput(
            name="setup_mode",
            display_name="Setup Mode",
            info="Configuration mode for setting up the vector store, with options like 'Sync', 'Async', or 'Off'.",
            options=["Sync", "Async", "Off"],
            advanced=True,
            value="Sync",
        ),
        BoolInput(
            name="pre_delete_collection",
            display_name="Pre Delete Collection",
            info="Boolean flag to determine whether to delete the collection before creating a new one.",
            advanced=True,
        ),
        StrInput(
            name="metadata_indexing_include",
            display_name="Metadata Indexing Include",
            info="Optional list of metadata fields to include in the indexing.",
            advanced=True,
        ),
        HandleInput(
            name="embedding",
            display_name="Embedding or Astra Vectorize",
            input_types=["Embeddings", "dict"],
            # TODO: This should be optional, but need to refactor langchain-astradb first.
            info="Allows either an embedding model or an Astra Vectorize configuration.",
        ),
        StrInput(
            name="metadata_indexing_exclude",
            display_name="Metadata Indexing Exclude",
            info="Optional list of metadata fields to exclude from the indexing.",
            advanced=True,
        ),
        StrInput(
            name="collection_indexing_policy",
            display_name="Collection Indexing Policy",
            info="Optional dictionary defining the indexing policy for the collection.",
            advanced=True,
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=4,
        ),
        DropdownInput(
            name="search_type",
            display_name="Search Type",
            info="Search type to use",
            options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
            value="Similarity",
            advanced=True,
        ),
        FloatInput(
            name="search_score_threshold",
            display_name="Search Score Threshold",
            info="Minimum similarity score threshold for search results. "
            "(when using 'Similarity with score threshold')",
            value=0,
            advanced=True,
        ),
        DictInput(
            name="search_filter",
            display_name="Search Metadata Filter",
            info="Optional dictionary of filters to apply to the search query.",
            advanced=True,
            is_list=True,
        ),
    ]

    @check_cached_vector_store
    def build_vector_store(self):
        """Create the HCD-backed ``AstraDBVectorStore`` and ingest ``ingest_data``.

        Decorated with ``check_cached_vector_store`` so that every output of
        the component reuses one store. Without it, each output would rebuild
        the store and add the ingested documents again.

        Returns:
            The configured ``AstraDBVectorStore`` instance.

        Raises:
            ImportError: If ``langchain-astradb`` or ``astrapy`` is not installed.
            ValueError: If ``setup_mode`` is unknown, the collection indexing
                policy is malformed, the store cannot be initialized, or
                documents cannot be added.
            TypeError: If ``ingest_data`` contains items that are not ``Data``.
        """
        try:
            from langchain_astradb import AstraDBVectorStore
            from langchain_astradb.utils.astradb import SetupMode
        except ImportError as e:
            msg = (
                "Could not import langchain Astra DB integration package. "
                "Please install it with `pip install langchain-astradb`."
            )
            raise ImportError(msg) from e

        try:
            from astrapy.authentication import UsernamePasswordTokenProvider
            from astrapy.constants import Environment
        except ImportError as e:
            msg = "Could not import astrapy integration package. Please install it with `pip install astrapy`."
            raise ImportError(msg) from e

        try:
            if not self.setup_mode:
                # Fall back to the first declared dropdown option.
                self.setup_mode = self._inputs["setup_mode"].options[0]
            setup_mode_value = SetupMode[self.setup_mode.upper()]
        except KeyError as e:
            msg = f"Invalid setup mode: {self.setup_mode}"
            raise ValueError(msg) from e

        embedding_kwargs = self._build_embedding_kwargs()
        token_provider = UsernamePasswordTokenProvider(self.username, self.password)

        vector_store_kwargs = {
            **embedding_kwargs,
            "collection_name": self.collection_name,
            "token": token_provider,
            "api_endpoint": self.api_endpoint,
            "namespace": self.namespace,
            # Empty / zero UI values are normalized to None (library default).
            "metric": self.metric or None,
            "batch_size": self.batch_size or None,
            "bulk_insert_batch_concurrency": self.bulk_insert_batch_concurrency or None,
            "bulk_insert_overwrite_concurrency": self.bulk_insert_overwrite_concurrency or None,
            "bulk_delete_concurrency": self.bulk_delete_concurrency or None,
            "setup_mode": setup_mode_value,
            "pre_delete_collection": self.pre_delete_collection or False,
            "environment": Environment.HCD,
            **self._build_indexing_kwargs(),
        }

        try:
            vector_store = AstraDBVectorStore(**vector_store_kwargs)
        except Exception as e:
            msg = f"Error initializing AstraDBVectorStore: {e}"
            raise ValueError(msg) from e

        self._add_documents_to_vector_store(vector_store)
        return vector_store

    def _build_embedding_kwargs(self) -> dict:
        """Translate the ``embedding`` input into ``AstraDBVectorStore`` keyword arguments.

        A non-dict value (an embeddings object) is passed through as
        ``embedding``. A dict is treated as an Astra Vectorize configuration.
        It becomes ``collection_vector_service_options``, plus an optional
        ``collection_embedding_api_key``.
        """
        if not isinstance(self.embedding, dict):
            return {"embedding": self.embedding}

        from astrapy.info import CollectionVectorServiceOptions

        # Shallow-copy so the filtering below never mutates the caller's config;
        # `or {}` also tolerates keys that are present but set to None.
        dict_options = dict(self.embedding.get("collection_vector_service_options") or {})
        # Drop entries whose key or value is empty/falsy.
        dict_options["authentication"] = {
            k: v for k, v in (dict_options.get("authentication") or {}).items() if k and v
        }
        dict_options["parameters"] = {
            k: v for k, v in (dict_options.get("parameters") or {}).items() if k and v
        }

        embedding_kwargs = {
            "collection_vector_service_options": CollectionVectorServiceOptions.from_dict(dict_options)
        }
        collection_embedding_api_key = self.embedding.get("collection_embedding_api_key")
        if collection_embedding_api_key:
            embedding_kwargs["collection_embedding_api_key"] = collection_embedding_api_key
        return embedding_kwargs

    @staticmethod
    def _parse_field_list(value) -> list[str]:
        """Return metadata field names from a comma-separated string or an iterable of names."""
        if not value:
            return []
        items = value.split(",") if isinstance(value, str) else value
        return [str(item).strip() for item in items if str(item).strip()]

    def _build_indexing_kwargs(self) -> dict:
        """Build the single metadata-indexing keyword argument, if any.

        Only one option is forwarded, in this order of precedence: include
        list, exclude list, then collection indexing policy.

        The include/exclude inputs are plain strings. The vector store expects
        an iterable of field names, so a raw string would be split into single
        characters. Here the string is split on commas instead.

        The policy is documented as a dictionary. When it is given as a
        string, it is parsed as JSON.

        Raises:
            ValueError: If the policy is not a valid JSON object.
        """
        if include := self._parse_field_list(self.metadata_indexing_include):
            return {"metadata_indexing_include": include}
        if exclude := self._parse_field_list(self.metadata_indexing_exclude):
            return {"metadata_indexing_exclude": exclude}

        policy = self.collection_indexing_policy
        if not policy:
            return {}
        if isinstance(policy, str):
            import json

            try:
                policy = json.loads(policy)
            except json.JSONDecodeError as e:
                msg = f"Invalid collection indexing policy (expected a JSON object): {e}"
                raise ValueError(msg) from e
        if not isinstance(policy, dict):
            msg = "Collection indexing policy must be a JSON object."
            raise ValueError(msg)
        return {"collection_indexing_policy": policy}

    def _add_documents_to_vector_store(self, vector_store) -> None:
        """Convert ``ingest_data`` to LangChain documents and add them to ``vector_store``.

        Raises:
            TypeError: If any ingest item is not a ``Data`` object.
            ValueError: If the vector store rejects the documents.
        """
        documents = []
        for _input in self.ingest_data or []:
            if isinstance(_input, Data):
                documents.append(_input.to_lc_document())
            else:
                msg = "Vector Store Inputs must be Data objects."
                raise TypeError(msg)

        if documents:
            logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
            try:
                vector_store.add_documents(documents)
            except Exception as e:
                msg = f"Error adding documents to AstraDBVectorStore: {e}"
                raise ValueError(msg) from e
        else:
            logger.debug("No documents to add to the Vector Store.")

    def _map_search_type(self) -> str:
        """Map the ``search_type`` dropdown label to the LangChain search-type name."""
        if self.search_type == "Similarity with score threshold":
            return "similarity_score_threshold"
        if self.search_type == "MMR (Max Marginal Relevance)":
            return "mmr"
        return "similarity"

    def _build_search_args(self):
        """Build keyword arguments for ``search`` / retriever ``search_kwargs``.

        ``score_threshold`` is always included. ``filter`` is added only when
        ``search_filter`` has entries with a truthy key and a truthy value.
        """
        args = {
            "k": self.number_of_results,
            "score_threshold": self.search_score_threshold,
        }
        if self.search_filter:
            # NOTE(review): falsy filter values (0, False, "") are dropped here — confirm intended.
            clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
            if len(clean_filter) > 0:
                args["filter"] = clean_filter
        return args

    def search_documents(self) -> list[Data]:
        """Run a search for ``search_input`` and return the hits as ``Data``.

        Returns an empty list when ``search_input`` is missing, not a
        string, or blank.

        Raises:
            ValueError: If the underlying search call fails.
        """
        vector_store = self.build_vector_store()

        logger.debug(f"Search input: {self.search_input}")
        logger.debug(f"Search type: {self.search_type}")
        logger.debug(f"Number of results: {self.number_of_results}")

        if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
            try:
                search_type = self._map_search_type()
                search_args = self._build_search_args()
                docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
            except Exception as e:
                msg = f"Error performing search in AstraDBVectorStore: {e}"
                raise ValueError(msg) from e

            logger.debug(f"Retrieved documents: {len(docs)}")
            data = docs_to_data(docs)
            logger.debug(f"Converted documents to data: {len(data)}")
            self.status = data
            return data

        logger.debug("No search input provided. Skipping search.")
        return []

    def get_retriever_kwargs(self):
        """Return the ``search_type`` / ``search_kwargs`` used to build a retriever."""
        search_args = self._build_search_args()
        return {
            "search_type": self._map_search_type(),
            "search_kwargs": search_args,
        }