Frazer2810 commited on
Commit
4555257
·
verified ·
1 Parent(s): 240eb33

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +61 -118
agent.py CHANGED
@@ -1,66 +1,53 @@
1
- """LangGraph Agent – GPT-4.1 / Hugging Face Spaces (import lazy)"""
 
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition, ToolNode
6
  from langchain_openai import ChatOpenAI
 
 
7
 
8
- # --------------------------------------------------------------------------- #
9
- # Import facoltativi (se il pacchetto non c'è, il provider viene disattivato) #
10
- # --------------------------------------------------------------------------- #
11
- def _lazy_import(name):
12
- try:
13
- module = __import__(name, fromlist=["*"])
14
- return module
15
- except ModuleNotFoundError:
16
- return None
17
-
18
- lg_google = _lazy_import("langchain_google_genai")
19
- lg_groq = _lazy_import("langchain_groq")
20
- lg_hf = _lazy_import("langchain_huggingface")
21
-
22
- if lg_google:
23
- ChatGoogleGenerativeAI = lg_google.ChatGoogleGenerativeAI
24
- if lg_groq:
25
- ChatGroq = lg_groq.ChatGroq
26
- if lg_hf:
27
- ChatHuggingFace = lg_hf.ChatHuggingFace
28
- HuggingFaceEndpoint = lg_hf.HuggingFaceEndpoint
29
- HuggingFaceEmbeddings = lg_hf.HuggingFaceEmbeddings
30
- else:
31
- from langchain_huggingface import HuggingFaceEmbeddings # solo embeddings
32
-
33
- # --------------------------------------------------------------------------- #
34
- # Tools & loaders #
35
- # --------------------------------------------------------------------------- #
36
  from langchain_community.tools.tavily_search import TavilySearchResults
37
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
38
- from langchain_community.vectorstores import SupabaseVectorStore
39
- from langchain_core.messages import SystemMessage, HumanMessage
40
- from langchain_core.tools import tool
41
- from langchain.tools.retriever import create_retriever_tool
42
- from supabase.client import Client, create_client
43
 
44
- load_dotenv() # Secrets di HF Spaces
 
 
 
 
 
 
 
45
 
46
- # -------------------- TOOL di esempio -------------------- #
 
 
47
  @tool
48
  def multiply(a: int, b: int) -> int: return a * b
 
49
  @tool
50
  def add(a: int, b: int) -> int: return a + b
 
51
  @tool
52
  def subtract(a: int, b: int) -> int: return a - b
 
53
  @tool
54
  def divide(a: int, b: int) -> float:
55
  if b == 0:
56
  raise ValueError("Cannot divide by zero.")
57
  return a / b
 
58
  @tool
59
  def modulus(a: int, b: int) -> int: return a % b
60
 
61
- # -------------------- Wikipedia -------------------------- #
 
 
62
  @tool
63
  def wiki_search(query: str) -> str:
 
64
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
65
  return "\n\n---\n\n".join(
66
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
@@ -68,9 +55,12 @@ def wiki_search(query: str) -> str:
68
  for d in docs
69
  )
70
 
71
- # -------------------- Tavily ----------------------------- #
 
 
72
  @tool
73
  def web_search(query: str) -> str:
 
74
  docs = TavilySearchResults(max_results=3).invoke(query=query)
75
  return "\n\n---\n\n".join(
76
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
@@ -78,9 +68,12 @@ def web_search(query: str) -> str:
78
  for d in docs
79
  )
80
 
81
- # -------------------- ArXiv ------------------------------ #
 
 
82
  @tool
83
  def arxiv_search(query: str) -> str:
 
84
  docs = ArxivLoader(query=query, load_max_docs=3).load()
85
  return "\n\n---\n\n".join(
86
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
@@ -88,107 +81,57 @@ def arxiv_search(query: str) -> str:
88
  for d in docs
89
  )
90
 
91
- # --------------------------------------------------------------------------- #
92
- # System prompt #
93
- # --------------------------------------------------------------------------- #
94
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
95
  system_prompt = f.read()
96
  sys_msg = SystemMessage(content=system_prompt)
97
 
