Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -46,26 +46,39 @@ class ExcelToTextTool(Tool):
|
|
| 46 |
|
| 47 |
class GaiaAgent:
|
| 48 |
"""
|
| 49 |
-
|
| 50 |
|
| 51 |
-
|
| 52 |
-
-
|
| 53 |
-
-
|
| 54 |
-
- Multimodal (
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
"reasoning": "groq/llama-3.3-70b-versatile", # 1K RPM, 12K TPM
|
| 60 |
-
"multimodal": "groq/meta-llama/llama-4-scout-17b-16e-instruct", # 1K RPM, 30K TPM
|
| 61 |
-
}
|
| 62 |
|
| 63 |
def __init__(self):
|
| 64 |
-
print("
|
|
|
|
|
|
|
|
|
|
| 65 |
self.api_key = os.getenv("GROQ_API_KEY")
|
| 66 |
if not self.api_key:
|
| 67 |
raise ValueError("GROQ_API_KEY not found in environment variables")
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
# Tools
|
| 70 |
self.tools = [
|
| 71 |
DuckDuckGoSearchTool(),
|
|
@@ -74,160 +87,147 @@ class GaiaAgent:
|
|
| 74 |
PythonInterpreterTool(),
|
| 75 |
FinalAnswerTool(),
|
| 76 |
]
|
| 77 |
-
|
| 78 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
self.last_call_time = 0
|
| 80 |
-
self.min_delay =
|
| 81 |
-
self.max_retries =
|
| 82 |
-
|
| 83 |
# Stats
|
| 84 |
self.total_tasks = 0
|
| 85 |
self.successful_tasks = 0
|
| 86 |
self.failed_tasks = 0
|
| 87 |
self.rate_limit_hits = 0
|
| 88 |
-
self.model_usage = {"fast": 0, "reasoning": 0, "multimodal": 0}
|
| 89 |
|
| 90 |
-
def
|
| 91 |
-
"""
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
# Reasoning: complex logic, math, proofs
|
| 99 |
-
if any(kw in q for kw in [
|
| 100 |
-
"prove", "counterexample", "logic", "reasoning", "algebra",
|
| 101 |
-
"calculate", "derive", "theorem", "equation", "formula",
|
| 102 |
-
"commutative", "associative", "distributive"
|
| 103 |
-
]):
|
| 104 |
-
return "reasoning"
|
| 105 |
-
|
| 106 |
-
# Fast: simple factual questions
|
| 107 |
-
return "fast"
|
| 108 |
-
|
| 109 |
-
def _create_agent(self, model_key: str) -> CodeAgent:
|
| 110 |
-
"""Create a new agent with the specified model."""
|
| 111 |
-
model_id = self.MODELS[model_key]
|
| 112 |
-
print(f"π€ Initializing {model_key} model: {model_id}")
|
| 113 |
-
|
| 114 |
-
model = LiteLLMModel(model_id=model_id, api_key=self.api_key)
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
def _try_model(self, agent: CodeAgent, question: str, model_key: str) -> Optional[str]:
|
| 124 |
-
"""Try to get answer from a model with retries."""
|
| 125 |
-
for attempt in range(self.max_retries + 1):
|
| 126 |
-
try:
|
| 127 |
-
print(f"π Attempt {attempt + 1}/{self.max_retries + 1} with {model_key} model")
|
| 128 |
-
answer = agent.run(question)
|
| 129 |
-
if answer:
|
| 130 |
-
self.model_usage[model_key] += 1
|
| 131 |
-
return answer
|
| 132 |
-
|
| 133 |
-
except Exception as e:
|
| 134 |
-
err = str(e)
|
| 135 |
-
print(f"β οΈ Error: {err[:200]}{'...' if len(err) > 200 else ''}")
|
| 136 |
-
|
| 137 |
-
if "rate limit" in err.lower() or "rate_limit" in err.lower():
|
| 138 |
-
self.rate_limit_hits += 1
|
| 139 |
-
|
| 140 |
-
# Extract wait time from error
|
| 141 |
-
wait_match = re.search(r'(\d+\.?\d*)\s*s', err)
|
| 142 |
-
wait_time = float(wait_match.group(1)) + 3 if wait_match else 20
|
| 143 |
-
|
| 144 |
-
if attempt < self.max_retries:
|
| 145 |
-
print(f"β³ Rate limit hit. Waiting {wait_time:.1f}s...")
|
| 146 |
-
time.sleep(wait_time)
|
| 147 |
-
continue
|
| 148 |
-
else:
|
| 149 |
-
print(f"β Rate limit exhausted for {model_key}")
|
| 150 |
-
return None
|
| 151 |
-
else:
|
| 152 |
-
# Non-rate-limit error
|
| 153 |
-
if attempt < self.max_retries:
|
| 154 |
-
print(f"π Retrying in 5s...")
|
| 155 |
-
time.sleep(5)
|
| 156 |
-
continue
|
| 157 |
-
else:
|
| 158 |
-
print(f"β Failed after {self.max_retries + 1} attempts")
|
| 159 |
-
return None
|
| 160 |
|
| 161 |
-
return
|
| 162 |
|
| 163 |
-
def __call__(self, task_id: str, question: str
|
| 164 |
-
"""Process a task with
|
| 165 |
self.total_tasks += 1
|
| 166 |
|
| 167 |
# Rate limiting
|
| 168 |
elapsed = time.time() - self.last_call_time
|
| 169 |
if elapsed < self.min_delay:
|
| 170 |
wait_time = self.min_delay - elapsed
|
| 171 |
-
print(f"β³
|
| 172 |
time.sleep(wait_time)
|
| 173 |
|
| 174 |
print(f"\n{'='*70}")
|
| 175 |
-
print(f"
|
| 176 |
-
print(f"
|
| 177 |
-
print(f"{'='*70}")
|
| 178 |
|
| 179 |
-
# Route to primary model
|
| 180 |
-
primary_model = self._route_model(question, image)
|
| 181 |
-
print(f"π― Routing to: {primary_model} model")
|
| 182 |
-
|
| 183 |
-
# Define fallback chain
|
| 184 |
-
fallback_chain = {
|
| 185 |
-
"reasoning": ["fast"],
|
| 186 |
-
"multimodal": ["reasoning", "fast"],
|
| 187 |
-
"fast": [], # No fallback for fast model
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
# Try primary model first
|
| 191 |
-
models_to_try = [primary_model] + fallback_chain.get(primary_model, [])
|
| 192 |
-
|
| 193 |
answer = None
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
try:
|
| 196 |
-
|
| 197 |
-
answer = self.
|
| 198 |
|
| 199 |
-
if answer:
|
| 200 |
self.successful_tasks += 1
|
|
|
|
| 201 |
break
|
| 202 |
else:
|
| 203 |
-
print(f"
|
| 204 |
-
|
|
|
|
|
|
|
| 205 |
|
| 206 |
except Exception as e:
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
|
|
|
| 210 |
if not answer:
|
| 211 |
-
answer = "β οΈ
|
| 212 |
self.failed_tasks += 1
|
| 213 |
|
|
|
|
| 214 |
self.last_call_time = time.time()
|
| 215 |
|
|
|
|
| 216 |
print(f"\n{'='*70}")
|
| 217 |
-
|
|
|
|
| 218 |
print(f"{'='*70}\n")
|
| 219 |
|
| 220 |
-
return answer
|
| 221 |
|
| 222 |
def get_stats(self) -> dict:
|
| 223 |
"""Get agent performance statistics."""
|
|
|
|
| 224 |
return {
|
| 225 |
"total_tasks": self.total_tasks,
|
| 226 |
"successful_tasks": self.successful_tasks,
|
| 227 |
"failed_tasks": self.failed_tasks,
|
| 228 |
-
"success_rate": f"{
|
| 229 |
"rate_limit_hits": self.rate_limit_hits,
|
| 230 |
-
"model_usage": self.model_usage,
|
| 231 |
}
|
| 232 |
|
| 233 |
def print_stats(self):
|
|
@@ -237,25 +237,21 @@ class GaiaAgent:
|
|
| 237 |
print(f"π AGENT STATISTICS")
|
| 238 |
print(f"{'='*70}")
|
| 239 |
print(f"Total Tasks: {stats['total_tasks']}")
|
| 240 |
-
print(f"Successful: {stats['successful_tasks']}")
|
| 241 |
-
print(f"Failed: {stats['failed_tasks']}")
|
| 242 |
print(f"Success Rate: {stats['success_rate']}")
|
| 243 |
-
print(f"Rate Limit Hits: {stats['rate_limit_hits']}")
|
| 244 |
-
print(f"\nModel Usage:")
|
| 245 |
-
for model, count in stats['model_usage'].items():
|
| 246 |
-
print(f" {model.capitalize():12} {count} tasks")
|
| 247 |
print(f"{'='*70}\n")
|
| 248 |
|
| 249 |
|
| 250 |
# Example usage
|
| 251 |
if __name__ == "__main__":
|
| 252 |
agent = GaiaAgent()
|
| 253 |
-
|
| 254 |
-
# Test
|
| 255 |
answer = agent(
|
| 256 |
task_id="test-001",
|
| 257 |
-
question="What is 2+2?
|
| 258 |
)
|
| 259 |
-
|
| 260 |
-
# Print statistics
|
| 261 |
agent.print_stats()
|
|
|
|
| 46 |
|
| 47 |
class GaiaAgent:
|
| 48 |
"""
|
| 49 |
+
Single-model agent using Llama 4 Scout exclusively.
|
| 50 |
|
| 51 |
+
Why Llama 4 Scout:
|
| 52 |
+
- 30K TPM (highest available - 5x more than llama-3.1-8b)
|
| 53 |
+
- 500K context window
|
| 54 |
+
- Multimodal support (images, chess)
|
| 55 |
+
- 1K RPM
|
| 56 |
|
| 57 |
+
This avoids the 6K TPM bottleneck of llama-3.1-8b-instant.
|
| 58 |
+
"""
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def __init__(self):
|
| 61 |
+
print("="*70)
|
| 62 |
+
print("β
GaiaAgent initialized with Llama 4 Scout (30K TPM)")
|
| 63 |
+
print("="*70)
|
| 64 |
+
|
| 65 |
self.api_key = os.getenv("GROQ_API_KEY")
|
| 66 |
if not self.api_key:
|
| 67 |
raise ValueError("GROQ_API_KEY not found in environment variables")
|
| 68 |
+
|
| 69 |
+
# Single model configuration - Llama 4 Scout for all tasks
|
| 70 |
+
self.model_id = "groq/meta-llama/llama-4-scout-17b-16e-instruct"
|
| 71 |
+
|
| 72 |
+
print(f"π€ Model: {self.model_id}")
|
| 73 |
+
print(f"π Limits: 30K TPM | 1K RPM | 500K context | Multimodal")
|
| 74 |
+
print("="*70 + "\n")
|
| 75 |
+
|
| 76 |
+
# Initialize model
|
| 77 |
+
self.model = LiteLLMModel(
|
| 78 |
+
model_id=self.model_id,
|
| 79 |
+
api_key=self.api_key,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
# Tools
|
| 83 |
self.tools = [
|
| 84 |
DuckDuckGoSearchTool(),
|
|
|
|
| 87 |
PythonInterpreterTool(),
|
| 88 |
FinalAnswerTool(),
|
| 89 |
]
|
| 90 |
+
|
| 91 |
+
# Create agent
|
| 92 |
+
self.agent = CodeAgent(
|
| 93 |
+
model=self.model,
|
| 94 |
+
tools=self.tools,
|
| 95 |
+
add_base_tools=True,
|
| 96 |
+
additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Rate limiting - 30K TPM is generous but agents make multiple calls
|
| 100 |
self.last_call_time = 0
|
| 101 |
+
self.min_delay = 10 # 10s between tasks (reasonable with 30K TPM)
|
| 102 |
+
self.max_retries = 3 # More retries since we have higher TPM
|
| 103 |
+
|
| 104 |
# Stats
|
| 105 |
self.total_tasks = 0
|
| 106 |
self.successful_tasks = 0
|
| 107 |
self.failed_tasks = 0
|
| 108 |
self.rate_limit_hits = 0
|
|
|
|
| 109 |
|
| 110 |
+
def _extract_wait_time(self, error_str: str) -> float:
|
| 111 |
+
"""Extract wait time from rate limit error message."""
|
| 112 |
+
patterns = [
|
| 113 |
+
r'try again in (\d+\.?\d*)\s*s',
|
| 114 |
+
r'retry in (\d+\.?\d*)\s*s',
|
| 115 |
+
r'(\d+\.?\d*)\s*s',
|
| 116 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
+
for pattern in patterns:
|
| 119 |
+
match = re.search(pattern, error_str)
|
| 120 |
+
if match:
|
| 121 |
+
return float(match.group(1)) + 5 # Add 5s buffer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
+
return 30 # Default fallback
|
| 124 |
|
| 125 |
+
def __call__(self, task_id: str, question: str) -> str:
|
| 126 |
+
"""Process a task with automatic rate limiting and retry."""
|
| 127 |
self.total_tasks += 1
|
| 128 |
|
| 129 |
# Rate limiting
|
| 130 |
elapsed = time.time() - self.last_call_time
|
| 131 |
if elapsed < self.min_delay:
|
| 132 |
wait_time = self.min_delay - elapsed
|
| 133 |
+
print(f"β³ Rate limit: waiting {wait_time:.1f}s...")
|
| 134 |
time.sleep(wait_time)
|
| 135 |
|
| 136 |
print(f"\n{'='*70}")
|
| 137 |
+
print(f"π Task #{self.total_tasks} | ID: {task_id}")
|
| 138 |
+
print(f"β Question: {question[:150]}{'...' if len(question) > 150 else ''}")
|
| 139 |
+
print(f"{'='*70}\n")
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
answer = None
|
| 142 |
+
|
| 143 |
+
# Retry loop with exponential backoff
|
| 144 |
+
for attempt in range(self.max_retries + 1):
|
| 145 |
try:
|
| 146 |
+
print(f"π Attempt {attempt + 1}/{self.max_retries + 1}")
|
| 147 |
+
answer = self.agent.run(question)
|
| 148 |
|
| 149 |
+
if answer and len(str(answer).strip()) > 0:
|
| 150 |
self.successful_tasks += 1
|
| 151 |
+
print(f"β
Success!")
|
| 152 |
break
|
| 153 |
else:
|
| 154 |
+
print(f"β οΈ Empty answer received")
|
| 155 |
+
if attempt < self.max_retries:
|
| 156 |
+
time.sleep(5)
|
| 157 |
+
continue
|
| 158 |
|
| 159 |
except Exception as e:
|
| 160 |
+
error_str = str(e)
|
| 161 |
+
|
| 162 |
+
# Show condensed error
|
| 163 |
+
if len(error_str) > 300:
|
| 164 |
+
print(f"β Error: {error_str[:300]}...")
|
| 165 |
+
else:
|
| 166 |
+
print(f"β Error: {error_str}")
|
| 167 |
+
|
| 168 |
+
# Check if it's a rate limit error
|
| 169 |
+
if "rate limit" in error_str.lower() or "rate_limit" in error_str.lower():
|
| 170 |
+
self.rate_limit_hits += 1
|
| 171 |
+
wait_time = self._extract_wait_time(error_str)
|
| 172 |
+
|
| 173 |
+
if attempt < self.max_retries:
|
| 174 |
+
print(f"β³ Rate limit hit. Waiting {wait_time:.1f}s before retry...")
|
| 175 |
+
|
| 176 |
+
# Show countdown for long waits
|
| 177 |
+
if wait_time > 10:
|
| 178 |
+
for remaining in range(int(wait_time), 0, -5):
|
| 179 |
+
print(f" β±οΈ {remaining}s remaining...", flush=True)
|
| 180 |
+
time.sleep(5)
|
| 181 |
+
else:
|
| 182 |
+
time.sleep(wait_time)
|
| 183 |
+
|
| 184 |
+
print(f"π Retrying...")
|
| 185 |
+
continue
|
| 186 |
+
else:
|
| 187 |
+
answer = "β οΈ Rate limit exceeded after all retries."
|
| 188 |
+
self.failed_tasks += 1
|
| 189 |
+
|
| 190 |
+
# Authentication error
|
| 191 |
+
elif "authentication" in error_str.lower() or "api key" in error_str.lower():
|
| 192 |
+
answer = "β οΈ Authentication failed. Check your GROQ_API_KEY."
|
| 193 |
+
self.failed_tasks += 1
|
| 194 |
+
break
|
| 195 |
+
|
| 196 |
+
# Other errors
|
| 197 |
+
else:
|
| 198 |
+
if attempt < self.max_retries:
|
| 199 |
+
print(f"π Retrying in 5s...")
|
| 200 |
+
time.sleep(5)
|
| 201 |
+
continue
|
| 202 |
+
else:
|
| 203 |
+
answer = f"β οΈ Failed after {self.max_retries + 1} attempts."
|
| 204 |
+
self.failed_tasks += 1
|
| 205 |
|
| 206 |
+
# Fallback
|
| 207 |
if not answer:
|
| 208 |
+
answer = "β οΈ Could not generate a valid response."
|
| 209 |
self.failed_tasks += 1
|
| 210 |
|
| 211 |
+
# Update timing
|
| 212 |
self.last_call_time = time.time()
|
| 213 |
|
| 214 |
+
# Print result
|
| 215 |
print(f"\n{'='*70}")
|
| 216 |
+
answer_preview = str(answer)[:250] + ('...' if len(str(answer)) > 250 else '')
|
| 217 |
+
print(f"βοΈ Answer: {answer_preview}")
|
| 218 |
print(f"{'='*70}\n")
|
| 219 |
|
| 220 |
+
return str(answer)
|
| 221 |
|
| 222 |
def get_stats(self) -> dict:
|
| 223 |
"""Get agent performance statistics."""
|
| 224 |
+
success_rate = (self.successful_tasks / self.total_tasks * 100) if self.total_tasks > 0 else 0
|
| 225 |
return {
|
| 226 |
"total_tasks": self.total_tasks,
|
| 227 |
"successful_tasks": self.successful_tasks,
|
| 228 |
"failed_tasks": self.failed_tasks,
|
| 229 |
+
"success_rate": f"{success_rate:.1f}%",
|
| 230 |
"rate_limit_hits": self.rate_limit_hits,
|
|
|
|
| 231 |
}
|
| 232 |
|
| 233 |
def print_stats(self):
|
|
|
|
| 237 |
print(f"π AGENT STATISTICS")
|
| 238 |
print(f"{'='*70}")
|
| 239 |
print(f"Total Tasks: {stats['total_tasks']}")
|
| 240 |
+
print(f"Successful: {stats['successful_tasks']} β
")
|
| 241 |
+
print(f"Failed: {stats['failed_tasks']} β")
|
| 242 |
print(f"Success Rate: {stats['success_rate']}")
|
| 243 |
+
print(f"Rate Limit Hits: {stats['rate_limit_hits']} π«")
|
|
|
|
|
|
|
|
|
|
| 244 |
print(f"{'='*70}\n")
|
| 245 |
|
| 246 |
|
| 247 |
# Example usage
|
| 248 |
if __name__ == "__main__":
|
| 249 |
agent = GaiaAgent()
|
| 250 |
+
|
| 251 |
+
# Test
|
| 252 |
answer = agent(
|
| 253 |
task_id="test-001",
|
| 254 |
+
question="What is 2+2? Show your calculation."
|
| 255 |
)
|
| 256 |
+
|
|
|
|
| 257 |
agent.print_stats()
|