Sanyam0605 commited on
Commit
35f2765
·
verified ·
1 Parent(s): 76f42c5

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +277 -0
main.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import traceback
6
+ from core_agent import GAIAAgent
7
+ from api_integration import GAIAApiClient
8
+
9
+ # Constants
10
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
11
+
12
def save_task_file(file_content, task_id):
    """
    Save a downloaded task attachment to a temporary file and return its path.

    Args:
        file_content: Raw file payload as ``bytes`` or ``str``. Falsy content
            (None, empty) is treated as "no file".
        task_id: Task identifier, used to build a unique filename.

    Returns:
        Absolute path of the written file, or ``None`` when there is no content.
    """
    if not file_content:
        return None

    # Bug fix: the file is opened in binary mode, so a str payload would
    # raise TypeError on write. Encode str transparently; bytes pass through.
    if isinstance(file_content, str):
        file_content = file_content.encode("utf-8")

    # NOTE(review): the extension is always .txt even for binary payloads;
    # the agent only receives a path, so this is acceptable for now — confirm.
    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, f"gaia_task_{task_id}.txt")

    # Write content to the file
    with open(file_path, 'wb') as f:
        f.write(file_content)

    print(f"File saved to {file_path}")
    return file_path
29
+
30
def get_agent_configuration():
    """
    Build the GAIAAgent keyword configuration from environment variables.

    Starts from defaults, layers in xAI credentials when present, then
    applies any ``AGENT_*`` overrides. Returns a plain dict suitable for
    ``GAIAAgent(**config)``.
    """
    config = {
        "model_type": "OpenAIServerModel",  # Default to OpenAIServerModel
        "model_id": "gpt-4o",               # Default model for OpenAI
        "temperature": 0.2,
        "executor_type": "local",
        "verbose": False,
        "provider": "hf-inference",  # For InferenceClientModel
        "timeout": 120,              # For InferenceClientModel
    }

    # xAI credentials are applied before the generic AGENT_* overrides,
    # so an explicit AGENT_MODEL_ID still wins.
    xai_key = os.getenv("XAI_API_KEY")
    if xai_key:
        config["api_key"] = xai_key
        xai_base = os.getenv("XAI_API_BASE")
        if xai_base:
            config["api_base"] = xai_base
        # Use a model that works well with xAI
        config["model_id"] = "mixtral-8x7b-32768"

    # Generic overrides: a falsy (unset or empty) variable is ignored,
    # mirroring the plain truthiness checks used elsewhere in this app.
    overrides = (
        ("AGENT_MODEL_TYPE", "model_type", str),
        ("AGENT_MODEL_ID", "model_id", str),
        ("AGENT_TEMPERATURE", "temperature", float),
        ("AGENT_EXECUTOR_TYPE", "executor_type", str),
        ("AGENT_API_BASE", "api_base", str),
        # InferenceClientModel specific settings
        ("AGENT_PROVIDER", "provider", str),
        ("AGENT_TIMEOUT", "timeout", int),
    )
    for env_name, config_key, cast in overrides:
        raw_value = os.getenv(env_name)
        if raw_value:
            config[config_key] = cast(raw_value)

    # AGENT_VERBOSE deliberately uses an "is not None" check so that an
    # explicitly set empty/false value still overrides the default.
    verbose_raw = os.getenv("AGENT_VERBOSE")
    if verbose_raw is not None:
        config["verbose"] = verbose_raw.lower() == "true"

    return config
84
+
85
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the GAIAAgent on them, submits all answers,
    and displays the results.

    Args:
        profile: Logged-in Hugging Face OAuth profile, or None when the
            user has not logged in.

    Returns:
        Tuple of (status message, pandas DataFrame of per-question results
        or None on early failure).
    """
    # Check for user login
    if not profile:
        return "Please Login to Hugging Face with the button.", None

    username = profile.username
    print(f"User logged in: {username}")

    # Link to this Space's code, sent along with the submission.
    space_id = os.getenv("SPACE_ID")
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    api_client = GAIAApiClient(DEFAULT_API_URL)

    agent, error = _init_agent()
    if error:
        return error, None

    questions_data, error = _fetch_questions(api_client)
    if error:
        return error, None

    results_log, answers_payload, failed = _answer_questions(
        agent, api_client, questions_data
    )

    # len(answers_payload) equals the number of attempted (non-skipped)
    # questions: every attempt appends exactly one entry, success or error.
    print(f"\nProcessing complete: {len(answers_payload)} questions processed, {failed} failures")

    if not answers_payload:
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    return _submit_answers(api_client, username, agent_code, answers_payload, results_log)


def _init_agent():
    """Create the GAIAAgent from env configuration. Returns (agent, error_msg)."""
    try:
        agent_config = get_agent_configuration()
        print(f"Using agent configuration: {agent_config}")
        agent = GAIAAgent(**agent_config)
        print("Agent initialized successfully")
        return agent, None
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error initializing agent: {e}\n{error_details}")
        return None, f"Error initializing agent: {e}"


def _fetch_questions(api_client):
    """Fetch the question list from the API. Returns (questions, error_msg)."""
    try:
        questions_data = api_client.get_questions()
        if not questions_data:
            return None, "Fetched questions list is empty or invalid format."
        print(f"Fetched {len(questions_data)} questions.")
        return questions_data, None
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error fetching questions: {e}\n{error_details}")
        return None, f"Error fetching questions: {e}"


def _download_task_file(api_client, task_id):
    """Best-effort download of a question's attachment; None when absent."""
    try:
        file_content = api_client.get_file(task_id)
        print(f"Downloaded file for task {task_id}")
        return save_task_file(file_content, task_id)
    except Exception as file_e:
        # Missing attachments are expected; log and continue without a file.
        print(f"No file found for task {task_id} or error: {file_e}")
        return None


