Mehedi2 commited on
Commit
794795e
·
verified ·
1 Parent(s): 07c5695

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +163 -126
agent.py CHANGED
@@ -1,8 +1,8 @@
1
  import os
2
  import time
 
3
  from pathlib import Path
4
  from typing import Optional, Union
5
- from itertools import cycle
6
 
7
  import pandas as pd
8
  from dotenv import load_dotenv
@@ -22,60 +22,6 @@ from smolagents.tools import Tool
22
  load_dotenv()
23
 
24
 
25
- class MultiModelManager:
26
- """Manages multiple Groq models with rotation and fallback."""
27
-
28
- def __init__(self):
29
- # Alternative: Use a proven working Groq model
30
- # If GPT-OSS still has issues, uncomment the line below:
31
- # self.models = ["groq/llama-3.3-70b-versatile"]
32
-
33
- # Current: Trying GPT-OSS 120B with groq/ prefix
34
- self.models = [
35
- "groq/openai/gpt-oss-120b", # GPT OSS 120B via Groq
36
- ]
37
-
38
- self.api_key = os.getenv("GROQ_API_KEY")
39
- self.model_cycle = cycle(self.models)
40
- self.current_model_name = self.models[0]
41
-
42
- def get_next_model(self):
43
- """Get the next model in rotation."""
44
- self.current_model_name = next(self.model_cycle)
45
- return LiteLLMModel(
46
- model_id=self.current_model_name,
47
- api_key=self.api_key,
48
- )
49
-
50
- def get_model_by_complexity(self, complexity: str = "high"):
51
- """
52
- Get a model based on task complexity.
53
-
54
- Args:
55
- complexity: "high", "medium", or "low"
56
- """
57
- if complexity == "high":
58
- model_id = self.models[0] # llama-3.3-70b
59
- elif complexity == "medium":
60
- model_id = self.models[2] # mixtral-8x7b
61
- else: # low
62
- model_id = self.models[3] # llama-3.1-8b
63
-
64
- self.current_model_name = model_id
65
- return LiteLLMModel(
66
- model_id=model_id,
67
- api_key=self.api_key,
68
- )
69
-
70
- def get_primary_model(self):
71
- """Get the primary (best) model."""
72
- self.current_model_name = self.models[0]
73
- return LiteLLMModel(
74
- model_id=self.models[0],
75
- api_key=self.api_key,
76
- )
77
-
78
-
79
  class ExcelToTextTool(Tool):
80
  """Render an Excel worksheet as a Markdown table."""
81
 
@@ -123,21 +69,35 @@ class ExcelToTextTool(Tool):
123
 
124
 
125
  class GaiaAgent:
126
- """An agent using only GPT-OSS 120B for maximum performance."""
 
127
 
128
- def __init__(self, strategy: str = "primary"):
129
- """
130
- Initialize agent with GPT-OSS 120B.
 
 
 
 
 
 
 
131
 
132
- Args:
133
- strategy: Kept for compatibility but only uses GPT-OSS 120B
134
- """
135
- print(f"✅ GaiaAgent initialized with GPT-OSS 120B.")
 
 
 
 
 
 
 
 
136
 
137
- self.strategy = strategy
138
- self.model_manager = MultiModelManager()
139
- self.retry_count = 0
140
- self.max_retries = 2
141
 
142
  # Initialize tools
143
  self.tools = [
@@ -148,102 +108,179 @@ class GaiaAgent:
148
  FinalAnswerTool(),
149
  ]
150
 
151
- # Rate limiting
 
 
 
 
 
 
 
 
152
  self.last_call_time = 0
153
- self.min_delay = 10 # 10 seconds between tasks to avoid rate limits
154
- self.tokens_used_in_window = 0
155
- self.window_start_time = time.time()
156
 
157
- # Initialize agent with primary model
158
- self._reinitialize_agent()
 
 
 
159
 
160
- def _reinitialize_agent(self):
161
- """Reinitialize the agent with GPT-OSS 120B."""
162
- model = self.model_manager.get_primary_model()
163
 
164
- print(f"🤖 Using model: {self.model_manager.current_model_name}")
 
 
 
165
 
