Spaces:
Sleeping
Sleeping
| # agent_v6.py | |
| # Develop an AI agent with LangGraph and LangChain | |
| # to answer the questions in the "gaia-benchmark/GAIA" dataset. | |
| from pathlib import Path | |
| import os, re, base64, mimetypes, tempfile, uuid, subprocess, json | |
| from urllib.parse import urlparse, unquote | |
| from PIL import Image | |
| import pytesseract | |
| import whisper | |
| import requests | |
| from typing import TypedDict, List, Optional, Dict, Any, Literal | |
| from langchain_core.tools import tool | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.graph import StateGraph, START, END | |
| # Optional: pdf parsing if GAIA sometimes includes PDFs | |
| try: | |
| import pdfplumber | |
| _HAS_PDFPLUMBER = True | |
| except Exception: | |
| _HAS_PDFPLUMBER = False | |
| # -------------- State ------------- | |
class EvidenceItem(TypedDict):
    """One piece of evidence extracted from an attachment during preprocessing."""

    # Which extraction produced ``text``. "unknown_file" and "preprocess_error"
    # are the breadcrumb kinds preprocess_node records for unsupported files
    # and for files whose processing raised.
    kind: Literal[
        "audio_transcript",
        "image_ocr",
        "image_vqa",
        "doc_text",
        "unknown_file",
        "preprocess_error",
    ]
    text: str  # extracted text, or a bracketed placeholder / error message
    path: Optional[str]  # local path of the file the evidence came from
    meta: Dict[str, Any]  # preprocess_node stores "mime" and "filename" here
class AgentState(TypedDict):
    """State dictionary handed from node to node in the agent graph."""

    # --- task input ---
    task_id: str
    question: str
    attachment_urls: list[str]  # empty list when no files

    # --- filled in by the fetch / preprocess nodes ---
    local_files: list[str]
    evidence: list[EvidenceItem]

    # --- filled in by the solve / validate nodes ---
    answer: str | None
    parsed_final_answer: str | None

    # When true the solver nodes append and parse a `final_answer:` line;
    # the nodes treat a missing key as True.
    emit_final_answer: bool
