prthm11 commited on
Commit
65409c0
·
verified ·
1 Parent(s): d03eb21

Delete utils/agent2.py

Browse files
Files changed (1) hide show
  1. utils/agent2.py +0 -259
utils/agent2.py DELETED
@@ -1,259 +0,0 @@
1
- import os
2
- from dotenv import load_dotenv
3
-
4
- from langchain_core.tools import tool
5
- from langgraph.prebuilt import tools_condition, ToolNode
6
- from langgraph.graph import START, StateGraph, MessagesState
7
-
8
- from langchain_community.tools.tavily_search import TavilySearchResults
9
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
- from langchain_core.messages import SystemMessage, HumanMessage
11
- from langchain_community.vectorstores import FAISS
12
- from langchain_huggingface import HuggingFaceEmbeddings
13
- from langchain_groq import ChatGroq
14
- from langchain.tools.retriever import create_retriever_tool
15
-
16
- """LangGraph Agent"""
17
- import os
18
- from dotenv import load_dotenv
19
- from langgraph.graph import START, StateGraph, MessagesState
20
- from langgraph.prebuilt import tools_condition
21
- from langgraph.prebuilt import ToolNode
22
- from langchain_google_genai import ChatGoogleGenerativeAI
23
- from langchain_groq import ChatGroq
24
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
25
- from langchain_community.tools.tavily_search import TavilySearchResults
26
- from langchain_community.document_loaders import WikipediaLoader
27
- from langchain_community.document_loaders import ArxivLoader
28
- from langchain_community.vectorstores import SupabaseVectorStore,FAISS
29
- from langchain_core.messages import SystemMessage, HumanMessage
30
- from langchain_core.tools import tool
31
- from langchain.tools.retriever import create_retriever_tool
32
- #from supabase.client import Client, create_client
33
-
34
- # ──────────────────────────────────────────────────────────
35
- # ENV
36
- # ──────────────────────────────────────────────────────────
37
- load_dotenv()
38
- # API Keys from .env file
39
- os.environ.setdefault("OPENAI_API_KEY", "<YOUR_OPENAI_KEY>") # Set your own key or env var
40
- os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY", "default_key_or_placeholder")
41
- os.environ["MISTRAL_API_KEY"] = os.getenv("MISTRAL_API_KEY", "default_key_or_placeholder")
42
-
43
- # Tavily API Key
44
- TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "default_key_or_placeholder")
45
- _forbidden = ["nsfw", "porn", "sex", "explicit"]
46
- _playwright_available = True # set False to disable Playwright
47
-
48
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim = 768
49
-
50
- @tool
51
- def multiply(a: int, b: int) -> int:
52
- """Multiply two numbers.
53
-
54
- Args:
55
- a: first int
56
- b: second int
57
- """
58
- return a * b
59
-
60
- @tool
61
- def add(a: int, b: int) -> int:
62
- """Add two numbers.
63
-
64
- Args:
65
- a: first int
66
- b: second int
67
- """
68
- return a + b
69
-
70
- @tool
71
- def subtract(a: int, b: int) -> int:
72
- """Subtract two numbers.
73
-
74
- Args:
75
- a: first int
76
- b: second int
77
- """
78
- return a - b
79
-
80
- @tool
81
- def divide(a: int, b: int) -> int:
82
- """Divide two numbers.
83
-
84
- Args:
85
- a: first int
86
- b: second int
87
- """
88
- if b == 0:
89
- raise ValueError("Cannot divide by zero.")
90
- return a / b
91
-
92
- @tool
93
- def modulus(a: int, b: int) -> int:
94
- """Get the modulus of two numbers.
95
-
96
- Args:
97
- a: first int
98
- b: second int
99
- """
100
- return a % b
101
-
102
- @tool
103
- def wiki_search(query: str) -> str:
104
- """Search Wikipedia for a query and return maximum 2 results.
105
-
106
- Args:
107
- query: The search query."""
108
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
109
- formatted_search_docs = "\n\n---\n\n".join(
110
- [
111
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
112
- for doc in search_docs
113
- ])
114
- return {"wiki_results": formatted_search_docs}
115
-
116
- @tool
117
- def web_search(query: str) -> str:
118
- """Search Tavily for a query and return maximum 3 results.
119
-
120
- Args:
121
- query: The search query."""
122
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
123
- formatted_search_docs = "\n\n---\n\n".join(
124
- [
125
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
126
- for doc in search_docs
127
- ])
128
- return {"web_results": formatted_search_docs}
129
-
130
- @tool
131
- def arvix_search(query: str) -> str:
132
- """Search Arxiv for a query and return maximum 3 result.
133
-
134
- Args:
135
- query: The search query."""
136
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
137
- formatted_search_docs = "\n\n---\n\n".join(
138
- [
139
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
140
- for doc in search_docs
141
- ])
142
- return {"arvix_results": formatted_search_docs}
143
-
144
-
145
-
146
- # load the system prompt from the file
147
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
148
- system_prompt = f.read()
149
-
150
- # System message
151
- sys_msg = SystemMessage(content=system_prompt)
152
-
153
- # build a retriever
154
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
155
-
156
- INDEX_PATH = "faiss_index"
157
- if os.path.exists(INDEX_PATH):
158
- vector_store = FAISS.load_local(INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
159
- else:
160
- vector_store = FAISS.from_texts(["__init__"], embeddings)
161
- vector_store.save_local(INDEX_PATH)
162
-
163
- create_retriever_tool = create_retriever_tool(
164
- retriever=vector_store.as_retriever(),
165
- name="Question Search",
166
- description="A tool to retrieve similar questions from a local FAISS vector store."
167
- )
168
-
169
-
170
-
171
- tools = [
172
- multiply,
173
- add,
174
- subtract,
175
- divide,
176
- modulus,
177
- wiki_search,
178
- web_search,
179
- arvix_search,
180
- ]
181
-
182
- # Build graph function
183
- def build_graph(provider: str = "groq"):
184
- """Build the graph"""
185
- # Load environment variables from .env file
186
- if provider == "google":
187
- # Google Gemini
188
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
189
- elif provider == "groq":
190
- # Groq https://console.groq.com/docs/models
191
- try:
192
- llm = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
193
- except Exception as e:
194
- print(f"Error initializing Groq: {str(e)}")
195
- raise
196
- elif provider == "huggingface":
197
- # TODO: Add huggingface endpoint
198
- llm = ChatHuggingFace(
199
- llm=HuggingFaceEndpoint(
200
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
201
- temperature=0,
202
- ),
203
- )
204
- else:
205
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
206
-
207
- # Bind tools to LLM
208
- llm_with_tools = llm.bind_tools(tools)
209
-
210
- # Node
211
- def assistant(state: MessagesState):
212
- """Assistant node"""
213
- try:
214
- if not state["messages"] or not state["messages"][-1].content:
215
- raise ValueError("Empty message content")
216
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
217
- except Exception as e:
218
- print(f"Error in assistant node: {str(e)}")
219
- raise
220
-
221
- def retriever(state: MessagesState):
222
- """Retriever node"""
223
- try:
224
- if not state["messages"] or not state["messages"][0].content:
225
- raise ValueError("Empty message content")
226
- similar_question = vector_store.similarity_search(state["messages"][0].content)
227
- example_msg = HumanMessage(
228
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
229
- )
230
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
231
- except Exception as e:
232
- print(f"Error in retriever node: {str(e)}")
233
- raise
234
-
235
- builder = StateGraph(MessagesState)
236
- builder.add_node("retriever", retriever)
237
- builder.add_node("assistant", assistant)
238
- builder.add_node("tools", ToolNode(tools))
239
- builder.add_edge(START, "retriever")
240
- builder.add_edge("retriever", "assistant")
241
- builder.add_conditional_edges(
242
- "assistant",
243
- tools_condition,
244
- )
245
- builder.add_edge("tools", "assistant")
246
-
247
- # Compile graph
248
- return builder.compile()
249
-
250
- # test
251
- if __name__ == "__main__":
252
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
253
- # Build the graph
254
- graph = build_graph(provider="groq")
255
- # Run the graph
256
- messages = [HumanMessage(content=question)]
257
- messages = graph.invoke({"messages": messages})
258
- for m in messages["messages"]:
259
- m.pretty_print()