| | |
| | import os |
| |
|
| | |
| | from langchain_community.vectorstores.azuresearch import AzureSearch |
| |
|
| |
|
| | |
# python-dotenv is an optional convenience: when it is installed, variables
# from a local .env file are loaded into the process environment at import
# time; when it is not, the module relies on the existing environment.
try:
    from dotenv import load_dotenv
except ImportError:
    # Only a missing package is tolerated. A bare ``except`` would also hide
    # KeyboardInterrupt/SystemExit and genuine failures while loading .env.
    pass
else:
    load_dotenv()
| |
|
| |
|
class AzureSearchWrapper:
    """
    Wrapper class for Azure AI Search vectorstore to handle filter conversion.

    Search entry points accept a dictionary-style ``filter`` and convert it to
    an OData filter string before calling the wrapped vectorstore. Any
    attribute not defined on the wrapper is looked up on the wrapped
    vectorstore instead.
    """

    # Key suffix that turns an equality condition into an exclusion ("ne").
    _EXCLUDE_SUFFIX = "_exclude"

    def __init__(self, azure_search_vectorstore):
        """Store the vectorstore instance that every call is delegated to."""
        self.vectorstore = azure_search_vectorstore

    def __getattr__(self, name):
        """Delegate all other attributes to the wrapped vectorstore."""
        # __getattr__ only runs after normal lookup has failed. If the missing
        # attribute is "vectorstore" itself (instance created without
        # __init__, e.g. via __new__/copy/unpickle), delegating would look up
        # self.vectorstore again and recurse until RecursionError.
        if name == "vectorstore":
            raise AttributeError(name)
        return getattr(self.vectorstore, name)

    @staticmethod
    def _quote_odata(value):
        """Render ``value`` as an OData string literal, doubling embedded single quotes."""
        return "'" + str(value).replace("'", "''") + "'"

    def _convert_dict_filter_to_odata(self, filter_dict):
        """
        Convert dictionary-style filters to Azure Search OData filter format.

        Rules, applied per key and joined with ``and``:
          * ``{"field": value}``            -> ``field eq 'value'``
          * ``{"field": [a, b]}``           -> ``(field eq 'a' or field eq 'b')``
          * ``{"field_exclude": value}``    -> ``field ne 'value'``
          * ``{"field_exclude": [a, b]}``   -> ``(field ne 'a' and field ne 'b')``
          * a one-element list is rendered without parentheses
          * an empty exclusion list adds no condition (nothing to exclude)
          * an empty inclusion list becomes ``false`` (nothing can match)

        Every value is rendered with ``str()`` inside single quotes; embedded
        single quotes are doubled so they cannot terminate the literal.

        Args:
            filter_dict (dict | str | None): Dictionary-style filter. A string
                is assumed to already be an OData expression and is returned
                unchanged.

        Returns:
            str | None: OData filter string, or None when there is nothing to
            filter on.
        """
        if not filter_dict:
            return None
        if isinstance(filter_dict, str):
            # Already an OData expression; there is nothing to convert.
            return filter_dict

        suffix = self._EXCLUDE_SUFFIX
        conditions = []

        for key, value in filter_dict.items():
            if key.endswith(suffix):
                # Strip only the trailing suffix; str.replace would also eat
                # "_exclude" occurring earlier in the field name.
                field = key[: -len(suffix)]
                operator, joiner = "ne", " and "
            else:
                field, operator, joiner = key, "eq", " or "

            if not isinstance(value, (list, tuple)):
                conditions.append(f"{field} {operator} {self._quote_odata(value)}")
                continue

            if not value:
                # An empty group would otherwise render as "()", which is not
                # a valid OData expression.
                if operator == "eq":
                    conditions.append("false")
                continue

            clauses = [f"{field} {operator} {self._quote_odata(item)}" for item in value]
            if len(clauses) == 1:
                conditions.append(clauses[0])
            else:
                conditions.append(f"({joiner.join(clauses)})")

        return " and ".join(conditions) if conditions else None

    def similarity_search_with_score(self, query, k=4, filter=None, **kwargs):
        """
        Override similarity_search_with_score to convert filters.

        The call is routed to the wrapped vectorstore's
        ``hybrid_search_with_score`` and the converted filter is passed under
        the ``filters`` keyword.
        """
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)

        return self.vectorstore.hybrid_search_with_score(
            query=query, k=k, filters=filter, **kwargs
        )

    def similarity_search(self, query, k=4, filter=None, **kwargs):
        """Override similarity_search to convert filters."""
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)

        # NOTE(review): the filter is forwarded as ``filter`` here, whereas
        # similarity_search_with_score forwards it as ``filters``. Confirm
        # which keyword the wrapped vectorstore actually honours.
        return self.vectorstore.similarity_search(
            query=query, k=k, filter=filter, **kwargs
        )

    def similarity_search_by_vector(self, embedding, k=4, filter=None, **kwargs):
        """Override similarity_search_by_vector to convert filters."""
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)

        # NOTE(review): same ``filter`` vs ``filters`` question as in
        # similarity_search - confirm against the wrapped vectorstore.
        return self.vectorstore.similarity_search_by_vector(
            embedding=embedding, k=k, filter=filter, **kwargs
        )

    def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
        """
        Override as_retriever to handle filter conversion in search_kwargs.

        When ``search_kwargs`` contains a ``"filter"`` entry, a shallow copy is
        made so the caller's dictionary is not modified, and a non-None filter
        is replaced in the copy by its OData form (still under ``"filter"``).
        """
        if search_kwargs and "filter" in search_kwargs:
            search_kwargs = search_kwargs.copy()
            if search_kwargs["filter"] is not None:
                search_kwargs["filter"] = self._convert_dict_filter_to_odata(
                    search_kwargs["filter"]
                )

        return self.vectorstore.as_retriever(
            search_type=search_type, search_kwargs=search_kwargs, **kwargs
        )
| |
|
| |
|
def get_azure_search_vectorstore(embeddings, text_key="content", index_name=None):
    """
    Build an Azure AI Search vectorstore and return it wrapped for filter compatibility.

    Connection settings are read from the environment:
    ``AI_SEARCH_INDEX_ENDPOINT`` (service endpoint) and ``AI_SEARCH_KEY``
    (API key).

    Args:
        embeddings: Object whose ``embed_query`` attribute is handed to the
            vectorstore as its embedding function.
        text_key: Forwarded to ``AzureSearch`` as ``content_key``
            (default: "content").
            NOTE(review): confirm ``AzureSearch`` actually honours this keyword.
        index_name: The name of the Azure Search index.

    Returns:
        AzureSearchWrapper: The new ``AzureSearch`` instance wrapped in
        ``AzureSearchWrapper``.

    Raises:
        ValueError: If either environment variable is unset or empty, or if
            ``index_name`` is not provided. The checks run in that order.
    """
    endpoint = os.getenv("AI_SEARCH_INDEX_ENDPOINT")
    if not endpoint:
        raise ValueError("AI_SEARCH_INDEX_ENDPOINT environment variable is required")

    api_key = os.getenv("AI_SEARCH_KEY")
    if not api_key:
        raise ValueError("AI_SEARCH_KEY environment variable is required")

    if not index_name:
        raise ValueError("index_name must be provided for Azure Search")

    store_options = {
        "azure_search_endpoint": endpoint,
        "azure_search_key": api_key,
        "index_name": index_name,
        "embedding_function": embeddings.embed_query,
        "content_key": text_key,
    }
    return AzureSearchWrapper(AzureSearch(**store_options))
| |
|
| |
|