Try / src /agenticRAG /components /search_tools.py
Alamgirapi's picture
Upload folder src
b325aad verified
from langchain_tavily import TavilySearch
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_community.tools import GoogleSerperRun
from src.config.settings import settings
from typing import Union, Literal
class SearchToolFactory:
"""Factory for creating search tools"""
_tavily_instance = None
_serper_instance = None
@classmethod
def get_search_tool(cls, provider: Literal["tavily", "serper"] = "tavily") -> Union[TavilySearch, GoogleSerperRun]:
"""Get or create search tool instance (singleton pattern)"""
if provider == "tavily":
if cls._tavily_instance is None:
cls._tavily_instance = TavilySearch(
api_key=settings.TAVILY_API_KEY
)
return cls._tavily_instance
elif provider == "serper":
if cls._serper_instance is None:
search_wrapper = GoogleSerperAPIWrapper(
serper_api_key=settings.SERPER_API_KEY
)
cls._serper_instance = GoogleSerperRun(api_wrapper=search_wrapper)
return cls._serper_instance
else:
raise ValueError(f"Unsupported provider: {provider}")
@classmethod
def create_new_search_tool(cls, provider: Literal["tavily", "serper"] = "tavily", **kwargs) -> Union[TavilySearch, GoogleSerperRun]:
"""Create a new search tool instance with custom parameters"""
if provider == "tavily":
return TavilySearch(
api_key=kwargs.get("api_key", settings.TAVILY_API_KEY),
max_results=kwargs.get("max_results", settings.SEARCH_RESULTS_COUNT),
search_depth=kwargs.get("search_depth", settings.TAVILY_SEARCH_DEPTH),
include_answer=kwargs.get("include_answer", settings.TAVILY_INCLUDE_ANSWER),
include_raw_content=kwargs.get("include_raw_content", settings.TAVILY_INCLUDE_RAW_CONTENT),
**{k: v for k, v in kwargs.items() if k not in ["api_key", "max_results", "search_depth", "include_answer", "include_raw_content"]}
)
elif provider == "serper":
search_wrapper = GoogleSerperAPIWrapper(
serper_api_key=kwargs.get("api_key", settings.SERPER_API_KEY),
k=kwargs.get("k", settings.SEARCH_RESULTS_COUNT),
type=kwargs.get("type", settings.SERPER_SEARCH_TYPE),
country=kwargs.get("country", settings.SERPER_COUNTRY),
location=kwargs.get("location", settings.SERPER_LOCATION),
**{k: v for k, v in kwargs.items() if k not in ["api_key", "k", "type", "country", "location"]}
)
return GoogleSerperRun(api_wrapper=search_wrapper)
else:
raise ValueError(f"Unsupported provider: {provider}")
@classmethod
def get_tavily_search(cls) -> TavilySearch:
"""Convenience method to get Tavily search tool"""
return cls.get_search_tool("tavily")
@classmethod
def get_serper_search(cls) -> GoogleSerperRun:
"""Convenience method to get Serper search tool"""
return cls.get_search_tool("serper")
@classmethod
def reset_instances(cls):
"""Reset singleton instances (useful for testing)"""
cls._tavily_instance = None
cls._serper_instance = None