sangwanparteek committed on
Commit
00ff2c1
·
1 Parent(s): 81917a3

adding agent code

Browse files
agent.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from rich.table import Table
4
+ from rich.console import Console
5
+ from langchain.agents import AgentExecutor, create_tool_calling_agent
6
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
7
+ from langchain.memory import ConversationBufferMemory
8
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
9
+ from config import get_llm
10
+ from prompt_template import gaia_prompt
11
+
12
+ from tools.file_attachment_query import file_attachment_query_tool
13
+ from tools.math_solver import math_solver_tool
14
+ from tools.google_search import google_search_tool
15
+ from tools.gemini_video_qa import gemini_video_qa_tool
16
+ from tools.riddle_solver import riddle_solver_tool
17
+ from tools.text_transformer import text_transformer_tool
18
+ from tools.wiki_content_fetcher import wiki_content_fetcher_tool
19
+ from tools.wiki_title_finder import wiki_title_finder_tool
20
+
21
+
22
class LangChainGAIAAgent:
    """LangChain-based agent for answering GAIA benchmark questions.

    Wires a provider-selected chat model, the registered tools, the GAIA
    system prompt, and conversation memory into an AgentExecutor.
    """

    def __init__(self, provider="deepseek"):
        print("LangChain GAIA Agent initialized.")

        # Select model (config.py handles provider switching)
        if provider == "huggingface":
            # BUG FIX: the original assigned the model to a local variable
            # `llm`, leaving self.llm undefined on this branch, so
            # create_tool_calling_agent below raised AttributeError.
            # NOTE(review): HuggingFaceEndpoint takes `endpoint_url` (or
            # `repo_id`), not `url` -- confirm against the
            # langchain_huggingface documentation.
            self.llm = ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                    temperature=0,
                )
            )
        else:
            self.llm = get_llm(provider)

        # Register all tools available to the agent.
        self.tools = [
            file_attachment_query_tool,
            math_solver_tool,
            google_search_tool,
            gemini_video_qa_tool,
            riddle_solver_tool,
            text_transformer_tool,
            wiki_content_fetcher_tool,
            wiki_title_finder_tool,
        ]

        # Combines the GAIA answer-format rules with LangChain tool
        # orchestration.
        # NOTE(review): gaia_prompt is a ChatPromptTemplate, which does not
        # expose a `.template` attribute (PromptTemplate does) -- verify
        # this does not raise AttributeError; the system text may need to
        # come from gaia_prompt.messages[0].prompt.template instead.
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", gaia_prompt.template),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])

        # Optional memory (multi-turn conversations)
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

        # Create tool-calling agent directly
        self.agent = create_tool_calling_agent(
            llm=self.llm,
            tools=self.tools,
            prompt=self.prompt,
        )

        # Wrap in AgentExecutor (LangChain runtime)
        self.agent_executor = AgentExecutor(
            agent=self.agent,
            tools=self.tools,
            memory=self.memory,
            verbose=True,
        )

        print("GAIA Agent ready with all tools and system rules.\n")

    def __call__(self, question: str) -> str:
        """
        Call the agent like a function.

        Returns the agent's answer string, or "[ERROR] ..." on failure.
        """
        print(f"Received question (first 50 chars): {question[:50]}...")
        try:
            response = self.agent_executor.invoke({"input": question})
            return response.get("output", "").strip()
        except Exception as e:
            return f"[ERROR] {str(e)}"

    def evaluate_random_questions(self, csv_path: str, sample_size: int = 3, show_steps: bool = True):
        """
        Evaluate GAIA benchmark questions from CSV.
        CSV must contain: 'question', 'answer', (optional) 'taskid'

        Prints per-question results (when show_steps) and a summary table.
        """
        df = pd.read_csv(csv_path)
        if not {"question", "answer"}.issubset(df.columns):
            print("CSV must contain 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return

        # BUG FIX: clamp so df.sample() cannot raise when the CSV has
        # fewer rows than sample_size.
        n = min(sample_size, len(df))
        samples = df.sample(n=n)
        records = []
        correct_count = 0

        for _, row in samples.iterrows():
            taskid = str(row.get("taskid", "")).strip()
            # A missing cell reads back as NaN, which stringifies to "nan";
            # treat that as "no taskid".
            if taskid.lower() == "nan":
                taskid = ""
            # str() guards against non-string cells (NaN / numeric answers).
            question = str(row["question"]).strip()
            expected = str(row["answer"]).strip()

            query = f"taskid: {taskid}, question: {question}" if taskid else question
            agent_answer = self(query).strip()

            is_correct = (expected == agent_answer)
            correct_count += is_correct
            records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))

            if show_steps:
                print("---")
                print(f"Question: {question}")
                print(f"Expected: {expected}")
                print(f"Agent: {agent_answer}")
                print(f"Correct: {is_correct}")

        # Pretty print summary
        console = Console()
        table = Table(show_lines=True)
        table.add_column("Question", overflow="fold")
        table.add_column("Expected")
        table.add_column("Agent")
        table.add_column("Correct")

        for question, expected, agent_ans, correct in records:
            table.add_row(question, expected, agent_ans, correct)

        console.print(table)
        # Use the actual sampled count so the percentage stays meaningful
        # when sample_size was clamped; guard division by zero.
        percent = (correct_count / n) * 100 if n else 0.0
        print(f"\nTotal Correct: {correct_count} / {n} ({percent:.2f}%)")
