dgsilvia commited on
Commit
062d582
·
verified ·
1 Parent(s): fa8605a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +214 -0
agent.py CHANGED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition
5
+ from langgraph.prebuilt import ToolNode
6
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
7
+ from langchain_community.document_loaders import WikipediaLoader
8
+ from langchain_community.document_loaders import ArxivLoader
9
+ from langchain_core.messages import SystemMessage, HumanMessage
10
+ from langchain_core.tools import tool
11
+ from langchain.tools.retriever import create_retriever_tool
12
+ from langchain_community.tools import DuckDuckGoSearchResults
13
+ from langchain_community.vectorstores import Chroma
14
+ import json
15
+
16
+
17
+
18
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    product = b * a
    return product
26
+
27
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.
    Args:
        a: first int
        b: second int
    """
    total = b + a
    return total
35
+
36
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.
    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
44
+
45
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division always yields a float in Python 3; the original
    # annotated the return as int, which misdescribed the tool's output.
    return a / b
55
+
56
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
64
+
65
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"wiki_results"`` mapping to the formatted results.
        (The original annotated ``-> str`` while returning a dict; the
        annotation is fixed to match the actual return value and the other
        search tools.)
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Wrap each document in pseudo-XML tags so the LLM can attribute sources.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
77
+
78
+
79
@tool
def web_search(query: str) -> dict:
    """Search DuckDuckGo for a query and return maximum 3 results using LangChain.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"web_results"`` mapping to the formatted results.
    """
    # output_format="list" makes the tool return a list of dicts with
    # 'snippet', 'title' and 'link' keys. Without it, run() returns one
    # concatenated string, and the loop below would iterate characters and
    # crash on doc["link"].
    search = DuckDuckGoSearchResults(max_results=3, output_format="list")
    docs = search.invoke(query)

    # Format the results for the LLM, one pseudo-XML document per hit.
    formatted = "\n\n---\n\n".join(
        f'<Document source="{doc["link"]}" page="">\n{doc["title"]}: {doc["snippet"]}\n</Document>'
        for doc in docs
    )
    return {"web_results": formatted}
92
+
93
@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"arxiv_results"`` mapping to the formatted results.
        (Fixes the original's "arvix_results" typo and the ``-> str``
        annotation, which did not match the dict actually returned.)
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Truncate each abstract to 1000 chars to keep the prompt small.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return {"arxiv_results": formatted_search_docs}
105
+
106
+
107
+
108
# Load the system prompt from a file sitting next to this module;
# read once at import time (raises FileNotFoundError if missing).
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
# NOTE(review): debug print left in — dumps the full prompt to stdout on import.
print(system_prompt)
# System message prepended to the conversation by the retriever node below.
sys_msg = SystemMessage(content=system_prompt)
114
+
115
+
116
+
117
# Use the same embedding model that was used when the vector store was built,
# otherwise similarity search would compare incompatible vectors.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Load the previously persisted Chroma vector store.
vector_store = Chroma(
    embedding_function=embeddings,
    persist_directory="./chroma_db",  # same path used when the store was saved
)

# Build the retriever tool. The original assigned the result back to
# `create_retriever_tool`, shadowing the imported factory function; use a
# distinct name so the factory stays callable.
question_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a local Chroma vector store.",
)
132
+
133
+
134
# Tools exposed to the LLM via bind_tools() in build_graph().
# NOTE(review): the Chroma retriever tool created above is not registered
# here — confirm whether that is intentional (the graph's retriever node
# queries the vector store directly instead).
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search,
]
144
+
145
# Build graph function
def build_graph():
    """Build and compile the agent graph.

    Flow: START -> retriever -> assistant -> (tools -> assistant)* -> END.
    The retriever node prepends the system message plus, when available, a
    similar solved question from the vector store; the assistant node then
    calls the tool-bound LLM until it stops emitting tool calls.

    Returns:
        The compiled LangGraph runnable.
    """
    llm = ChatHuggingFace(
        llm=HuggingFaceEndpoint(
            repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
            temperature=0,  # deterministic decoding for reproducible answers
        ),
    )
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: invoke the tool-bound LLM on the running messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: prepend the system prompt and a similar example."""
        similar = vector_store.similarity_search(state["messages"][0].content)
        # Guard the empty case: the original indexed similar[0]
        # unconditionally and raised IndexError when the store had no match.
        extra = []
        if similar:
            extra.append(
                HumanMessage(
                    content=f"Here I provide a similar question and answer for reference: \n\n{similar[0].page_content}",
                )
            )
        return {"messages": [sys_msg] + state["messages"] + extra}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the last message contains tool calls, else END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
183
+
184
# test
if __name__ == "__main__":
    graph = build_graph()

    # Load the question list from the JSON file.
    with open('questions.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Keep only entries that actually carry a question.
    questions = [entry['question'] for entry in data if 'question' in entry]

    for q in questions:
        print('orig:', q)
        result = graph.invoke({"messages": [HumanMessage(content=q)]})

        # Inspect only the last message produced by the graph.
        last = result["messages"][-1]
        content = last.content if hasattr(last, "content") else str(last)
        print("Full response:", content)

        if "FINAL ANSWER:" in content:
            answer = content.rsplit("FINAL ANSWER:", 1)[-1].strip()
            print("✅ Estratto finale:", answer)
        else:
            print("❌ Nessuna risposta finale trovata.")

        # Smoke test: only the first question is processed.
        break
213
+
214
+