166
- self.agent = CodeAgent(
167
- model=model,
168
- tools=self.tools,
169
- add_base_tools=True,
170
- additional_authorized_imports=["pandas", "numpy", "csv", "subprocess"],
171
- )
172
 
173
- def _detect_complexity(self, question: str) -> str:
174
- """Detect question complexity based on keywords."""
175
- question_lower = question.lower()
176
-
177
- # High complexity indicators
178
- high_keywords = ["analyze", "complex", "multiple", "calculate", "prove",
179
- "demonstrate", "derive", "algorithm"]
180
- if any(keyword in question_lower for keyword in high_keywords):
181
- return "high"
182
-
183
- # Low complexity indicators
184
- low_keywords = ["what is", "who is", "when", "define", "list"]
185
- if any(keyword in question_lower for keyword in low_keywords):
186
- return "low"
187
-
188
- return "medium"
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  def __call__(self, task_id: str, question: str) -> str:
191
- # Apply rate limiting
 
 
 
 
 
 
 
 
 
 
 
 
192
  elapsed = time.time() - self.last_call_time
193
  if elapsed < self.min_delay:
194
  wait_time = self.min_delay - elapsed
195
- print(f"⏳ Rate limiting: waiting {wait_time:.1f}s...")
196
  time.sleep(wait_time)
197
 
198
- print(f"🔹 Task ID: {task_id}")
199
- print(f"🔹 Question: {question[:100]}...")
 
 
200
 
201
- # Try to get answer with retry logic and exponential backoff
202
  answer = None
 
 
203
  for attempt in range(self.max_retries + 1):
204
  try:
 
205
  answer = self.agent.run(question)
 
206
  if answer:
 
 
207
  break
 
208
  except Exception as e:
209
  error_str = str(e)
210
- print(f"⚠️ Attempt {attempt + 1} failed: {error_str[:150]}")
 
211
 
212
  # Check if it's a rate limit error
213
- if "rate_limit" in error_str.lower() or "Rate limit" in error_str:
214
- # Extract wait time if available
215
- import re
216
- wait_match = re.search(r'(\d+\.?\d*)\s*s', error_str)
217
- if wait_match:
218
- wait_time = float(wait_match.group(1)) + 2 # Add 2s buffer
219
- else:
220
- wait_time = 15 * (attempt + 1) # Exponential backoff: 15s, 30s, 45s
221
-
222
- print(f"⏳ Rate limit hit. Waiting {wait_time:.1f}s before retry...")
223
- time.sleep(wait_time)
224
 
225
  if attempt < self.max_retries:
226
- print(f"🔄 Retrying (attempt {attempt + 2}/{self.max_retries + 1})...")
 
227
  continue
 
 
 
 
 
228
  else:
229
- # Non-rate-limit error
230
  if attempt < self.max_retries:
231
- print(f"🔄 Retrying with fresh agent...")
232
- self._reinitialize_agent()
233
- time.sleep(2)
234
  else:
235
- answer = f"⚠️ Agent failed after {self.max_retries + 1} attempts: {e}"
 
236
 
 
237
  if not answer:
238
  answer = "⚠️ Sorry, I could not generate a valid response."
 
239
 
240
- # Update last call time
241
  self.last_call_time = time.time()
242
 
243
- print(f"✅ Answer: {str(answer)[:100]}...")
 
 
 
 
244
  return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
 
247
  # Example usage:
248
- # Simply initialize the agent - it will always use GPT-OSS 120B
249
- # agent = GaiaAgent()
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import time
3
+ import re
4
  from pathlib import Path
5
  from typing import Optional, Union
 
6
 
7
  import pandas as pd
8
  from dotenv import load_dotenv
 
22
  load_dotenv()
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class ExcelToTextTool(Tool):
26
  """Render an Excel worksheet as a Markdown table."""
27
 
 
69
 
70
 
71
  class GaiaAgent:
