Prasanthkumar commited on
Commit
166ba87
·
verified ·
1 Parent(s): f17a5c8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +28 -135
model.py CHANGED
@@ -6,167 +6,83 @@ from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
- from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_community.document_loaders import WikipediaLoader
11
- from langchain_community.document_loaders import ArxivLoader
12
  from langchain_community.vectorstores import SupabaseVectorStore
13
  from langchain_core.messages import SystemMessage, HumanMessage
14
  from langchain_core.tools import tool
15
- from langchain.tools.retriever import create_retriever_tool
16
  from supabase.client import Client, create_client
17
- import os
18
  load_dotenv()
19
 
20
  url = os.getenv("SUPABASE_URL")
21
  key = os.getenv("SUPABASE_KEY")
22
 
23
-
24
  @tool
25
  def multiply(a: int, b: int) -> int:
26
- """Multiply two numbers.
27
- Args:
28
- a: first int
29
- b: second int
30
- """
31
  return a * b
32
 
33
  @tool
34
  def add(a: int, b: int) -> int:
35
- """Add two numbers.
36
-
37
- Args:
38
- a: first int
39
- b: second int
40
- """
41
  return a + b
42
 
43
  @tool
44
  def subtract(a: int, b: int) -> int:
45
- """Subtract two numbers.
46
-
47
- Args:
48
- a: first int
49
- b: second int
50
- """
51
  return a - b
52
 
53
  @tool
54
- def divide(a: int, b: int) -> int:
55
- """Divide two numbers.
56
-
57
- Args:
58
- a: first int
59
- b: second int
60
- """
61
  if b == 0:
62
  raise ValueError("Cannot divide by zero.")
63
  return a / b
64
 
65
  @tool
66
  def modulus(a: int, b: int) -> int:
67
- """Get the modulus of two numbers.
68
-
69
- Args:
70
- a: first int
71
- b: second int
72
- """
73
  return a % b
74
 
 
75
  @tool
76
  def wiki_search(query: str) -> str:
77
- """Search Wikipedia for a query and return maximum 2 results.
78
-
79
- Args:
80
- query: The search query."""
81
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
82
- formatted_search_docs = "\n\n---\n\n".join(
83
- [
84
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
85
- for doc in search_docs
86
- ])
87
- return {"wiki_results": formatted_search_docs}
88
 
89
  @tool
90
  def web_search(query: str) -> str:
91
- """Search Tavily for a query and return maximum 3 results.
92
-
93
- Args:
94
- query: The search query."""
95
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
96
- formatted_search_docs = "\n\n---\n\n".join(
97
- [
98
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
99
- for doc in search_docs
100
- ])
101
- return {"web_results": formatted_search_docs}
102
 
103
  @tool
104
  def arvix_search(query: str) -> str:
105
- """Search Arxiv for a query and return maximum 3 result.
106
-
107
- Args:
108
- query: The search query."""
109
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
110
- formatted_search_docs = "\n\n---\n\n".join(
111
- [
112
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
113
- for doc in search_docs
114
- ])
115
- return {"arvix_results": formatted_search_docs}
116
 
117
-
118
-
119
- # load the system prompt from the file
120
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
121
  system_prompt = f.read()
122
-
123
- # System message
124
  sys_msg = SystemMessage(content=system_prompt)
125
 
126
- # build a retriever
127
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
128
- supabase: Client = create_client(
129
- url,
130
- key
131
- )
132
  vector_store = SupabaseVectorStore(
133
  client=supabase,
134
- embedding= embeddings,
135
  table_name="documents",
136
  query_name="match_documents_langchain",
137
  )
138
- create_retriever_tool = create_retriever_tool(
139
- retriever=vector_store.as_retriever(),
140
- name="Question Search",
141
- description="A tool to retrieve similar questions from a vector store.",
142
- )
143
-
144
 
 
145
 
146
- tools = [
147
- multiply,
148
- add,
149
- subtract,
150
- divide,
151
- modulus,
152
- wiki_search,
153
- web_search,
154
- arvix_search,
155
- ]
156
-
157
- # Build graph function
158
  def build_graph(provider: str = "groq"):
159
- """Build the graph"""
160
- # Load environment variables from .env file
161
  if provider == "google":
162
- # Google Gemini
163
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
164
  elif provider == "groq":
165
- # Groq https://console.groq.com/docs/models
166
  api_key = os.getenv('GROQ_API')
167
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0,api_key = api_key) # optional : qwen-qwq-32b gemma2-9b-it
168
  elif provider == "huggingface":
