omarViga committed on
Commit
9ec7984
·
verified ·
1 Parent(s): 3e86fa9

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +87 -197
agent.py CHANGED
@@ -16,214 +16,104 @@ from langchain_core.tools import tool
16
  from langchain.tools.retriever import create_retriever_tool
17
  from supabase import create_client, Client
18
 
19
-
20
# Read Supabase credentials from the environment (works both for HF Space
# secrets and for local environment variables). The original read them
# twice — once via os.getenv and again via os.environ.get — which was
# redundant; a single read is kept.
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_KEY")

if not supabase_url or not supabase_key:
    raise ValueError("""
Missing Supabase credentials. Please set:
1. SUPABASE_URL and SUPABASE_KEY as Secrets in HF Space settings
OR
2. As environment variables if running locally
""")

# Module-level Supabase client used by the vector store below.
supabase: Client = create_client(supabase_url, supabase_key)
36
-
37
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    product = b * a
    return product
45
-
46
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = b + a
    return total
55
-
56
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
65
-
66
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division yields a float, so the return annotation is float
    # (the original `-> int` annotation was wrong for `a / b`).
    return a / b
77
-
78
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
87
-
88
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    # Return the formatted string directly: the original returned a dict
    # despite the declared -> str annotation, which confuses tool-calling
    # LLMs that expect a plain string observation.
    return formatted_search_docs
101
-
102
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.
    """
    # TavilySearchResults.invoke takes the query as its single positional
    # input (the original `invoke(query=query)` keyword call is not the
    # Runnable interface) and returns a list of dicts with "url" and
    # "content" keys — not Document objects, so the original
    # `doc.metadata[...]` access would raise AttributeError.
    results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{result.get("url", "")}"/>\n{result.get("content", "")}\n</Document>'
        for result in results
    )
    # Return a plain string to match the declared -> str annotation.
    return formatted_search_docs
115
-
116
@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.
    """
    # NOTE(review): "arvix" is a typo for "arxiv", kept because the tool
    # name is part of the agent's interface (it appears in the tools list).
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # NOTE(review): assumes each document's metadata has a "source" key —
    # TODO confirm against the installed ArxivLoader version.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    # Return the string itself: the original returned a dict despite the
    # declared -> str annotation.
    return formatted_search_docs
129
-
130
 
131
-
132
- # load the system prompt from the file
133
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
134
  system_prompt = f.read()
135
 
136
- # System message
137
  sys_msg = SystemMessage(content=system_prompt)
138
 
139
- # build a retriever
140
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
141
- supabase: Client = create_client(
142
- os.environ.get("https://ujbrcpxgwgsxyjqjhtam.supabase.co"),
143
- os.environ.get("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6InVqYnJjcHhnd2dzeHlqcWpodGFtIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NTA5MjU1NzIsImV4cCI6MjA2NjUwMTU3Mn0.i78C9tbh4jmVbhKfG_Xng1IMzATkf9fs31RtFYs22E8"))
144
- vector_store = SupabaseVectorStore(
145
- client=supabase,
146
- embedding= embeddings,
147
- table_name="documents",
148
- query_name="match_documents_langchain",
149
- )
150
- create_retriever_tool = create_retriever_tool(
151
- retriever=vector_store.as_retriever(),
152
- name="Question Search",
153
- description="A tool to retrieve similar questions from a vector store.",
154
- )
155
-
156
-
157
-
158
# Every tool exposed to the agent: math helpers plus the search tools.
tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, web_search, arvix_search,
]
168
 
169
- # Build graph function
170
  def build_graph(provider: str = "groq"):
