orbulat commited on
Commit
f3da208
·
verified ·
1 Parent(s): 4fad2a0

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +35 -23
agent.py CHANGED
@@ -1,7 +1,9 @@
1
  import os
2
  from langgraph.graph import START, StateGraph, MessagesState
3
  from langgraph.prebuilt import ToolNode, tools_condition
4
- from langchain_openai import ChatOpenAI
 
 
5
  from langchain_core.messages import SystemMessage, HumanMessage
6
  from langchain_core.tools import tool
7
  from langchain_community.tools.tavily_search import TavilySearchResults
@@ -14,6 +16,9 @@ from PIL import Image
14
  import re
15
  import requests
16
  from io import BytesIO
 
 
 
17
 
18
  # Load system prompt
19
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
@@ -22,7 +27,7 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
22
  # Tool: Wikipedia search
23
  @tool
24
  def wiki_search(query: str) -> str:
25
- """Wikipedia search tool."""
26
  try:
27
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
28
  return "\n\n---\n\n".join([doc.page_content for doc in docs])
@@ -32,7 +37,7 @@ def wiki_search(query: str) -> str:
32
  # Tool: Tavily web search
33
  @tool
34
  def web_search(query: str) -> str:
35
- """Tavily web search tool."""
36
  try:
37
  results = TavilySearchResults(max_results=3).invoke(query)
38
  if isinstance(results, list):
@@ -44,7 +49,7 @@ def web_search(query: str) -> str:
44
  # Tool: DuckDuckGo search
45
  @tool
46
  def duckduckgo_search(query: str) -> str:
47
- """DuckDuckGo search tool."""
48
  try:
49
  with DDGS() as ddgs:
50
  results = ddgs.text(query, max_results=3)
@@ -55,23 +60,21 @@ def duckduckgo_search(query: str) -> str:
55
  # Tool: YouTube transcript or duration extractor
56
  @tool
57
  def youtube_transcript(video_title_or_url: str) -> str:
58
- """YouTube transcript or duration extractor tool."""
59
  try:
60
  with DDGS() as ddgs:
61
  results = ddgs.videos(video_title_or_url, max_results=1)
62
  if not results:
63
  return "No video found by that title."
64
  video = results[0]
65
- video_url = video["url"]
66
- duration = video.get("duration")
67
- return f"Duration: {duration}"
68
  except Exception as e:
69
  return f"YouTube search failed: {e}"
70
 
71
- # Tool: Arxiv paper fetcher (parse arXiv.org abstract directly)
72
  @tool
73
  def arxiv_fetch(query_or_id: str) -> str:
74
- """Arxiv paper fetcher tool."""
75
  try:
76
  if re.match(r"\d{4}\.\d{5}(v\d+)?", query_or_id):
77
  abs_url = f"https://arxiv.org/abs/{query_or_id}"
@@ -88,7 +91,7 @@ def arxiv_fetch(query_or_id: str) -> str:
88
 
89
  @tool
90
  def math_solver(expression: str) -> str:
91
- """Math solver tool."""
92
  try:
93
  result = sympify(expression).evalf()
94
  return str(result)
@@ -97,12 +100,12 @@ def math_solver(expression: str) -> str:
97
 
98
  @tool
99
  def reverse_text(text: str) -> str:
100
- """Text reversal tool."""
101
  return text[::-1]
102
 
103
  @tool
104
  def image_info(url: str) -> str:
105
- """Image dimension fetcher tool."""
106
  try:
107
  response = requests.get(url)
108
  img = Image.open(BytesIO(response.content))
@@ -122,12 +125,21 @@ tools = [
122
  image_info
123
  ]
124
 
125
- def build_graph():
126
- llm = ChatOpenAI(
127
- model="gpt-4o",
128
- temperature=0,
129
- api_key=os.getenv("OPENAI_API_KEY")
130
- )
 
 
 
 
 
 
 
 
 
131
  llm_with_tools = llm.bind_tools(tools)
132
 
133
  def system_node(state: MessagesState):
@@ -147,9 +159,9 @@ def build_graph():
147
  return builder.compile()
148
 
149
  class BasicAgent:
150
- def __init__(self):
151
- print("GAIA LangGraph Agent Initialized")
152
- self.graph = build_graph()
153
 
154
  def __call__(self, question: str) -> str:
155
  try:
@@ -163,7 +175,7 @@ class BasicAgent:
163
  return f"FINAL ANSWER: error - {str(e)}"
164
 
165
  if __name__ == "__main__":