169
- # TODO: Add huggingface endpoint
170
  llm = ChatHuggingFace(
171
  llm=HuggingFaceEndpoint(
172
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
@@ -175,59 +91,36 @@ def build_graph(provider: str = "groq"):
175
  )
176
  else:
177
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
178
- # Bind tools to LLM
179
  llm_with_tools = llm.bind_tools(tools)
180
 
181
- # Node
182
  def assistant(state: MessagesState):
183
- """Assistant node"""
184
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
185
-
186
- # def retriever(state: MessagesState):
187
- # """Retriever node"""
188
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
189
- # example_msg = HumanMessage(
190
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
191
- # )
192
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
193
 
194
  def retriever(state: MessagesState):
195
- """Retriever node"""
196
- query = state["messages"][0].content
197
- similar_question = vector_store.similarity_search(query)
198
-
199
- if not similar_question:
200
- print("No similar documents found.")
201
  return {"messages": [sys_msg] + state["messages"]}
202
-
203
  example_msg = HumanMessage(
204
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
205
  )
206
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
207
 
208
-
209
  builder = StateGraph(MessagesState)
210
  builder.add_node("retriever", retriever)
211
  builder.add_node("assistant", assistant)
212
  builder.add_node("tools", ToolNode(tools))
213
  builder.add_edge(START, "retriever")
214
  builder.add_edge("retriever", "assistant")
215
- builder.add_conditional_edges(
216
- "assistant",
217
- tools_condition,
218
- )
219
  builder.add_edge("tools", "assistant")
220
 
221
- # Compile graph
222
  return builder.compile()
223
 
224
- # test
225
  if __name__ == "__main__":
226
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
227
- # Build the graph
228
  graph = build_graph(provider="groq")
229
- # Run the graph
230
  messages = [HumanMessage(content=question)]
231
- messages = graph.invoke({"messages": messages})
232
- for m in messages["messages"]:
233
- m.pretty_print()
 
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
 
10
  from langchain_community.vectorstores import SupabaseVectorStore
11
  from langchain_core.messages import SystemMessage, HumanMessage
12
  from langchain_core.tools import tool
13
+ from langchain_tavily import TavilySearch
14
  from supabase.client import Client, create_client
15
+
16
  load_dotenv()
17
 
18
  url = os.getenv("SUPABASE_URL")
19
  key = os.getenv("SUPABASE_KEY")
20
 
21
+ # Math Tools
22
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    # NOTE: the @tool decorator requires a docstring (or explicit description);
    # without one, tool creation fails at import time.
    return a * b
25
 
26
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    return a + b
29
 
30
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a.

    Args:
        a: first int
        b: second int
    """
    return a - b
33
 
34
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b.

    Args:
        a: numerator
        b: denominator (must be non-zero)

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
39
 
40
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers (a % b).

    Args:
        a: first int
        b: second int
    """
    return a % b
43
 
44
+ # Search Tools
45
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return up to 2 results.

    Args:
        query: The search query.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n---\n".join([doc.page_content for doc in search_docs])
 
 
 
 
 
49
 
50
@tool
def web_search(query: str) -> str:
    """Search the web via Tavily and return up to 3 results.

    Args:
        query: The search query.
    """
    # TavilySearch takes `max_results`, not `k`.
    tavily = TavilySearch(max_results=3)
    response = tavily.invoke(query)
    # TavilySearch returns a dict whose "results" entries are plain dicts
    # (keys like "title", "url", "content") -- not Documents, so reading
    # `.page_content` would raise AttributeError. Use the "content" field.
    results = response.get("results", []) if isinstance(response, dict) else response
    return "\n---\n".join([item.get("content", "") for item in results])
 
 
 
 
 
 
 
 
55
 
56
@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for a query and return up to 3 results (1000 chars each).

    Args:
        query: The search query.
    """
    # NOTE(review): "arvix" is a typo for "arxiv", but the name is kept so
    # existing references to this tool keep working.
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n---\n".join([doc.page_content[:1000] for doc in search_docs])
 
 
 
 
 
60
 
61
# Prime the agent with the system prompt stored alongside the code.
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()

sys_msg = SystemMessage(content=system_prompt)
65
 
66
# Vector store setup: embeddings + Supabase-backed similarity search.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
supabase: Client = create_client(url, key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)


@tool
def question_search(query: str) -> str:
    """Retrieve similar questions from the vector DB.

    Args:
        query: The question to look up similar examples for.
    """
    # The langchain `tool` factory takes the name positionally, not as a
    # `name=` kwarg, and tool names with spaces (e.g. "Question Search") are
    # rejected by function-calling providers -- hence a proper @tool wrapper.
    docs = vector_store.as_retriever().invoke(query)
    return "\n---\n".join([doc.page_content for doc in docs])


# Keep the original module-level name so downstream references still work.
retriever_tool = question_search
 
 
 
 
 
76
 
77
# Every tool exposed to the agent's LLM.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
    retriever_tool,
]
78
 
 
 
 
 
 
 
 
 
 
 
 
 
79
def build_graph(provider: str = "groq"):
    """Build the retriever -> assistant -> tools agent graph.

    Args:
        provider: LLM backend to use -- "google", "groq", or "huggingface".

    Returns:
        A compiled LangGraph runnable.

    Raises:
        ValueError: If provider is not one of the supported names.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        api_key = os.getenv('GROQ_API')
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=api_key)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            )
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """LLM node: respond (possibly with tool calls) to the conversation."""
        # FIX: the previous revision had a corrupted string literal
        # (`state["messages\]`) which is a SyntaxError.
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Prepend the system prompt and, when found, a similar solved example."""
        similar_docs = vector_store.similarity_search(state["messages"][0].content)
        if not similar_docs:
            return {"messages": [sys_msg] + state["messages"]}
        example_msg = HumanMessage(
            content=f"Here is a related example to help: \n\n{similar_docs[0].page_content}"
        )
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the assistant emitted tool calls, else END.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
119
 
 
120
if __name__ == "__main__":
    # Quick manual smoke test of the compiled agent graph.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    agent = build_graph(provider="groq")
    result = agent.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()