Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced GAIA Agent with LangGraph - Fixed Version | |
| Supports Ollama (local) and OpenAI (production) | |
| """ | |
| import os | |
| import re | |
| import json | |
| import requests | |
| import time | |
| import logging | |
| import base64 | |
| from typing import TypedDict, Annotated, Sequence, Literal | |
| import operator | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage | |
| from langchain_core.tools import tool | |
| from langchain_community.tools import DuckDuckGoSearchResults | |
| from langchain_experimental.utilities import PythonREPL | |
| import pandas as pd | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============ CONFIGURATION ============
# Each setting may be overridden by an environment variable of the same name;
# the literal is the fallback used when the variable is unset.
# NOTE: "qwen2.5:32b" is a text-only model -- it matches none of
# VISION_MODEL_KEYWORDS, so on the Ollama path image questions are refused
# unless a vision model (e.g. one containing "llama3.2-vision") is configured.
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5:32b")
# Trailing slashes are removed because endpoint paths are appended with "/".
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434").rstrip("/")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o")

# Substrings that mark an Ollama model name as vision-capable
# (matched case-insensitively by _is_vision_model).
VISION_MODEL_KEYWORDS = ["vision", "vl", "llava", "bakllava", "gemma3", "qwen2.5-vl", "llama3.2-vision"]
| def _is_vision_model(model_name: str) -> bool: | |
| """Check if the model name suggests vision capability.""" | |
| if not model_name: | |
| return False | |
| model_lower = model_name.lower() | |
| return any(keyword in model_lower for keyword in VISION_MODEL_KEYWORDS) | |
def is_ollama_available() -> bool:
    """Return True if an Ollama server answers at OLLAMA_BASE_URL.

    Sends GET ``/api/tags`` with a 2-second timeout. A 200 response means
    available. Any request failure (connection refused, timeout, invalid URL)
    means not available.
    """
    try:
        response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=2)
    except requests.RequestException:
        # Narrow catch (was a bare ``except:``) so Ctrl-C / SystemExit still propagate.
        return False
    return response.status_code == 200
def is_production() -> bool:
    """Report whether we appear to be running inside a HuggingFace Space.

    The only signal used is a non-empty ``SPACE_ID`` environment variable.
    """
    space_id = os.environ.get("SPACE_ID")
    return space_id is not None and space_id != ""
# ============ STATE DEFINITION ============
# Graph state shared by all nodes. This is TypedDict's functional form,
# equivalent to the class-based declaration.
AgentState = TypedDict(
    "AgentState",
    {
        # Conversation history. The operator.add reducer makes LangGraph
        # append the messages a node returns rather than replace the list.
        "messages": Annotated[Sequence[BaseMessage], operator.add],
        "task_id": str,
        # Path of the attached file, or None when there is no attachment.
        "file_path": str | None,
        # Agent-node executions so far; GAIAAgent compares it to max_iterations.
        "iteration_count": int,
        # Cleaned answer written by the extract_answer node.
        "final_answer": str | None,
    },
)
# ============ TOOL DEFINITIONS ============
def web_search(query: str) -> str:
    """
    Search the web for current information using DuckDuckGo.
    Use for recent events, facts, statistics, or information you're uncertain about.
    Args:
        query: Search query string
    """
    # LangChain builds the tool description the LLM sees from the docstring
    # above, so it is kept as-is.
    # Output: one "Title/Snippet/Link" entry per dict result joined by
    # "---", "No results found." when empty, or "Search failed: ..." on error.

    # Silence noisy loggers of the search backend.
    for noisy_logger in ("ddgs.ddgs", "primp"):
        logging.getLogger(noisy_logger).setLevel(logging.ERROR)
    try:
        results = DuckDuckGoSearchResults(max_results=8, output_format="list").run(query)
        if not isinstance(results, list):
            return str(results) if results else "No results found."
        entries = [
            f"Title: {item.get('title', 'N/A')}\n"
            f"Snippet: {item.get('snippet', 'N/A')}\n"
            f"Link: {item.get('link', 'N/A')}"
            for item in results
            if isinstance(item, dict)
        ]
        return "\n\n---\n\n".join(entries) if entries else "No results found."
    except Exception as e:
        return f"Search failed: {e}"
def python_executor(code: str) -> str:
    """
    Execute Python code for calculations, data analysis, or computational tasks.
    Available libraries: math, statistics, datetime, json, re, collections, pandas, numpy.
    Use print() to see output.
    Args:
        code: Python code to execute
    """
    # NOTE(security): the code is model-generated and runs in-process through
    # PythonREPL with no sandboxing.
    # Returns printed output (capped at 5000 chars), a hint when nothing was
    # printed, or "Execution error: ..." if running the REPL raises.
    import textwrap

    # Imports made available to every snippet (prepended to the user code).
    prelude = "\n".join([
        "",
        "import math",
        "import statistics",
        "import datetime",
        "import json",
        "import re",
        "from collections import Counter, defaultdict",
        "import pandas as pd",
        "import numpy as np",
        "from fractions import Fraction",
        "from decimal import Decimal",
        "",
    ])
    try:
        # Models sometimes wrap code in a Markdown fence or indent the whole
        # snippet; either would make it fail to parse after the prelude.
        fenced = re.fullmatch(r"\s*```[\w+-]*[ \t]*\n(.*?)\n?[ \t]*```\s*", code, re.DOTALL)
        if fenced:
            code = fenced.group(1)
        code = textwrap.dedent(code)

        result = PythonREPL().run(prelude + code)
        # Whitespace-only output counts as "no output" too.
        output = (result or "").strip() or "Code executed with no output. Use print()."
        if len(output) > 5000:
            output = output[:5000] + "\n... (truncated)"
        return output
    except Exception as e:
        return f"Execution error: {e}"
def _read_pdf(file_path: str) -> str:
    """Extract the text of a PDF for read_file.

    Tries langchain's PyPDFLoader first. If that raises for any reason
    (including not being installed), falls back to pdfplumber, which also
    renders tables as " | "-separated rows.
    """
    try:
        from langchain_community.document_loaders import PyPDFLoader
        pages = PyPDFLoader(file_path).load()
        content = "\n\n--- Page Break ---\n\n".join(p.page_content for p in pages)
        return f"PDF Content ({len(pages)} pages):\n{content}"
    except Exception as loader_error:
        try:
            import pdfplumber
            with pdfplumber.open(file_path) as pdf:
                text = []
                for page_number, page in enumerate(pdf.pages, start=1):
                    page_text = page.extract_text() or ""
                    table_text = ""
                    for table in page.extract_tables():
                        if table:
                            table_text += "\n[TABLE]\n"
                            for row in table:
                                # None marks an empty cell; real values such as 0 are kept.
                                table_text += " | ".join("" if c is None else str(c) for c in row) + "\n"
                    text.append(f"Page {page_number}:\n{page_text}\n{table_text}")
            return "PDF Content:\n" + "\n\n".join(text)
        except Exception as fallback_error:
            return f"Error reading PDF: {loader_error} (pdfplumber fallback failed: {fallback_error})"


def read_file(file_path: str) -> str:
    """
    Read content from files. Supports: PDF, TXT, CSV, JSON, XLSX, XLS, PY, MP3, WAV, images.
    ALWAYS use this FIRST when a file is provided.
    Args:
        file_path: Path to the file
    """
    # The docstring above is the tool description shown to the LLM.
    # Dispatch is by extension (case-insensitive). Every failure is
    # reported as a string rather than raised, so the agent can react to it.
    max_chars = 15000  # cap for text/JSON output so one file cannot flood the context
    try:
        if not os.path.exists(file_path):
            return f"Error: File not found at {file_path}"
        file_lower = file_path.lower()

        # Audio: transcribe locally.
        if file_lower.endswith(('.mp3', '.wav', '.m4a', '.ogg', '.flac', '.webm')):
            return _transcribe_audio(file_path)
        # Images cannot be turned into text here; return a marker with the path.
        if file_lower.endswith(('.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp')):
            return f"IMAGE_FILE:{file_path}"
        if file_lower.endswith('.pdf'):
            return _read_pdf(file_path)
        # Excel: every sheet, up to 200 rows each.
        if file_lower.endswith(('.xlsx', '.xls')):
            sheets = pd.read_excel(file_path, sheet_name=None)
            result = []
            for sheet_name, df in sheets.items():
                result.append(f"=== Sheet: {sheet_name} ({len(df)} rows) ===")
                result.append(f"Columns: {list(df.columns)}")
                result.append(df.to_string(max_rows=200))
            return "\n\n".join(result)
        if file_lower.endswith('.csv'):
            df = pd.read_csv(file_path)
            return f"CSV ({len(df)} rows):\nColumns: {list(df.columns)}\n{df.to_string(max_rows=200)}"
        if file_lower.endswith('.json'):
            # utf-8-sig also accepts a leading BOM, which json.load rejects under plain utf-8.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                data = json.load(f)
            # ensure_ascii=False keeps non-ASCII text readable instead of \uXXXX escapes.
            content = json.dumps(data, indent=2, ensure_ascii=False)
            if len(content) > max_chars:
                content = content[:max_chars] + "\n... (truncated)"
            return f"JSON:\n{content}"

        # Default: treat as text, ignoring undecodable bytes.
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        if len(content) > max_chars:
            content = content[:max_chars] + "\n... (truncated)"
        return f"File Content:\n{content}"
    except Exception as e:
        return f"Error reading file: {e}"
| def _transcribe_audio(file_path: str) -> str: | |
| """Transcribe audio using local Whisper (faster-whisper).""" | |
| try: | |
| from faster_whisper import WhisperModel | |
| # Use base model for speed, can be upgraded to "small", "medium", "large" for better accuracy | |
| model = WhisperModel("base", device="cpu", compute_type="int8") | |
| segments, info = model.transcribe(file_path, beam_size=5) | |
| transcript = " ".join([segment.text for segment in segments]) | |
| return f"Audio Transcription:\n{transcript}" | |
| except ImportError: | |
| return "Error: faster-whisper not installed. Install with: pip install faster-whisper" | |
| except Exception as e: | |
| logger.error(f"Audio transcription error: {e}") | |
| return f"Audio transcription failed: {e}" | |
def _check_calculator_ast(tree, allowed_names) -> None:
    """Raise ValueError unless *tree* is plain arithmetic over whitelisted names.

    Attribute access, subscripts, lambdas, comprehensions and unknown names
    are rejected. An empty ``__builtins__`` alone does not stop escapes such
    as ``().__class__.__bases__``.
    """
    import ast

    allowed_nodes = (
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare,
        ast.IfExp, ast.Call, ast.keyword, ast.Name, ast.Load, ast.Constant,
        ast.Tuple, ast.List, ast.operator, ast.unaryop, ast.boolop, ast.cmpop,
    )
    for node in ast.walk(tree):
        if not isinstance(node, allowed_nodes):
            raise ValueError(f"unsupported syntax: {type(node).__name__}")
        if isinstance(node, ast.Name) and node.id not in allowed_names:
            raise ValueError(f"unknown name: {node.id}")


def calculator(expression: str) -> str:
    """
    Evaluate mathematical expressions safely.
    Args:
        expression: Math expression like "sqrt(16) + log(100, 10)"
    """
    # Output: integral floats without ".0", other floats with up to 10
    # significant digits, anything else via str(); "Calculation error: ..."
    # if parsing, validation or evaluation fails.
    import ast
    import math

    safe_dict = {
        'abs': abs, 'round': round, 'min': min, 'max': max,
        'sum': sum, 'pow': pow, 'int': int, 'float': float,
        'sqrt': math.sqrt, 'log': math.log, 'log10': math.log10,
        'log2': math.log2, 'exp': math.exp,
        'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
        'asin': math.asin, 'acos': math.acos, 'atan': math.atan, 'atan2': math.atan2,
        'degrees': math.degrees, 'radians': math.radians, 'hypot': math.hypot,
        'ceil': math.ceil, 'floor': math.floor,
        'gcd': math.gcd, 'comb': math.comb, 'perm': math.perm,
        'pi': math.pi, 'e': math.e, 'tau': math.tau, 'factorial': math.factorial,
    }
    try:
        # strip(): eval() tolerates leading blanks but ast.parse does not.
        tree = ast.parse(expression.strip(), mode="eval")
        # The expression comes from the model (possibly steered by web
        # content), so validate the syntax tree before evaluating it.
        _check_calculator_ast(tree, safe_dict)
        result = eval(compile(tree, "<calculator>", "eval"), {"__builtins__": {}}, safe_dict)
        if isinstance(result, float) and result.is_integer():
            return str(int(result))
        return f"{result:.10g}" if isinstance(result, float) else str(result)
    except Exception as e:
        return f"Calculation error: {e}"
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for factual information.
    Best for historical facts, biographies, scientific concepts.
    Args:
        query: Topic to search
    """
    # Output: up to two "## Title\n<plain text>" sections (each capped at
    # max_chars) joined by "---", or a "No ..." / "Wikipedia search failed: ..." message.
    api_url = "https://en.wikipedia.org/w/api.php"
    # Wikimedia's API etiquette asks clients for an identifying User-Agent.
    headers = {"User-Agent": "GaiaBot/1.0 (GAIA benchmark agent)"}
    max_chars = 4000
    try:
        search_response = requests.get(
            api_url,
            params={"action": "query", "list": "search", "srsearch": query,
                    "format": "json", "srlimit": 3},
            headers=headers,
            timeout=15,
        )
        search_response.raise_for_status()
        hits = search_response.json().get("query", {}).get("search")
        if not hits:
            return f"No Wikipedia articles found for '{query}'"

        results = []
        for item in hits[:2]:
            title = item["title"]
            # MediaWiki treats a boolean flag as true whenever it is present,
            # so "exintro" is omitted (not sent as "false") to get the whole
            # article. "exchars" is capped server-side, so truncate locally.
            content_response = requests.get(
                api_url,
                params={"action": "query", "prop": "extracts", "explaintext": 1,
                        "titles": title, "format": "json"},
                headers=headers,
                timeout=15,
            )
            content_response.raise_for_status()
            pages = content_response.json().get("query", {}).get("pages", {})
            for page_id, page_data in pages.items():
                if page_id != "-1":  # "-1" is the id of a missing page
                    extract = page_data.get("extract", "No content")
                    if len(extract) > max_chars:
                        extract = extract[:max_chars] + "..."
                    results.append(f"## {title}\n{extract}")
        return "\n\n---\n\n".join(results) if results else "No content found."
    except Exception as e:
        return f"Wikipedia search failed: {e}"
