File size: 3,464 Bytes
b325aad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from langchain_tavily import TavilySearch
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_community.tools import GoogleSerperRun
from src.config.settings import settings
from typing import Union, Literal

class SearchToolFactory:
    """Factory for LangChain web-search tools (Tavily and Serper).

    ``get_search_tool`` returns a lazily created instance per provider that is
    cached on the class; ``create_new_search_tool`` builds a fresh, uncached
    instance whose options default to values from ``settings`` and can be
    overridden with keyword arguments.
    """

    # Cached instances handed out by get_search_tool(); cleared by reset_instances().
    _tavily_instance = None
    _serper_instance = None

    # Keyword arguments that create_new_search_tool() maps itself for each
    # provider. They are excluded from the pass-through kwargs so the same
    # keyword is never supplied twice ("got multiple values for keyword argument").
    _TAVILY_HANDLED_KWARGS = frozenset({
        "api_key", "tavily_api_key", "max_results", "search_depth",
        "include_answer", "include_raw_content",
    })
    _SERPER_HANDLED_KWARGS = frozenset({
        "api_key", "serper_api_key", "k", "type", "gl", "country", "location",
    })

    @classmethod
    def get_search_tool(cls, provider: Literal["tavily", "serper"] = "tavily") -> Union[TavilySearch, GoogleSerperRun]:
        """Return the cached search tool for *provider*, creating it on first use.

        Only the API key from ``settings`` is applied; every other option keeps
        the tool's own default. Use ``create_new_search_tool`` for an instance
        configured from ``settings`` / custom parameters.

        Args:
            provider: ``"tavily"`` or ``"serper"``.

        Returns:
            The cached ``TavilySearch`` or ``GoogleSerperRun`` instance.

        Raises:
            ValueError: If *provider* is not supported.
        """
        if provider == "tavily":
            if cls._tavily_instance is None:
                # langchain-tavily's TavilySearch takes its key as
                # ``tavily_api_key``; ``api_key`` is not a recognised argument
                # (confirm against the installed langchain-tavily version).
                cls._tavily_instance = TavilySearch(
                    tavily_api_key=settings.TAVILY_API_KEY
                )
            return cls._tavily_instance
        if provider == "serper":
            if cls._serper_instance is None:
                wrapper = GoogleSerperAPIWrapper(
                    serper_api_key=settings.SERPER_API_KEY
                )
                cls._serper_instance = GoogleSerperRun(api_wrapper=wrapper)
            return cls._serper_instance
        raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def create_new_search_tool(cls, provider: Literal["tavily", "serper"] = "tavily", **kwargs) -> Union[TavilySearch, GoogleSerperRun]:
        """Create a new, uncached search tool with custom parameters.

        Options missing from *kwargs* default to the matching ``settings``
        values. Any kwarg not handled explicitly is forwarded unchanged to
        ``TavilySearch`` (tavily) or ``GoogleSerperAPIWrapper`` (serper).

        Args:
            provider: ``"tavily"`` or ``"serper"``.
            **kwargs: Provider options.
                tavily: ``tavily_api_key`` (or legacy ``api_key``),
                ``max_results``, ``search_depth``, ``include_answer``,
                ``include_raw_content``.
                serper: ``serper_api_key`` (or ``api_key``), ``k``, ``type``,
                ``gl`` (or ``country``), ``location`` (accepted, not applied).

        Returns:
            A new ``TavilySearch`` or ``GoogleSerperRun`` instance.

        Raises:
            ValueError: If *provider* is not supported.
        """
        if provider == "tavily":
            passthrough = {
                k: v for k, v in kwargs.items() if k not in cls._TAVILY_HANDLED_KWARGS
            }
            return TavilySearch(
                tavily_api_key=kwargs.get(
                    "tavily_api_key", kwargs.get("api_key", settings.TAVILY_API_KEY)
                ),
                max_results=kwargs.get("max_results", settings.SEARCH_RESULTS_COUNT),
                search_depth=kwargs.get("search_depth", settings.TAVILY_SEARCH_DEPTH),
                include_answer=kwargs.get("include_answer", settings.TAVILY_INCLUDE_ANSWER),
                include_raw_content=kwargs.get("include_raw_content", settings.TAVILY_INCLUDE_RAW_CONTENT),
                **passthrough,
            )
        if provider == "serper":
            wrapper_kwargs = {
                "serper_api_key": kwargs.get(
                    "serper_api_key", kwargs.get("api_key", settings.SERPER_API_KEY)
                ),
                "k": kwargs.get("k", settings.SEARCH_RESULTS_COUNT),
                "type": kwargs.get("type", settings.SERPER_SEARCH_TYPE),
            }
            # GoogleSerperAPIWrapper has no ``country`` field; Serper's country
            # parameter is ``gl``. Only set it when a value is configured so the
            # wrapper's own default applies otherwise.
            # NOTE(review): assumes settings.SERPER_COUNTRY holds a Serper ``gl``
            # country code such as "us" — confirm.
            country = kwargs.get("gl", kwargs.get("country", settings.SERPER_COUNTRY))
            if country:
                wrapper_kwargs["gl"] = country
            # NOTE(review): GoogleSerperAPIWrapper exposes no ``location`` field
            # at construction, so ``location`` / settings.SERPER_LOCATION is
            # accepted but not applied here — confirm whether location
            # targeting is required.
            wrapper_kwargs.update(
                {k: v for k, v in kwargs.items() if k not in cls._SERPER_HANDLED_KWARGS}
            )
            return GoogleSerperRun(api_wrapper=GoogleSerperAPIWrapper(**wrapper_kwargs))
        raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def get_tavily_search(cls) -> TavilySearch:
        """Return the cached Tavily search tool (see ``get_search_tool``)."""
        return cls.get_search_tool("tavily")

    @classmethod
    def get_serper_search(cls) -> GoogleSerperRun:
        """Return the cached Serper search tool (see ``get_search_tool``)."""
        return cls.get_search_tool("serper")

    @classmethod
    def reset_instances(cls):
        """Drop the cached instances so the next ``get_search_tool`` call rebuilds them (useful for testing)."""
        cls._tavily_instance = None
        cls._serper_instance = None