claudi47 commited on
Commit
a765bf8
·
1 Parent(s): 8420ad2

Fix Groq auth and add support for GAIA file tasks

Browse files
Files changed (3) hide show
  1. README.md +2 -1
  2. app.py +160 -6
  3. requirements.txt +1 -0
README.md CHANGED
@@ -23,7 +23,7 @@ The app logs in with Hugging Face OAuth, downloads the GAIA evaluation questions
23
  Create a `.env` file with the secrets needed by the model provider and by Hugging Face Spaces:
24
 
25
  ```bash
26
- HF_TOKEN=your_token_here
27
  SPACE_ID=your-username/your-space-name
28
  ```
29
 
@@ -44,5 +44,6 @@ python app.py
44
  ## Notes
45
 
46
  - The app uses `https://agents-course-unit4-scoring.hf.space` as the scoring API.
 
47
  - The Gradio SDK version is pinned in this README frontmatter and dependencies are pinned in `requirements.txt`.
48
  - OAuth must be enabled on the Hugging Face Space for the login flow to work.
 
23
  Create a `.env` file with the secrets needed by the model provider and by Hugging Face Spaces:
24
 
25
  ```bash
26
+ GROQ_API_KEY=your_groq_key_here
27
  SPACE_ID=your-username/your-space-name
28
  ```
29
 
 
44
  ## Notes
45
 
46
  - The app uses `https://agents-course-unit4-scoring.hf.space` as the scoring API.
47
+ - Text answers use Groq `llama-3.3-70b-versatile`; audio files use Groq Whisper; image files use a Groq vision model.
48
  - The Gradio SDK version is pinned in this README frontmatter and dependencies are pinned in `requirements.txt`.
49
  - OAuth must be enabled on the Hugging Face Space for the login flow to work.
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  import requests
3
  import pandas as pd
4
  import gradio as gr
@@ -18,6 +20,12 @@ load_dotenv()
18
  DEFAULT_API_URL = (
19
  "https://agents-course-unit4-scoring.hf.space"
20
  )
 
 
 
 
 
 
21
 
22
  # Format instructions appended to every question
23
  # so that the agent returns exact-match-friendly
@@ -89,6 +97,7 @@ class GaiaFileFetcherTool(Tool):
89
  if not fname:
90
  fname = f"{task_id}{ext}"
91
 
 
92
  path = os.path.join(
93
  _tmp.gettempdir(), fname
94
  )
@@ -97,6 +106,130 @@ class GaiaFileFetcherTool(Tool):
97
  return path
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # --------------------------------------------------
101
  # Agent wrapper
102
  # --------------------------------------------------
@@ -104,15 +237,24 @@ class BasicAgent:
104
  def __init__(self):
105
  print("BasicAgent initialized.")
106
 
 
 
 
 
 
 
 
107
  model = OpenAIServerModel(
108
- model_id="llama-3.3-70b-versatile",
109
- api_base="https://api.groq.com/openai/v1",
110
- api_key=os.getenv("HF_TOKEN"),
111
  )
112
 
113
  self.file_tool = GaiaFileFetcherTool(
114
  api_url=DEFAULT_API_URL,
115
  )
 
 
116
 
117
  self.agent = CodeAgent(
118
  model=model,
@@ -123,10 +265,13 @@ class BasicAgent:
123
  ),
124
  VisitWebpageTool(),
125
  self.file_tool,
 
 
126
  ],
127
  max_steps=15,
128
  verbosity_level=0,
129
  additional_authorized_imports=[
 
130
  "json",
131
  "re",
132
  "csv",
@@ -136,6 +281,10 @@ class BasicAgent:
136
  "collections",
137
  "itertools",
138
  "os",
 
 
 
 
139
  ],
140
  )
141
 
@@ -153,7 +302,10 @@ class BasicAgent:
153
  f"\n\n[This question has an attached "
154
  f"file. Use the fetch_task_file tool "
155
  f"with task_id='{task_id}' to "
156
- f"download and read it.]"
 
 
 
157
  )
158
 
159
  prompt += ANSWER_FORMAT_INSTRUCTIONS