def fetch_webpage(url: str) -> str:
    """
    Fetch and extract text from a webpage URL.
    Args:
        url: The webpage URL
    """
    # Output: visible page text (script/style/nav/footer/header removed,
    # blank lines dropped, capped at max_chars), raw HTML when BeautifulSoup
    # is unavailable, or "Failed to fetch: ..." on any error.
    max_chars = 10000
    try:
        headers = {'User-Agent': 'Mozilla/5.0 (compatible; GaiaBot/1.0)'}
        response = requests.get(url, headers=headers, timeout=15)
        response.raise_for_status()
        # Without a charset in Content-Type, requests decodes text/* as
        # ISO-8859-1, which garbles UTF-8 pages. Record whether one was declared.
        declared_charset = "charset=" in response.headers.get("Content-Type", "").lower()
        try:
            from bs4 import BeautifulSoup
        except ImportError:
            if not declared_charset:
                response.encoding = response.apparent_encoding
            return f"Raw HTML:\n{response.text[:max_chars]}"

        # Parse the raw bytes so BeautifulSoup can honour the page's own
        # <meta charset> when the server did not declare an encoding.
        soup = BeautifulSoup(
            response.content,
            'html.parser',
            from_encoding=response.encoding if declared_charset else None,
        )
        for element in soup(['script', 'style', 'nav', 'footer', 'header']):
            element.decompose()
        text = soup.get_text(separator='\n', strip=True)
        text = '\n'.join(line.strip() for line in text.splitlines() if line.strip())
        if len(text) > max_chars:
            text = text[:max_chars] + "\n... (truncated)"
        return f"Webpage ({url}):\n{text}"
    except Exception as e:
        return f"Failed to fetch: {e}"