| # -------------- helpers --------------- | |
| def _filename_from_cd(cd: str) -> str | None: | |
| # RFC 6266/5987: filename* takes precedence; fall back to filename | |
| if not cd: | |
| return None | |
| # filename*= | |
| m = re.search(r"filename\*\s*=\s*([^']*)'[^']*'([^;]+)", cd, flags=re.I) | |
| if m: | |
| return unquote(m.group(2)).strip().strip('"') | |
| # filename= | |
| m = re.search(r'filename\s*=\s*"?(.*?)(?:"|;|$)', cd, flags=re.I) | |
| if m: | |
| return m.group(1).strip().strip('"') | |
| return None | |
| def _pick_extension(ct: str | None) -> str | None: | |
| if not ct: | |
| return None | |
| ct = ct.split(";", 1)[0].strip() | |
| ext = mimetypes.guess_extension(ct) | |
| # Fix common mis-maps | |
| return {".jpe": ".jpg"}.get(ext, ext) | |
| def _summarize_evidence(evidence: List[Dict[str, Any]], limit_chars: int = 6000) -> str: | |
| """Compact the evidence text for prompting; keep provenance-style tags.""" | |
| chunks = [] | |
| for i, e in enumerate(evidence, 1): | |
| t = e.get("text", "") or "" | |
| if len(t) > 1200: # keep things small but informative | |
| t = t[:1200] + " …" | |
| meta = e.get("meta", {}) | |
| tag = f"{e.get('kind','?')}" | |
| if meta.get("mime"): | |
| tag += f"({meta['mime']})" | |
| chunks.append(f"[{i}:{tag}] {t}") | |
| out = "\n".join(chunks) | |
| return out if len(out) <= limit_chars else out[:limit_chars] + " …" | |
| def _collect_image_paths(evidence: List[Dict[str, Any]], limit: int = 4) -> List[str]: | |
| """Find image file paths to attach to a vision model.""" | |
| paths = [] | |
| for e in evidence: | |
| if e.get("path") and str(e.get("meta", {}).get("mime","")).startswith("image"): | |
| p = e["path"] | |
| if os.path.exists(p) and p not in paths: | |
| paths.append(p) | |
| if len(paths) >= limit: | |
| break | |
| return paths | |
| def _image_to_data_url(path: str) -> str: | |
| """Encode an image file as a data URL for OpenAI chat image parts.""" | |
| with open(path, "rb") as f: | |
| b64 = base64.b64encode(f.read()).decode("utf-8") | |
| mime, _ = mimetypes.guess_type(path) | |
| mime = mime or "image/png" | |
| return f"data:{mime};base64,{b64}" | |
| def _ensure_final_answer_line(text: str, *, enabled: bool) -> str: | |
| """When enabled, ensure a `final_answer:` line. When disabled, strip any such line.""" | |
| if enabled: | |
| if re.search(r"(?im)^final_answer\s*:", text): | |
| return text | |
| # best-effort: take last non-empty line | |
| lines = [ln.strip() for ln in text.splitlines() if ln.strip() and not ln.strip().startswith("```")] | |
| candidate = lines[-1] if lines else "[NO_ANSWER]" | |
| return f"{text.rstrip()}\n\nfinal_answer: {candidate}" | |
| else: | |
| # remove any final_answer line(s) | |
| return re.sub(r"(?im)^final_answer\s*:\s*.*\n?", "", text).strip() | |
| def _parse_final_answer(text: str, *, enabled: bool) -> Optional[str]: | |
| """Only parse when enabled; otherwise return None.""" | |
| if not enabled: | |
| return None | |
| m = re.search(r"(?im)^final_answer\s*:\s*(.+)$", text) | |
| return m.group(1).strip() if m else None | |
def _convert_to_wav_mono16k(src_path: str) -> str:
    """Run ffmpeg to write a 1-channel, 16000 Hz WAV copy of ``src_path``.

    The output goes to a uniquely named file in the system temp directory.

    Returns:
        Path of the WAV file that was written.

    Raises:
        RuntimeError: ffmpeg exited non-zero or the output file is missing;
            the message carries the last 500 characters of ffmpeg's stderr.
    """
    print("converting to mono16... from: ", src_path)
    wav_name = f"gaia_{uuid.uuid4().hex}.wav"
    out = os.path.join(tempfile.gettempdir(), wav_name)
    # stderr is captured so it can be included in the error message.
    proc = subprocess.run(
        ["ffmpeg", "-y", "-i", src_path, "-ac", "1", "-ar", "16000", out],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    succeeded = proc.returncode == 0 and os.path.exists(out)
    if not succeeded:
        raise RuntimeError(f"ffmpeg failed: {proc.stderr[-500:]}")
    return out
| # ----------------------Tools ---------------------- | |
@tool
def download_file(url: str, headers: dict | None = None, auth_token: str | None = None) -> str:
    """Download a file following redirects and honoring Content-Disposition. Returns local path.

    Registered as a LangChain tool because fetch_node calls it through
    ``download_file.invoke({"url": ...})``.

    Args:
        url: Address to fetch.
        headers: Extra request headers merged over the default User-Agent.
        auth_token: When given, sent as ``Authorization: Bearer <token>``.

    Returns:
        Path of the downloaded file inside a fresh temporary directory.

    Raises:
        requests.HTTPError: The server answered with an error status.
    """
    hdrs = {"User-Agent": "gaia-agent/1.0"}
    if headers:
        hdrs.update(headers)
    if auth_token:
        hdrs["Authorization"] = f"Bearer {auth_token}"

    # Both the session and the streamed response are closed on exit.
    with requests.Session() as sess:
        with sess.get(url, headers=hdrs, timeout=(10, 60), stream=True, allow_redirects=True) as r:
            r.raise_for_status()

            # Determine filename
            fname = _filename_from_cd(r.headers.get("Content-Disposition", ""))
            if not fname:
                # Fallback to URL path (percent-decoded)
                path = urlparse(r.url).path or urlparse(url).path
                fname = unquote(os.path.basename(path))

            # The name comes from the remote server: keep only its last path
            # component so it cannot escape the download directory.
            fname = os.path.basename(fname.replace("\\", "/"))
            if fname in ("", ".", ".."):
                fname = f"download-{uuid.uuid4().hex}"

            # Ensure we have an extension
            base, ext = os.path.splitext(fname)
            if not ext:
                guess = _pick_extension(r.headers.get("Content-Type"))
                if guess:
                    fname = base + guess

            # Write to a temp folder (unique per call)
            out_dir = tempfile.mkdtemp(prefix="gaia_tmpdl_")
            out_path = os.path.join(out_dir, fname)
            print("out_path:", out_path)
            with open(out_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
    return out_path
# Whisper models already loaded in this process, keyed by model size, so that
# transcribing several files does not reload the weights each time.
_WHISPER_MODELS: Dict[str, Any] = {}


@tool
def transcribe_audio(path: str, model_size: str = "base") -> str:
    """
    Transcribe an audio file using Whisper (local).

    The file is passed to Whisper as-is; preprocess_node converts audio to
    mono/16k WAV before calling this tool. Registered as a LangChain tool
    because it is called through ``transcribe_audio.invoke({"path": ...})``.

    Args:
        path: Audio file to transcribe.
        model_size: Whisper model name passed to ``whisper.load_model``.

    Returns:
        The stripped transcript text (empty string if Whisper returns none).

    Raises:
        RuntimeError: Loading the model or transcribing failed; the original
            exception is chained as the cause.
    """
    print("running transcribe_audio")
    try:
        model = _WHISPER_MODELS.get(model_size)
        if model is None:
            model = whisper.load_model(model_size)
            _WHISPER_MODELS[model_size] = model
        result = model.transcribe(path)
        return (result.get("text") or "").strip()
    except Exception as e:
        raise RuntimeError(f"Whisper error: {e}") from e
@tool
def ocr_image(path: str) -> str:
    """OCR an image using Tesseract.

    Requires the tesseract binary to be installed on the system. Registered
    as a LangChain tool because it is called through
    ``ocr_image.invoke({"path": ...})``.

    Args:
        path: Image file to read.

    Returns:
        The recognised text with surrounding whitespace removed.
    """
    print("running ocr")
    # The context manager closes the underlying file once OCR is done.
    with Image.open(path) as img:
        text = pytesseract.image_to_string(img)
    return text.strip()
| # ------------------------------- Nodes ------------------------------ | |
def check_attachment_node(state: AgentState) -> AgentState:
    """Check if there is attachment.

    Looks only at the first entry of ``state["attachment_urls"]``:

    - no entries, or an entry that is not a non-empty string: the list is
      cleared;
    - an existing local path: kept as-is (fetch_node accepts local paths);
    - a URL: probed with HEAD (GET with streaming when the server answers
      405/501). The list is kept only when the response's
      Content-Disposition contains "attachment"; otherwise, or when the probe
      itself fails, it is cleared.

    Returns:
        The same state object, with ``attachment_urls`` possibly emptied.
    """
    print("enter check attachment node")
    urls = state.get("attachment_urls")
    if not urls:
        print("No attachment URLs provided.")
        state["attachment_urls"] = []
        return state

    url = urls[0]  # Get the first URL from the list
    if not isinstance(url, str) or not url.strip():
        # e.g. a nested empty list: there is nothing that could be fetched.
        print("No attachment URLs provided.")
        state["attachment_urls"] = []
        return state
    if os.path.exists(url):
        # Already a local file; no HTTP probe is possible or needed.
        return state

    headers = {"Accept": "application/json"}
    timeout = 30
    try:
        # 1) Try HEAD first
        r = requests.head(url, headers=headers, allow_redirects=True, timeout=timeout)
        # Some servers don't support HEAD; 405/501 are common.
        # Fallback to GET (stream) to read headers only.
        if r.status_code in (405, 501):
            r.close()
            r = requests.get(url, headers=headers, stream=True, allow_redirects=True, timeout=timeout)
    except requests.RequestException as exc:
        # An unreachable attachment should not abort the whole run.
        print("Attachment probe failed; skip downloading:", exc)
        state["attachment_urls"] = []
        return state

    try:
        # requests' header mapping is case-insensitive.
        cd = r.headers.get("Content-Disposition", "") or ""
    finally:
        # If we fell back to GET(stream=True), don't keep the connection open.
        r.close()

    if "attachment" in cd.lower():
        print("Need to download attachment:", _filename_from_cd(cd))
    else:
        print("No attachment header; skip downloading.")
        state["attachment_urls"] = []
    return state
def fetch_node(state: AgentState) -> AgentState:
    """Resolve every attachment entry to a local file path.

    Entries that already exist on disk are used as they are; anything else is
    downloaded with the ``download_file`` tool. The resulting paths are stored
    in ``state["local_files"]``.
    """
    print("enter fetch_node")
    state["local_files"] = [
        src if os.path.exists(src) else download_file.invoke({"url": src})
        for src in state["attachment_urls"]
    ]
    return state
def _remove_temp_file(path: Optional[str]) -> None:
    """Best-effort removal of a temporary file created during preprocessing."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        # The file may never have been written; there is nothing else to undo.
        pass


def _audio_transcript(path: str) -> str:
    """Convert audio to mono/16k WAV, transcribe it, then delete the WAV."""
    try:
        wav = _convert_to_wav_mono16k(path)
    except Exception as e:
        raise RuntimeError(f"Pre-conversion error: {e}") from e
    print("after conversion saving at tmp_wav path: ", wav)
    try:
        return transcribe_audio.invoke({"path": wav})
    finally:
        _remove_temp_file(wav)


def _image_ocr_text(path: str) -> str:
    """OCR an image after grayscale conversion and 2x upscaling of small images.

    Never raises: failures are returned as an "[OCR error: ...]" string.
    """
    tmp_ocr = None
    try:
        print("upscaling original small image: ", path)
        with Image.open(path) as src:
            img = src.convert("L")  # grayscale copy; source handle is closed
        w, h = img.size
        if max(w, h) < 1600:  # upscale small images to help OCR
            img = img.resize((w * 2, h * 2))
        tmp_ocr = os.path.join(tempfile.gettempdir(), f"ocr_{uuid.uuid4().hex}.png")
        img.save(tmp_ocr)
        print("After upscaling save at tmp_ocr path: ", tmp_ocr)
        return ocr_image.invoke({"path": tmp_ocr})
    except Exception as e:
        return f"[OCR error: {e}]"
    finally:
        _remove_temp_file(tmp_ocr)


def _pdf_text(path: str) -> str:
    """Extract text from a PDF with pdfplumber; problems come back as bracketed notes."""
    if not _HAS_PDFPLUMBER:
        return "[PDF support not installed; pip install pdfplumber]"
    try:
        with pdfplumber.open(path) as pdf:
            pages = [pg.extract_text() or "" for pg in pdf.pages]
        return "\n\n".join(pages).strip() or "[Empty or image-based PDF; try OCR]"
    except Exception as e:
        return f"[PDF parse error: {e}]"


def preprocess_node(state: AgentState) -> AgentState:
    """
    For each local file:
    - audio/* -> ASR transcript
    - image/* -> OCR text (basic enhancement to help OCR)
    - application/pdf -> text extraction (if pdfplumber available)
    - anything else -> an "unknown_file" breadcrumb
    Produces EvidenceItem entries and stores in state['evidence'], appended
    after any evidence already present. A file whose processing raises is
    recorded as a "preprocess_error" item instead of aborting the node.
    Temporary WAV/PNG files created along the way are deleted afterwards.
    """
    print("enter preprocessing node")
    ev: List[Dict[str, Any]] = list(state.get("evidence") or [])
    for path in state.get("local_files") or []:
        mime, _ = mimetypes.guess_type(path)
        meta = {"mime": mime or "application/octet-stream", "filename": os.path.basename(path)}
        print("mime", mime)
        is_pdf = mime == "application/pdf" or bool(
            mime and mime.startswith("application") and path.lower().endswith(".pdf")
        )
        try:
            if mime and mime.startswith("audio"):
                print("mime start with audio")
                kind, text = "audio_transcript", _audio_transcript(path)
            elif mime and mime.startswith("image"):
                print("mime start with image")
                kind, text = "image_ocr", _image_ocr_text(path)
            elif is_pdf:
                # Best-effort; image-only PDFs may need OCR.
                kind, text = "doc_text", _pdf_text(path)
            else:
                # Unknown/unsupported; keep a breadcrumb so you can inspect later
                kind, text = "unknown_file", "[Unsupported file type]"
        except Exception as e:
            # Deliberately broad: one bad attachment must not stop the others.
            kind, text = "preprocess_error", f"[Error processing {path}: {e}]"
        ev.append({"kind": kind, "text": text, "path": path, "meta": meta})
    state["evidence"] = ev
    return state
def solve_multimodal_node(state: AgentState) -> AgentState:
    """
    Use a vision-capable model (e.g., gpt-4o) and attach the image(s) PLUS the text evidence (ASR/OCR).

    Reads "question", "evidence" and "emit_final_answer" (treated as True when
    missing) from the state. Writes the model output to ``state["answer"]``
    and the parsed `final_answer:` value (None when emitting is off) to
    ``state["parsed_final_answer"]``.
    """
    print("enter solve_multimodal_node")
    emit = bool(state.get("emit_final_answer", True))
    end_instr = "" if not emit else " End your output with a single line: final_answer: <answer>"
    # Tolerate a missing/None question or evidence, like solve_text_only_node does
    # for the question.
    question = (state.get("question") or "").strip()
    evidence = state.get("evidence") or []

    vision_llm = ChatOpenAI(model="gpt-4o", temperature=0)  # vision-capable
    system_msg = SystemMessage(content=(
        "You solve GAIA tasks using the provided evidence and attached images.\n"
        "Be precise, quote numbers/strings exactly. If uncertain, say so.\n"
        "Your answer to the GAIA tasks should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. If your answer only include a single word, make the first letter capital.\n" + end_instr
    ))

    # Summarized text evidence (ASR/OCR/PDF text)
    ev_text = _summarize_evidence(evidence)
    text_part = (
        f"Question:\n{question}\n\n"
        f"Textual evidence (summarized):\n{ev_text}\n\n"
        "Use the attached images if any to read fine text, diagrams, or confirm details."
    )
    parts: List[Any] = [{"type": "text", "text": text_part}]

    # Attach up to 4 images (data URLs)
    for p in _collect_image_paths(evidence, limit=4):
        parts.append({"type": "image_url", "image_url": {"url": _image_to_data_url(p)}})

    resp = vision_llm.invoke([system_msg, HumanMessage(content=parts)])
    text = (resp.content or "").strip()
    text = _ensure_final_answer_line(text, enabled=emit)
    state["answer"] = text
    state["parsed_final_answer"] = _parse_final_answer(text, enabled=emit)
    return state
def solve_text_only_node(state: "AgentState") -> "AgentState":
    """
    Answer the question with a text-only model (gpt-4o-mini).

    The prompt holds the question plus the summarized textual evidence
    ("(none)" when the summary is empty); no images are attached. The model
    output is stored in ``state["answer"]`` and the parsed `final_answer:`
    value (None when emitting is off) in ``state["parsed_final_answer"]``.
    """
    print("enter solve_text_only_node")
    emit = bool(state.get("emit_final_answer", True))
    if emit:
        end_instr = " End your output with a single line: final_answer: <answer>"
    else:
        end_instr = ""

    question = (state.get("question") or "").strip()
    # Summarized text evidence (ASR/OCR/PDF text)
    ev_text = _summarize_evidence(state.get("evidence", [])) or "(none)"

    system_prompt = (
        "You solve GAIA tasks. Use careful step-by-step reasoning but keep it concise.\n"
        "You can use the provided textual evidence if there is any. \n"
        "Your answer to the GAIA tasks should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. If your answer only include a single word, make the first letter capital.\n" + end_instr
    )
    user_prompt = (
        f"Question:\n{question}\n\n"
        f"Textual evidence (summarized):\n{ev_text}"
    )

    # LLM (text-only). Swap model as you like.
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    resp = llm.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt),
    ])

    answer_text = _ensure_final_answer_line((resp.content or "").strip(), enabled=emit)
    state["answer"] = answer_text
    state["parsed_final_answer"] = _parse_final_answer(answer_text, enabled=emit)
    return state
def validate_format_node(state: AgentState) -> AgentState:
    """
    Normalise ``state["answer"]`` and ``state["parsed_final_answer"]``.

    With emitting on: an empty answer becomes a "[NO_ANSWER]" placeholder, a
    missing `final_answer:` line is added, and when several are present only
    the last one is kept (moved to the end). With emitting off: any
    `final_answer:` line is removed and the parsed value is None.
    """
    print("enter validate_format_node")
    emit = bool(state.get("emit_final_answer", True))
    body = (state.get("answer") or "").strip()

    # Nothing was generated: fill in placeholders and stop.
    if not body:
        if emit:
            state["answer"] = "No answer generated.\n\nfinal_answer: [NO_ANSWER]"
            state["parsed_final_answer"] = "[NO_ANSWER]"
        else:
            state["answer"] = "No answer generated."
            state["parsed_final_answer"] = None
        return state

    if not emit:
        # strip any lingering final_answer lines (paranoia)
        body = _ensure_final_answer_line(body, enabled=False)
        state["parsed_final_answer"] = None
        state["answer"] = body.strip()
        return state

    # keep only the LAST final_answer line if multiple
    found_lines = [m.group(0) for m in re.finditer(r"(?im)^final_answer\s*:\s*(.+)$", body)]
    if not found_lines:
        body = _ensure_final_answer_line(body, enabled=True)
    elif len(found_lines) > 1:
        without_markers = re.sub(r"(?im)^final_answer\s*:\s*.+\s*$", "", body).strip()
        body = f"{without_markers}\n\n{found_lines[-1]}"

    state["parsed_final_answer"] = _parse_final_answer(body, enabled=True)
    state["answer"] = body.strip()
    return state
| # ------------------------------- Router functions ------------------------------ | |
def route_intake(state: AgentState) -> Literal["with_files","no_files"]:
    """Pick the branch after the attachment check: any attachment entries or none."""
    if state.get("attachment_urls"):
        return "with_files"
    # A missing key, None, or an empty list all mean there is nothing to fetch.
    return "no_files"
def has_images(state: AgentState) -> bool:
    """Tell whether any evidence item's meta "mime" starts with "image"."""
    return any(
        str((item.get("meta") or {}).get("mime", "")).startswith("image")
        for item in state.get("evidence", [])
    )
def route_after_preprocess(state: AgentState) -> Literal["vision","text"]:
    """Choose the solver branch: "vision" when the evidence has images, else "text"."""
    if has_images(state):
        return "vision"
    return "text"
| # ---------- Graph ---------- | |
| # Build graph function | |
def build_graph():
    """Assemble and compile the agent's StateGraph.

    Flow: START -> check_attachment -> (fetch -> preprocess ->
    solve_multimodal | solve_text_only) or solve_text_only -> validate -> END.

    Returns:
        The compiled graph.
    """
    g = StateGraph(AgentState)

    # Register the nodes.
    node_table = (
        ("check_attachment", check_attachment_node),
        ("fetch", fetch_node),
        ("preprocess", preprocess_node),
        ("solve_multimodal", solve_multimodal_node),
        ("solve_text_only", solve_text_only_node),
        ("validate", validate_format_node),
    )
    for node_name, node_fn in node_table:
        g.add_node(node_name, node_fn)

    # Entry point, then branch on whether there are attachments.
    g.add_edge(START, "check_attachment")
    g.add_conditional_edges(
        "check_attachment",
        route_intake,  # returns "with_files" or "no_files"
        {"with_files": "fetch", "no_files": "solve_text_only"},
    )

    # Files branch: download, extract evidence, then pick a solver.
    g.add_edge("fetch", "preprocess")
    g.add_conditional_edges(
        "preprocess",
        route_after_preprocess,
        {
            "vision": "solve_multimodal",  # question + evidence + attach images
            "text": "solve_text_only",  # question + transcript/other text
        },
    )

    # Both solver branches converge on validation.
    for solver in ("solve_multimodal", "solve_text_only"):
        g.add_edge(solver, "validate")
    g.add_edge("validate", END)

    return g.compile()
| # test | |
# Manual smoke test: run one attachment-free question through the graph.
if __name__ == "__main__":
    task_id = "0001"
    task_q = "Who is the current president of France"
    task_urls: List[str] = []  # no attachments for this sample

    sample = {
        "task_id": task_id,
        "question": task_q,
        # Pass the list itself: wrapping it as [task_urls] would produce [[]],
        # which is truthy and would trigger the attachment probe on an empty list.
        "attachment_urls": task_urls,  # from GAIA sample
        "local_files": [],
        "evidence": [],
        "answer": None,
        "parsed_final_answer": None,
        "emit_final_answer": False,  # pure output mode: no final_answer line
    }

    agent_GAIA = build_graph()
    out = agent_GAIA.invoke(sample)
    print("---------------------------")
    print(out["answer"])