Parthiban007 committed on
Commit
492fbc1
·
verified ·
1 Parent(s): 4eac0bf

Update llm_agent.py

Browse files
Files changed (1) hide show
  1. llm_agent.py +168 -174
llm_agent.py CHANGED
@@ -1,174 +1,168 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from langgraph.graph import START, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition
5
- from langgraph.prebuilt import ToolNode
6
- from langchain_google_genai import ChatGoogleGenerativeAI,GoogleGenerativeAIEmbeddings
7
- from langchain_community.document_loaders import WikipediaLoader
8
- from langchain_community.document_loaders import ArxivLoader
9
- from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_community.vectorstores import FAISS
11
- from langchain_core.messages import SystemMessage,HumanMessage
12
- from langchain_core.tools import tool
13
-
14
-
15
- load_dotenv()
16
-
17
- os.environ["GOOGLE_API_KEY"] = os.getenv("GOOGLE_API_KEY")
18
-
19
- @tool
20
- def add(a:int,b:int)->int:
21
- """Add two Numbers
22
- Args:
23
- a:int
24
- b:int
25
-
26
- """
27
- return a+b
28
-
29
- @tool
30
- def subtract(a:int,b:int)->int:
31
- """Subtract two numbers
32
- Args:
33
- a:int
34
- b:int
35
- """
36
- return a-b
37
-
38
- @tool
39
- def multiply(a:int,b:int)->int:
40
- """Multiply Two Numbers
41
- Args:
42
- a:int
43
- b:int
44
- """
45
- return a*b
46
-
47
- @tool
48
- def divide(a:int,b:int)->int:
49
- """Divide two numbers
50
- Args:
51
- a:int
52
- b:int
53
- """
54
- if b==0:
55
- raise ValueError("Cannot Divide by Zero")
56
- return a//b
57
-
58
- @tool
59
- def modulus(a:int,b:int)->int:
60
- """Modulus of the two numbers
61
- Args:
62
- a:int
63
- b:int
64
- """
65
- return a%b
66
-
67
-
68
- @tool
69
- def wiki_search(query:str)->str:
70
- """Search Wikipedia for a query and return maximum 2 results
71
- Args:
72
- query: The Search Query : str
73
- """
74
- print(query)
75
-
76
- search_docs = WikipediaLoader(query=query,load_max_docs=2).load()
77
-
78
- return {"wiki_results": search_docs}
79
-
80
-
81
- @tool
82
- def web_search(query:str)->str:
83
- """ search Tavily for a query and return maximum 3 results
84
-
85
- Args:
86
- query: The Search Query
87
- """
88
- search_docs = TavilySearchResults(max_results=3).invoke(input=query)
89
-
90
- return {"web_results": search_docs}
91
-
92
- @tool
93
- def arxiv_search(query:str)->str:
94
- """"search Arxiv for a query and return maximum 3 results
95
- Args:
96
- query: search query
97
- """
98
-
99
- search_docs = ArxivLoader(query=query,load_max_docs = 3).load()
100
-
101
- return {"arxiv_resutls": search_docs}
102
-
103
-
104
- system_prompt = """
105
- You are a helpful assistant tasked with answering questions using a set of tools.
106
- Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
107
- FINAL ANSWER: [YOUR FINAL ANSWER].
108
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
109
- Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
110
- """
111
-
112
- sys_msg = SystemMessage(content=system_prompt)
113
-
114
-
115
- embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
116
-
117
-
118
-
119
- tools = [
120
- add,
121
- subtract,
122
- multiply,
123
- divide,
124
- modulus,
125
- wiki_search,
126
- web_search,
127
- arxiv_search,
128
- ]
129
-
130
-
131
-
132
- def build_graph():
133
- llm = ChatGoogleGenerativeAI(model = "gemini-2.0-flash")
134
- print(tools)
135
- # Bind tools to LLM
136
- llm_with_tools = llm.bind_tools(tools)
137
-
138
- # Node
139
- def assistant(state: MessagesState):
140
- """Assistant node"""
141
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
142
-
143
- # def retriever(state: MessagesState):
144
- # """Retriever node"""
145
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
146
- # example_msg = HumanMessage(
147
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question}",
148
- # )
149
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
150
-
151
- builder = StateGraph(MessagesState)
152
- # builder.add_node("retriever", retriever)
153
- builder.add_node("assistant", assistant)
154
- builder.add_node("tools", ToolNode(tools))
155
- builder.add_edge(START, "assistant")
156
- # builder.add_edge("retriever", "assistant")
157
- builder.add_conditional_edges(
158
- "assistant",
159
- tools_condition,
160
- )
161
- builder.add_edge("tools", "assistant")
162
-
163
- # Compile graph
164
- return builder.compile()
165
-
166
-
167
- if __name__ == "__main__":
168
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
169
- graph = build_graph()
170
- # Run the graph
171
- messages = [HumanMessage(content=question)]
172
- messages = graph.invoke({"messages": messages})
173
- for m in messages["messages"]:
174
- m.pretty_print()
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition
5
+ from langgraph.prebuilt import ToolNode
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_groq import ChatGroq
8
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_community.document_loaders import WikipediaLoader
11
+ from langchain_community.document_loaders import ArxivLoader
12
+ from langchain_community.vectorstores import SupabaseVectorStore
13
+ from langchain_core.messages import SystemMessage, HumanMessage
14
+ from langchain_core.tools import tool
15
+
16
+
17
# Load environment variables from a local .env file (if present).
load_dotenv()

