Mehedi2 committed on
Commit
d2a08ac
·
verified ·
1 Parent(s): f36e4c1

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +129 -133
agent.py CHANGED
@@ -46,26 +46,39 @@ class ExcelToTextTool(Tool):
46
 
47
  class GaiaAgent:
48
  """
49
- Multi-model agent with intelligent routing and fallback.
50
 
51
- Models:
52
- - Fast (14.4K RPM): Simple Q&A, factual questions
53
- - Reasoning (1K RPM, 12K TPM): Logic, math, complex reasoning
54
- - Multimodal (1K RPM, 30K TPM): Images, chess, visual tasks
55
- """
56
 
57
- MODELS = {
58
- "fast": "groq/llama-3.1-8b-instant", # 14.4K RPM, 6K TPM
59
- "reasoning": "groq/llama-3.3-70b-versatile", # 1K RPM, 12K TPM
60
- "multimodal": "groq/meta-llama/llama-4-scout-17b-16e-instruct", # 1K RPM, 30K TPM
61
- }
62
 
63
  def __init__(self):
64
- print("βœ… GaiaAgent initialized with multi-model routing")
 
 
 
65
  self.api_key = os.getenv("GROQ_API_KEY")
66
  if not self.api_key:
67
  raise ValueError("GROQ_API_KEY not found in environment variables")
68
-
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  # Tools
70
  self.tools = [
71
  DuckDuckGoSearchTool(),
@@ -74,160 +87,147 @@ class GaiaAgent:
74
  PythonInterpreterTool(),
75
  FinalAnswerTool(),
76
  ]
77
-
78
- # Rate limiting - conservative for CodeAgent's multiple API calls
 
 
 
 
 
 
 
 
79
  self.last_call_time = 0
80
- self.min_delay = 15 # 15s between tasks (agents make 5-10+ calls internally)
81
- self.max_retries = 2 # Reduced to 2 to avoid long waits
82
-
83
  # Stats
84
  self.total_tasks = 0
85
  self.successful_tasks = 0
86
  self.failed_tasks = 0
87
  self.rate_limit_hits = 0
88
- self.model_usage = {"fast": 0, "reasoning": 0, "multimodal": 0}
89
 
90
- def _route_model(self, question: str, image: Optional[str] = None) -> str:
91
- """Intelligently route to appropriate model based on question type."""
92
- q = question.lower()
93
-
94
- # Multimodal: images, chess, visual content
95
- if image or any(kw in q for kw in ["chess", "image", "picture", "photo", "visual", "diagram"]):
96
- return "multimodal"
97
-
98
- # Reasoning: complex logic, math, proofs
99
- if any(kw in q for kw in [
100
- "prove", "counterexample", "logic", "reasoning", "algebra",
101
- "calculate", "derive", "theorem", "equation", "formula",
102
- "commutative", "associative", "distributive"
103
- ]):
104
- return "reasoning"
105
-
106
- # Fast: simple factual questions
107
- return "fast"
108
-
109
- def _create_agent(self, model_key: str) -> CodeAgent:
110
- """Create a new agent with the specified model."""
111
- model_id = self.MODELS[model_key]
112
- print(f"πŸ€– Initializing {model_key} model: {model_id}")
113
-
114
- model = LiteLLMModel(model_id=model_id, api_key=self.api_key)
115
 
116
- return CodeAgent(
117
- model=model,
118
- tools=self.tools,
119
- add_base_tools=True,
120
- additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
121
- )
122
-
123
- def _try_model(self, agent: CodeAgent, question: str, model_key: str) -> Optional[str]:
124
- """Try to get answer from a model with retries."""
125
- for attempt in range(self.max_retries + 1):
126
- try:
127
- print(f"πŸš€ Attempt {attempt + 1}/{self.max_retries + 1} with {model_key} model")
128
- answer = agent.run(question)
129
- if answer:
130
- self.model_usage[model_key] += 1
131
- return answer
132
-
133
- except Exception as e:
134
- err = str(e)
135
- print(f"⚠️ Error: {err[:200]}{'...' if len(err) > 200 else ''}")
136
-
137
- if "rate limit" in err.lower() or "rate_limit" in err.lower():
138
- self.rate_limit_hits += 1
139
-
140
- # Extract wait time from error
141
- wait_match = re.search(r'(\d+\.?\d*)\s*s', err)
142
- wait_time = float(wait_match.group(1)) + 3 if wait_match else 20
143
-
144
- if attempt < self.max_retries:
145
- print(f"⏳ Rate limit hit. Waiting {wait_time:.1f}s...")
146
- time.sleep(wait_time)
147
- continue
148
- else:
149
- print(f"❌ Rate limit exhausted for {model_key}")
150
- return None
151
- else:
152
- # Non-rate-limit error
153
- if attempt < self.max_retries:
154
- print(f"πŸ”„ Retrying in 5s...")
155
- time.sleep(5)
156
- continue
157
- else:
158
- print(f"❌ Failed after {self.max_retries + 1} attempts")
159
- return None
160
 
