File size: 5,318 Bytes
90f0e8c
 
e6c5dc7
9f4d256
1841644
7af7a00
29aa0d3
39dab7f
 
63fd0a4
9f4d256
20d6ea1
f8c6b33
a7b1269
9a8b05a
90f0e8c
 
 
 
f8c6b33
 
 
8a568e9
f8c6b33
90f0e8c
8a568e9
1841644
a7b1269
f4c0a57
90f0e8c
f4c0a57
 
 
90f0e8c
f4c0a57
 
 
90f0e8c
f4c0a57
 
90f0e8c
 
f4c0a57
 
 
 
 
90f0e8c
f4c0a57
90f0e8c
a7b1269
1841644
90f0e8c
 
 
 
 
 
 
 
a7b1269
1841644
90f0e8c
 
 
 
 
 
 
 
a7b1269
90f0e8c
7af7a00
b02e265
 
90f0e8c
b02e265
 
 
7af7a00
90f0e8c
 
 
 
 
f4c0a57
90f0e8c
f4c0a57
90f0e8c
 
f4c0a57
 
 
 
90f0e8c
f4c0a57
 
90f0e8c
f8c6b33
f4c0a57
a7b1269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90f0e8c
f8c6b33
a7b1269
 
 
 
 
 
 
 
 
 
 
 
90f0e8c
a7b1269
 
 
90f0e8c
 
a7b1269
 
 
90f0e8c
a7b1269
90f0e8c
 
a7b1269
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import requests
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from langchain_community.document_loaders import WikipediaLoader
from llama_index.core.tools.types import ToolMetadata
from llama_index.core.schema import Document
from llama_index.core.tools import FunctionTool
from langchain_community.tools.tavily_search import TavilySearchResults
from llama_index.core.agent.workflow import AgentWorkflow

# Hugging Face API token, read from the environment; None if unset.
hf_token = os.getenv("HF_TOKEN")

# List of models to try in order
# (run_with_fallback() walks this list from index 0 whenever a model errors out)
model_list = [
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "microsoft/phi-3-mini-128k-instruct",
    "google/gemma-2b-it",
    "gpt2"
]

# Index into model_list of the model currently in use.
# Mutated by try_next_model() and run_with_fallback() below.
current_model_index = 0
llm = HuggingFaceInferenceAPI(
    model_name=model_list[current_model_index],
    token=hf_token,
)

# Numerical operation functions
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    product = a * b
    return product

def add(a: int, b: int) -> int:
    """Add two numbers."""
    total = a + b
    return total

def subtract(a: int, b: int) -> int:
    """Subtract two numbers."""
    difference = a - b
    return difference

def divide(a: int, b: int) -> float:
    """Divide two numbers, raises error on zero divisor."""
    if not b:
        raise ValueError("Cannot divide by zero.")
    return a / b

