Frazer2810 commited on
Commit
5c3558a
·
verified ·
1 Parent(s): 97783bb

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +85 -68
agent.py CHANGED
@@ -1,62 +1,64 @@
1
- """LangGraph Agent – Solo GPT-4.1 (OpenAI) con docstring obbligatorie."""
2
 
3
  import os
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import ToolNode, tools_condition
7
- from langchain_openai import ChatOpenAI
8
- from langchain_core.messages import SystemMessage, HumanMessage
9
- from langchain_core.tools import tool
10
 
11
- # Loader & search tools
 
 
 
 
 
 
 
 
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
 
14
 
15
- # ------------------------------------------------------------------ #
16
- # Ambiente #
17
- # ------------------------------------------------------------------ #
18
- load_dotenv() # carica OPENAI_KEY dallo Space
19
- OPENAI_KEY = os.getenv("OPENAI_KEY")
20
- if not OPENAI_KEY:
21
- raise ValueError("❌ OPENAI_KEY non impostata: aggiungila nei Secrets dello Space.")
22
-
23
- # ------------------------------------------------------------------ #
24
- # TOOL: aritmetica #
25
- # ------------------------------------------------------------------ #
26
  @tool
27
  def multiply(a: int, b: int) -> int:
28
  """Multiply two integers and return the product."""
29
  return a * b
30
 
 
31
  @tool
32
  def add(a: int, b: int) -> int:
33
  """Add two integers and return the sum."""
34
  return a + b
35
 
 
36
  @tool
37
  def subtract(a: int, b: int) -> int:
38
  """Subtract the second integer from the first and return the difference."""
39
  return a - b
40
 
 
41
  @tool
42
  def divide(a: int, b: int) -> float:
43
- """Divide the first integer by the second and return the quotient.
44
-
45
- Raises:
46
- ValueError: If b is zero.
47
- """
48
  if b == 0:
49
  raise ValueError("Cannot divide by zero.")
50
  return a / b
51
 
 
52
  @tool
53
  def modulus(a: int, b: int) -> int:
54
  """Return the remainder of the division of a by b."""
55
  return a % b
56
 
57
- # ------------------------------------------------------------------ #
58
- # TOOL: Wikipedia #
59
- # ------------------------------------------------------------------ #
 
60
  @tool
61
  def wiki_search(query: str) -> str:
62
  """Search Wikipedia (max 2 docs) and return formatted content."""
@@ -67,12 +69,13 @@ def wiki_search(query: str) -> str:
67
  for d in docs
68
  )
69
 
70
- # ------------------------------------------------------------------ #
71
- # TOOL: Tavily web search #
72
- # ------------------------------------------------------------------ #
 
73
  @tool
74
  def web_search(query: str) -> str:
75
- """Perform a web search using Tavily (max 3 docs) and return formatted content."""
76
  docs = TavilySearchResults(max_results=3).invoke(query=query)
77
  return "\n\n---\n\n".join(
78
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
@@ -80,12 +83,13 @@ def web_search(query: str) -> str:
80
  for d in docs
81
  )
82
 
83
- # ------------------------------------------------------------------ #
84
- # TOOL: ArXiv #
85
- # ------------------------------------------------------------------ #
 
86
  @tool
87
  def arxiv_search(query: str) -> str:
88
- """Search ArXiv (max 3 docs) and return the first 1000 chars of each paper."""
89
  docs = ArxivLoader(query=query, load_max_docs=3).load()
90
  return "\n\n---\n\n".join(
91
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
@@ -93,62 +97,75 @@ def arxiv_search(query: str) -> str:
93
  for d in docs
94
  )
95
 
96
- # ------------------------------------------------------------------ #
97
- # System prompt #
98
- # ------------------------------------------------------------------ #
 
99
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
100
  system_prompt = f.read()
101
  sys_msg = SystemMessage(content=system_prompt)
102
 
103
- # ------------------------------------------------------------------ #
104
- # Tool list #
105
- # ------------------------------------------------------------------ #
106
  tools = [
107
- multiply, add, subtract, divide, modulus,
108
- wiki_search, web_search, arxiv_search,
 
 
 
 
 
 
109
  ]
110
 
111
- # ------------------------------------------------------------------ #
112
- # Build LangGraph #
113
- # ------------------------------------------------------------------ #
114
- def build_graph():
115
- """Return a LangGraph graph that uses only GPT-4.1 via OpenAI."""
116
- llm = ChatOpenAI(
117
- model_name="gpt-4.1",
118
- temperature=0,
119
- openai_api_key=OPENAI_KEY,
120
- )
121
- llm_with_tools = llm.bind_tools(tools)
 
 
 
 
 
 
 
 
122
 
123
- # Nodes --------------------------------------------------------- #
124
- def prepend_system(state: MessagesState):
125
- """Prepend system prompt to the incoming messages."""
126
- return {"messages": [sys_msg] + state["messages"]}
127
 
 
128
  def assistant(state: MessagesState):
129
- """Invoke the LLM (tool calling enabled)."""
130
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
131
 
132
- # Graph --------------------------------------------------------- #
133
  builder = StateGraph(MessagesState)
134
- builder.add_node("system", prepend_system)
135
  builder.add_node("assistant", assistant)
136
  builder.add_node("tools", ToolNode(tools))
137
 
138
- builder.add_edge(START, "system")
139
- builder.add_edge("system", "assistant")
140
  builder.add_conditional_edges("assistant", tools_condition)
141
  builder.add_edge("tools", "assistant")
142
 
143
  return builder.compile()
144
 
145
- # ------------------------------------------------------------------ #
146
- # Test rapido #
147
- # ------------------------------------------------------------------ #
 
148
  if __name__ == "__main__":