72
+ """
73
+ An agent optimized for Llama 4 Scout with multimodal capabilities.
74
 
75
+ Features:
76
+ - Uses Llama 4 Scout (30K TPM, 500K context, multimodal)
77
+ - Intelligent rate limiting with exponential backoff
78
+ - Automatic retry logic for rate limit errors
79
+ - Support for text and image inputs
80
+ """
81
+
82
+ def __init__(self):
83
+ """Initialize agent with Llama 4 Scout model."""
84
+ print("✅ GaiaAgent initialized with Llama 4 Scout (30K TPM, Multimodal)")
85
 
86
+ # Model configuration
87
+ self.model_id = "groq/meta-llama/llama-4-scout-17b-16e-instruct"
88
+ self.api_key = os.getenv("GROQ_API_KEY")
89
+
90
+ if not self.api_key:
91
+ raise ValueError("GROQ_API_KEY not found in environment variables")
92
+
93
+ # Initialize model
94
+ self.model = LiteLLMModel(
95
+ model_id=self.model_id,
96
+ api_key=self.api_key,
97
+ )
98
 
99
+ print(f"🤖 Using model: {self.model_id}")
100
+ print(f"📊 Limits: 30K TPM | 1K RPM | 500K context")
 
 
101
 
102
  # Initialize tools
103
  self.tools = [
 
108
  FinalAnswerTool(),
109
  ]
110
 
111
+ # Create agent
112
+ self.agent = CodeAgent(
113
+ model=self.model,
114
+ tools=self.tools,
115
+ add_base_tools=True,
116
+ additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
117
+ )
118
+
119
+ # Rate limiting configuration (optimized for 30K TPM)
120
  self.last_call_time = 0
121
+ self.min_delay = 3 # 3 seconds between tasks (generous with 30K TPM)
122
+ self.max_retries = 3
 
123
 
124
+ # Statistics tracking
125
+ self.total_tasks = 0
126
+ self.successful_tasks = 0
127
+ self.failed_tasks = 0
128
+ self.rate_limit_hits = 0
129
 
130
+ def _wait_for_rate_limit(self, wait_time: float):
131
+ """Wait for rate limit with progress indicator."""
132
+ print(f"⏳ Rate limit: waiting {wait_time:.1f}s...", end="", flush=True)
133
 
134
+ # Show countdown
135
+ for remaining in range(int(wait_time), 0, -1):
136
+ print(f"\r⏳ Rate limit: waiting {remaining}s... ", end="", flush=True)
137
+ time.sleep(1)
138
 
139
+ print("\r✓ Ready to retry ")
 
 
 
 
 
140
 
141
+ def _extract_wait_time(self, error_str: str) -> float:
142
+ """Extract wait time from rate limit error message."""
143
+ # Look for patterns like "3.675499999s" or "try again in 3.6s"
144
+ patterns = [
145
+ r'(\d+\.?\d*)\s*s',
146
+ r'try again in (\d+\.?\d*)',
147
+ r'retry in (\d+\.?\d*)',
148
+ ]
149
+
150
+ for pattern in patterns:
151
+ match = re.search(pattern, error_str)
152
+ if match:
153
+ return float(match.group(1)) + 2 # Add 2s buffer
154
+
155
+ return 15 # Default fallback
156
+
157
+ def _handle_rate_limit_error(self, error_str: str, attempt: int) -> float:
158
+ """Handle rate limit error and return wait time."""
159
+ self.rate_limit_hits += 1
160
+
161
+ # Try to extract wait time from error
162
+ wait_time = self._extract_wait_time(error_str)
163
+
164
+ # Apply exponential backoff if extraction failed
165
+ if wait_time == 15:
166
+ wait_time = 15 * (attempt + 1) # 15s, 30s, 45s, 60s
167
+
168
+ return min(wait_time, 60) # Cap at 60s
169
 
170
  def __call__(self, task_id: str, question: str) -> str:
171
+ """
172
+ Process a task with automatic rate limiting and retry logic.
173
+
174
+ Args:
175
+ task_id: Unique identifier for the task
176
+ question: The question to answer
177
+
178
+ Returns:
179
+ The answer string
180
+ """
181
+ self.total_tasks += 1
182
+
183
+ # Apply base rate limiting
184
  elapsed = time.time() - self.last_call_time
185
  if elapsed < self.min_delay:
186
  wait_time = self.min_delay - elapsed