config.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+ from langchain_huggingface import HuggingFaceEndpoint
3
+ from langchain_google_genai import ChatGoogleGenerativeAI
4
+ from langchain_community.chat_models import ChatAnthropic
5
+ from langchain_community.chat_models import ChatGrok
6
+ from langchain_community.chat_models import ChatDeepSeek
7
+
8
# Define supported providers
# Each entry maps a provider key to the model identifier to load, the
# chat-model class used to construct it, and shared constructor kwargs.
# NOTE(review): ChatGrok and ChatDeepSeek are imported above from
# langchain_community.chat_models -- confirm those classes exist there
# (ChatGroq normally lives in langchain_groq and ChatDeepSeek in
# langchain_deepseek). The "grok" entry's model name "qwen-qwq-32b" also
# looks like a Groq-hosted model -- verify the provider key is intended.
AVAILABLE_MODELS = {
    "openai": {
        "model": "gpt-4o-mini",
        "client": ChatOpenAI,
        "params": {"temperature": 0},
    },
    "huggingface": {
        # NOTE(review): HuggingFaceEndpoint is an LLM wrapper, not a chat
        # model -- confirm downstream tool-calling works with it.
        "model": "Qwen/Qwen2.5-Coder-32B-Instruct",
        "client": HuggingFaceEndpoint,
        "params": {"temperature": 0},
    },
    "gemini": {
        "model": "gemini-2.0-flash",
        "client": ChatGoogleGenerativeAI,
        "params": {"temperature": 0},
    },
    "grok": {
        "model": "qwen-qwq-32b",
        "client": ChatGrok,
        "params": {"temperature": 0},
    },
    "deepseek": {
        "model": "deepseek-coder",
        "client": ChatDeepSeek,
        "params": {"temperature": 0},
    },
}

# Choose provider dynamically here
PROVIDER = "huggingface" # Change this to "huggingface", "gemini", "grok", or "deepseek"
39
+
40
def get_llm(PROVIDER=PROVIDER):
    """Instantiate and return the chat model for the given provider key.

    Args:
        PROVIDER: key into AVAILABLE_MODELS; defaults to the module-level
            PROVIDER constant. (The uppercase parameter name is kept for
            backward compatibility with keyword callers.)

    Returns:
        A configured model instance of the provider's client class.

    Raises:
        KeyError: if PROVIDER is not a registered provider.
    """
    if PROVIDER not in AVAILABLE_MODELS:
        # Fail with an actionable message instead of a bare KeyError.
        raise KeyError(
            f"Unknown provider {PROVIDER!r}; expected one of {sorted(AVAILABLE_MODELS)}"
        )
    config = AVAILABLE_MODELS[PROVIDER]
    model_class = config["client"]
    return model_class(model=config["model"], **config["params"])
