shan gao committed
Commit 2ce7947 · 1 Parent(s): c9beace
Files changed (3)
  1. agent.py +535 -0
  2. app.py +192 -15
  3. requirements.txt +9 -1
agent.py ADDED
@@ -0,0 +1,535 @@
+ # agent_v6.py
+ from pathlib import Path
+ import os, re, base64, mimetypes, tempfile, uuid, subprocess
+ from urllib.parse import urlparse, unquote
+ from PIL import Image
+ import pytesseract
+ import whisper
+ import requests
+ from typing import TypedDict, List, Optional, Dict, Any, Literal
+ from langchain_core.tools import tool
+ from langchain_core.messages import HumanMessage, SystemMessage
+ from langchain_openai import ChatOpenAI
+ from langgraph.graph import StateGraph, START, END
+
+
+ # Optional: PDF parsing, since GAIA tasks sometimes include PDFs
+ try:
+     import pdfplumber
+     _HAS_PDFPLUMBER = True
+ except Exception:
+     _HAS_PDFPLUMBER = False
+
+ # -------------- State --------------
+ class EvidenceItem(TypedDict):
+     # "unknown_file" and "preprocess_error" are also produced by preprocess_node below
+     kind: Literal["audio_transcript", "image_ocr", "image_vqa", "doc_text", "unknown_file", "preprocess_error"]
+     text: str
+     path: Optional[str]
+     meta: Dict[str, Any]
+
+ class AgentState(TypedDict):
+     task_id: str
+     question: str
+     attachment_urls: List[str]   # empty list when no files
+     local_files: List[str]
+     evidence: List[EvidenceItem]
+     answer: Optional[str]
+     parsed_final_answer: Optional[str]
+     emit_final_answer: bool      # default True preserves the old always-emit behavior
+
+ # -------------- Helpers --------------
+ def _filename_from_cd(cd: str) -> str | None:
+     # RFC 6266/5987: filename* takes precedence; fall back to filename
+     if not cd:
+         return None
+     # filename*=
+     m = re.search(r"filename\*\s*=\s*([^']*)'[^']*'([^;]+)", cd, flags=re.I)
+     if m:
+         return unquote(m.group(2)).strip().strip('"')
+     # filename=
+     m = re.search(r'filename\s*=\s*"?(.*?)(?:"|;|$)', cd, flags=re.I)
+     if m:
+         return m.group(1).strip().strip('"')
+     return None
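+ # Illustrative behaviour (hypothetical header values):
+ #   _filename_from_cd('attachment; filename="clip.mp3"')                    -> "clip.mp3"
+ #   _filename_from_cd("attachment; filename*=UTF-8''r%C3%A9sum%C3%A9.pdf")  -> "résumé.pdf"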
+
+ def _pick_extension(ct: str | None) -> str | None:
+     if not ct:
+         return None
+     ct = ct.split(";", 1)[0].strip()
+     ext = mimetypes.guess_extension(ct)
+     # Fix common mis-mappings
+     return {".jpe": ".jpg"}.get(ext, ext)
+
+ def _summarize_evidence(evidence: List[Dict[str, Any]], limit_chars: int = 6000) -> str:
+     """Compact the evidence text for prompting; keep provenance-style tags."""
+     chunks = []
+     for i, e in enumerate(evidence, 1):
+         t = e.get("text", "") or ""
+         if len(t) > 1200:  # keep things small but informative
+             t = t[:1200] + " …"
+         meta = e.get("meta", {})
+         tag = f"{e.get('kind','?')}"
+         if meta.get("mime"):
+             tag += f"({meta['mime']})"
+         chunks.append(f"[{i}:{tag}] {t}")
+     out = "\n".join(chunks)
+     return out if len(out) <= limit_chars else out[:limit_chars] + " …"
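+ # Each chunk looks like: [1:audio_transcript(audio/mpeg)] <first 1200 chars of text> …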
+
+ def _collect_image_paths(evidence: List[Dict[str, Any]], limit: int = 4) -> List[str]:
+     """Find image file paths to attach to a vision model."""
+     paths = []
+     for e in evidence:
+         if e.get("path") and str(e.get("meta", {}).get("mime", "")).startswith("image"):
+             p = e["path"]
+             if os.path.exists(p) and p not in paths:
+                 paths.append(p)
+                 if len(paths) >= limit:
+                     break
+     return paths
+
+ def _image_to_data_url(path: str) -> str:
+     """Encode an image file as a data URL for OpenAI chat image parts."""
+     with open(path, "rb") as f:
+         b64 = base64.b64encode(f.read()).decode("utf-8")
+     mime, _ = mimetypes.guess_type(path)
+     mime = mime or "image/png"
+     return f"data:{mime};base64,{b64}"
+
+ def _ensure_final_answer_line(text: str, *, enabled: bool) -> str:
+     """When enabled, ensure a `final_answer:` line. When disabled, strip any such line."""
+     if enabled:
+         if re.search(r"(?im)^final_answer\s*:", text):
+             return text
+         # best-effort: take the last non-empty line
+         lines = [ln.strip() for ln in text.splitlines() if ln.strip() and not ln.strip().startswith("```")]
+         candidate = lines[-1] if lines else "[NO_ANSWER]"
+         return f"{text.rstrip()}\n\nfinal_answer: {candidate}"
+     else:
+         # remove any final_answer line(s)
+         return re.sub(r"(?im)^final_answer\s*:\s*.*\n?", "", text).strip()
+
+ def _parse_final_answer(text: str, *, enabled: bool) -> Optional[str]:
+     """Only parse when enabled; otherwise return None."""
+     if not enabled:
+         return None
+     m = re.search(r"(?im)^final_answer\s*:\s*(.+)$", text)
+     return m.group(1).strip() if m else None
+
+ def _convert_to_wav_mono16k(src_path: str) -> str:
+     print("converting to mono/16 kHz from:", src_path)
+     out = os.path.join(tempfile.gettempdir(), f"gaia_{uuid.uuid4().hex}.wav")
+     cmd = ["ffmpeg", "-y", "-i", src_path, "-ac", "1", "-ar", "16000", out]
+     # Capture stderr for debugging
+     p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+     if p.returncode != 0 or not os.path.exists(out):
+         raise RuntimeError(f"ffmpeg failed: {p.stderr[-500:]}")
+     return out
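+ # Equivalent shell command (the ffmpeg binary must be on PATH):
+ #   ffmpeg -y -i input.mp3 -ac 1 -ar 16000 /tmp/gaia_<hex>.wav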
+
+ # -------------- Tools --------------
+ @tool
+ def download_file(url: str, headers: dict | None = None, auth_token: str | None = None) -> str:
+     """Download a file following redirects and honoring Content-Disposition. Returns local path."""
+     sess = requests.Session()
+     hdrs = {"User-Agent": "gaia-agent/1.0"}
+     if headers:
+         hdrs.update(headers)
+     if auth_token:
+         hdrs["Authorization"] = f"Bearer {auth_token}"
+
+     with sess.get(url, headers=hdrs, timeout=(10, 60), stream=True, allow_redirects=True) as r:
+         r.raise_for_status()
+
+         # Determine filename
+         cd = r.headers.get("Content-Disposition", "")
+         fname = _filename_from_cd(cd)
+
+         if not fname:
+             # Fall back to the URL path
+             path = urlparse(r.url).path or urlparse(url).path
+             fname = os.path.basename(path) or f"download-{uuid.uuid4().hex}"
+
+         # Ensure we have an extension
+         base, ext = os.path.splitext(fname)
+         if not ext:
+             guess = _pick_extension(r.headers.get("Content-Type"))
+             if guess:
+                 fname = base + guess
+
+         # Write to a temp folder (unique per call)
+         out_dir = tempfile.mkdtemp(prefix="gaia_tmpdl_")
+         out_path = os.path.join(out_dir, fname)
+
+         # Alternative: write to the working directory (e.g., on Colab)
+         # out_dir: str | Path = "."
+         # out_path = Path(out_dir) / fname
+
+         print("out_path:", out_path)
+
+         with open(out_path, "wb") as f:
+             for chunk in r.iter_content(chunk_size=1024 * 1024):
+                 if chunk:
+                     f.write(chunk)
+
+     return out_path
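+ # Invoked as a LangChain tool, the same way fetch_node calls it below (hypothetical URL):
+ #   local_path = download_file.invoke({"url": "https://example.com/task-file.mp3"})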
+
+
+ @tool
+ def transcribe_audio(path: str, model_size: str = "base") -> str:
+     """
+     Transcribe an audio file using Whisper (local). Callers should convert to a
+     mono/16 kHz WAV first (see _convert_to_wav_mono16k) for robustness.
+     Returns the transcript text; raises on failure (caller handles).
+     """
+     print("running transcribe_audio")
+     try:
+         model = whisper.load_model(model_size)
+         result = model.transcribe(path)
+         return (result.get("text") or "").strip()
+     except Exception as e:
+         raise RuntimeError(f"Whisper error: {e}")
+
+
+ @tool
+ def ocr_image(path: str) -> str:
+     """OCR an image using Tesseract (the tesseract binary must be installed on the system)."""
+     print("running ocr")
+     img = Image.open(path)
+     text = pytesseract.image_to_string(img)
+     return text.strip()
+
+
+ # -------------- Nodes --------------
+ def check_attachment_node(state: AgentState) -> AgentState:
+     """Check whether the task URL actually serves an attachment."""
+     print("enter check attachment node")
+
+     urls = state.get("attachment_urls")
+     if not urls:
+         print("No attachment URLs provided.")
+         state["attachment_urls"] = []
+         return state
+
+     url = urls[0]  # Get the first URL from the list
+     headers = {"Accept": "application/json"}
+     timeout = 30
+     # Try HEAD first
+     r = requests.head(url, headers=headers, allow_redirects=True, timeout=timeout)
+     # Some servers don't support HEAD; 405/501 are common. Fall back to GET (stream) to read headers only.
+     if r.status_code in (405, 501):
+         r.close()
+         r = requests.get(url, headers=headers, stream=True, allow_redirects=True, timeout=timeout)
+     try:
+         cd = r.headers.get("Content-Disposition", "") or r.headers.get("content-disposition", "")
+         is_attachment = "attachment" in cd.lower()
+
+         filename = None
+         if is_attachment:
+             m = re.search(r"filename\*=UTF-8''([^;]+)", cd, flags=re.I)
+             if m:
+                 filename = unquote(m.group(1))
+             else:
+                 m = re.search(r'filename="?([^";]+)"?', cd, flags=re.I)
+                 if m:
+                     filename = m.group(1)
+             print("Need to download attachment:", filename)
+         else:
+             print("No attachment header; skip downloading.")
+             state["attachment_urls"] = []
+         return state  # always return the state so LangGraph receives a dict update
+     finally:
+         # If we fell back to GET(stream=True), make sure we don't keep the connection open.
+         try:
+             r.close()
+         except Exception:
+             pass
+
+ def fetch_node(state: AgentState) -> AgentState:
+     print("enter fetch_node")
+
+     local_files = []
+     for u in state["attachment_urls"]:
+         # If already a local file path, just append it
+         if os.path.exists(u):
+             local_files.append(u)
+         else:
+             p = download_file.invoke({"url": u})
+             local_files.append(p)
+     state["local_files"] = local_files
+     return state
+
+ def preprocess_node(state: AgentState) -> AgentState:
+     """
+     For each local file:
+       - audio/*          -> ASR transcript
+       - image/*          -> OCR text (basic enhancement to help OCR)
+       - application/pdf  -> text extraction (if pdfplumber is available)
+     Produces EvidenceItem entries and stores them in state['evidence'].
+     """
+     print("enter preprocessing node")
+
+     ev: List[Dict[str, Any]] = list(state.get("evidence", []))
+
+     for path in state.get("local_files", []):
+         mime, _ = mimetypes.guess_type(path)
+         meta = {"mime": mime or "application/octet-stream", "filename": os.path.basename(path)}
+
+         print("mime", mime)
+
+         try:
+             if mime and mime.startswith("audio"):
+                 print("mime starts with audio")
+                 # --- ASR ---
+                 try:
+                     wav = _convert_to_wav_mono16k(path)
+                 except Exception as e:
+                     raise RuntimeError(f"Pre-conversion error: {e}")
+
+                 print("after conversion, saved at tmp wav path:", wav)
+                 txt = transcribe_audio.invoke({"path": wav})
+                 ev.append({"kind": "audio_transcript", "text": txt, "path": path, "meta": meta})
+
+             elif mime and mime.startswith("image"):
+                 print("mime starts with image")
+                 # --- OCR with simple pre-enhancement ---
+                 try:
+                     img = Image.open(path)
+                     img = img.convert("L")  # grayscale
+                     w, h = img.size
+                     if max(w, h) < 1600:  # upscale small images to help OCR
+                         print("upscaling small image:", path)
+                         img = img.resize((w * 2, h * 2))
+                     tmp_ocr = os.path.join(tempfile.gettempdir(), f"ocr_{uuid.uuid4().hex}.png")
+                     img.save(tmp_ocr)
+                     print("after enhancement, saved at tmp_ocr path:", tmp_ocr)
+                     ocr = ocr_image.invoke({"path": tmp_ocr})
+                 except Exception as e:
+                     ocr = f"[OCR error: {e}]"
+                 ev.append({"kind": "image_ocr", "text": ocr, "path": path, "meta": meta})
+
+             elif mime == "application/pdf" or (mime and mime.startswith("application") and path.lower().endswith(".pdf")):
+                 # --- PDF extraction (best-effort; image-only PDFs may need OCR) ---
+                 if _HAS_PDFPLUMBER:
+                     try:
+                         pages = []
+                         with pdfplumber.open(path) as pdf:
+                             for pg in pdf.pages:
+                                 pages.append(pg.extract_text() or "")
+                         txt = "\n\n".join(pages).strip() or "[Empty or image-based PDF; try OCR]"
+                     except Exception as e:
+                         txt = f"[PDF parse error: {e}]"
+                 else:
+                     txt = "[PDF support not installed; pip install pdfplumber]"
+                 ev.append({"kind": "doc_text", "text": txt, "path": path, "meta": meta})
+
+             else:
+                 # Unknown/unsupported; keep a breadcrumb so you can inspect later
+                 ev.append({"kind": "unknown_file", "text": "[Unsupported file type]", "path": path, "meta": meta})
+
+         except Exception as e:
+             ev.append({"kind": "preprocess_error", "text": f"[Error processing {path}: {e}]", "path": path, "meta": meta})
+
+     state["evidence"] = ev
+     return state
+
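+ # A produced evidence item looks like (illustrative values):
+ #   {"kind": "audio_transcript", "text": "…transcript…", "path": "/tmp/gaia_tmpdl_ab/clip.mp3",
+ #    "meta": {"mime": "audio/mpeg", "filename": "clip.mp3"}}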
+ def solve_multimodal_node(state: AgentState) -> AgentState:
+     """
+     Use a vision-capable model (e.g., gpt-4o) and attach the image(s) PLUS the text evidence (ASR/OCR).
+     """
+     print("enter solve_multimodal_node")
+
+     emit = bool(state.get("emit_final_answer", True))
+     end_instr = "" if not emit else " End your output with a single line: final_answer: <answer>"
+
+     question = state.get("question", "").strip()
+     evidence = state.get("evidence", [])
+
+     vision_llm = ChatOpenAI(model="gpt-4o", temperature=0)  # vision-capable
+     sys = SystemMessage(content=(
+         "You solve GAIA tasks using the provided evidence and attached images.\n"
+         "Be precise, quote numbers/strings exactly. If uncertain, say so.\n"
+         "Your answer to the GAIA tasks should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n" + end_instr
+     ))
+
+     # Summarized text evidence (ASR/OCR/PDF text)
+     ev_text = _summarize_evidence(evidence)
+     text_part = (
+         f"Question:\n{question}\n\n"
+         f"Textual evidence (summarized):\n{ev_text}\n\n"
+         "Use the attached images, if any, to read fine text or diagrams, or to confirm details."
+     )
+
+     parts: List[Any] = [{"type": "text", "text": text_part}]
+
+     # Attach up to 4 images (data URLs)
+     img_paths = _collect_image_paths(evidence, limit=4)
+     for p in img_paths:
+         parts.append({"type": "image_url", "image_url": {"url": _image_to_data_url(p)}})
+
+     resp = vision_llm.invoke([sys, HumanMessage(content=parts)])
+     text = (resp.content or "").strip()
+     text = _ensure_final_answer_line(text, enabled=emit)
+
+     state["answer"] = text
+     state["parsed_final_answer"] = _parse_final_answer(text, enabled=emit)
+     return state
+
+
+ def solve_text_only_node(state: AgentState) -> AgentState:
+     """
+     Text-only solve path. Consumes the question plus textual evidence
+     (e.g., audio transcripts from ASR, OCR text, PDF text). No images attached.
+     """
+     print("enter solve_text_only_node")
+
+     emit = bool(state.get("emit_final_answer", True))
+     end_instr = "" if not emit else " End your output with a single line: final_answer: <answer>"
+
+     question = (state.get("question") or "").strip()
+     evidence = state.get("evidence", [])
+
+     # Summarized text evidence (ASR/OCR/PDF text)
+     ev_text = _summarize_evidence(evidence) or "(none)"
+
+     # LLM (text-only). Swap the model as you like.
+     llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+
+     sys = SystemMessage(content=(
+         "You solve GAIA tasks. Use careful step-by-step reasoning but keep it concise.\n"
+         "You can use the provided textual evidence if there is any.\n"
+         "Your answer to the GAIA tasks should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n" + end_instr
+     ))
+
+     user = HumanMessage(content=(
+         f"Question:\n{question}\n\n"
+         f"Textual evidence (summarized):\n{ev_text}"
+     ))
+
+     resp = llm.invoke([sys, user])
+     text = (resp.content or "").strip()
+     text = _ensure_final_answer_line(text, enabled=emit)
+
+     state["answer"] = text
+     state["parsed_final_answer"] = _parse_final_answer(text, enabled=emit)
+     return state
+
+ def validate_format_node(state: AgentState) -> AgentState:
+     """
+     Ensure the final output contains `final_answer: ...` and capture it separately for scoring.
+     Also trims excessive whitespace and removes duplicate final_answer lines, if any.
+     """
+     print("enter validate_format_node")
+
+     emit = bool(state.get("emit_final_answer", True))
+     txt = (state.get("answer") or "").strip()
+
+     if not txt:
+         if emit:
+             state["answer"] = "No answer generated.\n\nfinal_answer: [NO_ANSWER]"
+             state["parsed_final_answer"] = "[NO_ANSWER]"
+         else:
+             state["answer"] = "No answer generated."
+             state["parsed_final_answer"] = None
+         return state
+
+     if emit:
+         # keep only the LAST final_answer line if there are multiple
+         matches = list(re.finditer(r"(?im)^final_answer\s*:\s*(.+)$", txt))
+         if len(matches) == 0:
+             txt = _ensure_final_answer_line(txt, enabled=True)
+         elif len(matches) > 1:
+             last = matches[-1].group(0)
+             txt_wo = re.sub(r"(?im)^final_answer\s*:\s*.+\s*$", "", txt).strip()
+             txt = f"{txt_wo}\n\n{last}"
+         state["parsed_final_answer"] = _parse_final_answer(txt, enabled=True)
+     else:
+         # strip any lingering final_answer lines (paranoia)
+         txt = _ensure_final_answer_line(txt, enabled=False)
+         state["parsed_final_answer"] = None
+
+     state["answer"] = txt.strip()
+     return state
+
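+ # e.g. with emit=True, "…reasoning…\nfinal_answer: 42\n…\nfinal_answer: 43" is normalized
+ # so that only "final_answer: 43" survives and parsed_final_answer == "43".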
+ # -------------- Router functions --------------
+ def route_intake(state: AgentState) -> Literal["with_files", "no_files"]:
+     """Route based on the presence of attachments (purely programmatic)."""
+     attachment_urls = state.get("attachment_urls") or []  # safe default
+     return "with_files" if attachment_urls else "no_files"
+
+ def has_images(state: AgentState) -> bool:
+     for e in state.get("evidence", []):
+         mime = (e.get("meta") or {}).get("mime", "")
+         if str(mime).startswith("image"):
+             return True
+     return False
+
+ def route_after_preprocess(state: AgentState) -> Literal["vision", "text"]:
+     return "vision" if has_images(state) else "text"
+
+ # ---------- Graph ----------
+ def build_graph():
+     g = StateGraph(AgentState)
+     g.add_node("check_attachment", check_attachment_node)
+     g.add_node("fetch", fetch_node)
+     g.add_node("preprocess", preprocess_node)
+     g.add_node("solve_multimodal", solve_multimodal_node)
+     g.add_node("solve_text_only", solve_text_only_node)
+     g.add_node("validate", validate_format_node)
+
+     # Entry edge
+     g.add_edge(START, "check_attachment")
+
+     # Conditional branching from check_attachment
+     g.add_conditional_edges(
+         "check_attachment",
+         route_intake,  # returns "with_files" or "no_files"
+         {
+             "with_files": "fetch",
+             "no_files": "solve_text_only",
+         },
+     )
+
+     # files branch
+     g.add_edge("fetch", "preprocess")
+
+     g.add_conditional_edges(
+         "preprocess",
+         route_after_preprocess,
+         {
+             "vision": "solve_multimodal",  # question + evidence + attached images
+             "text": "solve_text_only",     # question + transcript/other text
+         },
+     )
+
+     # both branches converge
+     g.add_edge("solve_multimodal", "validate")
+     g.add_edge("solve_text_only", "validate")
+     g.add_edge("validate", END)
+
+     # Compile the graph
+     graph_compiled = g.compile()
+     return graph_compiled
+
+
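+ # Topology: START -> check_attachment; with files: fetch -> preprocess ->
+ # (solve_multimodal if images else solve_text_only); without files: solve_text_only;
+ # both solver nodes -> validate -> END.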
+ # test
+ if __name__ == "__main__":
+     task_id = "0001"
+     task_q = "Who is the current president of France?"
+     task_urls = []  # no attachments for this smoke test
+     sample = {
+         "task_id": task_id,
+         "question": task_q,
+         "attachment_urls": task_urls,  # empty list routes to the no-files branch
+         "local_files": [],
+         "evidence": [],
+         "answer": None,
+         "parsed_final_answer": None,
+         "emit_final_answer": False,  # <<< pure output mode
+     }
+     agent_GAIA = build_graph()
+     out = agent_GAIA.invoke(sample)
+     print("---------------------------")
+     print(out["answer"])
app.py CHANGED
@@ -1,22 +1,199 @@
- from datasets import load_dataset
+ import os
  import gradio as gr
+ import requests
+ import pandas as pd
+ from agent import build_graph

- # all levels, validation split (public)
- ds_val = load_dataset(
-     "gaia-benchmark/GAIA",
-     "2023_all",
-     split="validation",
-     token=True,              # use_auth_token=True on older installs also works
-     trust_remote_code=True,  # needed because GAIA uses a loading script
- )
-
- ds_val.set_format(type="pandas")
- df_val = ds_val[:]
-
- temp_id = df_val.at[0, 'task_id']
-
- with gr.Blocks(title="Display task_id") as demo:
-     gr.Markdown("Task_id: {}".format(temp_id))
+ # --- Constants ---
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
+

+ def run_and_submit_all(profile: gr.OAuthProfile | None):
+     """
+     Fetches all questions, runs the GAIA agent on them, submits all answers,
+     and displays the results.
+     """
+     # --- Determine HF Space runtime URL and repo URL ---
+     space_id = os.getenv("SPACE_ID")  # used to build the link to the code
+
+     if profile:
+         username = f"{profile.username}"
+         print(f"User logged in: {username}")
+     else:
+         print("User not logged in.")
+         return "Please Login to Hugging Face with the button.", None
+
+     api_url = DEFAULT_API_URL
+     questions_url = f"{api_url}/questions"
+     submit_url = f"{api_url}/submit"
+
+     # 1. Instantiate the agent (modify this part to create your agent)
+     try:
+         agent_GAIA = build_graph()
+     except Exception as e:
+         print(f"Error instantiating agent: {e}")
+         return f"Error initializing agent: {e}", None
+     # For an app running as a Hugging Face Space, this link points to your codebase (useful for others, so please keep it public)
+     agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
+     print(agent_code)
+
+     # 2. Fetch questions
+     print(f"Fetching questions from: {questions_url}")
+     try:
+         response = requests.get(questions_url, timeout=15)
+         response.raise_for_status()
+         questions_data = response.json()
+         if not questions_data:
+             print("Fetched questions list is empty.")
+             return "Fetched questions list is empty or invalid format.", None
+         print(f"Fetched {len(questions_data)} questions.")
+     # requests' JSONDecodeError subclasses RequestException, so it must be caught first
+     except requests.exceptions.JSONDecodeError as e:
+         print(f"Error decoding JSON response from questions endpoint: {e}")
+         print(f"Response text: {response.text[:500]}")
+         return f"Error decoding server response for questions: {e}", None
+     except requests.exceptions.RequestException as e:
+         print(f"Error fetching questions: {e}")
+         return f"Error fetching questions: {e}", None
+     except Exception as e:
+         print(f"An unexpected error occurred fetching questions: {e}")
+         return f"An unexpected error occurred fetching questions: {e}", None
+
+     # 3. Run the agent
+     results_log = []
+     answers_payload = []
+     print(f"Running agent on {len(questions_data)} questions...")
+     for item in questions_data:
+         task_id = item.get("task_id")
+         question_text = item.get("question")
+         if not task_id or question_text is None:
+             print(f"Skipping item with missing task_id or question: {item}")
+             continue
+         task_url = f"{api_url}/files/{task_id}"
+         sample = {
+             "task_id": task_id,
+             "question": question_text,
+             "attachment_urls": [task_url],  # GAIA file endpoint for this task
+             "local_files": [],
+             "evidence": [],
+             "answer": None,
+             "parsed_final_answer": None,
+             "emit_final_answer": False,  # <<< pure output mode
+         }
+         try:
+             out = agent_GAIA.invoke(sample)
+             submitted_answer = out["answer"]
+
+             answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
+             results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
+         except Exception as e:
+             print(f"Error running agent on task {task_id}: {e}")
+             results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
+
+     if not answers_payload:
+         print("Agent did not produce any answers to submit.")
+         return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
+
+     # 4. Prepare submission
+     submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
+     status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
+     print(status_update)
+
+     # 5. Submit
+     print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
+     try:
+         response = requests.post(submit_url, json=submission_data, timeout=60)
+         response.raise_for_status()
+         result_data = response.json()
+         final_status = (
+             f"Submission Successful!\n"
+             f"User: {result_data.get('username')}\n"
+             f"Overall Score: {result_data.get('score', 'N/A')}% "
+             f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
+             f"Message: {result_data.get('message', 'No message received.')}"
+         )
+         print("Submission successful.")
+         results_df = pd.DataFrame(results_log)
+         return final_status, results_df
+     except requests.exceptions.HTTPError as e:
+         error_detail = f"Server responded with status {e.response.status_code}."
+         try:
+             error_json = e.response.json()
+             error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
+         except requests.exceptions.JSONDecodeError:
+             error_detail += f" Response: {e.response.text[:500]}"
+         status_message = f"Submission Failed: {error_detail}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except requests.exceptions.Timeout:
+         status_message = "Submission Failed: The request timed out."
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except requests.exceptions.RequestException as e:
+         status_message = f"Submission Failed: Network error - {e}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except Exception as e:
+         status_message = f"An unexpected error occurred during submission: {e}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+
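+ # Submission payload shape built above (illustrative values):
+ #   {"username": "...", "agent_code": "https://huggingface.co/spaces/<id>/tree/main",
+ #    "answers": [{"task_id": "...", "submitted_answer": "..."}]}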
+
+ # --- Build Gradio interface using Blocks ---
+ with gr.Blocks() as demo:
+     gr.Markdown("# Basic Agent Evaluation Runner")
+     gr.Markdown(
+         """
+         **Instructions:**
+         1. Clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc.
+         2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
+         3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
+         ---
+         **Disclaimers:**
+         Once you click the submit button, it can take quite some time (this is the time the agent needs to go through all the questions).
+         This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance, to avoid the long-running submit button, you could cache the answers and submit them in a separate action, or even answer the questions asynchronously.
+         """
+     )
+
+     gr.LoginButton()
+
+     run_button = gr.Button("Run Evaluation & Submit All Answers")
+
+     status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
+     # Removed max_rows=10 from the DataFrame constructor
+     results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
+
+     run_button.click(
+         fn=run_and_submit_all,
+         outputs=[status_output, results_table]
+     )

  if __name__ == "__main__":
-     demo.launch()
+     print("\n" + "-"*30 + " App Starting " + "-"*30)
+     # Check for SPACE_HOST and SPACE_ID at startup, for information
+     space_host_startup = os.getenv("SPACE_HOST")
+     space_id_startup = os.getenv("SPACE_ID")
+
+     if space_host_startup:
+         print(f"✅ SPACE_HOST found: {space_host_startup}")
+         print(f"   Runtime URL should be: https://{space_host_startup}.hf.space")
+     else:
+         print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
+
+     if space_id_startup:  # Print repo URLs if SPACE_ID is found
+         print(f"✅ SPACE_ID found: {space_id_startup}")
+         print(f"   Repo URL: https://huggingface.co/spaces/{space_id_startup}")
+         print(f"   Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
+     else:
+         print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
+
+     print("-"*(60 + len(" App Starting ")) + "\n")
+
+     print("Launching Gradio Interface for Basic Agent Evaluation...")
+     demo.launch(debug=True, share=False)
requirements.txt CHANGED
@@ -1,3 +1,14 @@
  gradio
  requests
- datasets<4.0
+ langgraph
+ langchain_openai
+ langchain_huggingface
+ sentence-transformers
+ langchain-community
+ ddgs
+ openai-whisper
+ pytesseract
+ pillow
+ pdfplumber
+ pandas
+ # the ffmpeg binary is required at runtime; on Spaces add "ffmpeg" to packages.txt (the pip "ffmpeg" package does not provide it)