Spaces:
Sleeping
Sleeping
| # Develop an AI agent with LangGraph and LangChain | |
| # to answer the questions in the "gaia-benchmark/GAIA" dataset. | |
| from pathlib import Path | |
| import os, re, base64, mimetypes, tempfile, uuid, subprocess, json | |
| from urllib.parse import urlparse, unquote | |
| from PIL import Image | |
| import pytesseract | |
| import whisper | |
| import requests | |
| from typing import TypedDict, List, Optional, Dict, Any, Literal | |
| from langchain_core.tools import tool | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.graph import StateGraph, START, END | |
| from tavily import TavilyClient | |
| import serpapi | |
| import trafilatura | |
| from readability import Document | |
| import html as _html | |
| import wikipedia | |
| from urllib.parse import parse_qs | |
| from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound | |
| import yt_dlp | |
| # ==== NEW: (optional) tiny helpers used by browsing nodes ==== | |
| def _has_search_key() -> bool: | |
| """Return True if any supported search backend is configured.""" | |
| return bool( | |
| os.getenv("TAVILY_API_KEY") | |
| or os.getenv("SERPAPI_API_KEY") | |
| or (os.getenv("GOOGLE_API_KEY") and os.getenv("GOOGLE_CSE_ID")) | |
| ) | |
| # Optional: pdf parsing if GAIA sometimes includes PDFs | |
| try: | |
| import pdfplumber | |
| _HAS_PDFPLUMBER = True | |
| except Exception: | |
| _HAS_PDFPLUMBER = False | |
| # -------------- State ------------- | |
class EvidenceItem(TypedDict):
    """One piece of evidence gathered for a task and later rendered into prompts.

    Items are appended by the file-preprocessing and browsing nodes and are
    summarised for the model by _summarize_evidence.
    """
    # Source category. The nodes in this file produce audio_transcript, image_ocr,
    # doc_text, unknown_file and preprocess_error; "image_vqa" is declared but not
    # produced by any code visible here.
    kind: Literal["audio_transcript","image_ocr","image_vqa","doc_text","unknown_file","preprocess_error"]
    # Extracted text, or a bracketed "[... error ...]" message when extraction failed.
    text: str
    # Local file the evidence came from; None for web / Wikipedia / YouTube evidence.
    path: Optional[str]
    # Free-form metadata; keys used in this file include mime, filename, source, title, url.
    meta: Dict[str, Any]
class AgentState(TypedDict):
    """Shared LangGraph state that flows through every node of the GAIA agent graph.

    Nodes read fields with ``state.get(...)``, write their results back into the
    same dict and return it.
    """
    # Task identifier (not read by any node shown here).
    task_id: str
    # The task's question text.
    question: str
    attachment_urls: List[str] # empty list when no files; fetch_node also accepts existing local paths
    # Local paths of downloaded (or already-local) attachments, filled by fetch_node.
    local_files: List[str]
    # Evidence accumulated by the browsing and preprocessing nodes.
    evidence: List[EvidenceItem]
    # Model output after validate_format_node.
    answer: Optional[str]
    # Text after the `final_answer:` marker, or None when emit_final_answer is off.
    parsed_final_answer: Optional[str]
    # When truthy (nodes default to True if the key is missing) the prompts request and
    # validation enforces a trailing `final_answer: ...` line; when falsy such lines are stripped.
    emit_final_answer: bool
    # Set by decide_browse_node; True routes through search -> crawl.
    use_browsing: Optional[bool]
    # Candidate pages ({title, url, snippet}) built by search_node, consumed by crawl_node.
    web_hits: Optional[List[Dict[str, str]]]
    # http(s) URLs found in the question text (set by decide_browse_node).
    question_urls: Optional[List[str]]
    # Subset of question_urls that point to YouTube.
    question_youtube_urls: Optional[List[str]]
