Sborole's picture
Update src/agent.py
6bb6cd7 verified
raw
history blame
6.35 kB
import ast
import operator
import os
import random
import time
import zipfile
from collections.abc import Iterator
from xml.etree import ElementTree

from src.settings import Settings
from smolagents import LiteLLMModel, ToolCallingAgent, CodeAgent, InferenceClientModel, Tool
from tools.FinalAnswerTool import FinalAnswerTool
from tools.ReadAudioTool import ReadAudioTool
from tools.ReadImageTool import ReadImageTool
from tools.ReadTextTool import ReadTextTool
from tools.ReadVideoTool import ReadVideoTool
from tools.WebSearchTool import SerpApiSearchTool
from tools.YouTubeTool import YouTubeTool
from tools.PythonRunnerTool import PythonRunnerTool
from tools.PythonCalcTool import PythonCalcTool
from tools.SemanticScholar import AcademicPaperSearchTool
from src.utils import InputTokenRateLimiter
import wikipedia as wiki
from markdownify import markdownify as to_markdown
settings = Settings()
import litellm
from litellm import completion
#litellm._turn_on_debug()
class WikiTitleFinder(Tool):
    """Agent tool that lists Wikipedia page titles related to a search query."""

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Search Wikipedia for ``query``.

        Returns:
            The matching titles joined by ", ", or "No results." when the
            search returns nothing.
        """
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
class WikiContentFetcher(Tool):
    """Agent tool that fetches a Wikipedia page and returns it as Markdown."""

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the HTML of the page ``page_title`` converted to Markdown.

        Returns:
            The page as Markdown. If no page matches, "'<title>' not found.".
            If the title is ambiguous, a message listing candidate titles so
            the agent can retry with a specific one.
        """
        try:
            return to_markdown(wiki.page(page_title).html())
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
        except wiki.exceptions.DisambiguationError as e:
            # Previously this escaped the tool as an exception. Listing the
            # options lets the agent pick an exact title; the list is capped
            # to keep the observation short.
            candidates = ", ".join(e.options[:20])
            return f"'{page_title}' is ambiguous. Candidate pages: {candidates}"
