Sborole commited on
Commit
240c48a
·
verified ·
1 Parent(s): 0da4149

Update src/agent.py

Browse files
Files changed (1) hide show
  1. src/agent.py +71 -0
src/agent.py CHANGED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from settings import Settings
2
+ from smolagents import LiteLLMModel, ToolCallingAgent
3
+ from tools import DuckDuckGoSearchTool, FinalAnswerTool
4
+ from src.utils import InputTokenRateLimiter
5
+ import time
6
+ import random
7
+
8
+ settings = Settings()
9
+ class GaiaAgent():
10
+ def __init__(self):
11
+
12
+ self.model = LiteLLMModel(
13
+ model_name=settings.model_name,
14
+ api_key=settings.api_key
15
+ )
16
+
17
+ self.agent = ToolCallingAgent(
18
+ tools=[
19
+ DuckDuckGoSearchTool(),
20
+ FinalAnswerTool()
21
+ ],
22
+ model = self.model,
23
+ max_steps=5
24
+ )
25
+ self.token_rate_limiter = InputTokenRateLimiter()
26
+ self.expected_tokens_per_step = 10000
27
+ self.max_retries = 3
28
+ self.base_delay = 5
29
+
30
+ def run(self, question: str, file_name: str = "", file_content: str = ""):
31
+ final_answer = None
32
+ retry_count = 0
33
+
34
+ input_text = f"Question: {question}\nFile Name: {file_name}\nFile Content: {file_content}"
35
+ print(f"Starting Agent with input text: {input_text}")
36
+
37
+ while True:
38
+ try:
39
+ for step in self.agent.run(input_text):
40
+ step_name = step.__class__.__name__
41
+ if step.output:
42
+ print(f"Step: {step_name} Output: {step.output}")
43
+ print(f"Step: {step_name}")
44
+
45
+ self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step)
46
+ tokens_used = getattr(step, "token_usage", None)
47
+ if tokens_used:
48
+ self.token_rate_limiter.add_tokens(tokens_used.input_tokens)
49
+
50
+ if step_name == "FinalAnswerStep":
51
+ final_answer = step.output
52
+ print(f"Final Answer: {final_answer}")
53
+ break
54
+ except Exception as e:
55
+ if "overload" in str(e).lower() or "rate limit" in str(e).lower():
56
+ print("Rate limit exceeded. Retrying...")
57
+ if retry_count >= self.max_retries:
58
+ print("Max retries reached. Exiting...")
59
+ break
60
+ delay = self.base_delay * (2 ** retry_count) + random.random()
61
+ print(f"API overload/rate limit. Retrying in {delay:.1f}s ... ({e})")
62
+ time.sleep(delay)
63
+ retry_count += 1
64
+
65
+ else:
66
+ print(f"Error: {e}")
67
+ break
68
+
69
+ print(f"\nFinished agent run.\n{'='*60}")
70
+ print(f"Final Answer: {final_answer}\n")
71
+ return final_answer