File size: 9,233 Bytes
a80fafc
51b92fc
794795e
a80fafc
 
efbc030
a80fafc
 
efbc030
 
8cb70b7
efbc030
 
 
 
 
 
8cb70b7
a80fafc
 
efbc030
a80fafc
 
26e0177
a80fafc
 
 
f36e4c1
a80fafc
f36e4c1
 
a80fafc
 
 
 
 
 
 
 
f36e4c1
a80fafc
f36e4c1
a80fafc
 
 
 
 
794795e
d2a08ac
51b92fc
d2a08ac
 
 
 
 
794795e
d2a08ac
 
f36e4c1
794795e
d2a08ac
 
 
 
794795e
 
 
d2a08ac
 
 
 
 
 
 
 
 
 
 
 
 
 
f36e4c1
26e0177
a80fafc
 
 
 
 
 
d2a08ac
 
 
 
 
 
 
 
 
 
26e0177
d2a08ac
 
 
f36e4c1
794795e
 
 
 
f36e4c1
d2a08ac
 
 
 
 
 
 
794795e
d2a08ac
 
 
 
794795e
d2a08ac
f36e4c1
d2a08ac
 
794795e
 
f36e4c1
51b92fc
 
 
d2a08ac
51b92fc
f36e4c1
794795e
d2a08ac
 
 
f36e4c1
 
d2a08ac
 
 
26e0177
d2a08ac
 
794795e
d2a08ac
794795e
d2a08ac
26e0177
f36e4c1
d2a08ac
 
 
 
794795e
26e0177
d2a08ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f36e4c1
d2a08ac
26e0177
d2a08ac
794795e
f36e4c1
d2a08ac
51b92fc
 
d2a08ac
794795e
d2a08ac
 
f36e4c1
794795e
d2a08ac
f36e4c1
794795e
 
d2a08ac
794795e
 
 
 
d2a08ac
794795e
 
f36e4c1
794795e
 
 
 
 
 
 
d2a08ac
 
794795e
d2a08ac
794795e
26e0177
 
f36e4c1
794795e
 
d2a08ac
 
794795e
 
d2a08ac
794795e
d2a08ac
794795e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import os
import time
import re
from pathlib import Path
from typing import Optional, Union

import pandas as pd
from dotenv import load_dotenv
from tabulate import tabulate

from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    FinalAnswerTool,
    LiteLLMModel,
    PythonInterpreterTool,
    WikipediaSearchTool,
)
from smolagents.tools import Tool

# Load environment variables
load_dotenv()


class ExcelToTextTool(Tool):
    """Render an Excel worksheet as a Markdown table.

    Registered with the agent as the ``excel_to_text`` tool. A missing file or
    any failure while reading/rendering the workbook is returned as a string
    starting with ``"Error"`` rather than raised, so the calling agent can read
    the message and react.
    """
    name = "excel_to_text"
    description = "Read an Excel file and return a Markdown table of the requested sheet."
    inputs = {
        "excel_path": {"type": "string", "description": "Path to the Excel file."},
        "sheet_name": {"type": "string", "description": "Worksheet name or index. Optional.", "nullable": True},
    }
    output_type = "string"

    def forward(self, excel_path: str, sheet_name: Optional[str] = None) -> str:
        """Return one worksheet of *excel_path* rendered as a Markdown table.

        Args:
            excel_path: Path to the workbook; ``~`` is expanded.
            sheet_name: Worksheet to read. An exact sheet name is used as-is.
                A string of digits, or an ``int``, is treated as a 0-based
                sheet position. Agent-generated code may pass an ``int``
                despite the declared ``string`` input type. ``None`` or ``""``
                selects the first sheet.

        Returns:
            The Markdown table, or an ``"Error..."`` message.
        """
        file_path = Path(excel_path).expanduser().resolve()
        if not file_path.is_file():
            return f"Error: Excel file not found at {file_path}"
        try:
            with pd.ExcelFile(file_path) as workbook:
                sheet: Union[str, int]
                if isinstance(sheet_name, int) and not isinstance(sheet_name, bool):
                    # Positional index passed directly by generated code.
                    sheet = sheet_name
                elif not sheet_name:
                    sheet = 0
                elif sheet_name in workbook.sheet_names:
                    # Prefer an exact name match so a sheet literally called
                    # e.g. "2023" is not misread as a positional index.
                    sheet = sheet_name
                elif str(sheet_name).strip().isdecimal():
                    sheet = int(str(sheet_name).strip())
                else:
                    # Unknown name: let pandas report the precise error.
                    sheet = sheet_name
                df = workbook.parse(sheet_name=sheet)
            if hasattr(df, "to_markdown"):
                return df.to_markdown(index=False)
            return tabulate(df, headers="keys", tablefmt="github", showindex=False)
        except Exception as e:
            return f"Error reading Excel file: {e}"