161
- return None
162
 
163
- def __call__(self, task_id: str, question: str, image: Optional[str] = None) -> str:
164
- """Process a task with intelligent routing and fallback."""
165
  self.total_tasks += 1
166
 
167
  # Rate limiting
168
  elapsed = time.time() - self.last_call_time
169
  if elapsed < self.min_delay:
170
  wait_time = self.min_delay - elapsed
171
- print(f"⏳ Base rate limit: waiting {wait_time:.1f}s...")
172
  time.sleep(wait_time)
173
 
174
  print(f"\n{'='*70}")
175
- print(f"πŸ”Ή Task #{self.total_tasks} | ID: {task_id}")
176
- print(f"πŸ”Ή Question: {question[:120]}{'...' if len(question) > 120 else ''}")
177
- print(f"{'='*70}")
178
 
179
- # Route to primary model
180
- primary_model = self._route_model(question, image)
181
- print(f"🎯 Routing to: {primary_model} model")
182
-
183
- # Define fallback chain
184
- fallback_chain = {
185
- "reasoning": ["fast"],
186
- "multimodal": ["reasoning", "fast"],
187
- "fast": [], # No fallback for fast model
188
- }
189
-
190
- # Try primary model first
191
- models_to_try = [primary_model] + fallback_chain.get(primary_model, [])
192
-
193
  answer = None
194
- for model_key in models_to_try:
 
 
195
  try:
196
- agent = self._create_agent(model_key)
197
- answer = self._try_model(agent, question, model_key)
198
 
199
- if answer:
200
  self.successful_tasks += 1
 
201
  break
202
  else:
203
- print(f"πŸ”„ Trying fallback model...")
204
- time.sleep(3) # Brief pause before fallback
 
 
205
 
206
  except Exception as e:
207
- print(f"❌ Failed to initialize {model_key}: {e}")
208
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
 
210
  if not answer:
211
- answer = "⚠️ All models failed to generate a valid response."
212
  self.failed_tasks += 1
213
 
 
214
  self.last_call_time = time.time()
215
 
 
216
  print(f"\n{'='*70}")
217
- print(f"πŸ“ Answer: {str(answer)[:200]}{'...' if len(str(answer)) > 200 else ''}")
 
218
  print(f"{'='*70}\n")
219
 
220
- return answer
221
 
222
  def get_stats(self) -> dict:
223
  """Get agent performance statistics."""
 
224
  return {
225
  "total_tasks": self.total_tasks,
226
  "successful_tasks": self.successful_tasks,
227
  "failed_tasks": self.failed_tasks,
228
- "success_rate": f"{(self.successful_tasks / self.total_tasks * 100):.1f}%" if self.total_tasks > 0 else "0%",
229
  "rate_limit_hits": self.rate_limit_hits,
230
- "model_usage": self.model_usage,
231
  }
232
 
233
  def print_stats(self):
@@ -237,25 +237,21 @@ class GaiaAgent:
237
  print(f"πŸ“Š AGENT STATISTICS")
238
  print(f"{'='*70}")
239
  print(f"Total Tasks: {stats['total_tasks']}")
240
- print(f"Successful: {stats['successful_tasks']}")
241
- print(f"Failed: {stats['failed_tasks']}")
242
  print(f"Success Rate: {stats['success_rate']}")
243
- print(f"Rate Limit Hits: {stats['rate_limit_hits']}")
244
- print(f"\nModel Usage:")
245
- for model, count in stats['model_usage'].items():
246
- print(f" {model.capitalize():12} {count} tasks")
247
  print(f"{'='*70}\n")
