Spaces:
No application file
No application file
| from langchain_tavily import TavilySearch | |
| from langchain_community.utilities import GoogleSerperAPIWrapper | |
| from langchain_community.tools import GoogleSerperRun | |
| from src.config.settings import settings | |
| from typing import Union, Literal | |
class SearchToolFactory:
    """Factory for Tavily and Serper search tools.

    ``get_search_tool`` caches one tool per provider on the class, while
    ``create_new_search_tool`` always builds a fresh, customisable tool.
    Every method is a classmethod, so callers use the class directly, e.g.
    ``SearchToolFactory.get_search_tool("serper")``.
    """

    # Cached tools, populated on first request by get_search_tool().
    _tavily_instance = None
    _serper_instance = None

    @classmethod
    def get_search_tool(
        cls, provider: Literal["tavily", "serper"] = "tavily"
    ) -> Union[TavilySearch, GoogleSerperRun]:
        """Return the cached search tool for *provider*, creating it on first use.

        The cached tool is configured with the provider's API key from
        ``settings`` only; use ``create_new_search_tool`` for custom options.

        Args:
            provider: Which backend to use, ``"tavily"`` or ``"serper"``.

        Returns:
            The shared tool instance for the requested provider.

        Raises:
            ValueError: If *provider* is not a supported backend.
        """
        if provider == "tavily":
            if cls._tavily_instance is None:
                # NOTE(review): confirm that TavilySearch accepts ``api_key``;
                # the langchain-tavily docs appear to use ``tavily_api_key``.
                cls._tavily_instance = TavilySearch(
                    api_key=settings.TAVILY_API_KEY
                )
            return cls._tavily_instance

        if provider == "serper":
            if cls._serper_instance is None:
                wrapper = GoogleSerperAPIWrapper(
                    serper_api_key=settings.SERPER_API_KEY
                )
                cls._serper_instance = GoogleSerperRun(api_wrapper=wrapper)
            return cls._serper_instance

        raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def create_new_search_tool(
        cls, provider: Literal["tavily", "serper"] = "tavily", **kwargs
    ) -> Union[TavilySearch, GoogleSerperRun]:
        """Build a new, uncached search tool with optional overrides.

        Defaults are read from ``settings``; any keyword argument supplied by
        the caller overrides the matching default, and unrecognised keywords
        are forwarded unchanged to the underlying constructor.

        Args:
            provider: Which backend to use, ``"tavily"`` or ``"serper"``.
            **kwargs: Overrides and extra constructor options. For Tavily the
                defaulted keys are ``api_key``, ``max_results``,
                ``search_depth``, ``include_answer`` and
                ``include_raw_content``. For Serper they are ``api_key``
                (passed on as ``serper_api_key``), ``k``, ``type``,
                ``country`` and ``location``.

        Returns:
            A freshly constructed tool; the cached instances are not touched.

        Raises:
            ValueError: If *provider* is not a supported backend.
        """
        if provider == "tavily":
            # Caller-supplied kwargs win over the settings-derived defaults.
            tavily_options = {
                "api_key": settings.TAVILY_API_KEY,
                "max_results": settings.SEARCH_RESULTS_COUNT,
                "search_depth": settings.TAVILY_SEARCH_DEPTH,
                "include_answer": settings.TAVILY_INCLUDE_ANSWER,
                "include_raw_content": settings.TAVILY_INCLUDE_RAW_CONTENT,
                **kwargs,
            }
            return TavilySearch(**tavily_options)

        if provider == "serper":
            serper_options = {
                "api_key": settings.SERPER_API_KEY,
                "k": settings.SEARCH_RESULTS_COUNT,
                "type": settings.SERPER_SEARCH_TYPE,
                "country": settings.SERPER_COUNTRY,
                "location": settings.SERPER_LOCATION,
                **kwargs,
            }
            # The wrapper names its key ``serper_api_key``; this factory
            # exposes it as ``api_key`` for symmetry with the Tavily branch.
            serper_api_key = serper_options.pop("api_key")
            wrapper = GoogleSerperAPIWrapper(
                serper_api_key=serper_api_key, **serper_options
            )
            return GoogleSerperRun(api_wrapper=wrapper)

        raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def get_tavily_search(cls) -> TavilySearch:
        """Return the cached Tavily search tool."""
        return cls.get_search_tool("tavily")

    @classmethod
    def get_serper_search(cls) -> GoogleSerperRun:
        """Return the cached Serper search tool."""
        return cls.get_search_tool("serper")

    @classmethod
    def reset_instances(cls):
        """Drop both cached tools so the next request rebuilds them.

        Useful for testing.
        """
        cls._tavily_instance = None
        cls._serper_instance = None