@@ -198,7 +350,7 @@ def run_and_submit_all(
198
 
199
  agent_code = (
200
  f"https://huggingface.co/spaces/"
201
- f"{space_id}/tree/main"
202
  )
203
  print(agent_code)
204
 
@@ -448,6 +600,8 @@ page fetching, and file download tools.*
448
  outputs=[status_output, results_table],
449
  )
450
 
 
 
451
  if __name__ == "__main__":
452
  print(
453
  "\n" + "-" * 30
@@ -469,4 +623,4 @@ if __name__ == "__main__":
469
 
470
  print("-" * 74 + "\n")
471
  print("Launching Gradio Interface...")
472
- demo.launch(debug=True, share=False)
 
1
  import os
2
+ import base64
3
+ import mimetypes
4
  import requests
5
  import pandas as pd
6
  import gradio as gr
 
20
  DEFAULT_API_URL = (
21
  "https://agents-course-unit4-scoring.hf.space"
22
  )
23
+ GROQ_API_BASE = "https://api.groq.com/openai/v1"
24
+ TEXT_MODEL_ID = "llama-3.3-70b-versatile"
25
+ VISION_MODEL_ID = (
26
+ "meta-llama/llama-4-scout-17b-16e-instruct"
27
+ )
28
+ AUDIO_MODEL_ID = "whisper-large-v3"
29
 
30
  # Format instructions appended to every question
31
  # so that the agent returns exact-match-friendly
 
97
  if not fname:
98
  fname = f"{task_id}{ext}"
99
 
100
+ fname = os.path.basename(fname)
101
  path = os.path.join(
102
  _tmp.gettempdir(), fname
103
  )
 
106
  return path
107
 
108
 
109
+ class GroqAudioTranscriptionTool(Tool):
110
+ """Transcribes an audio file with Groq Whisper."""
111
+
112
+ name = "transcribe_audio_file"
113
+ description = (
114
+ "Transcribes a local audio file path, such as an "
115
+ "MP3 downloaded with fetch_task_file. Returns the "
116
+ "plain transcript text."
117
+ )
118
+ inputs = {
119
+ "file_path": {
120
+ "type": "string",
121
+ "description": "Local path to the audio file.",
122
+ }
123
+ }
124
+ output_type = "string"
125
+
126
+ def forward(self, file_path: str) -> str:
127
+ api_key = os.getenv("GROQ_API_KEY")
128
+ if not api_key:
129
+ raise RuntimeError(
130
+ "GROQ_API_KEY is required for audio transcription."
131
+ )
132
+
133
+ with open(file_path, "rb") as audio_file:
134
+ response = requests.post(
135
+ f"{GROQ_API_BASE}/audio/transcriptions",
136
+ headers={
137
+ "Authorization": f"Bearer {api_key}",
138
+ },
139
+ files={
140
+ "file": (
141
+ os.path.basename(file_path),
142
+ audio_file,
143
+ )
144
+ },
145
+ data={
146
+ "model": AUDIO_MODEL_ID,
147
+ "response_format": "json",
148
+ "temperature": "0",
149
+ },
150
+ timeout=120,
151
+ )
152
+ response.raise_for_status()
153
+ return response.json().get("text", "").strip()
154
+
155
+
156
+ class GroqImageAnalysisTool(Tool):
157
+ """Answers questions about a local image with Groq vision."""
158
+
159
+ name = "analyze_image_file"
160
+ description = (
161
+ "Analyzes a local image file path and answers a "
162
+ "specific visual question about it."
163
+ )
164
+ inputs = {
165
+ "file_path": {
166
+ "type": "string",
167
+ "description": "Local path to the image file.",
168
+ },
169
+ "question": {
170
+ "type": "string",
171
+ "description": "The question to answer about the image.",
172
+ },
173
+ }
174
+ output_type = "string"
175
+
176
+ def forward(self, file_path: str, question: str) -> str:
177
+ api_key = os.getenv("GROQ_API_KEY")
178
+ if not api_key:
179
+ raise RuntimeError(
180
+ "GROQ_API_KEY is required for image analysis."
181
+ )
182
+
183
+ mime_type = (
184
+ mimetypes.guess_type(file_path)[0]
185
+ or "application/octet-stream"
186
+ )
187
+ with open(file_path, "rb") as image_file:
188
+ encoded = base64.b64encode(
189
+ image_file.read()
190
+ ).decode("ascii")
191
+
192
+ response = requests.post(
193
+ f"{GROQ_API_BASE}/chat/completions",
194
+ headers={
195
+ "Authorization": f"Bearer {api_key}",
196
+ "Content-Type": "application/json",
197
+ },
198
+ json={
199
+ "model": VISION_MODEL_ID,
200
+ "messages": [
201
+ {
202
+ "role": "user",
203
+ "content": [
204
+ {
205
+ "type": "text",
206
+ "text": question,
207
+ },
208
+ {
209
+ "type": "image_url",
210
+ "image_url": {
211
+ "url": (
212
+ f"data:{mime_type};"
213
+ f"base64,{encoded}"
214
+ )
215
+ },
216
+ },
217
+ ],
218
+ }
219
+ ],
220
+ "temperature": 0.1,
221
+ "max_completion_tokens": 512,
222
+ },
223
+ timeout=120,
224
+ )
225
+ response.raise_for_status()
226
+ return (
227
+ response.json()["choices"][0]["message"]
228
+ ["content"]
229
+ .strip()
230
+ )
231
+
232
+
233
  # --------------------------------------------------
234
  # Agent wrapper
235
  # --------------------------------------------------
 
237
  def __init__(self):
238
  print("BasicAgent initialized.")
239
 
240
+ groq_api_key = os.getenv("GROQ_API_KEY")
241
+ if not groq_api_key:
242
+ raise RuntimeError(
243
+ "Missing GROQ_API_KEY. Add it to your "
244
+ "Hugging Face Space secrets or local .env file."
245
+ )
246
+
247
  model = OpenAIServerModel(
248
+ model_id=TEXT_MODEL_ID,
249
+ api_base=GROQ_API_BASE,
250
+ api_key=groq_api_key,
251
  )
252
 
253
  self.file_tool = GaiaFileFetcherTool(
254
  api_url=DEFAULT_API_URL,
255
  )
256
+ self.audio_tool = GroqAudioTranscriptionTool()
257
+ self.image_tool = GroqImageAnalysisTool()
258
 
259
  self.agent = CodeAgent(
260
  model=model,
 
265
  ),
266
  VisitWebpageTool(),
267
  self.file_tool,
268
+ self.audio_tool,
269
+ self.image_tool,
270
  ],