248
 
249
 
250
  # Example usage
251
  if __name__ == "__main__":
252
  agent = GaiaAgent()
253
-
254
- # Test with a simple question
255
  answer = agent(
256
  task_id="test-001",
257
- question="What is 2+2? Explain your reasoning."
258
  )
259
-
260
- # Print statistics
261
  agent.print_stats()
 
46
 
47
  class GaiaAgent:
48
  """
49
+ Single-model agent using Llama 4 Scout exclusively.
50
 
51
+ Why Llama 4 Scout:
52
+ - 30K TPM (highest available - 5x more than llama-3.1-8b)
53
+ - 500K context window
54
+ - Multimodal support (images, chess)
55
+ - 1K RPM
56
 
57
+ This avoids the 6K TPM bottleneck of llama-3.1-8b-instant.
58
+ """
 
 
 
59
 
60
  def __init__(self):
61
+ print("="*70)
62
+ print("βœ… GaiaAgent initialized with Llama 4 Scout (30K TPM)")
63
+ print("="*70)
64
+
65
  self.api_key = os.getenv("GROQ_API_KEY")
66
  if not self.api_key:
67
  raise ValueError("GROQ_API_KEY not found in environment variables")
68
+
69
+ # Single model configuration - Llama 4 Scout for all tasks
70
+ self.model_id = "groq/meta-llama/llama-4-scout-17b-16e-instruct"
71
+
72
+ print(f"πŸ€– Model: {self.model_id}")
73
+ print(f"πŸ“Š Limits: 30K TPM | 1K RPM | 500K context | Multimodal")
74
+ print("="*70 + "\n")
75
+
76
+ # Initialize model
77
+ self.model = LiteLLMModel(
78
+ model_id=self.model_id,
79
+ api_key=self.api_key,
80
+ )
81
+
82
  # Tools
83
  self.tools = [
84
  DuckDuckGoSearchTool(),
 
87
  PythonInterpreterTool(),
88
  FinalAnswerTool(),
89
  ]
90
+
91
+ # Create agent
92
+ self.agent = CodeAgent(
93
+ model=self.model,
94
+ tools=self.tools,
95
+ add_base_tools=True,
96
+ additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
97
+ )
98
+
99
+ # Rate limiting - 30K TPM is generous but agents make multiple calls
100
  self.last_call_time = 0
101
+ self.min_delay = 10 # 10s between tasks (reasonable with 30K TPM)
102
+ self.max_retries = 3 # More retries since we have higher TPM
103
+
104
  # Stats
105
  self.total_tasks = 0
106
  self.successful_tasks = 0
107
  self.failed_tasks = 0
108
  self.rate_limit_hits = 0
 
109
 
