Wen-ChuangChou commited on
Commit
1729ab6
·
1 Parent(s): cee706e

Add support for the Gemini model

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. agent.py +314 -0
  3. app.py +2 -165
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .env
2
- *pycache*
 
 
1
  .env
2
+ *pycache*
3
+ results/
agent.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import requests
4
+ import sys
5
+ import time
6
+ from datetime import datetime
7
+ from dotenv import load_dotenv
8
+ from typing import Dict, List, Any
9
+ from smolagents import DuckDuckGoSearchTool, OpenAIServerModel, CodeAgent, ActionStep, TaskStep
10
+ from blablador import Models
11
+
12
+ load_dotenv()
13
+
14
+
15
class BasicAgent:
    """Question-answering agent built on a smolagents ``CodeAgent``.

    Supports two OpenAI-compatible model providers ("Blablador" and
    "Gemini") and, after every question, appends a JSON trace of the
    agent's memory to ``memory_file`` so runs can be inspected offline.
    """

    def __init__(self,
                 model_provider: str = "Blablador",
                 memory_file: str = "agent_memory.json"):
        """Configure the LLM backend and build the CodeAgent.

        Args:
            model_provider: "Blablador" or "Gemini". Any other value is
                reported and the process exits with status 1.
            memory_file: Path of the JSON file memory traces are written to.
        """
        self.model_provider = model_provider
        self.memory_file = memory_file

        if model_provider == "Blablador":
            models = Models(
                api_key=os.getenv("Blablador_API_KEY")).get_model_ids()
            # Blablador model ids look like "<id> - <human name> ...";
            # keep the first two words of the human-readable part for logging.
            model_id_blablador = 5
            model_name = " ".join(
                models[model_id_blablador].split(" - ")[1].split()[:2])
            print("The agent uses the following model:", model_name)

            answer_llm = OpenAIServerModel(
                model_id=models[model_id_blablador],
                api_base="https://helmholtz-blablador.fz-juelich.de:8000/v1",
                api_key=os.getenv("Blablador_API_KEY"),
                flatten_messages_as_text=True,
                temperature=0.2)

        elif model_provider == "Gemini":
            # model_name = "gemini-2.5-flash-preview-05-20"
            model_name = "gemini-2.0-flash"
            print("The agent uses the following model:", model_name)

            # Gemini exposed through its OpenAI-compatible endpoint.
            answer_llm = OpenAIServerModel(
                model_id=model_name,
                api_base=
                "https://generativelanguage.googleapis.com/v1beta/openai/",
                api_key=os.getenv("Gemini_API_KEY2"),
                temperature=0.2)
        else:
            print(
                f"Error: Unsupported model provider '{model_provider}'. Only 'Blablador' and 'Gemini' are supported."
            )
            sys.exit(1)

        self.agent = CodeAgent(
            tools=[DuckDuckGoSearchTool()],
            model=answer_llm,
            planning_interval=3,
            max_steps=10,
            # verbosity_level=LogLevel.ERROR,
        )

    def __call__(self,
                 question: str,
                 task_id: str = "",
                 file_url: str = "",
                 file_ext: str = "") -> str:
        """Answer one question and persist the agent's memory trace.

        Args:
            question: The question text.
            task_id: Identifier used to key the memory-trace entry.
            file_url: Optional URL of an attachment; currently only mentioned
                in the prompt, the file itself is not downloaded.
            file_ext: Extension of the attachment, used in the prompt note.

        Returns:
            The agent's answer, or an ``"Error: ..."`` string on failure.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question.
Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number
neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations, (e.g. for cities),
and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list,
apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

        if file_url:
            # File handling is not wired into agent.run() yet; record the
            # attachment metadata so a future version can pass it via
            # additional_args, and tell the model a file exists.
            # BUG FIX: this used to clobber the dict with the throwaway
            # string f"{file_url}_{file_ext}".
            additional_args = {"file_url": file_url,
                               "file_extension": file_ext}  # noqa: F841 — reserved for future use
            full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {question}\n\nNote: A {file_ext} file has been provided and is available for your analysis."
        else:
            full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {question}"

        try:
            answer = self.agent.run(full_prompt)
            print(f"Agent returning answer: {answer}")

            # Export memory after execution
            self.export_memory_to_json(task_id=task_id,
                                       question=question,
                                       answer=answer)

            # Sleep for 10 seconds if using Gemini to avoid rate limiting
            if self.model_provider == "Gemini":
                time.sleep(10)
            return answer
        except Exception as e:
            print(f"Error running agent: {e}")
            # Record the failure in the trace too, so failed questions are
            # visible in the memory file (the `error` field was previously
            # never populated).
            self.export_memory_to_json(task_id=task_id,
                                       question=question,
                                       error=str(e))
            return f"Error: {e}"

    def export_memory_to_json(self,
                              task_id: str = "",
                              question: str = "",
                              answer: str = "",
                              error: str = ""):
        """Append (or update) one question's memory trace in the JSON file.

        Entries are keyed by ``task_id``; an existing entry with the same id
        is overwritten, otherwise the entry is appended.
        """
        memory_data = self.extract_memory_data()

        # Load existing memory file if it exists; a corrupted or truncated
        # file must not crash the run, so fall back to a fresh structure.
        existing_data = {"questions": [], "batch_info": {}}
        if os.path.exists(self.memory_file):
            try:
                with open(self.memory_file, 'r', encoding='utf-8') as f:
                    existing_data = json.load(f)
            except (json.JSONDecodeError, OSError) as e:
                print(f"Warning: could not read {self.memory_file} ({e}); "
                      "starting a fresh memory file.")

        question_data = {
            "question_id": task_id or len(existing_data["questions"]) + 1,
            "timestamp": datetime.now().isoformat(),
            "model_provider": self.model_provider,
            "task": question,
            "result": answer,
            "error": error,
            "memory": memory_data,
            "memory_stats": self.get_memory_stats()
        }

        # Update in place when the task_id already exists, append otherwise.
        if task_id:
            for i, existing_question in enumerate(existing_data["questions"]):
                if existing_question["question_id"] == task_id:
                    existing_data["questions"][i] = question_data
                    break
            else:
                existing_data["questions"].append(question_data)
        else:
            existing_data["questions"].append(question_data)

        existing_data["batch_info"] = {
            "total_questions": len(existing_data["questions"]),
            "last_updated": datetime.now().isoformat(),
            "model_provider": self.model_provider
        }

        # default=str stringifies non-serializable objects (e.g. step
        # objects inside the memory dump) instead of raising.
        with open(self.memory_file, 'w', encoding='utf-8') as f:
            json.dump(existing_data,
                      f,
                      indent=2,
                      ensure_ascii=False,
                      default=str)

        print(f"Memory for question {task_id} exported to {self.memory_file}")

    def extract_memory_data(self) -> Dict[str, Any]:
        """Extract the agent's memory (system prompt + per-step data)."""
        memory_data = {"system_prompt": None, "steps": [], "full_steps": []}

        # Get system prompt
        if hasattr(self.agent.memory,
                   'system_prompt') and self.agent.memory.system_prompt:
            memory_data["system_prompt"] = {
                "content": str(self.agent.memory.system_prompt.system_prompt),
                "type": "system_prompt"
            }

        # Summarize each memory step; TaskStep and ActionStep carry
        # different fields, other step types get just the common header.
        for i, step in enumerate(self.agent.memory.steps):
            step_data = {
                "step_index": i,
                "step_type": type(step).__name__,
                "timestamp": datetime.now().isoformat()
            }

            if isinstance(step, TaskStep):
                step_data.update({
                    "task":
                    step.task,
                    "task_images":
                    len(step.task_images) if step.task_images else 0
                })
            elif isinstance(step, ActionStep):
                step_data.update({
                    "step_number":
                    step.step_number,
                    "llm_output":
                    getattr(step, 'action', None),
                    "observations":
                    step.observations,
                    "error":
                    str(step.error) if step.error else None,
                    "has_images":
                    len(step.observations_images) > 0
                    if step.observations_images else False
                })

            memory_data["steps"].append(step_data)

        # Full step dicts as provided by smolagents; best-effort since the
        # API surface differs between versions.
        try:
            memory_data["full_steps"] = self.agent.memory.get_full_steps()
        except Exception as e:
            print(f"Could not get full steps: {e}")
            memory_data["full_steps"] = []

        return memory_data

    def get_memory_stats(self) -> Dict[str, int]:
        """Count step types (task/action) and error vs successful steps."""
        stats = {
            "total_steps": len(self.agent.memory.steps),
            "task_steps": 0,
            "action_steps": 0,
            "error_steps": 0,
            "successful_steps": 0
        }

        for step in self.agent.memory.steps:
            if isinstance(step, TaskStep):
                stats["task_steps"] += 1
            elif isinstance(step, ActionStep):
                stats["action_steps"] += 1
                if step.error:
                    stats["error_steps"] += 1
                else:
                    stats["successful_steps"] += 1

        return stats

    def _download_file(self, file_url: str,
                       file_ext: str = "") -> "str | bytes | None":
        """Download a file, returning text for known text extensions,
        raw bytes otherwise, or None on any download error.

        (Return annotation fixed: the original claimed ``-> str`` but the
        method also returns ``bytes`` and ``None``.)
        """
        try:
            response = requests.get(file_url, timeout=30)
            response.raise_for_status()

            # For text files, return as string
            if file_ext.lower() in [
                    'txt', 'csv', 'json', 'md', 'py', 'js', 'html', 'xml'
            ]:
                return response.text
            else:
                # For binary files, return the content as bytes
                return response.content

        except Exception as e:
            print(f"Error downloading file from {file_url}: {e}")
            return None