| # -------------- helpers --------------- | |
| def _filename_from_cd(cd: str) -> str | None: | |
| # RFC 6266/5987: filename* takes precedence; fall back to filename | |
| if not cd: | |
| return None | |
| # filename*= | |
| m = re.search(r"filename\*\s*=\s*([^']*)'[^']*'([^;]+)", cd, flags=re.I) | |
| if m: | |
| return unquote(m.group(2)).strip().strip('"') | |
| # filename= | |
| m = re.search(r'filename\s*=\s*"?(.*?)(?:"|;|$)', cd, flags=re.I) | |
| if m: | |
| return m.group(1).strip().strip('"') | |
| return None | |
| def _pick_extension(ct: str | None) -> str | None: | |
| if not ct: | |
| return None | |
| ct = ct.split(";", 1)[0].strip() | |
| ext = mimetypes.guess_extension(ct) | |
| # Fix common mis-maps | |
| return {".jpe": ".jpg"}.get(ext, ext) | |
| def _summarize_evidence(evidence: List[Dict[str, Any]], limit_chars: int = 6000) -> str: | |
| """Compact the evidence text for prompting; keep provenance-style tags.""" | |
| chunks = [] | |
| for i, e in enumerate(evidence, 1): | |
| t = e.get("text", "") or "" | |
| if len(t) > 1200: # keep things small but informative | |
| t = t[:1200] + " …" | |
| meta = e.get("meta", {}) | |
| tag = f"{e.get('kind','?')}" | |
| if meta.get("mime"): | |
| tag += f"({meta['mime']})" | |
| if meta.get("title"): | |
| tag += f"[{meta['title']}]" | |
| if meta.get("url"): | |
| tag += f"<{meta['url']}>" | |
| chunks.append(f"[{i}:{tag}] {t}") | |
| out = "\n".join(chunks) | |
| return out if len(out) <= limit_chars else out[:limit_chars] + " …" | |
| def _collect_image_paths(evidence: List[Dict[str, Any]], limit: int = 4) -> List[str]: | |
| """Find image file paths to attach to a vision model.""" | |
| paths = [] | |
| for e in evidence: | |
| if e.get("path") and str(e.get("meta", {}).get("mime","")).startswith("image"): | |
| p = e["path"] | |
| if os.path.exists(p) and p not in paths: | |
| paths.append(p) | |
| if len(paths) >= limit: | |
| break | |
| return paths | |
| def _image_to_data_url(path: str) -> str: | |
| """Encode an image file as a data URL for OpenAI chat image parts.""" | |
| with open(path, "rb") as f: | |
| b64 = base64.b64encode(f.read()).decode("utf-8") | |
| mime, _ = mimetypes.guess_type(path) | |
| mime = mime or "image/png" | |
| return f"data:{mime};base64,{b64}" | |
| def _ensure_final_answer_line(text: str, *, enabled: bool) -> str: | |
| """When enabled, ensure a `final_answer:` line. When disabled, strip any such line.""" | |
| if enabled: | |
| if re.search(r"(?im)^final_answer\s*:", text): | |
| return text | |
| # best-effort: take last non-empty line | |
| lines = [ln.strip() for ln in text.splitlines() if ln.strip() and not ln.strip().startswith("```")] | |
| candidate = lines[-1] if lines else "[NO_ANSWER]" | |
| return f"{text.rstrip()}\n\nfinal_answer: {candidate}" | |
| else: | |
| # remove any final_answer line(s) | |
| return re.sub(r"(?im)^final_answer\s*:\s*.*\n?", "", text).strip() | |
| def _parse_final_answer(text: str, *, enabled: bool) -> Optional[str]: | |
| """Only parse when enabled; otherwise return None.""" | |
| if not enabled: | |
| return None | |
| m = re.search(r"(?im)^final_answer\s*:\s*(.+)$", text) | |
| return m.group(1).strip() if m else None | |
| def _convert_to_wav_mono16k(src_path: str) -> str: | |
| print("converting to mono16... from: ", src_path) | |
| out = os.path.join(tempfile.gettempdir(), f"gaia_{uuid.uuid4().hex}.wav") | |
| cmd = ["ffmpeg", "-y", "-i", src_path, "-ac", "1", "-ar", "16000", out] | |
| # Capture stderr for debugging | |
| p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) | |
| if p.returncode != 0 or not os.path.exists(out): | |
| raise RuntimeError(f"ffmpeg failed: {p.stderr[-500:]}") | |
| return out | |
| # ==== NEW: URL helpers ==== | |
| _URL_RE = re.compile(r'https?://\S+') | |
| def _extract_urls(text: str) -> List[str]: | |
| return _URL_RE.findall(text or "") | |
| # ----------------------Tools ---------------------- | |
def download_file(url: str, headers: dict | None = None, auth_token: str | None = None) -> str:
    """Download *url* into a fresh temporary directory and return the local path.

    Redirects are followed. The file name comes from Content-Disposition when
    present, otherwise from the (percent-decoded) final URL path. When the name
    has no extension, one is guessed from Content-Type.

    Args:
        url: Resource to download.
        headers: Extra request headers, merged over the default User-Agent.
        auth_token: Optional bearer token sent as ``Authorization: Bearer ...``.

    Raises:
        requests.HTTPError: For non-2xx responses (via ``raise_for_status``).
        requests.RequestException: For connection and timeout failures.
    """
    hdrs = {"User-Agent": "gaia-agent/1.0"}
    if headers:
        hdrs.update(headers)
    if auth_token:
        hdrs["Authorization"] = f"Bearer {auth_token}"
    with requests.Session() as sess, sess.get(
        url, headers=hdrs, timeout=(10, 60), stream=True, allow_redirects=True
    ) as r:
        r.raise_for_status()
        fname = _filename_from_cd(r.headers.get("Content-Disposition", ""))
        if not fname:
            url_path = urlparse(r.url).path or urlparse(url).path
            fname = unquote(os.path.basename(url_path))
        # The name is server-controlled: keep only its last path component so a
        # value like "../../etc/x" cannot escape the temporary directory.
        fname = os.path.basename(fname.replace("\\", "/")).strip()
        if fname in ("", ".", ".."):
            fname = f"download-{uuid.uuid4().hex}"
        base, ext = os.path.splitext(fname)
        if not ext:
            guess = _pick_extension(r.headers.get("Content-Type"))
            if guess:
                fname = base + guess
        # Unique directory per call so concurrent downloads never collide.
        out_dir = tempfile.mkdtemp(prefix="gaia_tmpdl_")
        out_path = os.path.join(out_dir, fname)
        print("out_path:", out_path)
        with open(out_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
    return out_path
| # ==== NEW: cache Whisper model so we don't reload each call ==== | |
# Whisper model cache: loading is slow, so the model is kept between calls
# and reloaded only when a different model_size is requested.
_WHISPER = None
_WHISPER_SIZE: Optional[str] = None


def transcribe_audio(path: str, model_size: str = "base") -> str:
    """Transcribe an audio file with a locally loaded Whisper model.

    The file is passed to Whisper as-is. Callers in this module convert audio to
    mono 16 kHz WAV beforehand via _convert_to_wav_mono16k.

    Args:
        path: Audio file to transcribe.
        model_size: Whisper model name (e.g. "base"); changing it reloads the model.

    Returns:
        The stripped transcript text ("" if Whisper returned none).

    Raises:
        RuntimeError: Wrapping any model-loading or transcription failure.
    """
    global _WHISPER, _WHISPER_SIZE
    print("running transcribe_audio")
    try:
        if _WHISPER is None or _WHISPER_SIZE != model_size:
            _WHISPER = whisper.load_model(model_size)
            _WHISPER_SIZE = model_size
        result = _WHISPER.transcribe(path)
    except Exception as e:
        raise RuntimeError(f"Whisper error: {e}") from e
    return (result.get("text") or "").strip()
def ocr_image(path: str) -> str:
    """OCR the image at *path* with Tesseract and return the stripped text.

    Requires the tesseract binary on the system (used through pytesseract).
    The image file is closed after reading.
    """
    print("running ocr")
    with Image.open(path) as img:
        text = pytesseract.image_to_string(img)
    return text.strip()
| # ==== NEW: WEB / WIKI / YOUTUBE TOOLS ========================================= | |
| # Choose your search backend (Tavily simplest). Set env var before use. | |
# Preferred search backend: True -> Tavily, False -> SerpAPI. If only the other
# backend's API key is configured, web_search falls back to that backend.
_USE_TAVILY = False


def _search_error_hit(message: str) -> List[Dict[str, str]]:
    """Wrap *message* in the single-hit list shape web_search returns on failure."""
    return [{"title": "", "url": "", "snippet": f"[search error: {message}]"}]


def _tavily_search(query: str, k: int) -> List[Dict[str, str]]:
    """Query Tavily (needs TAVILY_API_KEY) and normalise results to {title, url, snippet}."""
    tv = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
    res = tv.search(
        query=query,
        search_depth="advanced",
        max_results=k,
        include_answer=False,
        include_images=False,
    )
    return [
        {
            "title": r.get("title", ""),
            "url": r.get("url", ""),
            "snippet": (r.get("content", "") or "")[:400],
        }
        for r in res.get("results", [])
    ]


def _serpapi_search(query: str, k: int) -> List[Dict[str, str]]:
    """Query Google via SerpAPI (needs SERPAPI_API_KEY) and normalise results to {title, url, snippet}."""
    params = {"engine": "google", "q": query, "num": k, "api_key": os.getenv("SERPAPI_API_KEY")}
    results = serpapi.search(params)
    return [
        {
            "title": it.get("title", ""),
            "url": it.get("link", ""),
            "snippet": (it.get("snippet", "") or "")[:400],
        }
        for it in results.get("organic_results", [])[:k]
    ]


def web_search(query: str, k: int = 6) -> List[Dict[str, str]]:
    """Web search returning a list of {title, url, snippet} dicts.

    Uses the backend preferred by _USE_TAVILY when its key is set, otherwise the
    other configured backend. Never raises: backend errors and a missing API key
    are reported as a single hit with empty title/url and a "[search error: ...]"
    snippet.
    """
    backends = [("TAVILY_API_KEY", _tavily_search), ("SERPAPI_API_KEY", _serpapi_search)]
    if not _USE_TAVILY:
        backends.reverse()
    for env_name, backend in backends:
        if not os.getenv(env_name):
            continue
        try:
            return backend(query, k)
        except Exception as e:  # network/API failures are reported, not raised
            return _search_error_hit(str(e))
    return _search_error_hit("no search API key configured (set SERPAPI_API_KEY or TAVILY_API_KEY)")
def fetch_url_text(url: str, max_chars: int = 12000, timeout: int = 30) -> Dict[str, Any]:
    """
    Download a web page and extract its main text.

    trafilatura is tried first; when it yields 200 characters or fewer, a
    readability-lxml fallback is used (tags stripped, HTML entities unescaped).
    If the fallback fails, the short trafilatura text is kept when there is any.

    Returns:
        {url, title, text}. text is truncated to *max_chars* (+ " …"). On
        failure it is a "[fetch error: ...]" or "[extraction error: ...]" message.
    """
    headers = {
        "User-Agent": "gaia-agent/1.0 (+https://example.org)",
        "Accept": "text/html,*/*;q=0.8",
    }
    try:
        with requests.Session() as sess:
            r = sess.get(url, headers=headers, timeout=timeout)
            r.raise_for_status()
            html_content = r.text
    except Exception as e:
        return {"url": url, "title": "", "text": f"[fetch error: {e}]"}

    title = ""
    # 1) trafilatura (best boilerplate removal)
    try:
        extracted = trafilatura.extract(
            html_content, include_comments=False, include_tables=False, url=url
        ) or ""
    except Exception:
        extracted = ""
    if len(extracted) > 200:
        text = extracted
    else:
        # 2) fallback: readability
        try:
            doc = Document(html_content)
            title = doc.short_title() or ""
            body = doc.summary(html_partial=False)
            body = re.sub(r"<[^>]+>", " ", body)  # rudimentary HTML strip
            text = re.sub(r"\s+", " ", _html.unescape(body)).strip() or extracted
        except Exception as e2:
            if not extracted:
                return {"url": url, "title": "", "text": f"[extraction error: {e2}]"}
            text = extracted
    if len(text) > max_chars:
        text = text[:max_chars] + " …"
    # Fill the title from <title> when extraction did not provide one.
    if not title:
        m = re.search(r"<title[^>]*>(.*?)</title>", html_content, flags=re.I | re.S)
        if m:
            title = _html.unescape(m.group(1).strip())
    return {"url": url, "title": title or "", "text": text}
def wikipedia_lookup(query: str, sentences: int = 4) -> Dict[str, Any]:
    """
    Look up the best English Wikipedia match for *query*.

    A disambiguation page is resolved by retrying with its first listed option.

    Returns:
        {title, url, summary}. On failure url is "" and summary is a
        "[wikipedia search error: ...]" or "[wikipedia fetch error: ...]" message.
    """
    try:
        wikipedia.set_lang("en")
        titles = wikipedia.search(query, results=1)
    except Exception as e:
        return {"title": "", "url": "", "summary": f"[wikipedia search error: {e}]"}
    if not titles:
        return {"title": "", "url": "", "summary": "[wikipedia search error: no results]"}
    title = titles[0]
    try:
        try:
            page = wikipedia.page(title, auto_suggest=False, preload=False)
        except wikipedia.exceptions.DisambiguationError as dis:
            if not dis.options:
                raise
            title = dis.options[0]
            page = wikipedia.page(title, auto_suggest=False, preload=False)
        summary = wikipedia.summary(page.title, sentences=sentences, auto_suggest=False)
        return {"title": page.title, "url": page.url, "summary": summary}
    except Exception as e:
        return {"title": title, "url": "", "summary": f"[wikipedia fetch error: {e}]"}
| def youtube_get_transcript(url_or_id: str, prefer_langs: List[str] | None = None) -> str: | |
| """ | |
| Get YouTube transcript via API (no download). Returns plain text. | |
| """ | |
| print('try to get youtube video transcript') | |
| try: | |
| prefer_langs = prefer_langs or ["en", "en-US", "en-GB", "auto"] | |
| vid = url_or_id | |
| print("vid: ", vid) | |
| if "youtube.com" in url_or_id or "youtu.be" in url_or_id: | |
| u = urlparse(url_or_id) | |
| if u.netloc.endswith("youtu.be"): | |
| vid = u.path.lstrip("/") | |
| else: | |
| vid = parse_qs(u.query).get("v", [""])[0] | |
| ytt_api = YouTubeTranscriptApi() | |
| trs_list = ytt_api.list(vid) | |
| # choose first matching language | |
| for lang in prefer_langs: | |
| try: | |
| trs = trs_list.find_transcript([lang]) | |
| chunks = trs.fetch() | |
| print("transcript from youtube website?") | |
| print(" ".join([c["text"] for c in chunks if c.get("text")]).strip()) | |
| return " ".join([c["text"] for c in chunks if c.get("text")]).strip() | |
| except Exception: | |
| continue | |
| # fallback: first any transcript | |
| trs = list(trs_list)[0] | |
| chunks = trs.fetch() | |
| print("transcript from youtube website?") | |
| print(" ".join([c["text"] for c in chunks if c.get("text")]).strip()) | |
| return " ".join([c["text"] for c in chunks if c.get("text")]).strip() | |
| except (TranscriptsDisabled, NoTranscriptFound): | |
| return "[no captions available]" | |
| except Exception as e: | |
| return f"[youtube transcript error: {e}]" | |
def youtube_transcribe_audio(url: str, model_size: str = "base") -> str:
    """
    Download a YouTube video's audio with yt-dlp and transcribe it with Whisper.

    The download directory and the intermediate WAV are deleted afterwards.

    Returns:
        The transcript, or "[youtube download/transcribe error: ...]" on failure.
    """
    import shutil

    tmpdir = tempfile.mkdtemp(prefix="gaia_yt_")
    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": os.path.join(tmpdir, "%(id)s.%(ext)s"),
        "quiet": True,
        "no_warnings": True,
        "noplaylist": True,
    }
    wav = None
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            path = ydl.prepare_filename(info)
        wav = _convert_to_wav_mono16k(path)
        return transcribe_audio(wav, model_size=model_size)
    except Exception as e:
        return f"[youtube download/transcribe error: {e}]"
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
        if wav:
            try:
                os.remove(wav)
            except OSError:
                pass
