# Final_Assignment_D3MI4N / langgraph_new.py
# Author: D3MI4N — commit 1d0ce3b ("improving tools")
import os
import re
import sys
from dotenv import load_dotenv
import pandas as pd
import whisper
import requests
from urllib.parse import urlparse
from youtube_transcript_api import YouTubeTranscriptApi
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
# ** Retrieval imports **
from langchain_huggingface import HuggingFaceEmbeddings
from supabase.client import create_client
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.tools.retriever import create_retriever_tool
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode, tools_condition
load_dotenv()
# Enhanced system prompt optimized for GAIA.
# Prepended as the leading SystemMessage by enhanced_assistant_node when the
# conversation does not already start with one; it enforces GAIA's
# bare-answer output format (no prose, no prefixes, exact formatting).
SYSTEM = SystemMessage(content="""
You are a precise QA agent specialized in answering GAIA benchmark questions.
CRITICAL RESPONSE RULES:
- Answer with ONLY the exact answer, no explanations or conversational text
- NO XML tags, NO "FINAL ANSWER:", NO introductory phrases
- For lists: comma-separated, alphabetized if requested, no trailing punctuation
- For numbers: use exact format requested (USD as 12.34, codes bare, etc.)
- For yes/no: respond only "Yes" or "No"
- Use tools systematically for factual lookups, audio/video transcription, and data analysis
Your goal is to provide exact answers that match GAIA ground truth precisely.
""".strip())
# ─────────────────────────────────────────────────────────────────────────────
# ENHANCED TOOLS WITH MCP-STYLE ORGANIZATION
# ─────────────────────────────────────────────────────────────────────────────
@tool
def enhanced_web_search(query: str) -> dict:
    """Advanced web search with multiple result processing and filtering."""
    try:
        # Request several hits so that short/noisy entries can be filtered
        # out below while still leaving useful coverage.
        searcher = TavilySearchResults(max_results=5)
        hits = searcher.run(query)
        # Keep only hits whose content is substantial (> 20 characters).
        snippets = [
            f"Source: {hit.get('url', '')}\nContent: {hit.get('content', '').strip()}"
            for hit in hits
            if len(hit.get("content", "").strip()) > 20
        ]
        return {"web_results": "\n\n".join(snippets)}
    except Exception as e:
        return {"web_results": f"Search error: {str(e)}"}
@tool
def enhanced_wiki_search(query: str) -> dict:
    """Enhanced Wikipedia search with better content extraction.

    Tries the query as-is plus underscore/hyphen-normalized variants and
    returns up to 3 pages, each truncated to 2000 characters, under the
    "wiki_results" key. Never raises; errors come back as message strings.
    """
    try:
        # Try multiple query variations for better results.
        # dict.fromkeys dedupes while preserving order, so the identical
        # query is not sent several times when it has no '_' or '-'.
        queries = list(dict.fromkeys(
            [query, query.replace("_", " "), query.replace("-", " ")]
        ))
        for q in queries:
            try:
                pages = WikipediaLoader(query=q, load_max_docs=3).load()
            except Exception:
                # Loader failure for this variant only — try the next one.
                continue
            if pages:
                content = "\n\n".join(
                    f"Page: {p.metadata.get('title', 'Unknown')}\n{p.page_content[:2000]}"
                    for p in pages
                )
                return {"wiki_results": content}
        return {"wiki_results": "No Wikipedia results found"}
    except Exception as e:
        return {"wiki_results": f"Wikipedia error: {str(e)}"}
def _extract_youtube_video_id(url):
    """Return the 11-character video ID embedded in *url*, or None.

    Tries strict URL-shape patterns (watch?v=, youtu.be/, embed/) first,
    then a looser fallback for other formats.
    """
    patterns = [
        r"(?:youtube\.com/watch\?v=|youtu\.be/|youtube\.com/embed/)([a-zA-Z0-9_-]{11})",
        r"(?:v=|\/)([0-9A-Za-z_-]{11})",
    ]
    for pattern in patterns:
        match = re.search(pattern, url)
        if match:
            return match.group(1)
    return None