class MathSolver(Tool):
    """Agent tool that evaluates basic arithmetic without executing arbitrary code.

    Only numeric literals, parentheses, unary ``+``/``-`` and the binary
    operators ``+ - * / // % **`` are accepted. The expression is parsed with
    ``ast`` and walked node by node rather than handed to ``eval``. ``eval``
    with an empty ``__builtins__`` is not a sandbox: attribute chains such as
    ``().__class__.__base__.__subclasses__()`` still reach arbitrary classes.
    The expression here is model-generated, i.e. untrusted.
    """

    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Evaluate ``input`` and return the numeric result as a string.

        Args:
            input: Arithmetic expression such as ``"(2 + 3) * 4 ** 2"``. The
                name shadows the builtin but must match the key in ``inputs``.

        Returns:
            ``str(result)``, or ``"Math error: <reason>"`` if the expression
            cannot be parsed, uses an unsupported construct, or fails to
            evaluate (e.g. division by zero).
        """
        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
        # Cap integer exponents so inputs like "9**9**9" cannot hang the
        # process while building an enormous integer.
        max_int_exponent = 10_000

        def evaluate(node):
            if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, complex)):
                return node.value
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](evaluate(node.operand))
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                left = evaluate(node.left)
                right = evaluate(node.right)
                if (
                    isinstance(node.op, ast.Pow)
                    and isinstance(left, int)
                    and isinstance(right, int)
                    and right > max_int_exponent
                    and abs(left) > 1
                ):
                    raise ValueError("exponent too large")
                return binary_ops[type(node.op)](left, right)
            raise ValueError(f"unsupported syntax: {type(node).__name__}")

        try:
            # strip(): eval() tolerated surrounding whitespace, ast.parse does not.
            tree = ast.parse(input.strip(), mode="eval")
            return str(evaluate(tree.body))
        except Exception as e:
            return f"Math error: {e}"
class BasicAgent:
    """Answer questions with a smolagents ``CodeAgent`` backed by a LiteLLM model.

    The agent has tools for:
    - reading files (audio, image, text, video);
    - web, Wikipedia, YouTube and academic-paper search;
    - Python execution and math.

    ``run``/``__call__`` add optional file context to the question and retry
    on rate-limit/overload errors with exponential backoff. They always return
    a string.
    """

    # Lower-case substrings that mark an exception as a transient provider
    # error worth retrying. The joined/underscored spellings catch exception
    # class names (e.g. "RateLimitError") and provider error codes
    # (e.g. "rate_limit_error") that the phrase "rate limit" alone misses.
    _RETRYABLE_ERROR_MARKERS = (
        "overload",
        "rate limit",
        "ratelimit",
        "rate_limit",
        "too many requests",
    )

    def __init__(self):
        self.model = LiteLLMModel(
            model_id=settings.llm_model_id,
            api_key=settings.llm_api_key,
        )
        self.agent = CodeAgent(
            model=self.model,
            tools=[
                FinalAnswerTool(),
                ReadAudioTool(),
                ReadImageTool(),
                ReadTextTool(),
                ReadVideoTool(),
                # Must match the imported name; "SerpAPISearchTool" was a NameError.
                SerpApiSearchTool(),
                WikiTitleFinder(),
                WikiContentFetcher(),
                YouTubeTool(),
                PythonRunnerTool(),
                PythonCalcTool(),
                AcademicPaperSearchTool(),
                MathSolver(),
            ],
            max_steps=10,
            planning_interval=5,
        )
        self.token_rate_limiter = InputTokenRateLimiter()
        # Token estimate handed to the limiter before each streamed step is handled.
        self.expected_tokens_per_step = 10000
        # Extra attempts on rate-limit/overload errors (total attempts = max_retries + 1).
        self.max_retries = 3
        # Backoff base in seconds: retry k waits base_delay * 2**k plus up to 1s jitter.
        self.base_delay = 5

    def _read_file(self, file_path: str) -> str:
        """Return the text content of ``file_path``.

        Supported formats (the extension check is case-insensitive):
        - ``.txt``: decoded as UTF-8, with undecodable bytes replaced;
        - ``.docx``: paragraph text, one paragraph per line.

        Missing paths, unsupported types and unreadable .docx files print a
        message and return ``""``. ``run`` then falls back to passing the path
        itself.
        """
        if not os.path.isfile(file_path):
            print(f"File not found: {file_path}")
            return ""
        lowered = file_path.lower()
        if lowered.endswith(".txt"):
            with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                return f.read()
        if lowered.endswith(".docx"):
            try:
                return self._read_docx(file_path)
            except (zipfile.BadZipFile, KeyError, ElementTree.ParseError) as e:
                print(f"Could not read .docx file {file_path}: {e}")
                return ""
        # For unsupported formats, return empty string
        print(f"Unsupported file type: {file_path}")
        return ""

    @staticmethod
    def _read_docx(file_path: str) -> str:
        """Extract the text of the body-level paragraphs of a .docx file.

        A .docx file is a zip archive whose main text lives in
        ``word/document.xml``. Reading it with ``zipfile``/``ElementTree``
        avoids a third-party dependency.

        - Run text, including text inside hyperlinks, is concatenated per
          paragraph.
        - ``<w:tab/>`` becomes a tab; ``<w:br/>`` and ``<w:cr/>`` become
          a newline.
        - Paragraphs are joined with newlines. Paragraphs nested in tables
          are not included.

        Raises:
            zipfile.BadZipFile: if the file is not a zip archive.
            KeyError: if the archive has no ``word/document.xml``.
            xml.etree.ElementTree.ParseError: if that part is not valid XML.
        """
        w = "{http://schemas.openxmlformats.org/wordprocessingml/2006/main}"
        with zipfile.ZipFile(file_path) as archive:
            root = ElementTree.fromstring(archive.read("word/document.xml"))
        body = root.find(f"{w}body")
        if body is None:
            return ""
        lines = []
        for paragraph in body.findall(f"{w}p"):
            pieces = []
            for child in paragraph:
                # Runs sit directly in the paragraph or wrapped in a hyperlink;
                # other children (e.g. paragraph properties) carry no text.
                if child.tag == f"{w}r":
                    runs = [child]
                elif child.tag == f"{w}hyperlink":
                    runs = child.findall(f"{w}r")
                else:
                    continue
                for run in runs:
                    for item in run:
                        if item.tag == f"{w}t":
                            pieces.append(item.text or "")
                        elif item.tag == f"{w}tab":
                            pieces.append("\t")
                        elif item.tag in (f"{w}br", f"{w}cr"):
                            pieces.append("\n")
            lines.append("".join(pieces))
        return "\n".join(lines)

    def _is_retryable(self, error: Exception) -> bool:
        """Return True if ``error``'s message looks like a rate-limit/overload error."""
        message = str(error).lower()
        return any(marker in message for marker in self._RETRYABLE_ERROR_MARKERS)

    def _extract_answer(self, result):
        """Return the final answer carried by the value ``agent.run`` returned.

        smolagents' ``agent.run`` returns the final answer itself (any type)
        unless ``stream=True`` is passed, which ``run`` does not do. Such a
        value is returned unchanged. Iterating it instead turned int/None
        answers into errors and reduced lists/dicts to their last element.

        If an iterator of steps is returned, the steps are processed as
        follows:
        - the last string step or last truthy ``output`` attribute wins;
        - the token rate limiter is consulted and updated for every
          non-string step.
        """
        if not isinstance(result, Iterator):
            return result
        answer = None
        for step in result:
            if isinstance(step, str):
                answer = step
                continue
            output = getattr(step, "output", None)
            if output:
                answer = output
            self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step)
            tokens_used = getattr(step, "token_usage", None)
            if tokens_used:
                self.token_rate_limiter.add_tokens(tokens_used.input_tokens)
        return answer

    def run(self, question: str, file_content: str = "", file_path: str = "") -> str:
        """Answer ``question`` with the agent, optionally using a file as context.

        Args:
            question: The question to answer.
            file_content: Text to include as context. When empty and
                ``file_path`` is given, the file is read with ``_read_file``.
            file_path: Path to an attached file. If no text is obtained from
                it, the path itself is given to the agent instead.

        Returns:
            The agent's final answer converted to ``str``. Returns
            ``"I am unable to answer"`` if any of these happen:
            - the agent fails with a non-retryable error;
            - the agent still fails after ``max_retries`` retries;
            - the answer is None or blank.
        """
        # If file content is empty but file_path exists, read the file
        if not file_content and file_path:
            file_content = self._read_file(file_path)

        if file_content:
            context = f"Story content:\n{file_content}"
        elif file_path:
            context = f"File path: {file_path}"
        else:
            context = ""
        print(f"Starting Agent with question: {question}\nContext length: {len(context)} chars")

        final_input = f"Question: {question}\n\n{context}"
        final_answer = None
        # Always make at least one attempt, plus up to max_retries retries.
        for attempt in range(max(self.max_retries, 0) + 1):
            try:
                final_answer = self._extract_answer(self.agent.run(final_input))
                break  # Exit retry loop if successful
            except Exception as e:
                if not self._is_retryable(e):
                    print(f"Error during agent run: {e}")
                    break
                if attempt >= self.max_retries:
                    print("Max retries reached. Exiting...")
                    break
                delay = self.base_delay * (2 ** attempt) + random.random()
                print(f"Retrying in {delay:.1f}s due to rate limit... ({e})")
                time.sleep(delay)

        # Ensure a valid answer is always returned. Compare as text so that
        # falsy but legitimate answers such as 0 are kept.
        final_answer = "" if final_answer is None else str(final_answer)
        if not final_answer.strip():
            final_answer = "I am unable to answer"
        print(f"Finished agent run. Final Answer: {final_answer}\n{'='*50}")
        return final_answer

    def __call__(self, question: str, file_content: str = "", file_path: str = "") -> str:
        """Alias for ``run`` so the agent can be used as a callable."""
        return self.run(question, file_content, file_path)