| # ------------------------------- Nodes ------------------------------ | |
def check_attachment_node(state: AgentState) -> AgentState:
    """Check whether the task's first attachment URL really serves a file.

    - Local paths are kept as-is; fetch_node copies them through.
    - Remote URLs are probed with HEAD, falling back to a streamed GET for
      servers answering 405/501.
    - The attachment list is cleared when the response has no
      "attachment" Content-Disposition or the probe fails with a network error.

    Always returns *state*.
    """
    print("enter check attachment node")
    urls = state.get("attachment_urls")
    if not urls:
        print("No attachment URLs provided.")
        state["attachment_urls"] = []
        return state
    url = urls[0]  # only the first URL is probed
    if os.path.exists(url):
        print("Attachment is already a local file:", url)
        return state
    headers = {"Accept": "application/json"}
    timeout = 30
    r = None
    try:
        r = requests.head(url, headers=headers, allow_redirects=True, timeout=timeout)
        # Some servers don't support HEAD; fall back to GET (stream) to read headers only.
        if r.status_code in (405, 501):
            r.close()
            r = requests.get(url, headers=headers, stream=True, allow_redirects=True, timeout=timeout)
        # requests' header mapping is case-insensitive.
        cd = r.headers.get("Content-Disposition", "") or ""
    except requests.RequestException as e:
        print("Attachment probe failed; skip downloading:", e)
        state["attachment_urls"] = []
        return state
    finally:
        # Never keep a (possibly streamed) connection open.
        if r is not None:
            r.close()
    if "attachment" not in cd.lower():
        print("No attachment header; skip downloading.")
        state["attachment_urls"] = []
        return state
    print("Need to download attachment:", _filename_from_cd(cd))
    return state