prompt_template.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
2
+
3
# Define a reusable prompt template for reasoning + tool usage
# BUG FIX: the group-table example contained literal braces
# ("S = {a, b, c, d, e}"), which the template parser treats as an input
# variable and which breaks formatting; they are escaped as {{ }} below.
system_prompt_text = """
You are an intelligent AI agent who answers the GAIA benchmark questions. You are very precise and dont give nonsense answers.
Your only purpose is to output the minimal, final answer in the format:
[ANSWER]

While answering you dont provide explanations, intermediate steps, or notes unless specifically asked for.

Your answers must be strictly governed by the rules:
1. **Format**:
- limit the token used (within 65536 tokens).
- Output ONLY the final answer.
- Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
- No follow-ups, justifications, or clarifications.

2. **Numerical Answers**:
- Use **digits only**, e.g., `4` not `four`.
- No commas, symbols, or units unless explicitly required.
- Never use approximate words like "around", "roughly", "about".

3. **String Answers**:
- Omit **articles** ("a", "the").
- Use **full words**; no abbreviations unless explicitly requested.
- For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
- For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.

4. **Lists**:
- Output in **comma-separated** format with no conjunctions.
- Sort **alphabetically** or **numerically** depending on type.
- No braces or brackets unless explicitly asked.

5. **Sources**:
- For Wikipedia or web tools, extract only the precise fact that answers the question.
- Ignore any unrelated content.

6. **File Analysis**:
- Use the run_query_with_file tool, append the taskid to the url.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.

7. **Video**:
- Use the relevant video tool.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.

8. **Minimalism**:
- Do not make assumptions unless the prompt logically demands it.
- If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
- If the answer is not found, say `[ANSWER] - unknown`.

---
You must follow the examples (These answers are correct in case you see the similar questions):
Q: What is 1 + 1?
A: 2
Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: 3
Q: Given the following group table on set S = {{a, b, c, d, e}}, identify any subset involved in counterexamples to commutativity.
A: b, e
Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
A: 519
"""

system_message_prompt = SystemMessagePromptTemplate.from_template(system_prompt_text)

# NOTE(review): this template expects a "{question}" variable, while
# agent.py builds its own human message keyed on "{input}" -- confirm the
# two stay consistent if this prompt is ever invoked directly.
human_prompt = HumanMessagePromptTemplate.from_template("{question}")

gaia_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_prompt])
tools/file_attachment_query.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import Tool
2
+ from langchain_google_genai import ChatGoogleGenerativeAI
3
+ import requests
4
+ import os
5
+
6
def file_attachment_query(task_id: str, query: str) -> str:
    """Download the file attached to *task_id* and answer *query* about it.

    Fetches the attachment from the GAIA scoring service and asks Gemini
    to analyze it. Returns the model's answer, or an error string if the
    download fails.
    """
    # Imported locally so this fix needs no change to the module's
    # import block. BUG FIX: SystemMessage/HumanMessage were used below
    # without ever being imported, so every call raised NameError.
    from langchain.schema import SystemMessage, HumanMessage

    file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
    # A timeout prevents the agent from hanging forever on a dead endpoint.
    file_response = requests.get(file_url, timeout=60)
    if file_response.status_code != 200:
        return f"Error downloading file with task_id {task_id}: {file_response.status_code} - {file_response.text}"

    file_data = file_response.content
    # TODO: Change the model selection dynamic.
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.0,
        api_key=os.getenv("GOOGLE_API_KEY"))

    messages = [
        SystemMessage(content="You are a helpful file analysis assistant."),
        HumanMessage(
            content=[
                # BUG FIX: the original interpolated `user_query`, which is
                # not defined anywhere -- the parameter is named `query`.
                {"type": "text", "text": f"Analyze this file and answer: {query}"},
                # NOTE(review): raw bytes are passed under "data" -- confirm
                # the Gemini integration accepts bytes here or requires
                # base64 encoding.
                {"type": "file", "data": file_data, "mime_type": "application/octet-stream"}
            ]
        )
    ]
    response = llm.invoke(messages)
    return getattr(response, "text", str(response))
