|
|
from src.settings import Settings |
|
|
from smolagents import LiteLLMModel, ToolCallingAgent, CodeAgent, InferenceClientModel, Tool |
|
|
from tools.FinalAnswerTool import FinalAnswerTool |
|
|
from tools.ReadAudioTool import ReadAudioTool |
|
|
from tools.ReadImageTool import ReadImageTool |
|
|
from tools.ReadTextTool import ReadTextTool |
|
|
from tools.ReadVideoTool import ReadVideoTool |
|
|
from tools.WebSearchTool import ScaleSerpSearchTool |
|
|
from tools.WikipediaTool import LocalWikipediaTool |
|
|
from tools.YouTubeTool import YouTubeSearchTool |
|
|
from tools.PythonRunnerTool import PythonRunnerTool |
|
|
from tools.PythonCalcTool import PythonCalcTool |
|
|
from tools.SemanticScholar import AcademicPaperSearchTool |
|
|
from src.utils import InputTokenRateLimiter |
|
|
import wikipedia as wiki |
|
|
from markdownify import markdownify as to_markdown |
|
|
import csv
import json
import os
import random
import time
|
|
|
|
|
settings = Settings() |
|
|
import litellm |
|
|
from litellm import completion |
|
|
|
|
|
|
|
|
class WikiTitleFinder(Tool): |
|
|
name = "wiki_titles" |
|
|
description = "Search for related Wikipedia page titles." |
|
|
inputs = {"query": {"type": "string", "description": "Search query."}} |
|
|
output_type = "string" |
|
|
|
|
|
def forward(self, query: str) -> str: |
|
|
results = wiki.search(query) |
|
|
return ", ".join(results) if results else "No results." |
|
|
|
|
|
class WikiContentFetcher(Tool): |
|
|
name = "wiki_page" |
|
|
description = "Fetch Wikipedia page content." |
|
|
inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}} |
|
|
output_type = "string" |
|
|
|
|
|
def forward(self, page_title: str) -> str: |
|
|
try: |
|
|
return to_markdown(wiki.page(page_title).html()) |
|
|
except wiki.exceptions.PageError: |
|
|
return f"'{page_title}' not found." |
|
|
|
|
|
class MathSolver(Tool): |
|
|
name = "math_solver" |
|
|
description = "Safely evaluate basic math expressions." |
|
|
inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}} |
|
|
output_type = "string" |
|
|
|
|
|
def forward(self, input: str) -> str: |
|
|
try: |
|
|
return str(eval(input, {"__builtins__": {}})) |
|
|
except Exception as e: |
|
|
return f"Math error: {e}" |
|
|
|
|
|
class BasicAgent(): |
|
|
def __init__(self): |
|
|
self.model = LiteLLMModel( |
|
|
model_id=settings.llm_model_id, |
|
|
api_key=settings.llm_api_key |
|
|
) |
|
|
self.agent = CodeAgent( |
|
|
model=self.model, |
|
|
tools=[ |
|
|
ScaleSerpSearchTool(), |
|
|
FinalAnswerTool(), |
|
|
PythonCalcTool(), |
|
|
ReadAudioTool(), |
|
|
ReadImageTool(), |
|
|
ReadTextTool(), |
|
|
YouTubeSearchTool(), |
|
|
PythonRunnerTool(), |
|
|
MathSolver(), |
|
|
LocalWikipediaTool(), |
|
|
WikiTitleFinder(), |
|
|
WikiContentFetcher(), |
|
|
AcademicPaperSearchTool(), |
|
|
], |
|
|
max_steps=10, |
|
|
planning_interval=5, |
|
|
) |
|
|
self.token_rate_limiter = InputTokenRateLimiter() |
|
|
self.expected_tokens_per_step = 10000 |
|
|
self.max_retries = 3 |
|
|
self.base_delay = 5 |
|
|
self.token_rate_limiter = InputTokenRateLimiter() |
|
|
self.expected_tokens_per_step = 10000 |
|
|
self.max_retries = 3 |
|
|
self.base_delay = 5 |
|
|
|
|
|
|
|
|
def _read_file(self, file_path: str) -> str: |
|
|
if not os.path.exists(file_path): |
|
|
print(f"File not found: {file_path}") |
|
|
return "" |
|
|
if file_path.endswith(".txt"): |
|
|
with open(file_path, "r", encoding="utf-8") as f: |
|
|
return f.read() |
|
|
elif file_path.endswith(".csv"): |
|
|
data = [] |
|
|
try: |
|
|
|
|
|
with open(file_path, mode='r', newline='', encoding='utf-8') as file: |
|
|
reader = csv.DictReader(file) |
|
|
for row in reader: |
|
|
data.append(row) |
|
|
|
|
|
|
|
|
return json.dumps(data, indent=2) |
|
|
except Exception as e: |
|
|
print(f"Error reading CSV file: {e}") |
|
|
return "" |
|
|
elif file_path.endswith(".docx"): |
|
|
doc = docx.Document(file_path) |
|
|
return "\n".join([p.text for p in doc.paragraphs]) |
|
|
else: |
|
|
|
|
|
print(f"Unsupported file type: {file_path}") |
|
|
return "" |
|
|
|
|
|
def run(self, question: str, file_content: str = "", file_path: str = "") -> str: |
|
|
final_answer = None |
|
|
retry_count = 0 |
|
|
|
|
|
|
|
|
if not file_content and file_path: |
|
|
file_content = self._read_file(file_path) |
|
|
|
|
|
context = "" |
|
|
if file_content: |
|
|
context = f"Story content:\n{file_content}" |
|
|
elif file_path: |
|
|
context = f"File path: {file_path}" |
|
|
|
|
|
print(f"Starting Agent with question: {question}\nContext length: {len(context)} chars") |
|
|
|
|
|
while True: |
|
|
try: |
|
|
final_input = f"Question: {question}\n\n{context}" |
|
|
steps = self.agent.run(final_input) |
|
|
|
|
|
|
|
|
if isinstance(steps, str): |
|
|
steps = [steps] |
|
|
|
|
|
for step in steps: |
|
|
if isinstance(step, str): |
|
|
final_answer = step |
|
|
continue |
|
|
|
|
|
step_name = step.__class__.__name__ |
|
|
output = getattr(step, "output", None) |
|
|
if output: |
|
|
final_answer = output |
|
|
|
|
|
self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step) |
|
|
tokens_used = getattr(step, "token_usage", None) |
|
|
if tokens_used: |
|
|
self.token_rate_limiter.add_tokens(tokens_used.input_tokens) |
|
|
|
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
if "overload" in str(e).lower() or "rate limit" in str(e).lower(): |
|
|
if retry_count >= self.max_retries: |
|
|
print("Max retries reached. Exiting...") |
|
|
break |
|
|
delay = self.base_delay * (2 ** retry_count) + random.random() |
|
|
print(f"Retrying in {delay:.1f}s due to rate limit... ({e})") |
|
|
time.sleep(delay) |
|
|
retry_count += 1 |
|
|
else: |
|
|
print(f"Error during agent run: {e}") |
|
|
break |
|
|
|
|
|
|
|
|
if not final_answer: |
|
|
final_answer = "I am unable to answer" |
|
|
|
|
|
print(f"Finished agent run. Final Answer: {final_answer}\n{'='*50}") |
|
|
return final_answer |
|
|
|
|
|
def __call__(self, question: str, file_content: str = "", file_path: str = "") -> str: |
|
|
return self.run(question, file_content, file_path) |
|
|
|