def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers."""
    remainder = a % b
    return remainder

# Web search tool function
def web_search(query: str) -> list:
    """Search Tavily for a query and return up to 3 results.

    Args:
        query: Free-text search string.

    Returns:
        A list of llama_index Document objects (possibly empty), each
        carrying ``source``/``page`` metadata.
    """
    # BaseTool.invoke takes the tool input as its first positional arg
    # (a dict for structured tools) — ``invoke(query=query)`` raises
    # TypeError because ``query`` is not a keyword of invoke().
    results = TavilySearchResults(max_results=3).invoke({"query": query})
    docs = []
    for r in results:
        if isinstance(r, dict):
            # TavilySearchResults yields plain dicts ({"url": ..., "content": ...}),
            # not objects with .metadata/.page_content as the original loop assumed.
            meta = {"source": r.get("url", ""), "page": ""}
            text = r.get("content", "")
        else:
            # Defensive fallback for Document-like results from other tool versions.
            meta = {"source": r.metadata.get("source", ""), "page": r.metadata.get("page", "")}
            text = r.page_content
        docs.append(Document(text=text, metadata=meta))
    return docs

# Wikipedia search tool function
def wiki_search(query: str) -> list:
    """Search Wikipedia for a query and return up to 2 results."""
    pages = WikipediaLoader(query=query, load_max_docs=2).load()
    return [
        Document(
            text=page.page_content,
            metadata={
                "source": page.metadata.get("source", ""),
                "page": page.metadata.get("page", ""),
            },
        )
        for page in pages
    ]

# Wrap functions into FunctionTool instances
# (name/description in ToolMetadata is what the LLM sees when choosing a tool)
web_search_tool = FunctionTool(
    web_search,
    metadata=ToolMetadata(name="web_search", description="Tavily 3-hit search")
)
wiki_search_tool = FunctionTool(
    wiki_search,
    metadata=ToolMetadata(name="wiki_search", description="Wikipedia 2-hit search")
)

# Arithmetic tools, one per helper function defined above.
multiply_tool = FunctionTool(multiply, metadata=ToolMetadata(name="multiply", description="Multiply two numbers."))
add_tool      = FunctionTool(add,      metadata=ToolMetadata(name="add",      description="Add two numbers."))
subtract_tool = FunctionTool(subtract, metadata=ToolMetadata(name="subtract", description="Subtract two numbers."))
divide_tool   = FunctionTool(divide,   metadata=ToolMetadata(name="divide",   description="Divide two numbers."))
modulus_tool  = FunctionTool(modulus,  metadata=ToolMetadata(name="modulus",  description="Modulus operation on two numbers."))

# Aggregate all tools
tools = [
    web_search_tool,
    wiki_search_tool,
    multiply_tool,
    add_tool,
    subtract_tool,
    divide_tool,
    modulus_tool,
]

# Initialize agent
# Rebuilt by try_next_model()/run_with_fallback() whenever the LLM changes.
agent = AgentWorkflow.from_tools_or_functions(tools, llm=llm)

# Function to try the next model in the list
def try_next_model():
    """Switch to the next model in the list and reinitialize the agent.

    Returns:
        True if a next model was available and the agent was rebuilt,
        False if every model in ``model_list`` has already been tried.
    """
    global current_model_index, llm, agent

    # Check BEFORE mutating the index: the original incremented first,
    # so a failed switch left current_model_index == len(model_list),
    # making any later model_list[current_model_index] raise IndexError.
    if current_model_index + 1 >= len(model_list):
        return False
    current_model_index += 1

    # Reinitialize LLM with new model
    llm = HuggingFaceInferenceAPI(
        model_name=model_list[current_model_index],
        token=hf_token,
    )

    # Reinitialize agent with new LLM
    agent = AgentWorkflow.from_tools_or_functions(tools, llm=llm)
    return True

# Run with fallback logic
def run_with_fallback(query: str):
    """Run *query* through the agent, walking model_list on failures.

    Always restarts from the first model, then tries each model in turn
    until one succeeds or the list is exhausted.
    """
    global current_model_index, llm, agent

    # Reset to the first model unless we are already on it.
    if current_model_index != 0:
        current_model_index = 0
        llm = HuggingFaceInferenceAPI(
            model_name=model_list[current_model_index],
            token=hf_token,
        )
        agent = AgentWorkflow.from_tools_or_functions(tools, llm=llm)

    last_attempt = len(model_list) - 1
    for attempt in range(len(model_list)):
        try:
            result = agent.run(query)
        except Exception as e:
            print(f"Error with model {model_list[current_model_index]}: {e}")
            if attempt == last_attempt:
                break
            try_next_model()
        else:
            # Success path: report which model answered and return.
            print(f"Successfully ran query with model: {model_list[current_model_index]}")
            return result

    return "Sorry, encountered issues with all models."

# Make agent.run() work with asyncio by adding async support
async def run(query: str):
    """Async wrapper for the agent.run method to be compatible with app.py"""
    result = run_with_fallback(query)
    return result

# Add the async run method to the agent object
# NOTE(review): despite the comment above, the *synchronous* run_with_fallback
# is bound here, shadowing AgentWorkflow.run on this instance; the async
# ``run`` defined above is never attached — confirm which one app.py expects.
agent.run = run_with_fallback  # Replace with synchronous version for direct calls