98
- # --------------------------------------------------------------------------- #
99
- # Vector store / retriever #
100
- # --------------------------------------------------------------------------- #
101
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
102
- supabase: Client = create_client(
103
- os.environ.get("SUPABASE_URL"),
104
- os.environ.get("SUPABASE_SERVICE_KEY"),
105
- )
106
- vector_store = SupabaseVectorStore(
107
- client=supabase,
108
- embedding=embeddings,
109
- table_name="documents",
110
- query_name="match_documents_langchain",
111
- )
112
- question_search_tool = create_retriever_tool(
113
- retriever=vector_store.as_retriever(),
114
- name="Question Search",
115
- description="A tool to retrieve similar questions from a vector store.",
116
- )
117
-
118
- # --------------------------------------------------------------------------- #
119
- # Lista tool #
120
- # --------------------------------------------------------------------------- #
121
  tools = [
122
  multiply, add, subtract, divide, modulus,
123
  wiki_search, web_search, arxiv_search,
124
- question_search_tool,
125
  ]
126
 
127
- # --------------------------------------------------------------------------- #
128
- # Costruzione graph #
129
- # --------------------------------------------------------------------------- #
130
- def build_graph(provider: str = "openai"):
131
- # ------------------- LLM selection ------------------------------------- #
132
- if provider == "openai":
133
- key = os.getenv("OPENAI_KEY")
134
- if not key:
135
- raise ValueError("OPENAI_KEY mancante: aggiungi la secret nello Space.")
136
- llm = ChatOpenAI(model_name="gpt-4.1", temperature=0, openai_api_key=key)
137
-
138
- elif provider == "google":
139
- if not lg_google:
140
- raise ImportError("langchain_google_genai non installato.")
141
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
142
-
143
- elif provider == "groq":
144
- if not lg_groq:
145
- raise ImportError("langchain_groq non installato.")
146
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
147
-
148
- elif provider == "huggingface":
149
- if not lg_hf:
150
- raise ImportError("langchain_huggingface non installato.")
151
- llm = ChatHuggingFace(
152
- llm=HuggingFaceEndpoint(
153
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
154
- temperature=0,
155
- )
156
- )
157
- else:
158
- raise ValueError("Provider non valido.")
159
-
160
  llm_with_tools = llm.bind_tools(tools)
161
 
162
- # ------------------- Nodes -------------------------------------------- #
 
 
 
 
163
  def assistant(state: MessagesState):
164
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
165
 
166
- def retriever(state: MessagesState):
167
- similar = vector_store.similarity_search(state["messages"][0].content)
168
- if similar:
169
- example = HumanMessage(
170
- content=("Here I provide a similar question and answer for reference:\n\n"
171
- f"{similar[0].page_content}")
172
- )
173
- return {"messages": [sys_msg] + state["messages"] + [example]}
174
- return {"messages": [sys_msg] + state["messages"]}
175
-
176
- # ------------------- Graph -------------------------------------------- #
177
  builder = StateGraph(MessagesState)
178
- builder.add_node("retriever", retriever)
179
  builder.add_node("assistant", assistant)
180
  builder.add_node("tools", ToolNode(tools))
181
 
182
- builder.add_edge(START, "retriever")
183
- builder.add_edge("retriever", "assistant")
184
  builder.add_conditional_edges("assistant", tools_condition)
185
  builder.add_edge("tools", "assistant")
186
 
187
  return builder.compile()
188
 
189
- # --------------------------------------------------------------------------- #
190
- # Test rapido #
191
- # --------------------------------------------------------------------------- #
192
  if __name__ == "__main__":
193
  g = build_graph()
194
  q = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
 
1
+ """LangGraph Agent – Solo GPT-4.1 (OpenAI)"""
2
+
3
  import os
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
+ from langgraph.prebuilt import ToolNode, tools_condition
7
  from langchain_openai import ChatOpenAI
8
+ from langchain_core.messages import SystemMessage, HumanMessage
9
+ from langchain_core.tools import tool
10
 
11
+ # Tools – loaders
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
 
 
 
 
14
 
15
+ # ------------------------------------------------------------------ #
16
+ # Inizializzazione ambiente #
17
+ # ------------------------------------------------------------------ #
18
+ load_dotenv() # carica OPENAI_KEY dai secrets dello Space
19
+
20
+ OPENAI_KEY = os.getenv("OPENAI_KEY")
21
+ if not OPENAI_KEY:
22
+ raise ValueError("❌ OPENAI_KEY non impostata: aggiungila nei Secrets dello Space.")
23
 
