Mehedi2 commited on
Commit
f36e4c1
·
verified ·
1 Parent(s): 794795e

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +133 -158
agent.py CHANGED
@@ -24,82 +24,49 @@ load_dotenv()
24
 
25
  class ExcelToTextTool(Tool):
26
  """Render an Excel worksheet as a Markdown table."""
27
-
28
  name = "excel_to_text"
29
- description = (
30
- "Read an Excel file and return a Markdown table of the requested sheet. "
31
- "Accepts either the sheet name or a zero-based index (as a string)."
32
- )
33
  inputs = {
34
- "excel_path": {
35
- "type": "string",
36
- "description": "Path to the Excel file (.xlsx or .xls).",
37
- },
38
- "sheet_name": {
39
- "type": "string",
40
- "description": (
41
- "Worksheet name or zero-based index (as a string). "
42
- "Optional; defaults to the first sheet."
43
- ),
44
- "nullable": True,
45
- },
46
  }
47
  output_type = "string"
48
 
49
  def forward(self, excel_path: str, sheet_name: Optional[str] = None) -> str:
50
- """Load the Excel file and return the sheet as a Markdown table."""
51
  file_path = Path(excel_path).expanduser().resolve()
52
-
53
  if not file_path.is_file():
54
  return f"Error: Excel file not found at {file_path}"
55
-
56
  try:
57
- sheet: Union[str, int] = (
58
- int(sheet_name) if sheet_name and sheet_name.isdigit() else sheet_name or 0
59
- )
60
  df = pd.read_excel(file_path, sheet_name=sheet)
61
-
62
- if hasattr(df, "to_markdown"):
63
- return df.to_markdown(index=False)
64
-
65
- return tabulate(df, headers="keys", tablefmt="github", showindex=False)
66
-
67
  except Exception as e:
68
  return f"Error reading Excel file: {e}"
69
 
70
 
71
  class GaiaAgent:
72
  """
73
- An agent optimized for Llama 4 Scout with multimodal capabilities.
74
 
75
- Features:
76
- - Uses Llama 4 Scout (30K TPM, 500K context, multimodal)
77
- - Intelligent rate limiting with exponential backoff
78
- - Automatic retry logic for rate limit errors
79
- - Support for text and image inputs
80
  """
81
 
 
 
 
 
 
 
82
  def __init__(self):
83
- """Initialize agent with Llama 4 Scout model."""
84
- print("βœ… GaiaAgent initialized with Llama 4 Scout (30K TPM, Multimodal)")
85
-
86
- # Model configuration
87
- self.model_id = "groq/meta-llama/llama-4-scout-17b-16e-instruct"
88
  self.api_key = os.getenv("GROQ_API_KEY")
89
-
90
  if not self.api_key:
91
  raise ValueError("GROQ_API_KEY not found in environment variables")
92
-
93
- # Initialize model
94
- self.model = LiteLLMModel(
95
- model_id=self.model_id,
96
- api_key=self.api_key,
97
- )
98
-
99
- print(f"πŸ€– Using model: {self.model_id}")
100
- print(f"πŸ“Š Limits: 30K TPM | 1K RPM | 500K context")
101
-
102
- # Initialize tools
103
  self.tools = [
104
  DuckDuckGoSearchTool(),
105
  WikipediaSearchTool(),
@@ -107,146 +74,151 @@ class GaiaAgent:
107
  PythonInterpreterTool(),
108
  FinalAnswerTool(),
109
  ]
110
-
111
- # Create agent
112
- self.agent = CodeAgent(
113
- model=self.model,
114
- tools=self.tools,
115
- add_base_tools=True,
116
- additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
117
- )
118
-
119
- # Rate limiting configuration (optimized for 30K TPM)
120
  self.last_call_time = 0
121
- self.min_delay = 3 # 3 seconds between tasks (generous with 30K TPM)
122
- self.max_retries = 3
123
-
124
- # Statistics tracking
125
  self.total_tasks = 0
126
  self.successful_tasks = 0
127
  self.failed_tasks = 0
128
  self.rate_limit_hits = 0