149
- g = build_graph()
150
- query = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
151
- msgs = [HumanMessage(content=query)]
152
- result = g.invoke({"messages": msgs})
153
  for m in result["messages"]:
154
  m.pretty_print()
 
1
+ """LangGraph Agent – versione senza Supabase"""
2
 
3
  import os
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import ToolNode, tools_condition
 
 
 
7
 
8
+ # LLM providers
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+ from langchain_groq import ChatGroq
11
+ from langchain_huggingface import (
12
+ ChatHuggingFace,
13
+ HuggingFaceEndpoint,
14
+ )
15
+
16
+ # Tools & loaders
17
  from langchain_community.tools.tavily_search import TavilySearchResults
18
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
19
+ from langchain_core.messages import SystemMessage, HumanMessage
20
+ from langchain_core.tools import tool
21
 
22
+ load_dotenv() # carica eventuali variabili dal file .env
23
+
24
+ # --------------------------------------------------------------------------- #
25
+ # TOOL: operazioni aritmetiche #
26
+ # --------------------------------------------------------------------------- #
 
 
 
 
 
 
27
  @tool
28
  def multiply(a: int, b: int) -> int:
29
  """Multiply two integers and return the product."""
30
  return a * b
31
 
32
+
33
  @tool
34
  def add(a: int, b: int) -> int:
35
  """Add two integers and return the sum."""
36
  return a + b
37
 
38
+
39
  @tool
40
  def subtract(a: int, b: int) -> int:
41
  """Subtract the second integer from the first and return the difference."""
42
  return a - b
43
 
44
+
45
  @tool
46
  def divide(a: int, b: int) -> float:
47
+ """Divide a by b and return the quotient (error if b == 0)."""
 
 
 
 
48
  if b == 0:
49
  raise ValueError("Cannot divide by zero.")
50
  return a / b
51
 
52
+
53
  @tool
54
  def modulus(a: int, b: int) -> int:
55
  """Return the remainder of the division of a by b."""
56
  return a % b
57
 
58
+
59
+ # --------------------------------------------------------------------------- #
60
+ # TOOL: Wikipedia #
61
+ # --------------------------------------------------------------------------- #
62
  @tool
63
  def wiki_search(query: str) -> str:
64
  """Search Wikipedia (max 2 docs) and return formatted content."""
 
69
  for d in docs
70
  )
71
 
72
+
73
+ # --------------------------------------------------------------------------- #
74
+ # TOOL: Tavily web search #
75
+ # --------------------------------------------------------------------------- #
76
  @tool
77
  def web_search(query: str) -> str:
78
+ """Perform a web search with Tavily (max 3 docs) and return formatted content."""
79
  docs = TavilySearchResults(max_results=3).invoke(query=query)
80
  return "\n\n---\n\n".join(
81
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
 
83
  for d in docs
84
  )
85
 
86
+
87
+ # --------------------------------------------------------------------------- #
88
+ # TOOL: ArXiv #
89
+ # --------------------------------------------------------------------------- #
90
  @tool
91
  def arxiv_search(query: str) -> str:
92
+ """Search ArXiv (max 3 docs) and return first 1000 characters per paper."""
93
  docs = ArxivLoader(query=query, load_max_docs=3).load()
94
  return "\n\n---\n\n".join(
95
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
 
97
  for d in docs
98
  )
99
 
100
+
101
+ # --------------------------------------------------------------------------- #
102
+ # System prompt #
103
+ # --------------------------------------------------------------------------- #
104
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
105
  system_prompt = f.read()
106
  sys_msg = SystemMessage(content=system_prompt)
107
 
108
+ # --------------------------------------------------------------------------- #
109
+ # Lista tool #
110
+ # --------------------------------------------------------------------------- #
111
  tools = [
112
+ multiply,
113
+ add,
114
+ subtract,
115
+ divide,
116
+ modulus,
117
+ wiki_search,
118
+ web_search,
119
+ arxiv_search,
120
  ]
121
 
122
+ # --------------------------------------------------------------------------- #
123
+ # Build LangGraph #
124
+ # --------------------------------------------------------------------------- #
125
+ def build_graph(provider: str = "groq"):
126
+ """Return a LangGraph graph without Supabase dependencies."""
127
+ # ------------ LLM selection ------------------------------------------- #
128
+ if provider == "google":
129
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
130
+ elif provider == "groq":
131
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
132
+ elif provider == "huggingface":
133
+ llm = ChatHuggingFace(
134
+ llm=HuggingFaceEndpoint(
135
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
136
+ temperature=0,
137
+ )
138
+ )
139
+ else:
140
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
141
 
142
+ llm_with_tools = llm.bind_tools(tools)
 
 
 
143
 
144
+ # ------------------ Nodes -------------------------------------------- #
145
  def assistant(state: MessagesState):
146
+ """Invoke LLM with system prompt prepended."""
147
+ messages = [sys_msg] + state["messages"]
148
+ return {"messages": [llm_with_tools.invoke(messages)]}
149
 
150
+ # ------------------ Graph -------------------------------------------- #
151
  builder = StateGraph(MessagesState)
 
152
  builder.add_node("assistant", assistant)
153
  builder.add_node("tools", ToolNode(tools))
154
 
155
+ builder.add_edge(START, "assistant")
 
156
  builder.add_conditional_edges("assistant", tools_condition)
157
  builder.add_edge("tools", "assistant")
158
 
159
  return builder.compile()
160
 
161
+
162
+ # --------------------------------------------------------------------------- #
163
+ # Test rapido #
164
+ # --------------------------------------------------------------------------- #
165
  if __name__ == "__main__":
166
+ graph = build_graph(provider="groq")
167
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
168
+ messages = [HumanMessage(content=question)]
169
+ result = graph.invoke({"messages": messages})
170
  for m in result["messages"]:
171
  m.pretty_print()