Sborole's picture
Update src/agent.py
0460bef verified
raw
history blame
6.07 kB
from src.settings import Settings
from smolagents import LiteLLMModel, ToolCallingAgent, CodeAgent, InferenceClientModel, DuckDuckGoSearchTool, Tool
from tools.FinalAnswerTool import FinalAnswerTool
from tools.ReadAudioTool import ReadAudioTool
from tools.ReadImageTool import ReadImageTool
from tools.ReadTextTool import ReadTextTool
from tools.ReadVideoTool import ReadVideoTool
from tools.WebSearchTool import DuckDuckGoSearchTool
from tools.YouTubeTool import YouTubeTool
from tools.PythonRunnerTool import PythonRunnerTool
from tools.PythonCalcTool import PythonCalcTool
from tools.SemanticScholar import AcademicPaperSearchTool
from src.utils import InputTokenRateLimiter
import wikipedia as wiki
from markdownify import markdownify as to_markdown
import time
import random
# Module-level settings instance; BasicAgent reads llm_model_id and llm_api_key from it.
settings = Settings()
import litellm
from litellm import completion
#litellm._turn_on_debug()
class WikiTitleFinder(Tool):
    """Tool that lists Wikipedia page titles related to a free-text query."""

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the matching titles joined by ", ", or "No results." if none match."""
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
class WikiContentFetcher(Tool):
    """Tool that returns a Wikipedia page's HTML converted to Markdown."""

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Fetch ``page_title`` and return its content as Markdown.

        The title is first looked up exactly as given (e.g. a title returned
        by the ``wiki_titles`` tool). Only if no page has that exact title is
        the library's auto-suggest lookup tried, as a fallback.

        Returns:
            The page rendered as Markdown. If the title is ambiguous, a message
            listing candidate page titles is returned so the caller can retry
            with one of them. If no page is found, "'<title>' not found.".
        """
        try:
            try:
                # auto_suggest=False: the default auto-suggest lookup may
                # replace an exact, valid title with a different page.
                page = wiki.page(page_title, auto_suggest=False)
            except wiki.exceptions.PageError:
                # Fall back to the fuzzy lookup for loosely-phrased titles.
                page = wiki.page(page_title)
            return to_markdown(page.html())
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
        except wiki.exceptions.DisambiguationError as e:
            # Cap the list so a broad disambiguation page does not flood the context.
            options = ", ".join(e.options[:20])
            return f"'{page_title}' is ambiguous. Possible pages: {options}"
class MathSolver(Tool):
    """Tool that evaluates plain arithmetic expressions without executing code.

    ``eval`` is deliberately not used: even with empty ``__builtins__``,
    attribute chains such as ``().__class__.__base__.__subclasses__()`` can
    reach arbitrary objects, and the expression here is model-generated text
    that may be influenced by untrusted web content. The expression is instead
    parsed with ``ast`` and only numeric literals and arithmetic operators are
    evaluated.
    """

    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    # Largest absolute exponent accepted by ``**``; stops inputs such as
    # 9**9**9 from tying up the process computing an enormous integer.
    max_exponent = 10_000

    def forward(self, input: str) -> str:
        """Evaluate ``input`` and return the result as a string.

        Supported syntax: int/float/complex literals, parentheses, unary
        ``+``/``-`` and the binary operators ``+ - * / // % **``. Any other
        syntax, or any error raised while evaluating, is reported as
        ``"Math error: <reason>"``.
        """
        try:
            return str(self._evaluate(input))
        except Exception as e:
            return f"Math error: {e}"

    def _evaluate(self, expression: str):
        """Parse ``expression`` and evaluate its AST using only whitelisted operators."""
        import ast
        import operator

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def visit(node):
            if isinstance(node, ast.Expression):
                return visit(node.body)
            if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, complex)):
                return node.value
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](visit(node.operand))
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                left = visit(node.left)
                right = visit(node.right)
                if isinstance(node.op, ast.Pow) and abs(right) > self.max_exponent:
                    raise ValueError(f"exponent {right} exceeds limit {self.max_exponent}")
                return binary_ops[type(node.op)](left, right)
            # Names, calls, attributes, subscripts, etc. are all rejected.
            raise ValueError(f"unsupported syntax: {type(node).__name__}")

        # strip(): leading whitespace would otherwise be a syntax error.
        return visit(ast.parse(expression.strip(), mode="eval"))
class BasicAgent:
    """Wraps a smolagents ``CodeAgent`` with prompt building, logging and retries.

    The model is a ``LiteLLMModel`` configured from the module-level
    ``settings``. ``run`` retries with exponential backoff when the raised
    error's text mentions an overload or a rate limit.
    """

    def __init__(self):
        self.model = LiteLLMModel(
            model_id=settings.llm_model_id,
            api_key=settings.llm_api_key,
        )
        self.agent = CodeAgent(
            model=self.model,
            tools=[
                FinalAnswerTool(),
                ReadAudioTool(),
                ReadImageTool(),
                ReadTextTool(),
                ReadVideoTool(),
                DuckDuckGoSearchTool(),
                WikiTitleFinder(),
                WikiContentFetcher(),
                YouTubeTool(),
                PythonRunnerTool(),
                PythonCalcTool(),
                AcademicPaperSearchTool(),
                MathSolver(),
            ],
            max_steps=5,
            planning_interval=3,
        )
        self.token_rate_limiter = InputTokenRateLimiter()
        # Token estimate passed to the rate limiter's maybe_wait before each streamed step.
        self.expected_tokens_per_step = 10000
        self.max_retries = 3
        # Base backoff in seconds; doubled on every retry, plus up to 1s jitter.
        self.base_delay = 5

    def run(self, question: str, file_content: str = "", file_path: str = ""):
        """Run the agent on ``question`` and return its final answer.

        Args:
            question: Task text; sent to the agent prefixed with "Question: ".
            file_content: Optional attached text, appended as "Story content".
                Takes precedence over ``file_path`` when both are given.
            file_path: Optional path to an attached file, appended as "File path".

        Returns:
            The agent's final answer (whatever type the agent produced, not
            only strings), or None if the run failed with a non-retryable
            error or retries were exhausted.
        """
        from collections.abc import Iterator

        question = f"Question: {question}"
        context = self._build_context(file_content, file_path)
        print(f"Starting Agent with question text: {question}, {context}")
        final_input = f"{question}\n\n{context}"

        final_answer = None
        retry_count = 0
        while True:
            try:
                result = self.agent.run(final_input)
                if isinstance(result, Iterator):
                    # Streaming mode: the agent yields its steps one by one.
                    for step in result:
                        final_answer = self._handle_step(step, final_answer)
                else:
                    # Non-streaming mode: the return value *is* the final
                    # answer, whatever its type. Iterating it (as before)
                    # broke on numbers and mangled lists/dicts.
                    final_answer = result
                    print(f"Step: {type(result).__name__} Output: {result}")
                break  # Exit retry loop if successful
            except Exception as e:
                if not self._is_retryable(e):
                    print(f"Error: {e}")
                    break
                print("Rate limit exceeded. Retrying...")
                if retry_count >= self.max_retries:
                    print("Max retries reached. Exiting...")
                    break
                delay = self.base_delay * (2 ** retry_count) + random.random()
                print(f"Retrying in {delay:.1f}s ... ({e})")
                time.sleep(delay)
                retry_count += 1

        print(f"\nFinished agent run.\n{'='*60}")
        print(f"Final Answer: {final_answer}\n")
        return final_answer

    @staticmethod
    def _build_context(file_content: str, file_path: str) -> str:
        """Return the attachment section of the prompt ("" when there is none)."""
        if file_content:
            return f"Story content:\n{file_content}"
        if file_path:
            return f"File path: {file_path}"
        return ""

    def _handle_step(self, step, final_answer):
        """Log one streamed step, account its tokens, and return the updated final answer."""
        if isinstance(step, str):
            print(f"Step: String Output: {step}")
            return step
        step_name = step.__class__.__name__
        output = getattr(step, "output", None)
        if output:
            print(f"Step: {step_name} Output: {output}")
        self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step)
        tokens_used = getattr(step, "token_usage", None)
        if tokens_used:
            self.token_rate_limiter.add_tokens(tokens_used.input_tokens)
        if step_name == "FinalAnswerStep":
            print(f"Captured Final Answer from step: {output}")
            return output
        return final_answer

    @staticmethod
    def _is_retryable(error: Exception) -> bool:
        """Return True if the error text indicates an overload or rate limit."""
        message = str(error).lower()
        return "overload" in message or "rate limit" in message

    def __call__(self, question: str, file_content: str = "", file_path: str = ""):
        """Alias for ``run``; forwards all arguments, including ``file_path``."""
        return self.run(question, file_content, file_path)