# Tools bound to the LLM (GAIAAgent.__init__) and executed by the graph's ToolNode.
TOOLS = [
    web_search,
    python_executor,
    read_file,
    calculator,
    wikipedia_search,
    fetch_webpage,
]
# ============ SYSTEM PROMPT ============
# System instructions for the tool-using agent (also used by the Ollama
# vision path). Each list item is one line of the prompt.
SYSTEM_PROMPT = "\n".join([
    "You are an expert AI solving GAIA benchmark questions. Your goal is MAXIMUM ACCURACY.",
    "## CRITICAL: Answer Format (EXACT STRING MATCHING)",
    "Your final answer must be ONLY the answer value - nothing else.",
    "**Rules:**",
    '- Numbers: "42" (not "The answer is 42")',
    '- Names: Exact spelling "John Smith"',
    '- Lists: Comma-separated, NO spaces: "apple,banana,cherry"',
    "- Dates: Requested format or YYYY-MM-DD",
    '- Yes/No: "Yes" or "No"',
    '- NEVER use prefixes like "Answer:", "FINAL ANSWER:", etc.',
    "- NEVER explain - just the answer",
    "## Strategy",
    "1. **If file provided**: Use read_file FIRST - answer is usually there",
    "2. **For calculations**: Use python_executor or calculator",
    "3. **For facts**: wikipedia_search for historical, web_search for current",
    "4. **For URLs in question**: Use fetch_webpage",
    "5. **Verify**: Check spelling, formatting, precision",
    "## When Ready",
    "State ONLY the answer value. Nothing else.",
])
# ============ AGENT CLASS ============
class GAIAAgent:
    """LangGraph agent for GAIA benchmark.

    Backend selection: OpenAI when running on HuggingFace Spaces or when no
    local Ollama server responds; otherwise a local Ollama model. Text
    questions run through an agent -> tools loop capped at ``max_iterations``
    agent steps. Image questions bypass the graph and use one vision call.
    """

    # Attachment extensions with special handling (compared lower-cased).
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp')
    _AUDIO_EXTENSIONS = ('.mp3', '.wav', '.m4a', '.ogg', '.flac', '.webm')
    # MIME types for base64 data URLs; unknown extensions are labelled image/png.
    _IMAGE_MEDIA_TYPES = {
        '.png': 'image/png', '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg',
        '.gif': 'image/gif', '.webp': 'image/webp', '.bmp': 'image/bmp',
    }
    # Per-extension instructions prepended to questions with non-media attachments.
    _FILE_HINTS = {
        '.xlsx': "EXCEL file - use read_file to examine ALL sheets",
        '.xls': "EXCEL file - use read_file to examine ALL sheets",
        '.csv': "CSV file - use read_file, then python_executor for analysis",
        '.pdf': "PDF file - use read_file to extract ALL text",
        '.py': "Python file - use read_file to see the code",
    }

    def __init__(
        self,
        model_name: str | None = None,
        temperature: float = 0,
        max_iterations: int = 25,
    ):
        """Pick the LLM backend and compile the graph.

        Raises:
            ValueError: the OpenAI backend is selected but OPENAI_API_KEY is unset.
        """
        self.max_iterations = max_iterations
        self.use_openai = is_production() or not is_ollama_available()
        if self.use_openai:
            from langchain_openai import ChatOpenAI
            api_key = os.environ.get("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("OPENAI_API_KEY not found")
            self.model_name = model_name or OPENAI_MODEL
            self.llm = ChatOpenAI(model=self.model_name, temperature=temperature, api_key=api_key)
            self.supports_vision = True  # assumed for the OpenAI backend
            logger.info(f"Using OpenAI: {self.model_name}")
        else:
            from langchain_ollama import ChatOllama
            self.model_name = model_name or OLLAMA_MODEL
            self.llm = ChatOllama(model=self.model_name, base_url=OLLAMA_BASE_URL, temperature=temperature)
            self.supports_vision = _is_vision_model(self.model_name)
            logger.info(f"Using Ollama: {self.model_name} (vision: {self.supports_vision})")
        self.llm_with_tools = self.llm.bind_tools(TOOLS)
        self.graph = self._build_graph()

    # ---------- helpers ----------

    @classmethod
    def _image_data_url(cls, image_path: str) -> str:
        """Return the image file as a base64 ``data:`` URL. Raises OSError if unreadable."""
        ext = os.path.splitext(image_path)[1].lower()
        media_type = cls._IMAGE_MEDIA_TYPES.get(ext, "image/png")
        with open(image_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode('utf-8')
        return f"data:{media_type};base64,{image_data}"

    @staticmethod
    def _message_text(message) -> str:
        """Return the text of a message (or raw value) whose content is a str or a list of blocks."""
        content = getattr(message, "content", message)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Multimodal / provider-specific content: keep only the text parts.
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(block.get("text", ""))
            return "".join(parts)
        return str(content) if content else ""

    # ---------- graph ----------

    def _build_graph(self) -> StateGraph:
        """Wire agent -> (tools -> agent)* -> extract_answer -> END and compile it."""
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", self._agent_node)
        workflow.add_node("tools", ToolNode(TOOLS))
        workflow.add_node("extract_answer", self._extract_answer_node)
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            self._route,
            # "agent" loops straight back after the raw-code fallback in
            # _agent_node, which has already produced its own tool result.
            {"tools": "tools", "agent": "agent", "end": "extract_answer"},
        )
        workflow.add_edge("tools", "agent")
        workflow.add_edge("extract_answer", END)
        return workflow.compile()

    def _agent_node(self, state: AgentState) -> dict:
        """Run one LLM step and return the new message(s) plus the bumped iteration count."""
        messages = list(state["messages"])
        iteration = state.get("iteration_count", 0)
        file_path = state.get("file_path")

        # Ollama vision: make sure a text-only trailing HumanMessage carries the image.
        if (
            not self.use_openai
            and self.supports_vision
            and file_path
            and file_path.lower().endswith(self._IMAGE_EXTENSIONS)
            and os.path.exists(file_path)
            and messages
            and isinstance(messages[-1], HumanMessage)
            and isinstance(messages[-1].content, str)
        ):
            try:
                messages[-1] = HumanMessage(
                    content=[
                        {"type": "text", "text": messages[-1].content},
                        {"type": "image_url", "image_url": {"url": self._image_data_url(file_path)}},
                    ]
                )
            except OSError as e:
                logger.warning(f"Failed to add image to message: {e}")

        # Nudge the model to wrap up as the step budget runs out. These notes
        # go only to this call; they are not stored in the graph state.
        if iteration >= self.max_iterations - 2:
            messages.append(SystemMessage(content="⚠️ FINAL: Provide answer NOW. Just the value."))
        elif iteration >= self.max_iterations - 5:
            messages.append(SystemMessage(content="⚠️ Conclude soon. Provide the answer."))

        if self.use_openai:
            time.sleep(0.5)  # simple throttle between OpenAI calls

        try:
            response = self.llm_with_tools.invoke(messages)
        except Exception as e:
            error_str = str(e)
            logger.error(f"LLM error: {error_str}")
            recovered = self._recover_raw_code_call(error_str)
            if recovered:
                return {"messages": recovered, "iteration_count": iteration + 1}
            return {"messages": [AIMessage(content="Error occurred.")], "iteration_count": iteration + 1}
        return {"messages": [response], "iteration_count": iteration + 1}

    def _recover_raw_code_call(self, error_str: str) -> list | None:
        """Run raw code from an "error parsing tool call: raw='...'" failure via python_executor.

        Returns [AIMessage(tool call), ToolMessage(result)], or None when the
        error does not have that shape or recovery fails.
        """
        if "error parsing tool call" not in error_str.lower() or "raw=" not in error_str:
            return None
        try:
            # NOTE(review): assumes the error reads "raw='<code>', err=<reason>"
            # (seen with Ollama) -- confirm. Stopping at "', err=" keeps quotes
            # inside the code from cutting it short.
            match = (
                re.search(r"raw='(.*?)',\s*err=", error_str, re.DOTALL)
                or re.search(r"raw='(.*)'", error_str, re.DOTALL)
            )
            if not match:
                return None
            raw_code = match.group(1)
            logger.info("Detected raw Python code, wrapping in python_executor tool call")
            from langchain_core.messages import ToolMessage
            tool_call_id = f"call_{int(time.time() * 1000)}"
            # python_executor may be a LangChain tool (has .invoke) or a plain function.
            invoke = getattr(python_executor, "invoke", None)
            result = invoke({"code": raw_code}) if invoke else python_executor(raw_code)
            ai_msg = AIMessage(
                content="",
                tool_calls=[{"name": "python_executor", "args": {"code": raw_code}, "id": tool_call_id}],
            )
            tool_msg = ToolMessage(content=result, tool_call_id=tool_call_id)
            return [ai_msg, tool_msg]
        except Exception as parse_error:
            logger.error(f"Failed to extract code from error: {parse_error}")
            return None

    def _route(self, state: AgentState) -> Literal["tools", "agent", "end"]:
        """Choose the next node after an agent step."""
        if state.get("iteration_count", 0) >= self.max_iterations:
            return "end"
        last = state["messages"][-1]
        if getattr(last, "tool_calls", None):
            return "tools"
        # A trailing ToolMessage comes from the raw-code fallback: the tool has
        # already run, so send the result back to the model instead of ending.
        if getattr(last, "type", None) == "tool":
            return "agent"
        return "end"

    def _extract_answer_node(self, state: AgentState) -> dict:
        """Store the cleaned text of the most recent plausible AI reply as final_answer."""
        content = ""
        for msg in reversed(state["messages"]):
            if isinstance(msg, AIMessage):
                text = self._message_text(msg).strip()
                # Skip empty replies and prompt regurgitation.
                if text and self._is_valid_answer_candidate(text):
                    content = text
                    break
        return {"final_answer": self._clean_answer(content)}

    # ---------- answer post-processing ----------

    def _is_valid_answer_candidate(self, text: str) -> bool:
        """Check if text looks like a valid answer, not garbage (echoed prompt or tool-call text)."""
        if not text:
            return False
        text_lower = text.lower()
        # Reject if it contains prompt text patterns.
        bad_patterns = [
            "numbers: just", "format rules", "must follow",
            "critical: answer format", "when ready", "your final answer",
            "the benchmark uses", "exact string matching",
            "no prefixes", "no explanations"
        ]
        if any(p in text_lower for p in bad_patterns):
            return False
        # Reject if it looks like the question was repeated.
        if "provide the correct next move" in text_lower:
            return False
        if text.startswith("Review the"):
            return False
        # Reject tool call syntax written out as text.
        if text.startswith("web_search(") or text.startswith("read_file("):
            return False
        return True

    def _clean_answer(self, raw: str) -> str:
        """Reduce a model reply to the bare answer string used for exact matching."""
        if not raw:
            return ""
        answer = raw.strip()

        # Remove markdown emphasis / inline code.
        answer = re.sub(r'\*\*(.+?)\*\*', r'\1', answer)
        answer = re.sub(r'\*(.+?)\*', r'\1', answer)
        answer = re.sub(r'`(.+?)`', r'\1', answer)

        # Strip lead-ins. Word boundaries keep real answers such as "Sofia",
        # "Answers" or "Results" intact. "So" is stripped only before a comma
        # or an "answer" label, because "So ..." can start a genuine answer.
        # Repeat until stable so that "Therefore, the final answer is: 42"
        # reduces to "42" regardless of pattern order.
        prefixes = [
            r"^therefore\b\s*,?\s*",
            r"^thus\b\s*,?\s*",
            r"^so\b\s*(?:,\s*|(?=(?:the\s+)?(?:final\s+)?answer\b))",
            r"^(?:the\s+)?(?:final\s+)?answer\b\s*(?:is\b)?\s*:?\s*",
            r"^result\b\s*:?\s*",
        ]
        previous = None
        while answer != previous:
            previous = answer
            for p in prefixes:
                answer = re.sub(p, "", answer, flags=re.IGNORECASE)

        # Remove one pair of surrounding quotes.
        if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in ('"', "'"):
            answer = answer[1:-1]

        # Take first line.
        answer = answer.split('\n')[0].strip()

        # Remove trailing period for short answers.
        if answer.endswith('.') and len(answer.split()) <= 3:
            answer = answer[:-1]
        return answer.strip()

    # ---------- entry points ----------

    def run(self, question: str, task_id: str = "", file_path: str | None = None) -> str:
        """Answer *question*, optionally using the attachment at *file_path*.

        Returns the cleaned answer, "Unable to determine answer", or an
        "Error..." / "Agent error: ..." / "Vision error: ..." string on failure.
        """
        user_content = question

        if file_path and os.path.exists(file_path):
            ext = os.path.splitext(file_path)[1].lower()
            is_image = ext in self._IMAGE_EXTENSIONS
            is_audio = ext in self._AUDIO_EXTENSIONS

            if is_image:
                # Images are answered by a single vision call, not the tool graph.
                if self.use_openai:
                    return self._run_with_vision(question, task_id, file_path)
                if self.supports_vision:
                    return self._run_with_ollama_vision(question, task_id, file_path)
                logger.warning(f"Image file detected but model {self.model_name} doesn't support vision")
                return f"Error: Image file provided but model {self.model_name} doesn't support vision. Please use a vision-capable model like llama3.2-vision or qwen2.5-vl."

            if is_audio:
                audio_transcript = _transcribe_audio(file_path)
                # Only a successful transcription starts with this prefix; both
                # failure messages ("Error: ..." and "Audio transcription
                # failed: ...") must not be passed off as a transcript.
                if audio_transcript.startswith("Audio Transcription:"):
                    user_content = f"{question}\n\n{audio_transcript}"
                else:
                    logger.warning(f"Audio transcription failed: {audio_transcript}")
            else:
                hint = self._FILE_HINTS.get(ext, "Use read_file to examine contents")
                user_content = (
                    f"⚠️ FILE PROVIDED: {file_path}\n"
                    f"{hint}\n"
                    f'**Use read_file("{file_path}") FIRST.**\n'
                    f"Question: {question}"
                )

        # Point the model at URLs mentioned in the question.
        url_match = re.search(r'https?://[^\s]+', question)
        if url_match:
            # Drop sentence punctuation glued to the URL; keep balanced
            # parentheses such as Wikipedia's "Title_(disambiguation)".
            url = url_match.group().rstrip(".,;:!?'\"")
            while url.endswith(")") and url.count(")") > url.count("("):
                url = url[:-1].rstrip(".,;:!?'\"")
            user_content += f"\n\n💡 URL detected: {url} - Consider using fetch_webpage if needed."

        initial_state: AgentState = {
            "messages": [SystemMessage(content=SYSTEM_PROMPT), HumanMessage(content=user_content)],
            "task_id": task_id,
            "file_path": file_path,
            "iteration_count": 0,
            "final_answer": None
        }

        try:
            # Each iteration can take an agent step and a tools step; +10 is headroom.
            final_state = self.graph.invoke(initial_state, {"recursion_limit": self.max_iterations * 2 + 10})
            answer = final_state.get("final_answer") or ""
            if not answer or not self._is_valid_answer_candidate(answer):
                # Try harder: clean each AI reply (newest first) until one passes.
                for msg in reversed(final_state.get("messages", [])):
                    if isinstance(msg, AIMessage):
                        candidate = self._clean_answer(self._message_text(msg))
                        if candidate and self._is_valid_answer_candidate(candidate):
                            answer = candidate
                            break
            return answer if answer else "Unable to determine answer"
        except Exception as e:
            logger.error(f"Agent error: {e}")
            return f"Agent error: {str(e)}"

    def _run_with_vision(self, question: str, task_id: str, image_path: str) -> str:
        """Handle image questions with one OpenAI vision call (using the configured model)."""
        try:
            from openai import OpenAI
            client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
            response = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are solving GAIA benchmark questions. Provide ONLY the answer value, no explanations or prefixes."},
                    {"role": "user", "content": [
                        {"type": "text", "text": question},
                        {"type": "image_url", "image_url": {"url": self._image_data_url(image_path)}}
                    ]}
                ],
                max_tokens=500,
                temperature=0
            )
            # content can be None (e.g. a refusal); treat that as an empty answer.
            answer = (response.choices[0].message.content or "").strip()
            return self._clean_answer(answer)
        except Exception as e:
            logger.error(f"Vision error: {e}")
            return f"Vision error: {str(e)}"

    def _run_with_ollama_vision(self, question: str, task_id: str, image_path: str) -> str:
        """Handle image questions with one Ollama vision call (no tools)."""
        try:
            message = HumanMessage(
                content=[
                    {"type": "text", "text": question},
                    {"type": "image_url", "image_url": {"url": self._image_data_url(image_path)}}
                ]
            )
            response = self.llm.invoke([SystemMessage(content=SYSTEM_PROMPT), message])
            return self._clean_answer(self._message_text(response))
        except Exception as e:
            logger.error(f"Ollama vision error: {e}")
            return f"Vision error: {str(e)}"
def create_agent() -> GAIAAgent:
    """Build the default agent: temperature 0 and at most 25 agent iterations."""
    settings = {"temperature": 0, "max_iterations": 25}
    return GAIAAgent(**settings)