171
- """Build the graph"""
172
- # Load environment variables from .env file
173
- if provider == "google":
174
- # Google Gemini
175
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
176
- elif provider == "groq":
177
- # Groq https://console.groq.com/docs/models
178
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
179
- elif provider == "huggingface":
180
- # TODO: Add huggingface endpoint
181
- llm = ChatHuggingFace(
182
- llm=HuggingFaceEndpoint(
183
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
184
- temperature=0,
185
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  )
187
- else:
188
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
189
- # Bind tools to LLM
190
- llm_with_tools = llm.bind_tools(tools)
191
-
192
- # Node
193
- def assistant(state: MessagesState):
194
- """Assistant node"""
195
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
196
-
197
- def retriever(state: MessagesState):
198
- """Retriever node"""
199
- similar_question = vector_store.similarity_search(state["messages"][0].content)
200
- example_msg = HumanMessage(
201
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
202
- )
203
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
204
-
205
- builder = StateGraph(MessagesState)
206
- builder.add_node("retriever", retriever)
207
- builder.add_node("assistant", assistant)
208
- builder.add_node("tools", ToolNode(tools))
209
- builder.add_edge(START, "retriever")
210
- builder.add_edge("retriever", "assistant")
211
- builder.add_conditional_edges(
212
- "assistant",
213
- tools_condition,
214
- )
215
- builder.add_edge("tools", "assistant")
216
-
217
- # Compile graph
218
- return builder.compile()
219
-
220
# Smoke test: build the graph and run a single question through it.
if __name__ == "__main__":
    question = (
        "When was a picture of St. Thomas Aquinas first added to the "
        "Wikipedia page on the Principle of double effect?"
    )
    graph = build_graph(provider="groq")
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()
 
16
  from langchain.tools.retriever import create_retriever_tool
17
  from supabase import create_client, Client
18
 
19
# Load environment variables (.env file for local development).
load_dotenv()

# Supabase configuration — credentials come from the environment only.
try:
    supabase_url = os.environ['SUPABASE_URL']
    supabase_key = os.environ['SUPABASE_KEY']
    supabase: Client = create_client(supabase_url, supabase_key)
except KeyError as e:
    # Chain the original KeyError so the traceback shows which lookup
    # failed (the original re-raise lost the explicit cause).
    raise ValueError(f"""
Missing Supabase credentials. Please set:
1. SUPABASE_URL and SUPABASE_KEY as Secrets in HF Space settings
OR
2. As environment variables in a .env file for local development
Missing: {e}
""") from e
35
 
36
+ # Herramientas (tools) permanecen igual...
37
+ # [Mantén todas tus herramientas (@tool) como están]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ # Cargar prompt del sistema
 
40
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
41
  system_prompt = f.read()
42
 
 
43
  sys_msg = SystemMessage(content=system_prompt)
44
 
45
# Vector-store setup (best effort: the agent still works without retrieval,
# the retriever node checks `retriever_tool` before using the store).
try:
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    vector_store = SupabaseVectorStore(
        client=supabase,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )
    retriever_tool = create_retriever_tool(
        retriever=vector_store.as_retriever(),
        name="Question Search",
        description="Search for similar questions in our knowledge base",
    )
    tools.append(retriever_tool)  # adds the retriever to the module-level tools list
except Exception as e:
    # NOTE(review): this broad except also swallows NameError/ImportError
    # (e.g. if `tools` is not defined yet) — consider narrowing once the
    # expected failure modes are confirmed.
    print(f"Warning: Could not initialize vector store: {e}")
    retriever_tool = None
 
 
 
 
 
 
 
 
 
 
 
63
 
64
# Build the agent graph: START -> retriever -> assistant <-> tools.
def build_graph(provider: str = "groq"):
    """Build the graph with error handling.

    Args:
        provider: LLM backend to use — 'google', 'groq' (default) or
            'huggingface'.

    Returns:
        A compiled LangGraph state graph over MessagesState.

    Raises:
        ValueError: if `provider` is not one of the supported names.
    """
    try:
        if provider == "google":
            llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
        elif provider == "groq":
            llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
        elif provider == "huggingface":
            llm = ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
                    temperature=0,
                )
            )
        else:
            raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'")

        # Expose every registered tool (module-level `tools` list) to the model.
        llm_with_tools = llm.bind_tools(tools)

        # Graph nodes
        def assistant(state: MessagesState):
            # Run the tool-enabled LLM on the accumulated message history.
            return {"messages": [llm_with_tools.invoke(state["messages"])]}

        def retriever(state: MessagesState):
            # Best-effort retrieval: pass messages through untouched when the
            # vector store failed to initialize (retriever_tool is None).
            if not retriever_tool:
                return {"messages": state["messages"]}
            similar_questions = vector_store.similarity_search(state["messages"][-1].content, k=1)
            if similar_questions:
                example_msg = HumanMessage(
                    content=f"Similar question found:\n\n{similar_questions[0].page_content}"
                )
                return {"messages": state["messages"] + [example_msg]}
            return {"messages": state["messages"]}

        # Graph wiring: retrieval first, then the assistant, which loops
        # through the tools node until tools_condition routes to END.
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", retriever)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))

        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")

        return builder.compile()

    except Exception as e:
        # Log-and-reraise boundary: surfaces setup failures without hiding them.
        print(f"Error building graph: {e}")
        raise
118
+
119
 + # [The rest of the code remains unchanged]