271
  max_steps=15,
272
  verbosity_level=0,
273
  additional_authorized_imports=[
274
+ "base64",
275
  "json",
276
  "re",
277
  "csv",
 
281
  "collections",
282
  "itertools",
283
  "os",
284
+ "pathlib",
285
+ "mimetypes",
286
+ "pandas",
287
+ "openpyxl",
288
  ],
289
  )
290
 
 
302
  f"\n\n[This question has an attached "
303
  f"file. Use the fetch_task_file tool "
304
  f"with task_id='{task_id}' to "
305
+ f"download it. If it is audio, use "
306
+ f"transcribe_audio_file. If it is an "
307
+ f"image, use analyze_image_file. If it "
308
+ f"is a spreadsheet, read it with pandas.]"
309
  )
310
 
311
  prompt += ANSWER_FORMAT_INSTRUCTIONS
 
350
 
351
  agent_code = (
352
  f"https://huggingface.co/spaces/"
353
+ f"{space_id or 'unknown-space'}/tree/main"
354
  )
355
  print(agent_code)
356
 
 
600
  outputs=[status_output, results_table],
601
  )
602
 
603
+ demo.queue()
604
+
605
  if __name__ == "__main__":
606
  print(
607
  "\n" + "-" * 30
 
623
 
624
  print("-" * 74 + "\n")
625
  print("Launching Gradio Interface...")
626
+ demo.launch(debug=True, share=False)
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  gradio==5.25.2
2
  requests==2.32.5
3
  pandas==2.3.3
 
4
  smolagents[openai]==1.24.0
5
  ddgs==9.14.0
6
  wikipedia-api==0.10.2
 
1
  gradio==5.25.2
2
  requests==2.32.5
3
  pandas==2.3.3
4
+ openpyxl==3.1.5
5
  smolagents[openai]==1.24.0
6
  ddgs==9.14.0
7
  wikipedia-api==0.10.2