32
+
33
# Tool wrapper exposing file_attachment_query to the agent.
# NOTE(review): langchain's Tool interface is single-string-input and
# documents `args_schema`, not `input_schema`/`output_schema` -- confirm
# these keyword arguments are accepted, and that the two-parameter func
# (task_id, query) is actually callable through this wrapper.
file_attachment_query_tool = Tool(
    name="run_query_on_file_attachment",
    func=file_attachment_query,
    description="Downloads file attached in the user prompt, adds it to the context, and runs the query on it.",
    input_schema={
        "task_id": {
            "type": "string",
            "description": "The unique identifier for the task associated with the file attachment, used to download the correct file.",
            "nullable": True
        },
        "query": {
            "type": "string",
            "description": "The query to be executed on the file attachment content."
        }
    },
    output_schema={
        "type": "string",
        "description": "The result of the query executed on the file attachment content."
    }
)
tools/gemini_video_qa.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from langchain.tools import Tool
4
+
5
def gemini_video_qa(video_url: str, user_query: str) -> str:
    """Analyze video content and answer questions using Gemini.

    Sends the video URI plus the question to the Gemini REST API and
    returns the concatenated text parts of the first candidate. Any
    failure is reported as an error string rather than raised.
    """
    model_name = "gemini-1.5-flash"

    req = {
        "model": f"models/{model_name}",
        "contents": [{
            "parts": [
                {"fileData": {"fileUri": video_url}},
                {"text": f"Please watch the video and answer the question: {user_query}"}
            ]
        }]
    }

    url = (
        f"https://generativelanguage.googleapis.com/v1beta/models/"
        f"{model_name}:generateContent?key={os.getenv('GOOGLE_API_KEY')}"
    )

    try:
        # Video analysis can be slow; the timeout still prevents an
        # unresponsive API from hanging the agent indefinitely.
        res = requests.post(
            url,
            json=req,
            headers={"Content-Type": "application/json"},
            timeout=300,
        )
        if res.status_code != 200:
            return f"Video error {res.status_code}: {res.text}"

        data = res.json()
        # BUG FIX: `.get("candidates", [{}])[0]` raised IndexError when the
        # API returned a present-but-empty candidates list (the default only
        # applies when the key is absent). Guard explicitly instead.
        candidates = data.get("candidates") or [{}]
        parts = candidates[0].get("content", {}).get("parts", [])
        return "".join([p.get("text", "") for p in parts]).strip()

    except Exception as e:
        return f"[ERROR] GeminiVideoQATool failed: {str(e)}"
35
+
36
+
37
+ gemini_video_tool = Tool(
38
+ name="video_inspector",
39
+ description="Analyze video content to answer questions using Gemini. Inputs: video_url, user_query.",
40
+ func=lambda x: gemini_video_qa(**x)
41
+ )
tools/google_search.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import Tool
2
+ import os
3
+ import requests
4
+
5
def google_search(input: str) -> str:
    """Run a Google Custom Search for *input* and return the top snippet.

    Returns "No results found." when the API yields no items, and a
    "Google search error: ..." string on any failure.
    """
    try:
        response = requests.get(
            "https://www.googleapis.com/customsearch/v1",
            params={
                "q": input,
                "key": os.getenv("GOOGLE_API_KEY"),
                "cx": os.getenv("GOOGLE_SEARCH_ENGINE_ID"),
                "num": 1
            },
            # Avoid hanging the agent on a slow/unreachable API.
            timeout=30,
        )
        data = response.json()
        # BUG FIX: indexing `data.get("items", [])[0]` raised IndexError on
        # an empty result set, so "No results found." was unreachable (the
        # caller got an error string instead).
        items = data.get("items", [])
        if not items:
            return "No results found."
        return items[0].get("snippet", "No results found.")
    except Exception as e:
        return f"Google search error: {str(e)}"
