from langchain_tavily import TavilySearch from langchain_community.utilities import GoogleSerperAPIWrapper from langchain_community.tools import GoogleSerperRun from src.config.settings import settings from typing import Union, Literal class SearchToolFactory: """Factory for creating search tools""" _tavily_instance = None _serper_instance = None @classmethod def get_search_tool(cls, provider: Literal["tavily", "serper"] = "tavily") -> Union[TavilySearch, GoogleSerperRun]: """Get or create search tool instance (singleton pattern)""" if provider == "tavily": if cls._tavily_instance is None: cls._tavily_instance = TavilySearch( api_key=settings.TAVILY_API_KEY ) return cls._tavily_instance elif provider == "serper": if cls._serper_instance is None: search_wrapper = GoogleSerperAPIWrapper( serper_api_key=settings.SERPER_API_KEY ) cls._serper_instance = GoogleSerperRun(api_wrapper=search_wrapper) return cls._serper_instance else: raise ValueError(f"Unsupported provider: {provider}") @classmethod def create_new_search_tool(cls, provider: Literal["tavily", "serper"] = "tavily", **kwargs) -> Union[TavilySearch, GoogleSerperRun]: """Create a new search tool instance with custom parameters""" if provider == "tavily": return TavilySearch( api_key=kwargs.get("api_key", settings.TAVILY_API_KEY), max_results=kwargs.get("max_results", settings.SEARCH_RESULTS_COUNT), search_depth=kwargs.get("search_depth", settings.TAVILY_SEARCH_DEPTH), include_answer=kwargs.get("include_answer", settings.TAVILY_INCLUDE_ANSWER), include_raw_content=kwargs.get("include_raw_content", settings.TAVILY_INCLUDE_RAW_CONTENT), **{k: v for k, v in kwargs.items() if k not in ["api_key", "max_results", "search_depth", "include_answer", "include_raw_content"]} ) elif provider == "serper": search_wrapper = GoogleSerperAPIWrapper( serper_api_key=kwargs.get("api_key", settings.SERPER_API_KEY), k=kwargs.get("k", settings.SEARCH_RESULTS_COUNT), type=kwargs.get("type", settings.SERPER_SEARCH_TYPE), country=kwargs.get("country", settings.SERPER_COUNTRY), location=kwargs.get("location", settings.SERPER_LOCATION), **{k: v for k, v in kwargs.items() if k not in ["api_key", "k", "type", "country", "location"]} ) return GoogleSerperRun(api_wrapper=search_wrapper) else: raise ValueError(f"Unsupported provider: {provider}") @classmethod def get_tavily_search(cls) -> TavilySearch: """Convenience method to get Tavily search tool""" return cls.get_search_tool("tavily") @classmethod def get_serper_search(cls) -> GoogleSerperRun: """Convenience method to get Serper search tool""" return cls.get_search_tool("serper") @classmethod def reset_instances(cls): """Reset singleton instances (useful for testing)""" cls._tavily_instance = None cls._serper_instance = None