orbulat commited on
Commit
0ab18c4
·
verified ·
1 Parent(s): 3d13372

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +102 -18
agent.py CHANGED
@@ -6,8 +6,14 @@ from langchain_core.messages import SystemMessage, HumanMessage
6
  from langchain_core.tools import tool
7
  from langchain_community.tools.tavily_search import TavilySearchResults
8
  from langchain_community.document_loaders import WikipediaLoader
9
- from youtube_transcript_api import YouTubeTranscriptApi
 
 
 
 
10
  import re
 
 
11
 
12
  # Load system prompt
13
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
@@ -16,7 +22,7 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
16
  # Tool: Wikipedia search
17
  @tool
18
  def wiki_search(query: str) -> str:
19
- """Search Wikipedia for a query and return max 2 results."""
20
  try:
21
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
22
  return "\n\n---\n\n".join([doc.page_content for doc in docs])
@@ -26,30 +32,96 @@ def wiki_search(query: str) -> str:
26
  # Tool: Tavily web search
27
  @tool
28
  def web_search(query: str) -> str:
29
- """Web search with Tavily (Google-like)."""
30
  try:
31
- docs = TavilySearchResults(max_results=3).invoke(query)
32
- return "\n\n---\n\n".join([doc.page_content for doc in docs])
 
 
33
  except Exception as e:
34
  return f"Web search failed: {e}"
35
 
36
- # Tool: YouTube transcript parser
37
  @tool
38
- def youtube_transcript(video_url: str) -> str:
39
- """Extract transcript from a YouTube video URL."""
40
  try:
41
- video_id = re.search(r"v=([a-zA-Z0-9_-]{11})", video_url)
42
- if not video_id:
43
- return "Invalid YouTube URL"
44
- transcript = YouTubeTranscriptApi.get_transcript(video_id.group(1))
45
- text = " ".join([entry['text'] for entry in transcript])
46
- return text
47
  except Exception as e:
48
- return f"Transcript fetch failed: {e}"
49
 
50
- tools = [wiki_search, web_search, youtube_transcript]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Build LangGraph
53
  def build_graph():