@tool
def youtube_transcript_tool(url: str) -> dict:
    """Extract transcript from YouTube videos with enhanced error handling.

    Returns {"transcript": text} where text is the timestamped transcript
    or a human-readable error message — this tool never raises.
    """
    try:
        print(f"DEBUG: Processing YouTube URL: {url}", file=sys.stderr)
        video_id = _extract_youtube_video_id(url)
        if not video_id:
            return {"transcript": "Error: Could not extract video ID from URL"}
        print(f"DEBUG: Extracted video ID: {video_id}", file=sys.stderr)
        try:
            transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
            try:
                # Prefer an English transcript when one exists.
                transcript = transcript_list.find_transcript(['en'])
            except Exception:
                # No English track: fall back to the first available language.
                available = list(transcript_list)
                if not available:
                    return {"transcript": "No transcripts available"}
                transcript = available[0]
            transcript_data = transcript.fetch()
            # Prefix each line with its start time for temporal context.
            # NOTE(review): assumes fetch() yields dict-like entries with
            # 'start'/'text' keys (older youtube_transcript_api API shape) —
            # confirm against the pinned library version.
            lines = [f"[{entry['start']:.1f}s] {entry['text']}" for entry in transcript_data]
            return {"transcript": "\n".join(lines)}
        except Exception as e:
            return {"transcript": f"Error fetching transcript: {str(e)}"}
    except Exception as e:
        return {"transcript": f"YouTube processing error: {str(e)}"}
@tool
def enhanced_audio_transcribe(path: str) -> dict:
    """Enhanced audio transcription with better file handling.

    Resolves *path* (absolute, relative, or bare filename in the CWD),
    verifies ffmpeg is installed, then transcribes with the Whisper "base"
    model. Returns {"transcript": text-or-error-message}; never raises.
    """
    try:
        import shutil

        # os.path.abspath is a no-op on already-absolute paths, so it
        # safely normalizes relative and absolute inputs alike.
        abs_path = os.path.abspath(path)
        print(f"DEBUG: Transcribing audio file: {abs_path}", file=sys.stderr)
        if not os.path.isfile(abs_path):
            # Fall back to looking for the bare filename in the CWD.
            candidate = os.path.join(os.getcwd(), os.path.basename(path))
            if os.path.isfile(candidate):
                abs_path = candidate
            else:
                return {"transcript": f"Error: Audio file not found at {abs_path}"}
        # Whisper shells out to ffmpeg; probe with shutil.which instead of
        # spawning a subprocess, so the user gets an actionable error early.
        if shutil.which("ffmpeg") is None:
            return {"transcript": "Error: ffmpeg not found. Please install ffmpeg."}
        # Load and transcribe with the small "base" model (speed/accuracy trade-off).
        model = whisper.load_model("base")
        result = model.transcribe(abs_path)
        return {"transcript": result["text"].strip()}
    except Exception as e:
        return {"transcript": f"Transcription error: {str(e)}"}
@tool
def enhanced_excel_analysis(path: str, query: str = "", sheet_name: str = None) -> dict:
    """Enhanced Excel analysis with query-specific processing.

    Loads one sheet (*sheet_name*, or the first sheet when omitted) and
    returns, under "excel_analysis": columns, row count, first 5 rows,
    describe() stats for numeric columns, plus extra sections keyed off
    hints in *query*:
      - "total"/"sum"      -> per-column sums of numeric columns
      - "food"/"category"  -> value counts for every object-dtype column
    Never raises; errors are returned as a message string.
    """
    try:
        # Resolve the path: abspath normalizes relative input; fall back to
        # the bare filename in the current directory if the path is missing.
        abs_path = path if os.path.isabs(path) else os.path.abspath(path)
        if not os.path.isfile(abs_path):
            candidate = os.path.join(os.getcwd(), os.path.basename(path))
            if os.path.isfile(candidate):
                abs_path = candidate
            else:
                return {"excel_analysis": f"Error: Excel file not found at {abs_path}"}
        # sheet_name=0 selects the first sheet when none was requested.
        df = pd.read_excel(abs_path, sheet_name=sheet_name or 0)
        analysis = {
            "columns": list(df.columns),
            "row_count": len(df),
            "sheet_info": f"Analyzing sheet: {sheet_name or 'default'}",
        }
        # Computed once here and reused for both totals and the summary
        # (the original recomputed this selection twice).
        numeric_cols = df.select_dtypes(include=['number']).columns
        query_lower = query.lower() if query else ""
        if "total" in query_lower or "sum" in query_lower:
            analysis["totals"] = {col: df[col].sum() for col in numeric_cols}
        if "food" in query_lower or "category" in query_lower:
            # Surface value counts for every text-like column as categories.
            for col in df.columns:
                if df[col].dtype == 'object':
                    analysis[f"{col}_categories"] = df[col].value_counts().to_dict()
        # Always include sample data so the model can see the table shape.
        analysis["sample_data"] = df.head(5).to_dict('records')
        if len(numeric_cols) > 0:
            analysis["numeric_summary"] = df[numeric_cols].describe().to_dict()
        return {"excel_analysis": analysis}
    except Exception as e:
        return {"excel_analysis": f"Excel analysis error: {str(e)}"}