class GaiaAgent:
    """
    Single-model agent using Llama 4 Scout exclusively.

    Why Llama 4 Scout:
    - 30K TPM (highest available - 5x more than llama-3.1-8b)
    - 500K context window
    - Multimodal support (images, chess)
    - 1K RPM

    This avoids the 6K TPM bottleneck of llama-3.1-8b-instant.

    Usage: ``agent(task_id, question)`` runs a smolagents ``CodeAgent`` on the
    question. Consecutive tasks are spaced at least ``min_delay`` seconds
    apart, and failed attempts are retried up to ``max_retries`` times. The
    call always returns a string: the agent's answer, or a message starting
    with "⚠️" when no answer could be produced. Per-instance counters are
    available through ``get_stats`` / ``print_stats``.
    """

    # Retry hint inside provider rate-limit messages, e.g.
    # "try again in 7.66s", "try again in 2m59.56s", "retry after 420ms".
    _RETRY_HINT = re.compile(
        r"(?:try again|retry)\s+(?:in|after)\s+((?:\d+(?:\.\d+)?\s*(?:ms|h|m|s)\s*)+)",
        re.IGNORECASE,
    )
    # One "<number><unit>" component of a retry hint.
    _DURATION_PART = re.compile(r"(\d+(?:\.\d+)?)\s*(ms|h|m|s)", re.IGNORECASE)
    _UNIT_SECONDS = {"ms": 0.001, "s": 1.0, "m": 60.0, "h": 3600.0}

    def __init__(self):
        print("="*70)
        print("✅ GaiaAgent initialized with Llama 4 Scout (30K TPM)")
        print("="*70)

        self.api_key = os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        # Single model configuration - Llama 4 Scout for all tasks
        self.model_id = "groq/meta-llama/llama-4-scout-17b-16e-instruct"

        print(f"🤖 Model: {self.model_id}")
        print("📊 Limits: 30K TPM | 1K RPM | 500K context | Multimodal")
        print("="*70 + "\n")

        # Initialize model
        self.model = LiteLLMModel(
            model_id=self.model_id,
            api_key=self.api_key,
        )

        # Tools
        self.tools = [
            DuckDuckGoSearchTool(),
            WikipediaSearchTool(),
            ExcelToTextTool(),
            PythonInterpreterTool(),
            FinalAnswerTool(),
        ]

        # Create agent.
        # NOTE(review): authorizing "subprocess" lets model-generated code
        # import it and run shell commands on this host — confirm intended.
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            add_base_tools=True,
            additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
        )

        # Rate limiting - 30K TPM is generous but agents make multiple calls
        self.last_call_time = 0
        self.min_delay = 10  # 10s between tasks (reasonable with 30K TPM)
        self.max_retries = 3  # More retries since we have higher TPM

        # Stats
        self.total_tasks = 0
        self.successful_tasks = 0
        self.failed_tasks = 0
        self.rate_limit_hits = 0

    def _extract_wait_time(self, error_str: str) -> float:
        """Return seconds to wait before retrying, parsed from a rate-limit error.

        Only an explicit "try again in ..." / "retry in|after ..." hint is
        honoured. Compound durations such as "2m59.56s" and sub-second ones
        such as "420ms" are summed. A 5 s buffer is added. Without a hint the
        default is 30 s.
        """
        match = self._RETRY_HINT.search(error_str)
        if match:
            seconds = sum(
                float(value) * self._UNIT_SECONDS[unit.lower()]
                for value, unit in self._DURATION_PART.findall(match.group(1))
            )
            return seconds + 5  # Add 5s buffer
        return 30  # Default fallback

    @staticmethod
    def _has_content(result) -> bool:
        """True unless *result* is None or renders as blank text (``0`` counts)."""
        return result is not None and bool(str(result).strip())

    @staticmethod
    def _sleep_with_countdown(wait_time: float, step: float = 5.0) -> None:
        """Sleep *wait_time* seconds, printing a countdown for waits over 10s."""
        if wait_time <= 10:
            time.sleep(wait_time)
            return
        remaining = wait_time
        while remaining > 0:
            print(f"   ⏱️  {int(remaining)}s remaining...", flush=True)
            chunk = min(step, remaining)  # last chunk is partial: no overshoot
            time.sleep(chunk)
            remaining -= chunk

    def _handle_error(self, error_str: str, retries_left: bool) -> Optional[str]:
        """Log a failed attempt and decide what happens next.

        Returns ``None`` (after sleeping) when another attempt should be made.
        Otherwise it counts the task as failed and returns the final failure
        message.
        """
        # Show condensed error
        if len(error_str) > 300:
            print(f"❌ Error: {error_str[:300]}...")
        else:
            print(f"❌ Error: {error_str}")

        lowered = error_str.lower()
        if any(marker in lowered for marker in ("rate limit", "rate_limit", "ratelimit")):
            self.rate_limit_hits += 1
            if retries_left:
                wait_time = self._extract_wait_time(error_str)
                print(f"⏳ Rate limit hit. Waiting {wait_time:.1f}s before retry...")
                self._sleep_with_countdown(wait_time)
                print("🔄 Retrying...")
                return None
            message = "⚠️ Rate limit exceeded after all retries."
        elif "authentication" in lowered or "api key" in lowered:
            # Retrying cannot fix bad credentials.
            message = "⚠️ Authentication failed. Check your GROQ_API_KEY."
        elif retries_left:
            print("🔄 Retrying in 5s...")
            time.sleep(5)
            return None
        else:
            message = f"⚠️ Failed after {self.max_retries + 1} attempts."
        self.failed_tasks += 1
        return message

    def __call__(self, task_id: str, question: str) -> str:
        """Answer *question* with the agent, spacing tasks and retrying on failure.

        Args:
            task_id: Identifier shown in the progress log.
            question: Prompt passed to ``CodeAgent.run``.

        Returns:
            ``str()`` of the agent's answer, or a "⚠️ ..." message explaining
            why no answer was produced. Falsy but meaningful answers such as
            ``0`` are returned (as ``"0"``) rather than treated as empty.
        """
        self.total_tasks += 1

        # Keep at least ``min_delay`` seconds between consecutive tasks.
        elapsed = time.time() - self.last_call_time
        if elapsed < self.min_delay:
            wait_time = self.min_delay - elapsed
            print(f"⏳ Rate limit: waiting {wait_time:.1f}s...")
            time.sleep(wait_time)

        print(f"\n{'='*70}")
        print(f"📋 Task #{self.total_tasks} | ID: {task_id}")
        print(f"❓ Question: {question[:150]}{'...' if len(question) > 150 else ''}")
        print(f"{'='*70}\n")

        # Stays None unless we get a real answer or a final failure message.
        answer = None

        # Retry loop: rate-limit errors wait for the server-suggested delay;
        # other errors and empty answers wait a fixed 5s between attempts.
        for attempt in range(self.max_retries + 1):
            retries_left = attempt < self.max_retries
            print(f"🚀 Attempt {attempt + 1}/{self.max_retries + 1}")
            try:
                result = self.agent.run(question)
            except Exception as e:  # top-level boundary around the whole agent run
                answer = self._handle_error(str(e), retries_left)
                if answer is None:
                    continue  # already waited; try again
                break
            if self._has_content(result):
                answer = result
                self.successful_tasks += 1
                print("✅ Success!")
                break
            print("⚠️ Empty answer received")
            if retries_left:
                time.sleep(5)

        # Fallback: every attempt produced an empty answer.
        if answer is None:
            answer = "⚠️ Could not generate a valid response."
            self.failed_tasks += 1

        # Update timing
        self.last_call_time = time.time()

        # Print result
        print(f"\n{'='*70}")
        answer_preview = str(answer)[:250] + ('...' if len(str(answer)) > 250 else '')
        print(f"✍️  Answer: {answer_preview}")
        print(f"{'='*70}\n")

        return str(answer)

    def get_stats(self) -> dict:
        """Get agent performance statistics."""
        success_rate = (self.successful_tasks / self.total_tasks * 100) if self.total_tasks > 0 else 0
        return {
            "total_tasks": self.total_tasks,
            "successful_tasks": self.successful_tasks,
            "failed_tasks": self.failed_tasks,
            "success_rate": f"{success_rate:.1f}%",
            "rate_limit_hits": self.rate_limit_hits,
        }

    def print_stats(self):
        """Print agent performance statistics."""
        stats = self.get_stats()
        print(f"\n{'='*70}")
        print("📊 AGENT STATISTICS")
        print(f"{'='*70}")
        print(f"Total Tasks:       {stats['total_tasks']}")
        print(f"Successful:        {stats['successful_tasks']} ✅")
        print(f"Failed:            {stats['failed_tasks']} ❌")
        print(f"Success Rate:      {stats['success_rate']}")
        print(f"Rate Limit Hits:   {stats['rate_limit_hits']} 🚫")
        print(f"{'='*70}\n")


# Example usage: answer a single smoke-test question, then print the
# agent's success/failure counters.
if __name__ == "__main__":
    demo_agent = GaiaAgent()
    demo_agent(task_id="test-001", question="What is 2+2? Show your calculation.")
    demo_agent.print_stats()