129
-
130
- def _wait_for_rate_limit(self, wait_time: float):
131
- """Wait for rate limit with progress indicator."""
132
- print(f"⏳ Rate limit: waiting {wait_time:.1f}s...", end="", flush=True)
133
-
134
- # Show countdown
135
- for remaining in range(int(wait_time), 0, -1):
136
- print(f"\r⏳ Rate limit: waiting {remaining}s... ", end="", flush=True)
137
- time.sleep(1)
138
 
139
- print("\rβœ“ Ready to retry ")
140
-
141
- def _extract_wait_time(self, error_str: str) -> float:
142
- """Extract wait time from rate limit error message."""
143
- # Look for patterns like "3.675499999s" or "try again in 3.6s"
144
- patterns = [
145
- r'(\d+\.?\d*)\s*s',
146
- r'try again in (\d+\.?\d*)',
147
- r'retry in (\d+\.?\d*)',
148
- ]
149
 
150
- for pattern in patterns:
151
- match = re.search(pattern, error_str)
152
- if match:
153
- return float(match.group(1)) + 2 # Add 2s buffer
 
 
 
154
 
155
- return 15 # Default fallback
156
-
157
- def _handle_rate_limit_error(self, error_str: str, attempt: int) -> float:
158
- """Handle rate limit error and return wait time."""
159
- self.rate_limit_hits += 1
160
-
161
- # Try to extract wait time from error
162
- wait_time = self._extract_wait_time(error_str)
163
 
164
- # Apply exponential backoff if extraction failed
165
- if wait_time == 15:
166
- wait_time = 15 * (attempt + 1) # 15s, 30s, 45s, 60s
167
 
168
- return min(wait_time, 60) # Cap at 60s
169
-
170
- def __call__(self, task_id: str, question: str) -> str:
171
- """
172
- Process a task with automatic rate limiting and retry logic.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
- Args:
175
- task_id: Unique identifier for the task
176
- question: The question to answer
177
-
178
- Returns:
179
- The answer string
180
- """
181
  self.total_tasks += 1
182
 
183
- # Apply base rate limiting
184
  elapsed = time.time() - self.last_call_time
185
  if elapsed < self.min_delay:
186
  wait_time = self.min_delay - elapsed
187
  print(f"⏳ Base rate limit: waiting {wait_time:.1f}s...")
188
  time.sleep(wait_time)
189
-
190
  print(f"\n{'='*70}")
191
  print(f"πŸ”Ή Task #{self.total_tasks} | ID: {task_id}")
192
  print(f"πŸ”Ή Question: {question[:120]}{'...' if len(question) > 120 else ''}")
193
  print(f"{'='*70}")
 
 
 
 
194
 
195
- answer = None
 
 
 
 
 
196
 
197
- # Retry loop with exponential backoff
198
- for attempt in range(self.max_retries + 1):
 
 
 
199
  try:
200
- print(f"\nπŸš€ Attempt {attempt + 1}/{self.max_retries + 1}...")
201
- answer = self.agent.run(question)
202
 
203
  if answer:
204
  self.successful_tasks += 1
205
- print(f"βœ… Success!")
206
  break
 
 
 
207
 
208
  except Exception as e:
209
- error_str = str(e)
210
- print(f"\n⚠️ Error on attempt {attempt + 1}:")
211
- print(f" {error_str[:200]}{'...' if len(error_str) > 200 else ''}")
212
-
213
- # Check if it's a rate limit error
214
- if any(keyword in error_str.lower() for keyword in ['rate_limit', 'rate limit', 'quota']):
215
- wait_time = self._handle_rate_limit_error(error_str, attempt)
216
-
217
- if attempt < self.max_retries:
218
- print(f"πŸ”„ Rate limit detected. Retrying after wait...")
219
- self._wait_for_rate_limit(wait_time)
220
- continue
221
- else:
222
- answer = f"⚠️ Rate limit exceeded after {self.max_retries + 1} attempts. Please try again later."
223
- self.failed_tasks += 1
224
-
225
- # Non-rate-limit error
226
- else:
227
- if attempt < self.max_retries:
228
- print(f"πŸ”„ Retrying in 5s...")
229
- time.sleep(5)
230
- continue
231
- else:
232
- answer = f"⚠️ Agent failed: {error_str[:300]}"
233
- self.failed_tasks += 1
234
-
235
- # Fallback if no answer generated
236
  if not answer:
237
- answer = "⚠️ Sorry, I could not generate a valid response."
238
  self.failed_tasks += 1
239
-
240
- # Update timing
241
  self.last_call_time = time.time()
242
 
243
- # Print results
244
  print(f"\n{'='*70}")
245
  print(f"πŸ“ Answer: {str(answer)[:200]}{'...' if len(str(answer)) > 200 else ''}")
246
- print(f"{'='*70}")
247
 
248
  return answer
249
-
250
  def get_stats(self) -> dict:
251
  """Get agent performance statistics."""
252
  return {
@@ -255,8 +227,9 @@ class GaiaAgent:
255
  "failed_tasks": self.failed_tasks,
256
  "success_rate": f"{(self.successful_tasks / self.total_tasks * 100):.1f}%" if self.total_tasks > 0 else "0%",
257
  "rate_limit_hits": self.rate_limit_hits,
 
258
  }
259
-
260
  def print_stats(self):
261
  """Print agent performance statistics."""
262
  stats = self.get_stats()
@@ -268,19 +241,21 @@ class GaiaAgent:
268
  print(f"Failed: {stats['failed_tasks']}")
269
  print(f"Success Rate: {stats['success_rate']}")
270
  print(f"Rate Limit Hits: {stats['rate_limit_hits']}")
 
 
 
271
  print(f"{'='*70}\n")
272
 
273
 
274
- # Example usage:
275
  if __name__ == "__main__":
276
- # Initialize agent
277
  agent = GaiaAgent()
278
-
279
  # Test with a simple question
280
  answer = agent(
281
  task_id="test-001",
282
  question="What is 2+2? Explain your reasoning."
283
  )
284
-
285
  # Print statistics
286
  agent.print_stats()
 
24
 
25
class ExcelToTextTool(Tool):
    """Render an Excel worksheet as a Markdown table."""

    name = "excel_to_text"
    description = "Read an Excel file and return a Markdown table of the requested sheet."
    inputs = {
        "excel_path": {"type": "string", "description": "Path to the Excel file."},
        "sheet_name": {"type": "string", "description": "Worksheet name or index. Optional.", "nullable": True},
    }
    output_type = "string"

    def forward(self, excel_path: str, sheet_name: Optional[str] = None) -> str:
        """Load *excel_path* and return the requested sheet as a Markdown table.

        Args:
            excel_path: Path to the Excel file; ``~`` is expanded and the
                path resolved before use.
            sheet_name: Worksheet name, or a zero-based index given as a
                digit string (e.g. ``"1"``). Defaults to the first sheet.

        Returns:
            The sheet rendered as a GitHub-style Markdown table, or an
            ``Error: ...`` string on failure — this tool never raises.
        """
        file_path = Path(excel_path).expanduser().resolve()
        if not file_path.is_file():
            return f"Error: Excel file not found at {file_path}"
        try:
            # A purely-numeric sheet_name is treated as a zero-based index.
            # NOTE(review): a sheet literally named "2023" would therefore be
            # misread as an index — confirm this trade-off is acceptable.
            sheet: Union[str, int] = int(sheet_name) if sheet_name and sheet_name.isdigit() else sheet_name or 0
            df = pd.read_excel(file_path, sheet_name=sheet)
            # DataFrame.to_markdown needs the optional 'tabulate' extra; fall
            # back to calling tabulate directly when the method is missing.
            return df.to_markdown(index=False) if hasattr(df, "to_markdown") else tabulate(df, headers="keys", tablefmt="github", showindex=False)
        except Exception as e:
            # Surface the error as a string so the agent loop keeps running.
            return f"Error reading Excel file: {e}"
45
 
46
 
47
  class GaiaAgent:
48
  """
49
+ Multi-model agent with intelligent routing and fallback.
50
 
51
+ Models:
52
+ - Fast (14.4K RPM): Simple Q&A, factual questions
53
+ - Reasoning (1K RPM, 12K TPM): Logic, math, complex reasoning
54
+ - Multimodal (1K RPM, 30K TPM): Images, chess, visual tasks
 
55
  """
56
 