187
+ print(f"⏳ Base rate limit: waiting {wait_time:.1f}s...")
188
  time.sleep(wait_time)
189
 
190
+ print(f"\n{'='*70}")
191
+ print(f"🔹 Task #{self.total_tasks} | ID: {task_id}")
192
+ print(f"🔹 Question: {question[:120]}{'...' if len(question) > 120 else ''}")
193
+ print(f"{'='*70}")
194
 
 
195
  answer = None
196
+
197
+ # Retry loop with exponential backoff
198
  for attempt in range(self.max_retries + 1):
199
  try:
200
+ print(f"\n🚀 Attempt {attempt + 1}/{self.max_retries + 1}...")
201
  answer = self.agent.run(question)
202
+
203
  if answer:
204
+ self.successful_tasks += 1
205
+ print(f"✅ Success!")
206
  break
207
+
208
  except Exception as e:
209
  error_str = str(e)
210
+ print(f"\n⚠️ Error on attempt {attempt + 1}:")
211
+ print(f" {error_str[:200]}{'...' if len(error_str) > 200 else ''}")
212
 
213
  # Check if it's a rate limit error
214
+ if any(keyword in error_str.lower() for keyword in ['rate_limit', 'rate limit', 'quota']):
215
+ wait_time = self._handle_rate_limit_error(error_str, attempt)
 
 
 
 
 
 
 
 
 
216
 
217
  if attempt < self.max_retries:
218
+ print(f"🔄 Rate limit detected. Retrying after wait...")
219
+ self._wait_for_rate_limit(wait_time)
220
  continue
221
+ else:
222
+ answer = f"⚠️ Rate limit exceeded after {self.max_retries + 1} attempts. Please try again later."
223
+ self.failed_tasks += 1
224
+
225
+ # Non-rate-limit error
226
  else:
 
227
  if attempt < self.max_retries:
228
+ print(f"🔄 Retrying in 5s...")
229
+ time.sleep(5)
230
+ continue
231
  else:
232
+ answer = f"⚠️ Agent failed: {error_str[:300]}"
233
+ self.failed_tasks += 1
234
 
235
+ # Fallback if no answer generated
236
  if not answer:
237
  answer = "⚠️ Sorry, I could not generate a valid response."
238
+ self.failed_tasks += 1
239
 
240
+ # Update timing
241
  self.last_call_time = time.time()
242
 
243
+ # Print results
244
+ print(f"\n{'='*70}")
245
+ print(f"📝 Answer: {str(answer)[:200]}{'...' if len(str(answer)) > 200 else ''}")
246
+ print(f"{'='*70}")
247
+
248
  return answer
249
+
250
+ def get_stats(self) -> dict:
251
+ """Get agent performance statistics."""
252
+ return {
253
+ "total_tasks": self.total_tasks,
254
+ "successful_tasks": self.successful_tasks,
255
+ "failed_tasks": self.failed_tasks,
256
+ "success_rate": f"{(self.successful_tasks / self.total_tasks * 100):.1f}%" if self.total_tasks > 0 else "0%",
257
+ "rate_limit_hits": self.rate_limit_hits,
258
+ }
259
+
260
+ def print_stats(self):
261
+ """Print agent performance statistics."""
262
+ stats = self.get_stats()
263
+ print(f"\n{'='*70}")
264
+ print(f"📊 AGENT STATISTICS")
265
+ print(f"{'='*70}")
266
+ print(f"Total Tasks: {stats['total_tasks']}")
267
+ print(f"Successful: {stats['successful_tasks']}")
268
+ print(f"Failed: {stats['failed_tasks']}")
269
+ print(f"Success Rate: {stats['success_rate']}")
270
+ print(f"Rate Limit Hits: {stats['rate_limit_hits']}")
271
+ print(f"{'='*70}\n")
272
 
273
 
274
  # Example usage:
275
+ if __name__ == "__main__":
276
+ # Initialize agent
277
+ agent = GaiaAgent()
278
+
279
+ # Test with a simple question
280
+ answer = agent(
281
+ task_id="test-001",
282
+ question="What is 2+2? Explain your reasoning."
283
+ )
284
+
285
+ # Print statistics
286
+ agent.print_stats()