54
  llm = ChatOpenAI(
55
  model="gpt-4o",
@@ -74,7 +146,6 @@ def build_graph():
74
  builder.add_edge("tools", "assistant")
75
  return builder.compile()
76
 
77
- # Agent for GAIA benchmark
78
  class BasicAgent:
79
  def __init__(self):
80
  print("GAIA LangGraph Agent Initialized")
@@ -90,3 +161,16 @@ class BasicAgent:
90
  return final_msg
91
  except Exception as e:
92
  return f"FINAL ANSWER: error - {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from langchain_core.tools import tool
7
  from langchain_community.tools.tavily_search import TavilySearchResults
8
  from langchain_community.document_loaders import WikipediaLoader
9
+ from youtube_transcript_api import YouTubeTranscriptApi, NoTranscriptFound
10
+ from duckduckgo_search import DDGS
11
+ from langchain_community.document_loaders import ArxivLoader
12
+ from sympy import sympify
13
+ from PIL import Image
14
  import re
15
+ import requests
16
+ from io import BytesIO
17
 
18
  # Load system prompt
19
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
 
22
  # Tool: Wikipedia search
23
  @tool
24
  def wiki_search(query: str) -> str:
25
+ """Wikipedia search tool."""
26
  try:
27
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
28
  return "\n\n---\n\n".join([doc.page_content for doc in docs])
 
32
  # Tool: Tavily web search
33
  @tool
34
  def web_search(query: str) -> str:
35
+ """Tavily web search tool."""
36
  try:
37
+ results = TavilySearchResults(max_results=3).invoke(query)
38
+ if isinstance(results, list):
39
+ return "\n\n---\n\n".join([r["content"] if isinstance(r, dict) else str(r) for r in results])
40
+ return str(results)
41
  except Exception as e:
42
  return f"Web search failed: {e}"
43
 
44
+ # Tool: DuckDuckGo search
45
  @tool
46
+ def duckduckgo_search(query: str) -> str:
47
+ """DuckDuckGo search tool."""
48
  try:
49
+ with DDGS() as ddgs:
50
+ results = ddgs.text(query, max_results=3)
51
+ return "\n\n---\n\n".join([r["body"] for r in results if "body" in r])
 
 
 
52
  except Exception as e:
53
+ return f"DuckDuckGo search failed: {e}"
54
 
55
+ # Tool: YouTube transcript or duration extractor
56
+ @tool
57
+ def youtube_transcript(video_title_or_url: str) -> str:
58
+ """YouTube transcript or duration extractor tool."""
59
+ try:
60
+ with DDGS() as ddgs:
61
+ results = ddgs.videos(video_title_or_url, max_results=1)
62
+ if not results:
63
+ return "No video found by that title."
64
+ video = results[0]
65
+ video_url = video["url"]
66
+ duration = video.get("duration")
67
+ return f"Duration: {duration}"
68
+ except Exception as e:
69
+ return f"YouTube search failed: {e}"
70
+
71
+ # Tool: Arxiv paper fetcher (parse arXiv.org abstract directly)
72
+ @tool
73
+ def arxiv_fetch(query_or_id: str) -> str:
74
+ """Arxiv paper fetcher tool."""
75
+ try:
76
+ if re.match(r"\d{4}\.\d{5}(v\d+)?", query_or_id):
77
+ abs_url = f"https://arxiv.org/abs/{query_or_id}"
78
+ api_url = f"http://export.arxiv.org/api/query?id_list={query_or_id}"
79
+ res = requests.get(api_url)
80
+ if res.status_code == 200:
81
+ return res.text[:2000] + f"\n\nFull: {abs_url}"
82
+ return f"Could not retrieve metadata from arXiv API"
83
+ else:
84
+ docs = ArxivLoader(query=query_or_id, load_max_docs=2).load()
85
+ return "\n\n---\n\n".join([doc.page_content for doc in docs])
86
+ except Exception as e:
87
+ return f"ArXiv fetch failed: {e}"
88
+
89
+ @tool
90
+ def math_solver(expression: str) -> str:
91
+ """Math solver tool."""
92
+ try:
93
+ result = sympify(expression).evalf()
94
+ return str(result)
95
+ except Exception as e:
96
+ return f"Math error: {e}"
97
+
98
+ @tool
99
+ def reverse_text(text: str) -> str:
100
+ """Text reversal tool."""
101
+ return text[::-1]
102
+
103
+ @tool
104
+ def image_info(url: str) -> str:
105
+ """Image dimension fetcher tool."""
106
+ try:
107
+ response = requests.get(url)
108
+ img = Image.open(BytesIO(response.content))
109
+ return f"Image size: {img.size} (width x height)"
110
+ except Exception as e:
111
+ return f"Image error: {e}"
112
+
113
+ # Tools list
114
+ tools = [
115
+ wiki_search,
116
+ web_search,
117
+ duckduckgo_search,
118
+ youtube_transcript,
119
+ arxiv_fetch,
120
+ math_solver,
121
+ reverse_text,
122
+ image_info
123
+ ]
124
 
 
125
  def build_graph():
126
  llm = ChatOpenAI(
127
  model="gpt-4o",
 
146
  builder.add_edge("tools", "assistant")
147
  return builder.compile()
148
 
 
149
  class BasicAgent:
150
  def __init__(self):
151
  print("GAIA LangGraph Agent Initialized")
 
161
  return final_msg
162
  except Exception as e:
163
  return f"FINAL ANSWER: error - {str(e)}"
164
+
165
+ if __name__ == "__main__":
166
+ agent = BasicAgent()
167
+ questions = [
168
+ "What is the zip code of the Eiffel Tower?",
169
+ "What is the capital city of Australia?",
170
+ "How long is the video titled 'The History of Time' on YouTube?",
171
+ "What does the arXiv paper '2303.12712' say about Transformer performance?",
172
+ ]
173
+
174
+ for q in questions:
175
+ print(f"\n[Question]: {q}")
176
+ print(agent(q))