# Propagate the Groq API key into the process environment for downstream
# clients. os.getenv returns None when the variable is absent, and assigning
# None into os.environ raises an opaque TypeError — fail fast instead.
_groq_api_key = os.getenv("GROQ_API_KEY")
if _groq_api_key is None:
    raise RuntimeError("GROQ_API_KEY is not set; add it to the environment or a .env file.")
os.environ["GROQ_API_KEY"] = _groq_api_key
19
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
27
+
28
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
37
+
38
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference of two integers (a minus b).

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
47
+
48
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Returns:
        The true-division quotient a / b (a float, even for exact divisions).

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # `/` is true division and always yields a float, so the return
    # annotation is `float` (the previous `-> int` annotation was wrong).
    return a / b
59
+
60
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of a divided by b.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
69
+
70
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        A dict with key "wiki_results" mapping to the formatted documents.
        (The function returns a dict, so the annotation is `dict`, not `str`.)
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    entries = []
    for doc in search_docs:
        # Use .get so a document missing a metadata key cannot raise KeyError.
        source = doc.metadata.get("source", "")
        page = doc.metadata.get("page", "")
        entries.append(f'<Document source="{source}" page="{page}"/>\n{doc.page_content}\n</Document>')
    formatted_search_docs = "\n\n---\n\n".join(entries)
    # Keyed dict so downstream consumers can tell which tool produced it.
    return {"wiki_results": formatted_search_docs}
83
+
84
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with key "web_results" mapping to the formatted results.
    """
    # Pass the query as the tool input positionally: `.invoke(query=query)`
    # is not an accepted keyword for Runnable.invoke (the previous commit
    # used `input=query` for the same reason).
    search_docs = TavilySearchResults(max_results=3).invoke(query)
    entries = []
    for doc in search_docs:
        # Tavily returns plain dicts (with "url"/"content" keys), not
        # Document objects — `doc.metadata` would raise AttributeError.
        source = doc.get("url", "")
        content = doc.get("content", "")
        entries.append(f'<Document source="{source}"/>\n{content}\n</Document>')
    formatted_search_docs = "\n\n---\n\n".join(entries)
    return {"web_results": formatted_search_docs}
97
+
98
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with key "arvix_results" mapping to the formatted documents.

    NOTE(review): the name "arvix_search" / key "arvix_results" is a known
    misspelling of "arxiv", kept as-is because the tools list (and possibly
    downstream consumers) reference it by this name.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    entries = []
    for doc in search_docs:
        # Use .get: ArxivLoader metadata may not include these exact keys.
        source = doc.metadata.get("source", "")
        page = doc.metadata.get("page", "")
        # Truncate page content to 1000 chars to keep the tool payload small.
        entries.append(f'<Document source="{source}" page="{page}"/>\n{doc.page_content[:1000]}\n</Document>')
    formatted_search_docs = "\n\n---\n\n".join(entries)
    return {"arvix_results": formatted_search_docs}
111
+
112
+
113
+
114
# System prompt: forces every answer to end with the exact
# "FINAL ANSWER: ..." template expected by the evaluation harness.
system_prompt = """
You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
"""
# Wrapped once as a SystemMessage for reuse.
sys_msg = SystemMessage(content=system_prompt)

# Sentence-transformer embedding model (768-dim output per the model card).
# NOTE(review): `embeddings` is not referenced anywhere in this file —
# presumably intended for the imported SupabaseVectorStore; confirm before removing.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768

# Tool set exposed to the LLM via bind_tools() in build_graph().
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]
137
+
138
# Build graph function
def build_graph():
    """Construct and compile the tool-calling agent graph.

    Returns:
        A compiled LangGraph graph: an assistant LLM node with the tool set
        bound, conditionally routed to a ToolNode and looped back until the
        model stops emitting tool calls.
    """
    # temperature=0 for deterministic answers; alternatives: qwen-qwq-32b, gemma2-9b-it
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0)

    # Bind the tool schemas so the model can emit tool calls.
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: run the LLM over the system prompt + conversation."""
        # Prepend sys_msg so the FINAL ANSWER formatting rules apply; without
        # this the module-level system prompt is defined but never used.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Route to "tools" when the last AI message contains tool calls,
    # otherwise end the run.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()
161
+
162
if __name__ == "__main__":
    # Smoke-test the agent on one sample question.
    question = "What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?"
    agent = build_graph()
    # Invoke the graph with a single human message and print the full trace
    # (assistant turns and tool results alike).
    result = agent.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()