166
- agent = BasicAgent()
167
  questions = [
168
  "What is the zip code of the Eiffel Tower?",
169
  "What is the capital city of Australia?",
 
1
  import os
2
  from langgraph.graph import START, StateGraph, MessagesState
3
  from langgraph.prebuilt import ToolNode, tools_condition
4
+ from langchain_google_genai import ChatGoogleGenerativeAI
5
+ from langchain_groq import ChatGroq
6
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
7
  from langchain_core.messages import SystemMessage, HumanMessage
8
  from langchain_core.tools import tool
9
  from langchain_community.tools.tavily_search import TavilySearchResults
 
16
  import re
17
  import requests
18
  from io import BytesIO
19
+ from dotenv import load_dotenv
20
+
21
+ load_dotenv()
22
 
23
  # Load system prompt
24
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
 
27
  # Tool: Wikipedia search
28
  @tool
29
  def wiki_search(query: str) -> str:
30
+ """Search Wikipedia for a query and return content from up to 2 documents."""
31
  try:
32
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
33
  return "\n\n---\n\n".join([doc.page_content for doc in docs])
 
37
  # Tool: Tavily web search
38
  @tool
39
  def web_search(query: str) -> str:
40
+ """Search the web using Tavily and return content from up to 3 results."""
41
  try:
42
  results = TavilySearchResults(max_results=3).invoke(query)
43
  if isinstance(results, list):
 
49
  # Tool: DuckDuckGo search
50
  @tool
51
  def duckduckgo_search(query: str) -> str:
52
+ """Search using DuckDuckGo and return summaries from up to 3 results."""
53
  try:
54
  with DDGS() as ddgs:
55
  results = ddgs.text(query, max_results=3)
 
60
  # Tool: YouTube transcript or duration extractor
61
  @tool
62
  def youtube_transcript(video_title_or_url: str) -> str:
63
+ """Get duration of a YouTube video using its title or URL."""
64
  try:
65
  with DDGS() as ddgs:
66
  results = ddgs.videos(video_title_or_url, max_results=1)
67
  if not results:
68
  return "No video found by that title."
69
  video = results[0]
70
+ return f"Duration: {video.get('duration')}"
 
 
71
  except Exception as e:
72
  return f"YouTube search failed: {e}"
73
 
74
+ # Tool: Arxiv paper fetcher
75
  @tool
76
  def arxiv_fetch(query_or_id: str) -> str:
77
+ """Fetch metadata from arXiv either by ID or search query."""
78
  try:
79
  if re.match(r"\d{4}\.\d{5}(v\d+)?", query_or_id):
80
  abs_url = f"https://arxiv.org/abs/{query_or_id}"
 
91
 
92
  @tool
93
  def math_solver(expression: str) -> str:
94
+ """Evaluate a math expression and return the result."""
95
  try:
96
  result = sympify(expression).evalf()
97
  return str(result)
 
100
 
101
  @tool
102
  def reverse_text(text: str) -> str:
103
+ """Reverse the input string."""
104
  return text[::-1]
105
 
106
  @tool
107
  def image_info(url: str) -> str:
108
+ """Fetch image size (width x height) from a given URL."""
109
  try:
110
  response = requests.get(url)
111
  img = Image.open(BytesIO(response.content))
 
125
  image_info
126
  ]
127
 
128
+ def build_graph(provider: str = "groq"):
129
+ if provider == "google":
130
+ llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
131
+ elif provider == "groq":
132
+ llm = ChatGroq(model="llama3-70b-8192", temperature=0)
133
+ elif provider == "huggingface":
134
+ llm = ChatHuggingFace(
135
+ llm=HuggingFaceEndpoint(
136
+ url="https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct",
137
+ temperature=0,
138
+ ),
139
+ )
140
+ else:
141
+ raise ValueError("Invalid provider. Choose 'google', 'groq', or 'huggingface'.")
142
+
143
  llm_with_tools = llm.bind_tools(tools)
144
 
145
  def system_node(state: MessagesState):
 
159
  return builder.compile()
160
 
161
  class BasicAgent:
162
+ def __init__(self, provider="groq"):
163
+ print(f"GAIA LangGraph Agent Initialized using {provider}")
164
+ self.graph = build_graph(provider)
165
 
166
  def __call__(self, question: str) -> str:
167
  try:
 
175
  return f"FINAL ANSWER: error - {str(e)}"
176
 
177
  if __name__ == "__main__":
178
+ agent = BasicAgent(provider="groq")
179
  questions = [
180
  "What is the zip code of the Eiffel Tower?",
181
  "What is the capital city of Australia?",