@tool
def web_file_downloader(url: str) -> dict:
    """Download and analyze files from web URLs.

    Audio responses are written to a temp file and transcribed via the
    enhanced_audio_transcribe tool; text/HTML bodies are returned truncated
    to 5000 characters; everything else is summarized by size and type.
    Never raises; errors come back under the "content" key.
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # Determine file type from the response headers or the URL suffix.
        content_type = response.headers.get('content-type', '').lower()
        if 'audio' in content_type or url.endswith(('.mp3', '.wav', '.m4a')):
            # Persist to disk so Whisper/ffmpeg can read it as a file.
            temp_path = f"temp_audio_{hash(url) % 10000}.wav"
            with open(temp_path, 'wb') as f:
                f.write(response.content)
            try:
                # enhanced_audio_transcribe is a LangChain tool object, not a
                # plain function — it must be called through .invoke().
                return enhanced_audio_transcribe.invoke({"path": temp_path})
            finally:
                # Best-effort cleanup; never mask the transcription result.
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
        elif 'text' in content_type or 'html' in content_type:
            return {"content": response.text[:5000]}  # cap size for the LLM context
        else:
            return {"content": f"Downloaded {len(response.content)} bytes of {content_type}"}
    except Exception as e:
        return {"content": f"Download error: {str(e)}"}
# ─────────────────────────────────────────────────────────────────────────────
# ENHANCED RETRIEVER TOOL
# ─────────────────────────────────────────────────────────────────────────────
# Tools that work without any external vector store; the retriever is
# appended below only when the Supabase setup succeeds. Keeping a single
# base list avoids the two diverging copies the original maintained.
_BASE_TOOLS = [enhanced_web_search, enhanced_wiki_search, youtube_transcript_tool,
               enhanced_audio_transcribe, enhanced_excel_analysis, web_file_downloader]
try:
    emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    # Requires SUPABASE_URL / SUPABASE_SERVICE_KEY in the environment;
    # a KeyError here is caught below and the retriever is simply skipped.
    supabase = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"])
    vector_store = SupabaseVectorStore(
        client=supabase,
        embedding=emb,
        table_name="documents",
        query_name="match_documents_langchain",
    )

    @tool
    def gaia_qa_retriever(query: str) -> dict:
        """Retrieve similar GAIA Q&A pairs with enhanced search."""
        try:
            retriever = vector_store.as_retriever(search_kwargs={"k": 5})
            docs = retriever.invoke(query)
            if not docs:
                return {"gaia_results": "No similar GAIA examples found"}
            results = []
            for i, doc in enumerate(docs, 1):
                content = doc.page_content
                # Reformat the stored "Q: ... A: ..." pairs for readability.
                content = content.replace("Q: ", "\nQuestion: ").replace(" A: ", "\nAnswer: ")
                results.append(f"Example {i}:{content}\n")
            return {"gaia_results": "\n".join(results)}
        except Exception as e:
            return {"gaia_results": f"Retrieval error: {str(e)}"}

    TOOLS = _BASE_TOOLS + [gaia_qa_retriever]
except Exception as e:
    print(f"Warning: Supabase retriever not available: {e}")
    TOOLS = list(_BASE_TOOLS)
# ─────────────────────────────────────────────────────────────────────────────
# ENHANCED AGENT & GRAPH SETUP
# ─────────────────────────────────────────────────────────────────────────────
# Deterministic decoding so repeated runs give reproducible answers.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # Set temperature to 0 for consistency
llm_with_tools = llm.bind_tools(TOOLS)
# Build graph with proper state management.
# NOTE(review): MessagesState declares only a "messages" channel; the extra
# "tool_call_count" key returned by enhanced_assistant_node may not persist
# across graph steps — confirm, or extend the state schema with that field.
builder = StateGraph(MessagesState)
def enhanced_assistant_node(state: dict) -> dict:
    """Enhanced assistant node with better answer processing.

    LangGraph node: invokes the tool-bound LLM on the conversation. If the
    model requests tools (and the call cap is not exceeded), the tool-call
    message is appended so tools_condition routes to the ToolNode; otherwise
    the model's text is normalized by process_final_answer() and returned as
    the final AIMessage.
    """
    MAX_TOOL_CALLS = 5  # Increased for complex GAIA questions
    msgs = state.get("messages", [])
    # NOTE(review): with MessagesState this key may never round-trip through
    # the graph (schema only declares "messages"), in which case this always
    # reads 0 and the cap is never reached — verify.
    tool_call_count = state.get("tool_call_count", 0)
    # Ensure the GAIA system prompt always leads the conversation.
    if not msgs or not isinstance(msgs[0], SystemMessage):
        msgs = [SYSTEM] + msgs
    print(f"\n➑️ Assistant processing (tool calls: {tool_call_count})", file=sys.stderr)
    # Log the latest message for debugging
    if msgs:
        latest = msgs[-1]
        if hasattr(latest, 'content'):
            print(f"β†’ Latest input: {latest.content[:200]}...", file=sys.stderr)
    try:
        out: AIMessage = llm_with_tools.invoke(msgs)
        print(f"β†’ Model wants to use tools: {len(out.tool_calls) > 0}", file=sys.stderr)
        if out.tool_calls:
            if tool_call_count >= MAX_TOOL_CALLS:
                # Cap reached: emit a plain answer instead of looping forever.
                print("β›” Tool call limit reached", file=sys.stderr)
                fallback = AIMessage(content="Unable to determine answer with available information.")
                return {
                    "messages": msgs + [fallback],
                    "tool_call_count": tool_call_count
                }
            # NOTE(review): returning the full msgs list (not just the delta)
            # relies on add_messages deduplicating by message id — confirm
            # messages are not duplicated across turns.
            return {
                "messages": msgs + [out],
                "tool_call_count": tool_call_count + 1
            }
        # Process final answer for GAIA format
        answer_content = process_final_answer(out.content)
        print(f"βœ… Final answer: {answer_content!r}", file=sys.stderr)
        return {
            "messages": msgs + [AIMessage(content=answer_content)],
            "tool_call_count": tool_call_count
        }
    except Exception as e:
        # LLM invocation failed: surface a generic answer rather than crash the graph.
        print(f"❌ Assistant error: {e}", file=sys.stderr)
        error_msg = AIMessage(content="Error processing request.")
        return {
            "messages": msgs + [error_msg],
            "tool_call_count": tool_call_count
        }
def process_final_answer(content: str) -> str:
    """Process the final answer to match GAIA requirements exactly.

    Strips XML-like tags, boilerplate prefixes ("the answer is", ...),
    surrounding quotes and trailing punctuation, normalizes comma-separated
    lists, and keeps only the first line.

    Returns "Unable to determine answer" if nothing survives cleaning.
    """
    if not content:
        return "Unable to determine answer"
    # Remove any XML-like tags
    content = re.sub(r'<[^>]*>', '', content)
    # (pattern, replacement) pairs. The quote pattern keeps its capture
    # group via r'\1' so the quoted answer itself is preserved — the old
    # code substituted '' here and deleted quoted answers entirely.
    unwanted_patterns = [
        (r'^.*?(?:answer is|answer:|final answer:)\s*', ''),
        (r'^.*?(?:the result is|result:)\s*', ''),
        (r'^.*?(?:therefore,|thus,|so,)\s*', ''),
        (r'\.$', ''),                   # trailing period
        (r'^["\'](.+)["\']$', r'\1'),   # surrounding quotes
    ]
    for pattern, replacement in unwanted_patterns:
        content = re.sub(pattern, replacement, content, flags=re.IGNORECASE)
    # Clean up whitespace
    content = content.strip()
    # Normalize comma-separated lists (skip prose that merely contains commas).
    if ',' in content and not any(word in content.lower() for word in ['however', 'although', 'because']):
        items = [item.strip() for item in content.split(',')]
        content = ', '.join(items)
        content = content.rstrip('.,;')
    # GAIA expects a single-line answer: keep only the first line.
    content = content.split('\n')[0].strip()
    return content if content else "Unable to determine answer"
# Build the graph: assistant <-> tools loop, entered at the assistant.
builder.add_node("assistant", enhanced_assistant_node)
builder.add_node("tools", ToolNode(TOOLS))  # executes whichever tool the model requested
builder.add_edge(START, "assistant")
# Route to "tools" when the last AIMessage carries tool calls, else finish.
builder.add_conditional_edges(
    "assistant",
    tools_condition,
    {"tools": "tools", END: END}
)
builder.add_edge("tools", "assistant")  # loop back so the model sees tool output
# Compile the graph with configuration
graph = builder.compile()
# ─────────────────────────────────────────────────────────────────────────────
# GAIA API INTERACTION FUNCTIONS
# ─────────────────────────────────────────────────────────────────────────────
def get_gaia_questions():
    """Fetch the full question list from the GAIA scoring API.

    Returns the decoded JSON list of question dicts, or [] on any error
    (network failure, non-2xx status, bad JSON).
    """
    try:
        # Explicit timeout so a hung server cannot stall the client forever.
        response = requests.get(
            "https://agents-course-unit4-scoring.hf.space/questions",
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        print(f"Error fetching GAIA questions: {e}")
        return []
def get_random_gaia_question():
    """Fetch a single random question from the GAIA scoring API.

    Returns the decoded JSON question dict, or None on any error
    (network failure, non-2xx status, bad JSON).
    """
    try:
        # Explicit timeout so a hung server cannot stall the client forever.
        response = requests.get(
            "https://agents-course-unit4-scoring.hf.space/random-question",
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        print(f"Error fetching random GAIA question: {e}")
        return None
def answer_gaia_question(question_text: str) -> str:
    """Answer a single GAIA question using the agent.

    Runs the compiled graph on the question and returns the final
    message's stripped text, or an error string on failure.
    """
    try:
        # Seed the graph state with the user's question and a zeroed counter.
        state = {
            "messages": [HumanMessage(content=question_text)],
            "tool_call_count": 0,
        }
        result = graph.invoke(state)
        if result and result.get("messages"):
            return result["messages"][-1].content.strip()
        return "No answer generated"
    except Exception as e:
        print(f"Error answering question: {e}")
        return f"Error: {str(e)}"
# ─────────────────────────────────────────────────────────────────────────────
# TESTING AND VALIDATION
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke tests: print the graph topology, run canned GAIA-style questions,
    # then (if the scoring API is reachable) try one real GAIA question.
    print("πŸ” Enhanced GAIA Agent Graph Structure:")
    try:
        print(graph.get_graph().draw_mermaid())
    except:
        # Diagram rendering is purely cosmetic; ignore any failure.
        print("Could not generate mermaid diagram")
    print("\nπŸ§ͺ Testing with GAIA-style questions...")
    # Test questions that cover different GAIA capabilities:
    # arithmetic, factual lookup, list formatting, Excel analysis, audio.
    test_questions = [
        "What is 2 + 2?",
        "What is the capital of France?",
        "List the vegetables from this list: broccoli, apple, carrot. Alphabetize and use comma separation.",
        "Given the Excel file at test_sales.xlsx, what were total sales for food? Express in USD with two decimals.",
        "Examine the audio file at ./test.wav. What is its transcript?",
    ]
    # Add an extra audio-transcription test when the sample file is present
    # (despite the original comment, this checks audio, not YouTube).
    if os.path.exists("test.wav"):
        test_questions.append("What does the speaker say in the audio file test.wav?")
    for i, question in enumerate(test_questions, 1):
        print(f"\nπŸ“ Test {i}: {question}")
        try:
            answer = answer_gaia_question(question)
            print(f"βœ… Answer: {answer!r}")
        except Exception as e:
            print(f"❌ Error: {e}")
        print("-" * 80)
    # Test with a real GAIA question if API is available
    print("\n🌍 Testing with real GAIA question...")
    try:
        random_q = get_random_gaia_question()
        if random_q:
            print(f"πŸ“‹ GAIA Question: {random_q.get('question', 'N/A')}")
            answer = answer_gaia_question(random_q.get('question', ''))
            print(f"🎯 Agent Answer: {answer!r}")
            print(f"πŸ’‘ Task ID: {random_q.get('task_id', 'N/A')}")
    except Exception as e:
        print(f"Could not test with real GAIA question: {e}")