110
+ def _extract_wait_time(self, error_str: str) -> float:
111
+ """Extract wait time from rate limit error message."""
112
+ patterns = [
113
+ r'try again in (\d+\.?\d*)\s*s',
114
+ r'retry in (\d+\.?\d*)\s*s',
115
+ r'(\d+\.?\d*)\s*s',
116
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ for pattern in patterns:
119
+ match = re.search(pattern, error_str)
120
+ if match:
121
+ return float(match.group(1)) + 5 # Add 5s buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ return 30 # Default fallback
124
 
125
+ def __call__(self, task_id: str, question: str) -> str:
126
+ """Process a task with automatic rate limiting and retry."""
127
  self.total_tasks += 1
128
 
129
  # Rate limiting
130
  elapsed = time.time() - self.last_call_time
131
  if elapsed < self.min_delay:
132
  wait_time = self.min_delay - elapsed
133
+ print(f"⏳ Rate limit: waiting {wait_time:.1f}s...")
134
  time.sleep(wait_time)
135
 
136
  print(f"\n{'='*70}")
137
+ print(f"πŸ“‹ Task #{self.total_tasks} | ID: {task_id}")
138
+ print(f"❓ Question: {question[:150]}{'...' if len(question) > 150 else ''}")
139
+ print(f"{'='*70}\n")
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  answer = None
142
+
143
+ # Retry loop with exponential backoff
144
+ for attempt in range(self.max_retries + 1):
145
  try:
146
+ print(f"πŸš€ Attempt {attempt + 1}/{self.max_retries + 1}")
147
+ answer = self.agent.run(question)
148
 
149
+ if answer and len(str(answer).strip()) > 0:
150
  self.successful_tasks += 1
151
+ print(f"βœ… Success!")
152
  break
153
  else:
154
+ print(f"⚠️ Empty answer received")
155
+ if attempt < self.max_retries:
156
+ time.sleep(5)
157
+ continue
158
 
159
  except Exception as e:
160
+ error_str = str(e)
161
+
162
+ # Show condensed error
163
+ if len(error_str) > 300:
164
+ print(f"❌ Error: {error_str[:300]}...")
165
+ else:
166
+ print(f"❌ Error: {error_str}")
167
+
168
+ # Check if it's a rate limit error
169
+ if "rate limit" in error_str.lower() or "rate_limit" in error_str.lower():
170
+ self.rate_limit_hits += 1
171
+ wait_time = self._extract_wait_time(error_str)
172
+
173
+ if attempt < self.max_retries:
174
+ print(f"⏳ Rate limit hit. Waiting {wait_time:.1f}s before retry...")
175
+
176
+ # Show countdown for long waits
177
+ if wait_time > 10:
178
+ for remaining in range(int(wait_time), 0, -5):
179
+ print(f" ⏱️ {remaining}s remaining...", flush=True)
180
+ time.sleep(5)
181
+ else:
182
+ time.sleep(wait_time)
183
+
184
+ print(f"πŸ”„ Retrying...")
185
+ continue
186
+ else:
187
+ answer = "⚠️ Rate limit exceeded after all retries."
188
+ self.failed_tasks += 1
189
+
190
+ # Authentication error
191
+ elif "authentication" in error_str.lower() or "api key" in error_str.lower():
192
+ answer = "⚠️ Authentication failed. Check your GROQ_API_KEY."
193
+ self.failed_tasks += 1
194
+ break
195
+
196
+ # Other errors
197
+ else:
198
+ if attempt < self.max_retries:
199
+ print(f"πŸ”„ Retrying in 5s...")
200
+ time.sleep(5)
201
+ continue
202
+ else:
203
+ answer = f"⚠️ Failed after {self.max_retries + 1} attempts."
204
+ self.failed_tasks += 1
205
 
206
+ # Fallback
207
  if not answer:
208
+ answer = "⚠️ Could not generate a valid response."
209
  self.failed_tasks += 1
210
 
211
+ # Update timing
212
  self.last_call_time = time.time()
213
 
214
+ # Print result
215
  print(f"\n{'='*70}")
216
+ answer_preview = str(answer)[:250] + ('...' if len(str(answer)) > 250 else '')
217
+ print(f"✍️ Answer: {answer_preview}")
218
  print(f"{'='*70}\n")
219
 
220
+ return str(answer)
221
 
222
  def get_stats(self) -> dict:
223
  """Get agent performance statistics."""
224
+ success_rate = (self.successful_tasks / self.total_tasks * 100) if self.total_tasks > 0 else 0
225
  return {
226
  "total_tasks": self.total_tasks,
227
  "successful_tasks": self.successful_tasks,
228
  "failed_tasks": self.failed_tasks,
229
+ "success_rate": f"{success_rate:.1f}%",
230
  "rate_limit_hits": self.rate_limit_hits,
 
231
  }
232
 
233
  def print_stats(self):
 
237
  print(f"πŸ“Š AGENT STATISTICS")
238
  print(f"{'='*70}")
239
  print(f"Total Tasks: {stats['total_tasks']}")
240
+ print(f"Successful: {stats['successful_tasks']} βœ…")
241
+ print(f"Failed: {stats['failed_tasks']} ❌")
242
  print(f"Success Rate: {stats['success_rate']}")
243
+ print(f"Rate Limit Hits: {stats['rate_limit_hits']} 🚫")
 
 
 
244
  print(f"{'='*70}\n")
245
 
246
 
247
  # Example usage
248
  if __name__ == "__main__":
249
  agent = GaiaAgent()
250
+
251
+ # Test
252
  answer = agent(
253
  task_id="test-001",
254
+ question="What is 2+2? Show your calculation."
255
  )
256
+
 
257
  agent.print_stats()