57
+ MODELS = {
58
+ "fast": "groq/llama-3.1-8b-instant", # 14.4K RPM, 6K TPM
59
+ "reasoning": "groq/llama-3.3-70b-versatile", # 1K RPM, 12K TPM
60
+ "multimodal": "groq/meta-llama/llama-4-scout-17b-16e-instruct", # 1K RPM, 30K TPM
61
+ }
62
+
63
  def __init__(self):
64
+ print("βœ… GaiaAgent initialized with multi-model routing")
 
 
 
 
65
  self.api_key = os.getenv("GROQ_API_KEY")
 
66
  if not self.api_key:
67
  raise ValueError("GROQ_API_KEY not found in environment variables")
68
+
69
+ # Tools
 
 
 
 
 
 
 
 
 
70
  self.tools = [
71
  DuckDuckGoSearchTool(),
72
  WikipediaSearchTool(),
 
74
  PythonInterpreterTool(),
75
  FinalAnswerTool(),
76
  ]
77
+
78
+ # Rate limiting - conservative for CodeAgent's multiple API calls
 
 
 
 
 
 
 
 
79
  self.last_call_time = 0
80
+ self.min_delay = 15 # 15s between tasks (agents make 5-10+ calls internally)
81
+ self.max_retries = 2 # Reduced to 2 to avoid long waits
82
+
83
+ # Stats
84
  self.total_tasks = 0
85
  self.successful_tasks = 0
86
  self.failed_tasks = 0
87
  self.rate_limit_hits = 0
88
+ self.model_usage = {"fast": 0, "reasoning": 0, "multimodal": 0}
89
+
90
+ def _route_model(self, question: str, image: Optional[str] = None) -> str:
91
+ """Intelligently route to appropriate model based on question type."""
92
+ q = question.lower()
 
 
 
 
93
 
94
+ # Multimodal: images, chess, visual content
95
+ if image or any(kw in q for kw in ["chess", "image", "picture", "photo", "visual", "diagram"]):
96
+ return "multimodal"
 
 
 
 
 
 
 
97
 
98
+ # Reasoning: complex logic, math, proofs
99
+ if any(kw in q for kw in [
100
+ "prove", "counterexample", "logic", "reasoning", "algebra",
101
+ "calculate", "derive", "theorem", "equation", "formula",
102
+ "commutative", "associative", "distributive"
103
+ ]):
104
+ return "reasoning"
105
 
106
+ # Fast: simple factual questions
107
+ return "fast"
108
+
109
+ def _create_agent(self, model_key: str) -> CodeAgent:
110
+ """Create a new agent with the specified model."""
111
+ model_id = self.MODELS[model_key]
112
+ print(f"πŸ€– Initializing {model_key} model: {model_id}")
 
113
 
114
+ model = LiteLLMModel(model_id=model_id, api_key=self.api_key)
 
 
115
 
116
+ return CodeAgent(
117
+ model=model,
118
+ tools=self.tools,
119
+ add_base_tools=True,
120
+ additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
121
+ )
122
+
123
    def _try_model(self, agent: CodeAgent, question: str, model_key: str) -> Optional[str]:
        """Try to get an answer from one model, retrying on failure.

        Args:
            agent: The CodeAgent to run.
            question: The question to answer.
            model_key: ``MODELS`` key, used for logging and usage accounting.

        Returns:
            The answer string on success, or ``None`` once all
            ``max_retries + 1`` attempts are exhausted (the caller then
            falls back to the next model in its chain).
        """
        for attempt in range(self.max_retries + 1):
            try:
                print(f"🚀 Attempt {attempt + 1}/{self.max_retries + 1} with {model_key} model")
                answer = agent.run(question)
                if answer:
                    self.model_usage[model_key] += 1
                    return answer
                # Falsy answer: fall through and retry on the next iteration.

            except Exception as e:
                err = str(e)
                print(f"⚠️ Error: {err[:200]}{'...' if len(err) > 200 else ''}")

                if "rate limit" in err.lower() or "rate_limit" in err.lower():
                    self.rate_limit_hits += 1

                    # Extract the provider-suggested wait (e.g. "3.6s") from
                    # the error text; pad by 3s, default to 20s if absent.
                    wait_match = re.search(r'(\d+\.?\d*)\s*s', err)
                    wait_time = float(wait_match.group(1)) + 3 if wait_match else 20

                    if attempt < self.max_retries:
                        print(f"⏳ Rate limit hit. Waiting {wait_time:.1f}s...")
                        time.sleep(wait_time)
                        continue
                    else:
                        print(f"❌ Rate limit exhausted for {model_key}")
                        return None
                else:
                    # Non-rate-limit error: short fixed 5s backoff instead.
                    if attempt < self.max_retries:
                        print(f"🔄 Retrying in 5s...")
                        time.sleep(5)
                        continue
                    else:
                        print(f"❌ Failed after {self.max_retries + 1} attempts")
                        return None

        # Reached only when every attempt returned a falsy answer.
        return None