def fetch_node(state: AgentState) -> AgentState:
    """Materialise attachments as local files in ``state["local_files"]``.

    Existing local paths are used directly; URLs are downloaded with download_file.
    A failed download is recorded as a "preprocess_error" evidence item instead of
    aborting the whole graph run.
    """
    print("enter fetch_node")
    local_files: List[str] = []
    evidence = list(state.get("evidence") or [])
    for u in state.get("attachment_urls") or []:
        if os.path.exists(u):
            local_files.append(u)
            continue
        try:
            local_files.append(download_file(u))
        except Exception as e:
            print(f"download failed for {u}: {e}")
            evidence.append({
                "kind": "preprocess_error",
                "text": f"[Download failed for {u}: {e}]",
                "path": None,
                "meta": {"url": u},
            })
    state["local_files"] = local_files
    state["evidence"] = evidence
    return state
# MIME types outside text/* whose content is plain text worth quoting as evidence.
_TEXT_LIKE_MIMES = ("application/json", "application/xml", "application/javascript", "application/x-sh")
# Text attachments are read up to this many characters; prompts are summarised further anyway.
_MAX_TEXT_FILE_CHARS = 20000


def _remove_quietly(path: Optional[str]) -> None:
    """Best-effort delete of a temporary file; missing files and OS errors are ignored."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def _audio_transcript(path: str) -> str:
    """Convert *path* to mono 16 kHz WAV, transcribe it, and delete the temporary WAV."""
    try:
        wav = _convert_to_wav_mono16k(path)
    except Exception as e:
        raise RuntimeError(f"Pre-conversion error: {e}") from e
    print("after conversion saving at tmp_wav path: ", wav)
    try:
        return transcribe_audio(wav)
    finally:
        _remove_quietly(wav)


def _image_ocr(path: str) -> str:
    """OCR *path* after grayscale conversion (and 2x upscaling when small); errors become text."""
    tmp_ocr = os.path.join(tempfile.gettempdir(), f"ocr_{uuid.uuid4().hex}.png")
    try:
        print("upscaling original small image: ", path)
        with Image.open(path) as src:
            img = src.convert("L")  # grayscale
        w, h = img.size
        if max(w, h) < 1600:  # upscale small images to help OCR
            img = img.resize((w * 2, h * 2))
        img.save(tmp_ocr)
        print("After upscaling save at tmp_ocr path: ", tmp_ocr)
        return ocr_image(tmp_ocr)
    except Exception as e:
        return f"[OCR error: {e}]"
    finally:
        _remove_quietly(tmp_ocr)


def _pdf_text(path: str) -> str:
    """Extract PDF text with pdfplumber (best-effort; image-only PDFs yield a hint message)."""
    if not _HAS_PDFPLUMBER:
        return "[PDF support not installed; pip install pdfplumber]"
    try:
        with pdfplumber.open(path) as pdf:
            pages = [pg.extract_text() or "" for pg in pdf.pages]
        return "\n\n".join(pages).strip() or "[Empty or image-based PDF; try OCR]"
    except Exception as e:
        return f"[PDF parse error: {e}]"


def _plain_text(path: str) -> str:
    """Read a text-like attachment as UTF-8 (undecodable bytes replaced), capped in length."""
    with open(path, encoding="utf-8", errors="replace") as f:
        content = f.read(_MAX_TEXT_FILE_CHARS)
    return content or "[Empty text file]"


def preprocess_node(state: AgentState) -> AgentState:
    """
    Turn each local file into an EvidenceItem appended to ``state["evidence"]``:
      - audio/*                    -> Whisper transcript ("audio_transcript")
      - image/*                    -> Tesseract OCR after enhancement ("image_ocr")
      - application/pdf            -> pdfplumber text ("doc_text")
      - text/* and JSON/XML/JS/sh  -> raw file text ("doc_text")
      - anything else              -> "unknown_file" breadcrumb
    Any unexpected failure for a file becomes a "preprocess_error" item.
    """
    print("enter preprocessing node")
    ev: List[Dict[str, Any]] = list(state.get("evidence") or [])
    for path in state.get("local_files") or []:
        mime, _ = mimetypes.guess_type(path)
        meta = {"mime": mime or "application/octet-stream", "filename": os.path.basename(path)}
        print("mime", mime)
        try:
            if mime and mime.startswith("audio"):
                ev.append({"kind": "audio_transcript", "text": _audio_transcript(path), "path": path, "meta": meta})
            elif mime and mime.startswith("image"):
                ev.append({"kind": "image_ocr", "text": _image_ocr(path), "path": path, "meta": meta})
            elif mime == "application/pdf":
                ev.append({"kind": "doc_text", "text": _pdf_text(path), "path": path, "meta": meta})
            elif mime and (mime.startswith("text") or mime in _TEXT_LIKE_MIMES):
                ev.append({"kind": "doc_text", "text": _plain_text(path), "path": path, "meta": meta})
            else:
                # Unknown/unsupported; keep a breadcrumb so it can be inspected later.
                ev.append({"kind": "unknown_file", "text": "[Unsupported file type]", "path": path, "meta": meta})
        except Exception as e:
            ev.append({"kind": "preprocess_error", "text": f"[Error processing {path}: {e}]", "path": path, "meta": meta})
    state["evidence"] = ev
    return state
def solve_multimodal_node(state: AgentState) -> AgentState:
    """
    Answer with a vision-capable model (gpt-4o).

    The model receives the question, the summarised text evidence (ASR/OCR/PDF/web)
    and up to four evidence images as data URLs. Images that cannot be read are
    skipped rather than aborting the run. Sets ``answer`` and ``parsed_final_answer``.
    """
    print("enter solve_multimodal_node")
    emit = bool(state.get("emit_final_answer", True))
    end_instr = "" if not emit else " End your output with a single line: final_answer: <answer>"
    # `or ""` also covers a question key that is present but None.
    question = (state.get("question") or "").strip()
    evidence = state.get("evidence") or []
    vision_llm = ChatOpenAI(model="gpt-4o", temperature=0)  # vision-capable
    sys_msg = SystemMessage(content=(
        "You solve GAIA tasks using the provided evidence and attached images.\n"
        "Be precise, quote numbers/strings exactly. If uncertain, say so.\n"
        "Your answer to the GAIA tasks should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. If your answer only include a single word, make the first letter capital.\n" + end_instr
    ))
    ev_text = _summarize_evidence(evidence)
    text_part = (
        f"Question:\n{question}\n\n"
        f"Textual evidence (summarized):\n{ev_text}\n\n"
        "Use the attached images if any to read fine text, diagrams, or confirm details."
    )
    parts: List[Any] = [{"type": "text", "text": text_part}]
    for p in _collect_image_paths(evidence, limit=4):
        try:
            data_url = _image_to_data_url(p)
        except OSError as e:
            print(f"skip unreadable image {p}: {e}")
            continue
        parts.append({"type": "image_url", "image_url": {"url": data_url}})
    resp = vision_llm.invoke([sys_msg, HumanMessage(content=parts)])
    content = resp.content
    # Message content may be a list of content parts rather than a plain string.
    if isinstance(content, list):
        content = "".join(
            part.get("text", "") if isinstance(part, dict) else str(part) for part in content
        )
    text = (content or "").strip()
    text = _ensure_final_answer_line(text, enabled=emit)
    state["answer"] = text
    state["parsed_final_answer"] = _parse_final_answer(text, enabled=emit)
    return state
def solve_text_only_node(state: "AgentState") -> "AgentState":
    """
    Answer from text alone with gpt-4o-mini; no images are attached.

    The prompt carries the question plus the summarised textual evidence
    (ASR transcripts, OCR text, PDF/web text), or "(none)" when there is none.
    Sets ``answer`` and ``parsed_final_answer``.
    """
    print("enter solve_text_only_node")
    want_marker = bool(state.get("emit_final_answer", True))
    suffix = " End your output with a single line: final_answer: <answer>" if want_marker else ""
    question = (state.get("question") or "").strip()
    evidence_text = _summarize_evidence(state.get("evidence", [])) or "(none)"
    model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    system_prompt = (
        "You solve GAIA tasks. Use careful step-by-step reasoning but keep it concise.\n"
        "You can use the provided textual evidence if there is any. \n"
        "Your answer to the GAIA tasks should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. If your answer only include a single word, make the first letter capital.\n"
    ) + suffix
    reply = model.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"Question:\n{question}\n\nTextual evidence (summarized):\n{evidence_text}"),
    ])
    final_text = _ensure_final_answer_line((reply.content or "").strip(), enabled=want_marker)
    state["answer"] = final_text
    state["parsed_final_answer"] = _parse_final_answer(final_text, enabled=want_marker)
    return state
def validate_format_node(state: AgentState) -> AgentState:
    """
    Normalise ``answer`` and recompute ``parsed_final_answer``.

    With emit_final_answer on (the default), exactly one ``final_answer:`` line
    remains: one is added if missing, and only the last is kept if several exist.
    With it off, all such lines are stripped and the parsed value is None.
    An empty answer becomes a "No answer generated." placeholder.
    """
    print("enter validate_format_node")
    emit = bool(state.get("emit_final_answer", True))
    raw = (state.get("answer") or "").strip()

    if not raw:
        state["answer"] = "No answer generated.\n\nfinal_answer: [NO_ANSWER]" if emit else "No answer generated."
        state["parsed_final_answer"] = "[NO_ANSWER]" if emit else None
        return state

    if not emit:
        # Pure-output mode: remove any lingering final_answer lines.
        raw = _ensure_final_answer_line(raw, enabled=False)
        state["parsed_final_answer"] = None
        state["answer"] = raw.strip()
        return state

    marker_lines = [m.group(0) for m in re.finditer(r"(?im)^final_answer\s*:\s*(.+)$", raw)]
    if not marker_lines:
        raw = _ensure_final_answer_line(raw, enabled=True)
    elif len(marker_lines) > 1:
        body = re.sub(r"(?im)^final_answer\s*:\s*.+\s*$", "", raw).strip()
        raw = f"{body}\n\n{marker_lines[-1]}"
    state["parsed_final_answer"] = _parse_final_answer(raw, enabled=True)
    state["answer"] = raw.strip()
    return state
| # ------------------------------- Router functions ------------------------------ | |
def route_intake(state: AgentState) -> Literal["with_files","no_files"]:
    """Pick the branch after check_attachment: any attachment URL left means "with_files"."""
    if state.get("attachment_urls"):
        return "with_files"
    return "no_files"
def has_images(state: AgentState) -> bool:
    """True when some evidence item's meta MIME type starts with "image"."""
    return any(
        str((item.get("meta") or {}).get("mime", "")).startswith("image")
        for item in state.get("evidence", [])
    )
| # ==== CHANGED: fix return type Literal to match actual branch key ==== | |
def route_after_preprocess(state: AgentState) -> Literal["vision","text"]:
    """Send image-bearing evidence to the vision solver, everything else to text-only."""
    if has_images(state):
        return "vision"
    return "text"
| # ==== NEW: Browsing router ==== | |
| def needs_browsing(q: str) -> bool: | |
| q = (q or "").lower() | |
| hot = ["today","current","latest","price","How","who","where","what","How many", | |
| "2023","2024","2025","news","wins","Which", | |
| "http://","https://","wikipedia","youtube.com"] | |
| # Only browse if we *also* have a search key, so the sample runs without keys. | |
| return _has_search_key() and any(w in q for w in hot) | |
| # ==== NEW: Decide browse node ==== | |
def decide_browse_node(state: AgentState) -> AgentState:
    """Record URLs found in the question and decide whether to run the browse branch.

    Browsing is enabled when the question links to YouTube (handled without a
    search key) or when needs_browsing() says so.
    """
    print("enter decide_browse_node")
    question = state.get("question", "")
    found_urls = _extract_urls(question)
    youtube_links = [link for link in found_urls if _is_youtube(link)]
    state["question_urls"] = found_urls
    state["question_youtube_urls"] = youtube_links
    state["use_browsing"] = bool(youtube_links) or needs_browsing(question)
    return state
def route_browse(state: AgentState) -> Literal["browse","skip"]:
    """Follow the browse branch only when decide_browse_node set use_browsing."""
    if state.get("use_browsing"):
        return "browse"
    return "skip"
| # ==== NEW: Search node ==== | |
def search_node(state: AgentState) -> AgentState:
    """Build ``web_hits`` for crawl_node and optionally seed Wikipedia evidence.

    URLs quoted in the question come first (YouTube links ahead of the rest,
    de-duplicated), then web-search hits. Web search and the Wikipedia seed run
    only when a search key is configured. The Wikipedia seed is also limited to
    questions of at most 30 words, and error results are not stored as evidence.
    """
    print("enter search_node")
    q = state.get("question") or ""
    max_words_for_wiki = 30
    # question_urls already contains the YouTube links; de-duplicate, keeping order.
    question_links = list(dict.fromkeys(
        (state.get("question_youtube_urls") or []) + (state.get("question_urls") or [])
    ))
    preseed = [{"title": "(from question)", "url": u, "snippet": ""} for u in question_links]
    hits: List[Dict[str, str]] = []
    evidence = list(state.get("evidence") or [])
    if _has_search_key():
        hits = web_search(q, k=6) or []
        if len(q.split()) <= max_words_for_wiki:
            wiki = wikipedia_lookup(q, sentences=4)
            summary = (wiki.get("summary") or "").strip()
            # Failed lookups come back with an empty url and a bracketed message.
            if summary and wiki.get("url"):
                evidence.append({
                    "kind": "doc_text",
                    "text": wiki["summary"],
                    "path": None,
                    "meta": {"source": "wikipedia", "title": wiki.get("title", ""),
                             "url": wiki.get("url", ""), "mime": "text/plain"},
                })
    state["evidence"] = evidence
    state["web_hits"] = preseed + hits
    return state
| def _is_youtube(u: str) -> bool: | |
| try: | |
| net = urlparse(u).netloc.lower() | |
| return ("youtube.com" in net) or ("youtu.be" in net) | |
| except Exception: | |
| return False | |
def crawl_node(state: AgentState) -> AgentState:
    """Fetch up to four web hits (one per domain) and add their text as evidence.

    YouTube hits use the caption API first and fall back to download + Whisper
    when captions are missing or the API errors. Other pages go through
    fetch_url_text; empty pages and fetch/extraction errors are skipped.
    """
    print("enter crawl_node")
    max_pages = 4
    ev = list(state.get("evidence") or [])
    hits: List[Dict[str, str]] = state.get("web_hits") or []
    print("hits: ", hits)

    def _domain(u: str) -> str:
        # removeprefix drops only a literal "www."; lstrip("www.") would also eat
        # leading 'w'/'.' characters (e.g. "web.archive.org" -> "eb.archive.org").
        try:
            return urlparse(u).netloc.lower().removeprefix("www.")
        except ValueError:
            return ""

    seen_domains = set()
    picked = []
    for h in hits:
        u = h.get("url") or ""
        if not u:
            continue
        d = _domain(u)
        if not d or d in seen_domains:
            continue
        seen_domains.add(d)
        picked.append(h)
        if len(picked) >= max_pages:
            break
    print("picked: ", picked)

    for h in picked:
        u = h["url"]
        title = h.get("title", "")
        print("url: ", u)
        if _is_youtube(u):
            cap = youtube_get_transcript(u)
            print("cap: ", cap)
            if cap and not cap.startswith(("[no captions", "[youtube transcript error")):
                ev.append({"kind": "doc_text", "text": cap, "path": None,
                           "meta": {"source": "youtube", "title": title, "url": u, "mime": "text/plain"}})
                continue
            # Fallback: download + ASR (heavier).
            cap2 = youtube_transcribe_audio(u, model_size="base")
            ev.append({"kind": "audio_transcript", "text": cap2, "path": None,
                       "meta": {"source": "youtube", "title": title, "url": u, "mime": "audio"}})
            continue
        out = fetch_url_text(u, max_chars=12000)
        text = out.get("text", "") or ""
        if not text or text.startswith(("[fetch error", "[extraction error")):
            continue
        ev.append({
            "kind": "doc_text",
            "text": text,
            "path": None,
            "meta": {"source": "web", "title": out.get("title", "") or title, "url": u, "mime": "text/html"},
        })
    state["evidence"] = ev
    return state
| # ---------- Graph ---------- | |
| # Build graph function | |
def build_graph():
    """Assemble and compile the GAIA agent state machine.

    Flow: decide_browse -> (search -> crawl)? -> check_attachment ->
    (fetch -> preprocess -> solve_multimodal | solve_text_only) or solve_text_only
    -> validate -> END.
    """
    graph = StateGraph(AgentState)
    node_table = (
        ("decide_browse", decide_browse_node),
        ("search", search_node),
        ("crawl", crawl_node),
        ("check_attachment", check_attachment_node),
        ("fetch", fetch_node),
        ("preprocess", preprocess_node),
        ("solve_multimodal", solve_multimodal_node),
        ("solve_text_only", solve_text_only_node),
        ("validate", validate_format_node),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    graph.add_edge(START, "decide_browse")
    graph.add_conditional_edges("decide_browse", route_browse, {
        "browse": "search",
        "skip": "check_attachment",
    })
    for src, dst in (("search", "crawl"), ("crawl", "check_attachment")):
        graph.add_edge(src, dst)
    graph.add_conditional_edges("check_attachment", route_intake, {
        "with_files": "fetch",
        "no_files": "solve_text_only",
    })
    graph.add_edge("fetch", "preprocess")
    graph.add_conditional_edges("preprocess", route_after_preprocess, {
        "vision": "solve_multimodal",  # question + evidence + attached images
        "text": "solve_text_only",     # question + transcript/other text
    })
    # Both solver branches converge on validation.
    for src, dst in (("solve_multimodal", "validate"), ("solve_text_only", "validate"), ("validate", END)):
        graph.add_edge(src, dst)
    return graph.compile()
| # test | |
if __name__ == "__main__":
    # Manual smoke run: one text-only question, no attachments, pure output mode.
    demo_state: AgentState = {
        "task_id": "0001",
        "question": "Who is the current president of France",
        "attachment_urls": [],  # flat empty list: this task has no files
        "local_files": [],
        "evidence": [],
        "answer": None,
        "parsed_final_answer": None,
        # Set True to force a trailing `final_answer:` line for scoring.
        "emit_final_answer": False,
        # Browse-pipeline fields, filled in by the graph.
        "use_browsing": None,
        "web_hits": None,
        "question_urls": None,
        "question_youtube_urls": None,
    }
    final_state = build_graph().invoke(demo_state)
    print("---------------------------")
    print(final_state["answer"])