def _answer_questions(agent, api_client, questions_data):
    """
    Run the agent over every question.

    Returns (results_log, answers_payload, failed_count). Failures are
    recorded as "AGENT ERROR: ..." answers rather than aborting the run.
    """
    results_log = []
    answers_payload = []
    total_questions = len(questions_data)
    print(f"Running agent on {total_questions} questions...")

    attempted = 0
    failed = 0

    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue

        attempted += 1
        print(f"Processing question {attempted}/{total_questions}: Task ID {task_id}")

        try:
            file_path = _download_task_file(api_client, task_id)
            submitted_answer = agent.answer_question(question_text, file_path)
        except Exception as e:
            failed += 1
            error_details = traceback.format_exc()
            print(f"Error running agent on task {task_id}: {e}\n{error_details}")
            submitted_answer = f"AGENT ERROR: {e}"

        # Single append path for both success and error answers (the
        # original duplicated this block in each branch).
        answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
        results_log.append({
            "Task ID": task_id,
            "Question": question_text,
            "Submitted Answer": submitted_answer,
        })

    return results_log, answers_payload, failed


def _submit_answers(api_client, username, agent_code, answers_payload, results_log):
    """Submit answers to the scoring API and format the final status message."""
    # NOTE: the original built a submission_data dict here that was never
    # used — submit_answers takes the fields as separate arguments.
    print(f"Submitting {len(answers_payload)} answers for username '{username}'...")

    try:
        result_data = api_client.submit_answers(
            username.strip(),
            agent_code,
            answers_payload
        )

        # Calculate success rate
        correct_count = result_data.get('correct_count', 0)
        total_attempted = result_data.get('total_attempted', len(answers_payload))
        success_rate = (correct_count / total_attempted) * 100 if total_attempted > 0 else 0

        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({correct_count}/{total_attempted} correct, {success_rate:.1f}% success rate)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )

        print("Submission successful.")
        return final_status, pd.DataFrame(results_log)
    except Exception as e:
        error_details = traceback.format_exc()
        status_message = f"Submission Failed: {e}\n{error_details}"
        print(status_message)
        return status_message, pd.DataFrame(results_log)
225
+
226
# Build Gradio Interface
# Top-level UI: instructions, a Hugging Face login button, one run button,
# and two output widgets (status text + results table) wired to
# run_and_submit_all.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner")
    gr.Markdown(
        """
        **Instructions:**

        1. Log in to your Hugging Face account using the button below.
        2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.

        **Configuration:**

        You can configure the agent by setting these environment variables:
        - `AGENT_MODEL_TYPE`: Model type (HfApiModel, InferenceClientModel, LiteLLMModel, OpenAIServerModel)
        - `AGENT_MODEL_ID`: Model ID
        - `AGENT_TEMPERATURE`: Temperature for generation (0.0-1.0)
        - `AGENT_EXECUTOR_TYPE`: Type of executor ('local' or 'e2b')
        - `AGENT_VERBOSE`: Enable verbose logging (true/false)
        - `AGENT_API_BASE`: Base URL for API calls (for OpenAIServerModel)

        **xAI Support:**
        - `XAI_API_KEY`: Your xAI API key
        - `XAI_API_BASE`: Base URL for xAI API (default: https://api.groq.com/openai/v1)
        - When using xAI, set AGENT_MODEL_TYPE=OpenAIServerModel and AGENT_MODEL_ID=mixtral-8x7b-32768

        **InferenceClientModel specific settings:**
        - `AGENT_PROVIDER`: Provider for InferenceClientModel (e.g., "hf-inference")
        - `AGENT_TIMEOUT`: Timeout in seconds for API calls
        """
    )

    gr.LoginButton()

    run_button = gr.Button("Run Evaluation & Submit All Answers")

    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    # NOTE(review): no inputs are listed — presumably Gradio injects the
    # OAuth profile from run_and_submit_all's gr.OAuthProfile annotation;
    # confirm against the Gradio login docs.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table]
    )
268
+
269
if __name__ == "__main__":
    # Announce startup and surface the effective configuration before the
    # UI comes up, so misconfigured env vars are visible in the logs.
    separator = "-" * 30
    print("\n" + separator + " App Starting " + separator)

    startup_config = get_agent_configuration()
    print(f"Agent configuration: {startup_config}")

    # Launch the Gradio app locally (no public share link).
    demo.launch(debug=True, share=False)