162
+
163
    def __call__(self, task_id: str, question: str, image: Optional[str] = None) -> str:
        """Process a task with intelligent routing and fallback.

        Args:
            task_id: Unique identifier for the task (used for logging only).
            question: The question to answer.
            image: Optional image reference; forces multimodal routing.

        Returns:
            The answer string, or a warning message if every model failed.
        """
        self.total_tasks += 1

        # Rate limiting: enforce min_delay seconds between successive tasks.
        elapsed = time.time() - self.last_call_time
        if elapsed < self.min_delay:
            wait_time = self.min_delay - elapsed
            print(f"⏳ Base rate limit: waiting {wait_time:.1f}s...")
            time.sleep(wait_time)

        print(f"\n{'='*70}")
        print(f"🔹 Task #{self.total_tasks} | ID: {task_id}")
        print(f"🔹 Question: {question[:120]}{'...' if len(question) > 120 else ''}")
        print(f"{'='*70}")

        # Route to primary model.
        primary_model = self._route_model(question, image)
        print(f"🎯 Routing to: {primary_model} model")

        # Define fallback chain: when the preferred model fails, cheaper /
        # faster models are tried in order.
        fallback_chain = {
            "reasoning": ["fast"],
            "multimodal": ["reasoning", "fast"],
            "fast": [],  # No fallback for fast model
        }

        # Try primary model first, then its fallbacks.
        models_to_try = [primary_model] + fallback_chain.get(primary_model, [])

        answer = None
        for model_key in models_to_try:
            try:
                agent = self._create_agent(model_key)
                answer = self._try_model(agent, question, model_key)

                if answer:
                    self.successful_tasks += 1
                    break
                else:
                    print(f"🔄 Trying fallback model...")
                    time.sleep(3)  # Brief pause before fallback
            except Exception as e:
                # Agent construction failed; skip to the next model in chain.
                print(f"❌ Failed to initialize {model_key}: {e}")
                continue

        if not answer:
            answer = "⚠️ All models failed to generate a valid response."
            self.failed_tasks += 1

        self.last_call_time = time.time()

        print(f"\n{'='*70}")
        print(f"📝 Answer: {str(answer)[:200]}{'...' if len(str(answer)) > 200 else ''}")
        print(f"{'='*70}\n")

        return answer
221
+
222
  def get_stats(self) -> dict:
223
  """Get agent performance statistics."""
224
  return {
 
227
  "failed_tasks": self.failed_tasks,
228
  "success_rate": f"{(self.successful_tasks / self.total_tasks * 100):.1f}%" if self.total_tasks > 0 else "0%",
229
  "rate_limit_hits": self.rate_limit_hits,
230
+ "model_usage": self.model_usage,
231
  }
232
+
233
  def print_stats(self):
234
  """Print agent performance statistics."""
235
  stats = self.get_stats()
 
241
  print(f"Failed: {stats['failed_tasks']}")
242
  print(f"Success Rate: {stats['success_rate']}")
243
  print(f"Rate Limit Hits: {stats['rate_limit_hits']}")
244
+ print(f"\nModel Usage:")
245
+ for model, count in stats['model_usage'].items():
246
+ print(f" {model.capitalize():12} {count} tasks")
247
  print(f"{'='*70}\n")
248
 
249
 
250
# Example usage
if __name__ == "__main__":
    # Construction raises ValueError when GROQ_API_KEY is not set.
    agent = GaiaAgent()

    # Test with a simple question
    answer = agent(
        task_id="test-001",
        question="What is 2+2? Explain your reasoning."
    )

    # Print statistics
    agent.print_stats()