app.py CHANGED
@@ -3,175 +3,12 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
6
- import sys
7
- import time
8
- from dotenv import load_dotenv
9
- from smolagents import DuckDuckGoSearchTool, OpenAIServerModel, CodeAgent, Tool
10
- from blablador import Models
11
 
12
  # (Keep Constants as is)
13
  # --- Constants ---
14
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
15
 
16
- # --- Basic Agent Definition ---
17
- # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
18
- load_dotenv()
19
-
20
-
21
- class BasicAgent:
22
-
23
- def __init__(self, model_provider: str = "Blablador"):
24
- self.model_provider = model_provider
25
-
26
- if model_provider == "Blablador":
27
-
28
- models = Models(
29
- api_key=os.getenv("Blablador_API_KEY")).get_model_ids()
30
- model_id_blablador = 5
31
- model_name = " ".join(
32
- models[model_id_blablador].split(" - ")[1].split()[:2])
33
- print("The agent uses the following model:", model_name)
34
-
35
- answer_llm = OpenAIServerModel(
36
- model_id=models[model_id_blablador],
37
- api_base="https://helmholtz-blablador.fz-juelich.de:8000/v1",
38
- api_key=os.getenv("Blablador_API_KEY"),
39
- flatten_messages_as_text=True,
40
- temperature=0.2)
41
-
42
- elif model_provider == "Gemini":
43
-
44
- # model_name = "gemini-2.5-flash-preview-05-20"
45
- model_name = "gemini-2.0-flash"
46
- print("The agent uses the following model:", model_name)
47
-
48
- answer_llm = OpenAIServerModel(
49
- model_id=model_name,
50
- api_base=
51
- "https://generativelanguage.googleapis.com/v1beta/openai/",
52
- api_key=os.getenv("Gemini_API_KEY2"),
53
- temperature=0.2)
54
- else:
55
- print(
56
- f"Error: Unsupported model provider '{model_provider}'. Only 'Blablador' and 'Gemini' are supported."
57
- )
58
- sys.exit(1)
59
-
60
- self.agent = CodeAgent(
61
- tools=[DuckDuckGoSearchTool()],
62
- model=answer_llm,
63
- planning_interval=3,
64
- max_steps=10,
65
- # verbosity_level=LogLevel.ERROR,
66
- )
67
-
68
- def __call__(self,
69
- question: str,
70
- file_url: str = "",
71
- file_ext: str = "") -> str:
72
- print(f"Agent received question (first 50 chars): {question[:50]}...")
73
-
74
- SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question.
75
- Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
76
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
77
- If you are asked for a number, don't use comma to write your number
78
- neither use units such as $ or percent sign unless specified otherwise.
79
- If you are asked for a string, don't use articles, neither ABBREVIATIONS, (e.g. for cities),
80
- and write the digits in plain text unless specified otherwise.
81
- If you are asked for a comma separated list,
82
- apply the above rules depending of whether the element to be put in the list is a number or a string.
83
- """
84
-
85
- # Prepare additional_args for file handling
86
- additional_args = {}
87
-
88
- # Handle file if provided
89
- if file_url:
90
- # print(f"Downloading file from: {file_url}")
91
- # file_content = self._download_file(file_url, file_ext)
92
-
93
- # if file_content is not None:
94
- # # Give the file a clear name based on its extension
95
- # if file_ext.lower() == 'csv':
96
- # # For CSV files, try to load as DataFrame
97
- # try:
98
- # import io
99
- # if isinstance(file_content, str):
100
- # df = pd.read_csv(io.StringIO(file_content))
101
- # else:
102
- # df = pd.read_csv(io.BytesIO(file_content))
103
- # additional_args['dataframe'] = df
104
- # additional_args['csv_file'] = file_content
105
- # print(f"Loaded CSV file with shape: {df.shape}")
106
- # except Exception as e:
107
- # print(f"Could not parse CSV file: {e}")
108
- # additional_args['file_content'] = file_content
109
-
110
- # elif file_ext.lower() in ['json']:
111
- # try:
112
- # import json
113
- # if isinstance(file_content, bytes):
114
- # file_content = file_content.decode('utf-8')
115
- # json_data = json.loads(file_content)
116
- # additional_args['json_data'] = json_data
117
- # additional_args['file_content'] = file_content
118
- # print(f"Loaded JSON file")
119
- # except Exception as e:
120
- # print(f"Could not parse JSON file: {e}")
121
- # additional_args['file_content'] = file_content
122
-
123
- # else:
124
- # # For other file types, just pass the content
125
- # additional_args['file_content'] = file_content
126
- # if file_ext:
127
- # additional_args['file_extension'] = file_ext
128
- # print(f"Loaded {file_ext} file")
129
-
130
- # Update the prompt to mention the file
131
- # full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {question}\n\nNote: A {file_ext} file has been provided and is available for your analysis."
132
- additional_args = f"{file_url}_{file_ext}"
133
- full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {question}\n\nNote: A {file_ext} file has been provided and is available for your analysis."
134
-
135
- # else:
136
- # full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {question}\n\nNote: Could not retrieve the file from {file_url}."
137
- else:
138
- full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {question}"
139
-
140
- # # Combine system prompt with the user question
141
- # full_prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {question}"
142
-
143
- try:
144
- answer = self.agent.run(full_prompt)
145
- # answer = self.agent.run(
146
- # task=full_prompt,
147
- # additional_args=additional_args if additional_args else None)
148
- print(f"Agent returning answer: {answer}")
149
- if self.model_provider == "Gemini":
150
- time.sleep(10)
151
- return answer
152
- except Exception as e:
153
- print(f"Error running agent: {e}")
154
- return f"Error: {e}"
155
-
156
- def _download_file(self, file_url: str, file_ext: str = "") -> str:
157
- """Download file content from URL and return as text or bytes"""
158
- try:
159
- response = requests.get(file_url, timeout=30)
160
- response.raise_for_status()
161
-
162
- # For text files, return as string
163
- if file_ext.lower() in [
164
- 'txt', 'csv', 'json', 'md', 'py', 'js', 'html', 'xml'
165
- ]:
166
- return response.text
167
- else:
168
- # For binary files, return the content as bytes
169
- return response.content
170
-
171
- except Exception as e:
172
- print(f"Error downloading file from {file_url}: {e}")
173
- return None
174
-
175
 
176
  def run_and_submit_all(profile: gr.OAuthProfile | None):
177
  """
@@ -244,7 +81,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
244
  file_url = f"{api_url}/files/{task_id}"
245
 
246
  try:
247
- submitted_answer = agent(question_text)
248
  # submitted_answer = agent(question_text, file_url, file_ext)
249
  answers_payload.append({
250
  "task_id": task_id,
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from agent import BasicAgent
 
 
 
 
7
 
8
  # (Keep Constants as is)
9
  # --- Constants ---
10
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def run_and_submit_all(profile: gr.OAuthProfile | None):
14
  """
 
81
  file_url = f"{api_url}/files/{task_id}"
82
 
83
  try:
84
+ submitted_answer = agent(question_text, task_id)
85
  # submitted_answer = agent(question_text, file_url, file_ext)
86
  answers_payload.append({
87
  "task_id": task_id,