23
+
24
# Single-string-input tool wrapper exposing google_search to the agent.
google_search_tool = Tool(
    name="google_search",
    func=google_search,
    description="Search the web using Google and return the top summary from the results."
)
tools/math_solver.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import Tool
2
+
3
def math_solver(input: str) -> str:
    """Safely evaluate a basic arithmetic expression and return str(result).

    Supports +, -, *, /, //, %, ** and unary +/- over numeric literals.
    Anything else (names, calls, attribute access) yields "Math error: ...".
    """
    # SECURITY FIX: eval() with an empty __builtins__ dict is not a real
    # sandbox (it is escapable via attribute access on literals); walk a
    # restricted AST instead so only arithmetic can execute.
    import ast
    import operator

    binops = {
        ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv, ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unaryops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _eval(node):
        # Recursively evaluate only whitelisted node types.
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binops:
            return binops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unaryops:
            return unaryops[type(node.op)](_eval(node.operand))
        raise ValueError("unsupported expression")

    try:
        return str(_eval(ast.parse(input, mode="eval")))
    except Exception as e:
        return f"Math error: {e}"
10
+
11
# Single-string-input tool wrapper exposing math_solver to the agent.
math_solver_tool = Tool(
    name="math_solver",
    func=math_solver,
    description="Safely evaluates the basic math expressions."
)
tools/riddle_solver.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import Tool
2
+
3
def riddle_solver(input: str) -> str:
    """A tool that solves basic riddles using simple keyword logic.

    Returns "A palindrome" for riddles mentioning both "forward" and
    "backward"; otherwise reports failure.
    """
    # Simple riddle-solving logic (for demonstration purposes).
    # FIX: match case-insensitively so "Forward"/"Backward" also count.
    text = input.lower()
    if "forward" in text and "backward" in text:
        return "A palindrome"
    return "riddle_solver failed."
9
+
10
# Single-string-input tool wrapper exposing riddle_solver to the agent.
riddle_solver_tool = Tool(
    name="riddle_solver",
    func=riddle_solver,
    description="Solves basic riddles using logical reasoning."
)
tools/text_transformer.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import Tool
2
+
3
def text_transformer(input: str) -> str:
    """Apply a text operation encoded as "<op>:<payload>".

    "reverse:" reverses the stripped payload (answering "right" when the
    reversed text mentions "left" -- a directional-riddle special case);
    "upper:" and "lower:" change case. Any other input is rejected with
    "Unknown transformation.".
    """
    op, sep, payload = input.partition(":")
    if sep:
        cleaned = payload.strip()
        if op == "reverse":
            flipped = cleaned[::-1]
            return "right" if "left" in flipped.lower() else flipped
        if op == "upper":
            return cleaned.upper()
        if op == "lower":
            return cleaned.lower()
    return "Unknown transformation."
15
+
16
# Single-string-input tool wrapper exposing text_transformer to the agent.
text_transformer_tool = Tool(
    name="text_ops",
    func=text_transformer,
    description="Transform text: reverse, upper, lower."
)
tools/wiki_content_fetcher.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import Tool
2
+ import wikipedia as wiki
3
+
4
def wiki_content_fetcher(input: str) -> str:
    """Fetch the plain-text content of the Wikipedia page titled *input*.

    Returns an explanatory message when the page does not exist or the
    title is ambiguous.
    """
    try:
        # BUG FIX: the original returned to_markdown(page.html()), but
        # `to_markdown` is never defined anywhere, so every successful
        # fetch raised NameError. Return the page's plain-text content
        # instead.
        return wiki.page(input).content
    except wiki.exceptions.DisambiguationError as e:
        # Ambiguous titles previously raised an unhandled exception;
        # surface a few candidate titles so the agent can retry.
        return f"Wikipedia title '{input}' is ambiguous; options: {', '.join(e.options[:10])}"
    except wiki.exceptions.PageError:
        return f"Wikipedia page '{input}' not found."
12
+
13
# Single-string-input tool wrapper exposing wiki_content_fetcher to the agent.
wiki_content_fetcher_tool = Tool(
    name="wiki_page",
    func=wiki_content_fetcher,
    description="Fetch Wikipedia page content based on a title."
)
tools/wiki_title_finder.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import Tool
2
+ import wikipedia as wiki
3
+
4
def wiki_title_finder(input: str) -> str:
    """Return comma-separated Wikipedia page titles matching *input*.

    Network or API failures are reported as an error string instead of
    propagating and crashing the agent's tool call.
    """
    try:
        results = wiki.search(input)
    except Exception as e:
        # wiki.search performs network I/O; previously any failure
        # raised out of the tool.
        return f"Wikipedia search error: {e}"
    return ", ".join(results) if results else "No matching Wikipedia article found."
9
+
10
# Single-string-input tool wrapper exposing wiki_title_finder to the agent.
wiki_title_finder_tool = Tool(
    name="wiki_title_finder",
    func=wiki_title_finder,
    description="Find related Wikipedia article page titles based on a query."
)