Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -24,82 +24,49 @@ load_dotenv()
|
|
| 24 |
|
| 25 |
class ExcelToTextTool(Tool):
|
| 26 |
"""Render an Excel worksheet as a Markdown table."""
|
| 27 |
-
|
| 28 |
name = "excel_to_text"
|
| 29 |
-
description =
|
| 30 |
-
"Read an Excel file and return a Markdown table of the requested sheet. "
|
| 31 |
-
"Accepts either the sheet name or a zero-based index (as a string)."
|
| 32 |
-
)
|
| 33 |
inputs = {
|
| 34 |
-
"excel_path": {
|
| 35 |
-
|
| 36 |
-
"description": "Path to the Excel file (.xlsx or .xls).",
|
| 37 |
-
},
|
| 38 |
-
"sheet_name": {
|
| 39 |
-
"type": "string",
|
| 40 |
-
"description": (
|
| 41 |
-
"Worksheet name or zero-based index (as a string). "
|
| 42 |
-
"Optional; defaults to the first sheet."
|
| 43 |
-
),
|
| 44 |
-
"nullable": True,
|
| 45 |
-
},
|
| 46 |
}
|
| 47 |
output_type = "string"
|
| 48 |
|
| 49 |
def forward(self, excel_path: str, sheet_name: Optional[str] = None) -> str:
|
| 50 |
-
"""Load the Excel file and return the sheet as a Markdown table."""
|
| 51 |
file_path = Path(excel_path).expanduser().resolve()
|
| 52 |
-
|
| 53 |
if not file_path.is_file():
|
| 54 |
return f"Error: Excel file not found at {file_path}"
|
| 55 |
-
|
| 56 |
try:
|
| 57 |
-
sheet: Union[str, int] = (
|
| 58 |
-
int(sheet_name) if sheet_name and sheet_name.isdigit() else sheet_name or 0
|
| 59 |
-
)
|
| 60 |
df = pd.read_excel(file_path, sheet_name=sheet)
|
| 61 |
-
|
| 62 |
-
if hasattr(df, "to_markdown"):
|
| 63 |
-
return df.to_markdown(index=False)
|
| 64 |
-
|
| 65 |
-
return tabulate(df, headers="keys", tablefmt="github", showindex=False)
|
| 66 |
-
|
| 67 |
except Exception as e:
|
| 68 |
return f"Error reading Excel file: {e}"
|
| 69 |
|
| 70 |
|
| 71 |
class GaiaAgent:
|
| 72 |
"""
|
| 73 |
-
|
| 74 |
|
| 75 |
-
|
| 76 |
-
-
|
| 77 |
-
-
|
| 78 |
-
-
|
| 79 |
-
- Support for text and image inputs
|
| 80 |
"""
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def __init__(self):
|
| 83 |
-
"
|
| 84 |
-
print("β
GaiaAgent initialized with Llama 4 Scout (30K TPM, Multimodal)")
|
| 85 |
-
|
| 86 |
-
# Model configuration
|
| 87 |
-
self.model_id = "groq/meta-llama/llama-4-scout-17b-16e-instruct"
|
| 88 |
self.api_key = os.getenv("GROQ_API_KEY")
|
| 89 |
-
|
| 90 |
if not self.api_key:
|
| 91 |
raise ValueError("GROQ_API_KEY not found in environment variables")
|
| 92 |
-
|
| 93 |
-
#
|
| 94 |
-
self.model = LiteLLMModel(
|
| 95 |
-
model_id=self.model_id,
|
| 96 |
-
api_key=self.api_key,
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
print(f"π€ Using model: {self.model_id}")
|
| 100 |
-
print(f"π Limits: 30K TPM | 1K RPM | 500K context")
|
| 101 |
-
|
| 102 |
-
# Initialize tools
|
| 103 |
self.tools = [
|
| 104 |
DuckDuckGoSearchTool(),
|
| 105 |
WikipediaSearchTool(),
|
|
@@ -107,146 +74,151 @@ class GaiaAgent:
|
|
| 107 |
PythonInterpreterTool(),
|
| 108 |
FinalAnswerTool(),
|
| 109 |
]
|
| 110 |
-
|
| 111 |
-
#
|
| 112 |
-
self.agent = CodeAgent(
|
| 113 |
-
model=self.model,
|
| 114 |
-
tools=self.tools,
|
| 115 |
-
add_base_tools=True,
|
| 116 |
-
additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
# Rate limiting configuration (optimized for 30K TPM)
|
| 120 |
self.last_call_time = 0
|
| 121 |
-
self.min_delay =
|
| 122 |
-
self.max_retries =
|
| 123 |
-
|
| 124 |
-
#
|
| 125 |
self.total_tasks = 0
|
| 126 |
self.successful_tasks = 0
|
| 127 |
self.failed_tasks = 0
|
| 128 |
self.rate_limit_hits = 0
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
# Show countdown
|
| 135 |
-
for remaining in range(int(wait_time), 0, -1):
|
| 136 |
-
print(f"\rβ³ Rate limit: waiting {remaining}s... ", end="", flush=True)
|
| 137 |
-
time.sleep(1)
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
"""Extract wait time from rate limit error message."""
|
| 143 |
-
# Look for patterns like "3.675499999s" or "try again in 3.6s"
|
| 144 |
-
patterns = [
|
| 145 |
-
r'(\d+\.?\d*)\s*s',
|
| 146 |
-
r'try again in (\d+\.?\d*)',
|
| 147 |
-
r'retry in (\d+\.?\d*)',
|
| 148 |
-
]
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
wait_time = self._extract_wait_time(error_str)
|
| 163 |
|
| 164 |
-
|
| 165 |
-
if wait_time == 15:
|
| 166 |
-
wait_time = 15 * (attempt + 1) # 15s, 30s, 45s, 60s
|
| 167 |
|
| 168 |
-
return
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
Returns:
|
| 179 |
-
The answer string
|
| 180 |
-
"""
|
| 181 |
self.total_tasks += 1
|
| 182 |
|
| 183 |
-
#
|
| 184 |
elapsed = time.time() - self.last_call_time
|
| 185 |
if elapsed < self.min_delay:
|
| 186 |
wait_time = self.min_delay - elapsed
|
| 187 |
print(f"β³ Base rate limit: waiting {wait_time:.1f}s...")
|
| 188 |
time.sleep(wait_time)
|
| 189 |
-
|
| 190 |
print(f"\n{'='*70}")
|
| 191 |
print(f"πΉ Task #{self.total_tasks} | ID: {task_id}")
|
| 192 |
print(f"πΉ Question: {question[:120]}{'...' if len(question) > 120 else ''}")
|
| 193 |
print(f"{'='*70}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
-
#
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
| 199 |
try:
|
| 200 |
-
|
| 201 |
-
answer = self.agent
|
| 202 |
|
| 203 |
if answer:
|
| 204 |
self.successful_tasks += 1
|
| 205 |
-
print(f"β
Success!")
|
| 206 |
break
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
except Exception as e:
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
# Check if it's a rate limit error
|
| 214 |
-
if any(keyword in error_str.lower() for keyword in ['rate_limit', 'rate limit', 'quota']):
|
| 215 |
-
wait_time = self._handle_rate_limit_error(error_str, attempt)
|
| 216 |
-
|
| 217 |
-
if attempt < self.max_retries:
|
| 218 |
-
print(f"π Rate limit detected. Retrying after wait...")
|
| 219 |
-
self._wait_for_rate_limit(wait_time)
|
| 220 |
-
continue
|
| 221 |
-
else:
|
| 222 |
-
answer = f"β οΈ Rate limit exceeded after {self.max_retries + 1} attempts. Please try again later."
|
| 223 |
-
self.failed_tasks += 1
|
| 224 |
-
|
| 225 |
-
# Non-rate-limit error
|
| 226 |
-
else:
|
| 227 |
-
if attempt < self.max_retries:
|
| 228 |
-
print(f"π Retrying in 5s...")
|
| 229 |
-
time.sleep(5)
|
| 230 |
-
continue
|
| 231 |
-
else:
|
| 232 |
-
answer = f"β οΈ Agent failed: {error_str[:300]}"
|
| 233 |
-
self.failed_tasks += 1
|
| 234 |
-
|
| 235 |
-
# Fallback if no answer generated
|
| 236 |
if not answer:
|
| 237 |
-
answer = "β οΈ
|
| 238 |
self.failed_tasks += 1
|
| 239 |
-
|
| 240 |
-
# Update timing
|
| 241 |
self.last_call_time = time.time()
|
| 242 |
|
| 243 |
-
# Print results
|
| 244 |
print(f"\n{'='*70}")
|
| 245 |
print(f"π Answer: {str(answer)[:200]}{'...' if len(str(answer)) > 200 else ''}")
|
| 246 |
-
print(f"{'='*70}")
|
| 247 |
|
| 248 |
return answer
|
| 249 |
-
|
| 250 |
def get_stats(self) -> dict:
|
| 251 |
"""Get agent performance statistics."""
|
| 252 |
return {
|
|
@@ -255,8 +227,9 @@ class GaiaAgent:
|
|
| 255 |
"failed_tasks": self.failed_tasks,
|
| 256 |
"success_rate": f"{(self.successful_tasks / self.total_tasks * 100):.1f}%" if self.total_tasks > 0 else "0%",
|
| 257 |
"rate_limit_hits": self.rate_limit_hits,
|
|
|
|
| 258 |
}
|
| 259 |
-
|
| 260 |
def print_stats(self):
|
| 261 |
"""Print agent performance statistics."""
|
| 262 |
stats = self.get_stats()
|
|
@@ -268,19 +241,21 @@ class GaiaAgent:
|
|
| 268 |
print(f"Failed: {stats['failed_tasks']}")
|
| 269 |
print(f"Success Rate: {stats['success_rate']}")
|
| 270 |
print(f"Rate Limit Hits: {stats['rate_limit_hits']}")
|
|
|
|
|
|
|
|
|
|
| 271 |
print(f"{'='*70}\n")
|
| 272 |
|
| 273 |
|
| 274 |
-
# Example usage
|
| 275 |
if __name__ == "__main__":
|
| 276 |
-
# Initialize agent
|
| 277 |
agent = GaiaAgent()
|
| 278 |
-
|
| 279 |
# Test with a simple question
|
| 280 |
answer = agent(
|
| 281 |
task_id="test-001",
|
| 282 |
question="What is 2+2? Explain your reasoning."
|
| 283 |
)
|
| 284 |
-
|
| 285 |
# Print statistics
|
| 286 |
agent.print_stats()
|
|
|
|
| 24 |
|
| 25 |
class ExcelToTextTool(Tool):
    """Render an Excel worksheet as a Markdown table."""

    name = "excel_to_text"
    description = "Read an Excel file and return a Markdown table of the requested sheet."
    inputs = {
        "excel_path": {"type": "string", "description": "Path to the Excel file."},
        "sheet_name": {"type": "string", "description": "Worksheet name or index. Optional.", "nullable": True},
    }
    output_type = "string"

    def forward(self, excel_path: str, sheet_name: Optional[str] = None) -> str:
        """Load the Excel file and return the requested sheet as a Markdown table.

        Args:
            excel_path: Path to the Excel file (.xlsx or .xls).
            sheet_name: Worksheet name or zero-based index (as a string).
                Optional; defaults to the first sheet.

        Returns:
            The sheet rendered as a Markdown table, or an error message string
            (this tool reports failures as strings rather than raising, so the
            agent loop can surface them to the model).
        """
        file_path = Path(excel_path).expanduser().resolve()
        if not file_path.is_file():
            return f"Error: Excel file not found at {file_path}"
        try:
            # A purely numeric sheet_name is treated as a zero-based sheet index;
            # anything else is passed through as a sheet name (empty/None -> 0).
            sheet: Union[str, int] = int(sheet_name) if sheet_name and sheet_name.isdigit() else sheet_name or 0
            df = pd.read_excel(file_path, sheet_name=sheet)
            # Prefer pandas' native Markdown rendering; fall back to tabulate
            # for very old pandas versions without DataFrame.to_markdown.
            return df.to_markdown(index=False) if hasattr(df, "to_markdown") else tabulate(df, headers="keys", tablefmt="github", showindex=False)
        except Exception as e:
            return f"Error reading Excel file: {e}"
|
| 45 |
|
| 46 |
|
| 47 |
class GaiaAgent:
    """
    Multi-model agent with intelligent routing and fallback.

    Models:
    - Fast (14.4K RPM): Simple Q&A, factual questions
    - Reasoning (1K RPM, 12K TPM): Logic, math, complex reasoning
    - Multimodal (1K RPM, 30K TPM): Images, chess, visual tasks
    """

    # Groq model ids, keyed by routing category.
    MODELS = {
        "fast": "groq/llama-3.1-8b-instant",  # 14.4K RPM, 6K TPM
        "reasoning": "groq/llama-3.3-70b-versatile",  # 1K RPM, 12K TPM
        "multimodal": "groq/meta-llama/llama-4-scout-17b-16e-instruct",  # 1K RPM, 30K TPM
    }

    def __init__(self):
        print("✅ GaiaAgent initialized with multi-model routing")
        self.api_key = os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        # Tools
        # NOTE(review): the source diff hides one context line here (a fifth
        # tool entry) — restore it from the full file if present.
        self.tools = [
            DuckDuckGoSearchTool(),
            WikipediaSearchTool(),
            PythonInterpreterTool(),
            FinalAnswerTool(),
        ]

        # Rate limiting - conservative for CodeAgent's multiple API calls
        self.last_call_time = 0
        self.min_delay = 15  # 15s between tasks (agents make 5-10+ calls internally)
        self.max_retries = 2  # Reduced to 2 to avoid long waits

        # Stats
        self.total_tasks = 0
        self.successful_tasks = 0
        self.failed_tasks = 0
        self.rate_limit_hits = 0
        self.model_usage = {"fast": 0, "reasoning": 0, "multimodal": 0}

    def _route_model(self, question: str, image: "Optional[str]" = None) -> str:
        """Intelligently route to the appropriate model based on question type.

        Returns one of the MODELS keys: "multimodal", "reasoning", or "fast".
        """
        q = question.lower()

        # Multimodal: images, chess, visual content
        if image or any(kw in q for kw in ["chess", "image", "picture", "photo", "visual", "diagram"]):
            return "multimodal"

        # Reasoning: complex logic, math, proofs
        if any(kw in q for kw in [
            "prove", "counterexample", "logic", "reasoning", "algebra",
            "calculate", "derive", "theorem", "equation", "formula",
            "commutative", "associative", "distributive"
        ]):
            return "reasoning"

        # Fast: simple factual questions
        return "fast"

    def _create_agent(self, model_key: str) -> "CodeAgent":
        """Create a new CodeAgent wired to the model named by *model_key*."""
        model_id = self.MODELS[model_key]
        print(f"🤖 Initializing {model_key} model: {model_id}")

        model = LiteLLMModel(model_id=model_id, api_key=self.api_key)

        return CodeAgent(
            model=model,
            tools=self.tools,
            add_base_tools=True,
            additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
        )

    def _try_model(self, agent: "CodeAgent", question: str, model_key: str) -> "Optional[str]":
        """Try to get an answer from a model, retrying on errors.

        Rate-limit errors wait for the server-suggested delay (parsed from the
        error text) before retrying; other errors retry after a fixed 5s.
        Returns the answer string, or None when all attempts are exhausted.
        """
        for attempt in range(self.max_retries + 1):
            try:
                print(f"🔄 Attempt {attempt + 1}/{self.max_retries + 1} with {model_key} model")
                answer = agent.run(question)
                if answer:
                    self.model_usage[model_key] += 1
                    return answer

            except Exception as e:
                err = str(e)
                print(f"⚠️ Error: {err[:200]}{'...' if len(err) > 200 else ''}")

                if "rate limit" in err.lower() or "rate_limit" in err.lower():
                    self.rate_limit_hits += 1

                    # Extract wait time from error (e.g. "try again in 3.6s"),
                    # add a 3s safety margin; default to 20s when unparseable.
                    wait_match = re.search(r'(\d+\.?\d*)\s*s', err)
                    wait_time = float(wait_match.group(1)) + 3 if wait_match else 20

                    if attempt < self.max_retries:
                        print(f"⏳ Rate limit hit. Waiting {wait_time:.1f}s...")
                        time.sleep(wait_time)
                        continue
                    else:
                        print(f"❌ Rate limit exhausted for {model_key}")
                        return None
                else:
                    # Non-rate-limit error
                    if attempt < self.max_retries:
                        print(f"🔄 Retrying in 5s...")
                        time.sleep(5)
                        continue
                    else:
                        print(f"❌ Failed after {self.max_retries + 1} attempts")
                        return None

        return None

    def __call__(self, task_id: str, question: str, image: "Optional[str]" = None) -> str:
        """Process a task with intelligent routing and fallback.

        Routes the question to a primary model, then walks a fallback chain
        (multimodal -> reasoning -> fast) until one model produces an answer.

        Returns:
            The answer string, or a warning string when every model failed.
        """
        self.total_tasks += 1

        # Rate limiting: enforce min_delay seconds between tasks.
        elapsed = time.time() - self.last_call_time
        if elapsed < self.min_delay:
            wait_time = self.min_delay - elapsed
            print(f"⏳ Base rate limit: waiting {wait_time:.1f}s...")
            time.sleep(wait_time)

        print(f"\n{'='*70}")
        print(f"🔹 Task #{self.total_tasks} | ID: {task_id}")
        print(f"🔹 Question: {question[:120]}{'...' if len(question) > 120 else ''}")
        print(f"{'='*70}")

        # Route to primary model
        primary_model = self._route_model(question, image)
        print(f"🎯 Routing to: {primary_model} model")

        # Define fallback chain
        fallback_chain = {
            "reasoning": ["fast"],
            "multimodal": ["reasoning", "fast"],
            "fast": [],  # No fallback for fast model
        }

        # Try primary model first, then each fallback in order.
        models_to_try = [primary_model] + fallback_chain.get(primary_model, [])

        answer = None
        for model_key in models_to_try:
            try:
                agent = self._create_agent(model_key)
                answer = self._try_model(agent, question, model_key)

                if answer:
                    self.successful_tasks += 1
                    break
                else:
                    print(f"🔄 Trying fallback model...")
                    time.sleep(3)  # Brief pause before fallback

            except Exception as e:
                print(f"❌ Failed to initialize {model_key}: {e}")
                continue

        if not answer:
            answer = "⚠️ All models failed to generate a valid response."
            self.failed_tasks += 1

        self.last_call_time = time.time()

        print(f"\n{'='*70}")
        print(f"📝 Answer: {str(answer)[:200]}{'...' if len(str(answer)) > 200 else ''}")
        print(f"{'='*70}\n")

        return answer

    def get_stats(self) -> dict:
        """Get agent performance statistics as a plain dict."""
        # NOTE(review): the first two keys were hidden diff context; names
        # inferred from the counters above and print_stats — confirm against
        # the full file.
        return {
            "total_tasks": self.total_tasks,
            "successful_tasks": self.successful_tasks,
            "failed_tasks": self.failed_tasks,
            "success_rate": f"{(self.successful_tasks / self.total_tasks * 100):.1f}%" if self.total_tasks > 0 else "0%",
            "rate_limit_hits": self.rate_limit_hits,
            "model_usage": self.model_usage,
        }

    def print_stats(self):
        """Print agent performance statistics."""
        stats = self.get_stats()
        # NOTE(review): the header lines below were hidden diff context and are
        # reconstructed — confirm exact wording against the full file.
        print(f"\n{'='*70}")
        print(f"📊 AGENT STATISTICS")
        print(f"{'='*70}")
        print(f"Total Tasks: {stats['total_tasks']}")
        print(f"Successful: {stats['successful_tasks']}")
        print(f"Failed: {stats['failed_tasks']}")
        print(f"Success Rate: {stats['success_rate']}")
        print(f"Rate Limit Hits: {stats['rate_limit_hits']}")
        print(f"\nModel Usage:")
        for model, count in stats['model_usage'].items():
            print(f"  {model.capitalize():12} {count} tasks")
        print(f"{'='*70}\n")
|
| 248 |
|
| 249 |
|
| 250 |
+
# Example usage
if __name__ == "__main__":
    agent = GaiaAgent()

    # Test with a simple question
    answer = agent(
        task_id="test-001",
        question="What is 2+2? Explain your reasoning."
    )

    # Print statistics
    agent.print_stats()
|