# Provenance (from the hosting page header): "Update src/agent.py",
# commit 0e0b570 (verified), author Sborole.
from src.settings import Settings
from smolagents import LiteLLMModel, ToolCallingAgent, CodeAgent, InferenceClientModel, Tool
from tools.FinalAnswerTool import FinalAnswerTool
from tools.ReadAudioTool import ReadAudioTool
from tools.ReadImageTool import ReadImageTool
from tools.ReadTextTool import ReadTextTool
from tools.ReadExcelTool import ReadExcelTool
from tools.ReadVideoTool import ReadVideoTool
from tools.WebSearchTool import TavilySearchTool
from tools.WikipediaTool import LocalWikipediaTool
from tools.YouTubeTool import YouTubeSearchTool
from tools.PythonRunnerTool import PythonRunnerTool
from tools.PythonCalcTool import PythonCalcTool
from tools.SemanticScholar import TavilyResearchTool
from src.utils import InputTokenRateLimiter
import wikipedia as wiki
from markdownify import markdownify as to_markdown
import time
import random
settings = Settings()
import litellm
from litellm import completion
#litellm._turn_on_debug()
class WikiTitleFinder(Tool):
    """Tool that looks up Wikipedia page titles matching a search query."""

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {
        "query": {"type": "string", "description": "Search query."},
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the matching titles joined by ", ", or "No results."

        Args:
            query: Free-text search string passed to ``wiki.search``.
        """
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
class WikiContentFetcher(Tool):
    """Tool that fetches a Wikipedia page and returns its HTML as Markdown."""

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {
        "page_title": {"type": "string", "description": "Wikipedia page title."},
    }
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the page content converted to Markdown.

        Args:
            page_title: Title of the Wikipedia page to load.

        Returns:
            The page HTML converted with ``to_markdown``. If the title does
            not exist, a "not found" message is returned; if the title is
            ambiguous, a message listing the candidate titles is returned so
            the caller can retry with a precise one.
        """
        try:
            return to_markdown(wiki.page(page_title).html())
        except wiki.exceptions.DisambiguationError as err:
            # An ambiguous title is an expected outcome, not a failure: hand
            # the candidate titles back instead of letting the tool raise.
            candidates = ", ".join(str(option) for option in err.options)
            return f"'{page_title}' is ambiguous. Possible titles: {candidates}"
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
class MathSolver(Tool):
    """Tool that evaluates a math expression string and returns the result."""

    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {
        "input": {"type": "string", "description": "Math expression to evaluate."},
    }
    output_type = "string"

    def forward(self, input: str) -> str:
        """Evaluate ``input`` and return the result as a string.

        Any exception raised during evaluation is reported as
        ``"Math error: <message>"`` instead of propagating.

        Args:
            input: Expression text. The parameter name shadows the builtin
                but has to match the key declared in ``inputs``.
        """
        # NOTE(security): eval() with an empty ``__builtins__`` is NOT a real
        # sandbox -- attribute access on literals can still reach arbitrary
        # objects. Treat the expression as trusted input, or replace this
        # with an ``ast``-based whitelist evaluator.
        restricted_globals = {"__builtins__": {}}  # fresh dict on every call
        try:
            rendered = str(eval(input, restricted_globals))
        except Exception as err:
            return f"Math error: {err}"
        else:
            return rendered
class BasicAgent():
    """Question-answering wrapper around a ``CodeAgent`` with retry handling.

    The instance is callable: ``agent(question, file_content, file_path)`` is
    the same as ``agent.run(...)``.
    """

    # Lower-cased substrings that mark an exception as transient and worth
    # retrying. "ratelimit" / "rate_limit" also match exception class names
    # such as "RateLimitError" once the message has been lower-cased.
    _RETRYABLE_MARKERS = ("overload", "rate limit", "ratelimit", "rate_limit")

    # Returned when the agent produced no usable answer.
    _FALLBACK_ANSWER = "I am unable to answer"

    def __init__(self):
        """Build the model, the tool-equipped agent, and the retry settings."""
        self.model = LiteLLMModel(
            model_id=settings.llm_model_id,
            api_key=settings.llm_api_key,
        )
        self.agent = CodeAgent(
            model=self.model,
            tools=[
                TavilySearchTool(),
                TavilyResearchTool(),
                FinalAnswerTool(),
                PythonCalcTool(),
                ReadAudioTool(),
                ReadImageTool(),
                ReadExcelTool(),
                ReadTextTool(),
                YouTubeSearchTool(),
                PythonRunnerTool(),
                MathSolver(),
                LocalWikipediaTool(),
                WikiTitleFinder(),
                WikiContentFetcher(),
            ],
            max_steps=10,
            planning_interval=5,
        )
        # These four settings were previously assigned twice in a row, which
        # built a second, redundant rate limiter; they are now set once.
        self.token_rate_limiter = InputTokenRateLimiter()
        self.expected_tokens_per_step = 10000
        self.max_retries = 3  # retries allowed for transient errors
        self.base_delay = 5  # seconds; doubled on every retry

    @staticmethod
    def _is_blank(value) -> bool:
        """Return True for ``None`` and for empty strings/containers.

        Numbers and booleans are never blank, so an answer of ``0`` or
        ``False`` is kept instead of being mistaken for "no answer".
        """
        if value is None:
            return True
        try:
            return len(value) == 0
        except TypeError:
            return False

    def _read_file(self, file_path: str) -> str:
        """Return the textual content of a ``.txt``, ``.csv`` or ``.docx`` file.

        CSV files are returned as a JSON array with one object per row. The
        extension check is case-insensitive.

        Args:
            file_path: Path of the file to read.

        Returns:
            The file content, or ``""`` when the file is missing, has an
            unsupported extension, or cannot be read (a message is printed
            in each of those cases).
        """
        # Imported locally: the module header does not import these names,
        # so the method used to fail with NameError on its first line.
        import csv
        import json
        import os

        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            return ""

        lowered = file_path.lower()

        if lowered.endswith(".txt"):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    return f.read()
            except (OSError, UnicodeError) as e:
                print(f"Error reading text file: {e}")
                return ""

        if lowered.endswith(".csv"):
            try:
                with open(file_path, mode="r", newline="", encoding="utf-8") as file:
                    rows = list(csv.DictReader(file))
                # Return the structured data as a JSON string
                return json.dumps(rows, indent=2)
            except Exception as e:
                print(f"Error reading CSV file: {e}")
                return ""

        if lowered.endswith(".docx"):
            try:
                # Optional dependency, only needed for this branch; a missing
                # package is reported like any other read failure.
                import docx

                doc = docx.Document(file_path)
                return "\n".join(p.text for p in doc.paragraphs)
            except Exception as e:
                print(f"Error reading DOCX file: {e}")
                return ""

        # For unsupported formats, return empty string
        print(f"Unsupported file type: {file_path}")
        return ""

    def run(self, question: str, file_content: str = "", file_path: str = "") -> str:
        """Answer ``question``, optionally using an attached file as context.

        Args:
            question: The question to put to the agent.
            file_content: Text of the attached file, if already loaded.
            file_path: Path of the attached file. It is read when
                ``file_content`` is empty; if nothing can be read, the path
                itself is given to the agent as context.

        Returns:
            The agent's answer converted to ``str``, or
            ``"I am unable to answer"`` when no answer was produced.
            Exceptions from the agent are caught here: transient ones are
            retried with exponential backoff, others end the run.
        """
        # If file content is empty but file_path exists, read the file
        if not file_content and file_path:
            file_content = self._read_file(file_path)

        context = ""
        if file_content:
            context = f"Story content:\n{file_content}"
        elif file_path:
            context = f"File path: {file_path}"

        print(f"Starting Agent with question: {question}\nContext length: {len(context)} chars")

        final_input = f"Question: {question}\n\n{context}"
        final_answer = None
        retry_count = 0

        while True:
            try:
                raw = self.agent.run(final_input)
                try:
                    # A plain string is the answer itself; any other iterable
                    # is treated as a sequence of steps.
                    steps = [raw] if isinstance(raw, str) else iter(raw)
                except TypeError:
                    # Non-iterable result (e.g. a number, or an object with an
                    # ``output`` attribute): it is the answer. Iterating it
                    # used to raise and discard the answer.
                    final_answer = getattr(raw, "output", raw)
                    steps = ()

                for step in steps:
                    if isinstance(step, str):
                        final_answer = step
                        continue
                    output = getattr(step, "output", None)
                    if not self._is_blank(output):
                        final_answer = output
                    self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step)
                    tokens_used = getattr(step, "token_usage", None)
                    if tokens_used:
                        self.token_rate_limiter.add_tokens(tokens_used.input_tokens)
                break  # Exit retry loop if successful
            except Exception as e:
                message = str(e).lower()
                if not any(marker in message for marker in self._RETRYABLE_MARKERS):
                    print(f"Error during agent run: {e}")
                    break
                if retry_count >= self.max_retries:
                    print("Max retries reached. Exiting...")
                    break
                # Exponential backoff plus up to one second of jitter.
                delay = self.base_delay * (2 ** retry_count) + random.random()
                print(f"Retrying in {delay:.1f}s due to rate limit... ({e})")
                time.sleep(delay)
                retry_count += 1

        # Ensure a valid answer is always returned
        if self._is_blank(final_answer):
            final_answer = self._FALLBACK_ANSWER
        elif not isinstance(final_answer, str):
            # Honour the declared ``-> str`` return type for non-string answers.
            final_answer = str(final_answer)

        print(f"Finished agent run. Final Answer: {final_answer}\n{'='*50}")
        return final_answer

    def __call__(self, question: str, file_content: str = "", file_path: str = "") -> str:
        """Delegate to :meth:`run` so the instance can be called directly."""
        return self.run(question, file_content, file_path)