24
+ # ------------------------------------------------------------------ #
25
+ # TOOL di esempio (aritmetica) #
26
+ # ------------------------------------------------------------------ #
27
  @tool
28
  def multiply(a: int, b: int) -> int: return a * b
29
+
30
  @tool
31
  def add(a: int, b: int) -> int: return a + b
32
+
33
  @tool
34
  def subtract(a: int, b: int) -> int: return a - b
35
+
36
  @tool
37
  def divide(a: int, b: int) -> float:
38
  if b == 0:
39
  raise ValueError("Cannot divide by zero.")
40
  return a / b
41
+
42
  @tool
43
  def modulus(a: int, b: int) -> int: return a % b
44
 
45
+ # ------------------------------------------------------------------ #
46
+ # TOOL: Wikipedia #
47
+ # ------------------------------------------------------------------ #
48
  @tool
49
  def wiki_search(query: str) -> str:
50
+ """Search Wikipedia (max 2 docs) and return formatted result."""
51
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
52
  return "\n\n---\n\n".join(
53
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
 
55
  for d in docs
56
  )
57
 
58
+ # ------------------------------------------------------------------ #
59
+ # TOOL: Tavily #
60
+ # ------------------------------------------------------------------ #
61
  @tool
62
  def web_search(query: str) -> str:
63
+ """Search Tavily (max 3 docs) and return formatted result."""
64
  docs = TavilySearchResults(max_results=3).invoke(query=query)
65
  return "\n\n---\n\n".join(
66
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
 
68
  for d in docs
69
  )
70
 
71
+ # ------------------------------------------------------------------ #
72
+ # TOOL: ArXiv #
73
+ # ------------------------------------------------------------------ #
74
  @tool
75
  def arxiv_search(query: str) -> str:
76
+ """Search ArXiv (max 3 docs) and return formatted snippet."""
77
  docs = ArxivLoader(query=query, load_max_docs=3).load()
78
  return "\n\n---\n\n".join(
79
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
 
81
  for d in docs
82
  )
83
 
84
+ # ------------------------------------------------------------------ #
85
+ # System prompt #
86
+ # ------------------------------------------------------------------ #
87
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
88
  system_prompt = f.read()
89
  sys_msg = SystemMessage(content=system_prompt)
90
 
91
+ # ------------------------------------------------------------------ #
92
+ # Lista tool #
93
+ # ------------------------------------------------------------------ #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  tools = [
95
  multiply, add, subtract, divide, modulus,
96
  wiki_search, web_search, arxiv_search,
 
97
  ]
98
 
99
+ # ------------------------------------------------------------------ #
100
+ # Costruzione graph #
101
+ # ------------------------------------------------------------------ #
102
+ def build_graph():
103
+ """Restituisce un graph LangGraph che usa solo GPT-4.1."""
104
+ llm = ChatOpenAI(
105
+ model_name="gpt-4.1",
106
+ temperature=0,
107
+ openai_api_key=OPENAI_KEY,
108
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  llm_with_tools = llm.bind_tools(tools)
110
 
111
+ # Nodes --------------------------------------------------------- #
112
+ def prepend_system(state: MessagesState):
113
+ """Prepend system prompt to conversation."""
114
+ return {"messages": [sys_msg] + state["messages"]}
115
+
116
  def assistant(state: MessagesState):
117
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
118
 
119
+ # Graph --------------------------------------------------------- #
 
 
 
 
 
 
 
 
 
 
120
  builder = StateGraph(MessagesState)
121
+ builder.add_node("system", prepend_system)
122
  builder.add_node("assistant", assistant)
123
  builder.add_node("tools", ToolNode(tools))
124
 
125
+ builder.add_edge(START, "system")
126
+ builder.add_edge("system", "assistant")
127
  builder.add_conditional_edges("assistant", tools_condition)
128
  builder.add_edge("tools", "assistant")
129
 
130
  return builder.compile()
131
 
132
+ # ------------------------------------------------------------------ #
133
+ # Test rapido (facoltativo) #
134
+ # ------------------------------------------------------------------ #
135
  if __name__ == "__main__":
136
  g = build_graph()
137
  q = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"