tsrrus commited on
Commit
2237b5d
·
verified ·
1 Parent(s): 66ca9ad

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +44 -100
agent.py CHANGED
@@ -1,4 +1,3 @@
1
- """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
@@ -121,33 +120,21 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
  # build a retriever
124
- try:
125
- embeddings = HuggingFaceEmbeddings(
126
- model_name="sentence-transformers/all-mpnet-base-v2"
127
- ) # dim=768
128
- supabase: Client = create_client(
129
- os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY")
130
- )
131
- vector_store = SupabaseVectorStore(
132
- client=supabase,
133
- embedding=embeddings,
134
- table_name="documents",
135
- query_name="match_documents_langchain",
136
- )
137
-
138
- # Test the connection
139
- test_results = vector_store.similarity_search("test query", k=1)
140
- print(f"Vector store initialized successfully. Test returned {len(test_results)} results.")
141
-
142
- except Exception as e:
143
- print(f"Warning: Vector store initialization failed: {e}")
144
- vector_store = None
145
-
146
  create_retriever_tool = create_retriever_tool(
147
- retriever=vector_store.as_retriever() if vector_store else None,
148
  name="Question Search",
149
  description="A tool to retrieve similar questions from a vector store.",
150
- ) if vector_store else None
151
 
152
 
153
 
@@ -163,20 +150,25 @@ tools = [
163
  ]
164
 
165
  # Build graph function
166
- def build_graph(provider: str = "huggingface"):
167
- """Build the graph with improved error handling"""
168
-
169
- if provider == "groq":
170
- llm = ChatGroq(
171
- model="qwen-qwq-32b", temperature=0
172
- ) # optional : qwen-qwq-32b gemma2-9b-it
 
 
173
  elif provider == "huggingface":
 
174
  llm = ChatHuggingFace(
175
- llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
 
 
 
176
  )
177
  else:
178
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
179
-
180
  # Bind tools to LLM
181
  llm_with_tools = llm.bind_tools(tools)
182
 
@@ -184,76 +176,28 @@ def build_graph(provider: str = "huggingface"):
184
  def assistant(state: MessagesState):
185
  """Assistant node"""
186
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
187
 
188
  def retriever(state: MessagesState):
189
- """Retriever node with error handling"""
190
- try:
191
- # Check if vector_store is available
192
- if vector_store is None:
193
- print("Vector store not available, proceeding without retrieval")
194
- return {"messages": [sys_msg] + state["messages"]}
195
-
196
- similar_question = vector_store.similarity_search(state["messages"][0].content)
197
-
198
- # Check if we have results before accessing them
199
- if similar_question and len(similar_question) > 0:
200
- example_msg = HumanMessage(
201
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
202
- )
203
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
204
- else:
205
- # No similar questions found, proceed without reference
206
- print("No similar questions found in vector store")
207
- return {"messages": [sys_msg] + state["messages"]}
208
-
209
- except Exception as e:
210
- print(f"Error in retriever: {e}")
211
- # Fallback: continue without retrieval
212
- return {"messages": [sys_msg] + state["messages"]}
213
 
214
  builder = StateGraph(MessagesState)
215
  builder.add_node("retriever", retriever)
216
- builder.add_node("assistant", assistant)
217
- builder.add_node("tools", ToolNode(tools))
218
- builder.add_edge(START, "retriever")
219
- builder.add_edge("retriever", "assistant")
220
- builder.add_conditional_edges(
221
- "assistant",
222
- tools_condition,
223
- )
224
- builder.add_edge("tools", "assistant")
225
 
226
- # Compile graph
227
- return builder.compile()
228
-
229
- def retriever(state: MessagesState):
230
- """Retriever node with error handling"""
231
- try:
232
- similar_question = vector_store.similarity_search(state["messages"][0].content)
233
-
234
- # Check if we have results before accessing them
235
- if similar_question and len(similar_question) > 0:
236
- example_msg = HumanMessage(
237
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
238
- )
239
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
240
- else:
241
- # No similar questions found, proceed without reference
242
- print("No similar questions found in vector store")
243
- return {"messages": [sys_msg] + state["messages"]}
244
-
245
- except Exception as e:
246
- print(f"Error in retriever: {e}")
247
- # Fallback: continue without retrieval
248
- return {"messages": [sys_msg] + state["messages"]}
249
 
250
- # test
251
- if __name__ == "__main__":
252
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
253
- # Build the graph
254
- graph = build_graph(provider="groq")
255
- # Run the graph
256
- messages = [HumanMessage(content=question)]
257
- messages = graph.invoke({"messages": messages})
258
- for m in messages["messages"]:
259
- m.pretty_print()
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
 
120
  sys_msg = SystemMessage(content=system_prompt)
121
 
122
# build a retriever
# Embeddings must match the vector dimension of the Supabase "documents"
# table; all-mpnet-base-v2 produces 768-dim vectors.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
# NOTE(review): os.environ.get returns None when a variable is unset, so a
# missing SUPABASE_URL / SUPABASE_SERVICE_KEY surfaces as an opaque client
# error — consider failing fast with a clear message. TODO confirm the
# deployment always sets both variables.
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"))
# Vector store backed by the "documents" table, queried through the
# "match_documents_langchain" Postgres function.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding= embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# NOTE(review): this assignment rebinds (shadows) the imported
# `create_retriever_tool` factory with the tool instance it returns; the
# factory cannot be called again in this module. Renaming the variable
# (e.g. `question_search_tool`) would be safer — verify against the
# `tools` list before changing.
create_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
138
 
139
 
140
 
 
150
  ]
151
 
152
# Build graph function
def build_graph(provider: str = "groq"):
    """Build and compile the LangGraph agent graph.

    Args:
        provider: LLM backend to use: "google", "groq" or "huggingface".

    Returns:
        A compiled ``StateGraph`` whose single "retriever" node answers
        directly from the most similar document in the Supabase vector store.

    Raises:
        ValueError: If ``provider`` is not one of the supported backends.
    """
    from langchain_core.messages import AIMessage

    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "huggingface":
        # TODO: Add huggingface endpoint
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Bind tools to LLM.
    # NOTE(review): llm_with_tools is only referenced by the `assistant` node
    # below, which is never added to the graph — the compiled graph is
    # retriever-only. The LLM setup is kept because the provider validation
    # (ValueError above) is caller-visible behavior.
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node (currently not wired into the graph)."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Answer from the single most similar stored document.

        Searches the vector store with the latest message as the query and
        extracts the reference answer embedded after "Final answer :".
        """
        query = state["messages"][-1].content
        matches = vector_store.similarity_search(query, k=1)
        # Guard against an empty result set: the original indexed [0]
        # unconditionally and raised IndexError when nothing matched.
        if not matches:
            return {"messages": [AIMessage(content="No similar question found.")]}

        content = matches[0].page_content
        # Stored documents embed the reference answer after "Final answer :".
        if "Final answer :" in content:
            answer = content.split("Final answer :")[-1].strip()
        else:
            answer = content.strip()

        return {"messages": [AIMessage(content=answer)]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)

    # Retriever is both the entry and the finish point of the graph.
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")

    # Compile graph
    return builder.compile()