Frazer2810 committed on
Commit 435c072 · verified · 1 Parent(s): f6d75b3

Update agent.py

Files changed (1)
  1. agent.py +178 -200
agent.py CHANGED
@@ -1,114 +1,124 @@
- """LangGraph Agent with OpenAI"""
  import os
  from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import tools_condition, ToolNode
- from langchain_openai import ChatOpenAI
- from langchain_community.document_loaders import WikipediaLoader
- from langchain_community.document_loaders import ArxivLoader
- from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_core.tools import tool

- # Tools definition
  @tool
- def multiply(a: int, b: int) -> int:
-     """Multiply two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a * b

  @tool
- def add(a: int, b: int) -> int:
-     """Add two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a + b

  @tool
- def subtract(a: int, b: int) -> int:
-     """Subtract two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a - b

  @tool
  def divide(a: int, b: int) -> float:
-     """Divide two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
- def modulus(a: int, b: int) -> int:
-     """Get the modulus of two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a % b

  @tool
  def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.
-
-     Args:
-         query: The search query.
-     """
-     try:
-         search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-         if not search_docs:
-             return f"No Wikipedia results found for: {query}"
-
-         formatted_search_docs = "\n\n---\n\n".join(
-             [
-                 f'Source: {doc.metadata.get("source", "Wikipedia")}\nContent: {doc.page_content[:2000]}...'
-                 for doc in search_docs
-             ])
-         return formatted_search_docs
-     except Exception as e:
-         return f"Error searching Wikipedia: {str(e)}"
-

  @tool
  def arxiv_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 results.
-
-     Args:
-         query: The search query.
-     """
-     try:
-         search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-         if not search_docs:
-             return f"No Arxiv results found for: {query}"
-
-         formatted_search_docs = "\n\n---\n\n".join(
-             [
-                 f'Title: {doc.metadata.get("Title", "Unknown")}\nAuthors: {doc.metadata.get("Authors", "Unknown")}\nContent: {doc.page_content[:1500]}...'
-                 for doc in search_docs
-             ])
-         return formatted_search_docs
-     except Exception as e:
-         return f"Error searching Arxiv: {str(e)}"
-
- # System prompt
- system_prompt = """You are a general AI assistant. I will ask you a question. Do not report your thoughts or comments, and finish your answer with the following template: [YOUR FINAL ANSWER]. [YOUR FINAL ANSWER] should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
-
- # Tools list
  tools = [
      multiply,
      add,
@@ -116,125 +126,93 @@ tools = [
      divide,
      modulus,
      wiki_search,
      arxiv_search,
  ]

- class LangGraphAgent:
-     """LangGraph Agent with OpenAI that can be used in HuggingFace Space evaluation"""
-
-     def __init__(self):
-         """Initialize the agent with OpenAI LLM and tools"""
-         print("Initializing LangGraphAgent...")
-
-         # Get API key from environment
-         self.api_key = os.environ.get("OPENAI_KEY") or os.environ.get("OPENAI_API_KEY")
-         if not self.api_key:
-             raise ValueError("OPENAI_KEY environment variable is required")
-
-         # Initialize the graph
-         self.graph = self._build_graph()
-         print("LangGraphAgent initialized successfully.")
-
-     def _build_graph(self):
-         """Build the LangGraph workflow"""
-         # Initialize OpenAI LLM
          llm = ChatOpenAI(
-             model="gpt-4-turbo",  # Changed from gpt-4-turbo-preview
              temperature=0,
-             api_key=self.api_key
          )
-
-         # Bind tools to LLM
-         llm_with_tools = llm.bind_tools(tools)
-
-         # System message
-         sys_msg = SystemMessage(content=system_prompt)
-
-         # Node functions
-         def assistant(state: MessagesState):
-             """Assistant node"""
-             # Ensure system message is included
-             messages = state["messages"]
-             if not any(isinstance(msg, SystemMessage) for msg in messages):
-                 messages = [sys_msg] + messages
-
-             response = llm_with_tools.invoke(messages)
-             return {"messages": [response]}
-
-         # Build the graph
-         builder = StateGraph(MessagesState)
-
-         # Add nodes
-         builder.add_node("assistant", assistant)
-         builder.add_node("tools", ToolNode(tools))
-
-         # Add edges
-         builder.add_edge(START, "assistant")
-         builder.add_conditional_edges(
-             "assistant",
-             tools_condition,
          )
-         builder.add_edge("tools", "assistant")
-
-         # Compile and return
-         return builder.compile()
-
-     def __call__(self, question: str) -> str:
-         """
-         Process a question and return an answer.
-
-         Args:
-             question: The question to answer
-
-         Returns:
-             str: The answer to the question
-         """
-         print(f"Agent received question (first 100 chars): {question[:100]}...")
-
-         try:
-             # Create message
-             messages = [HumanMessage(content=question)]
-
-             # Invoke the graph
-             result = self.graph.invoke({"messages": messages})
-
-             # Extract the final answer
-             ai_messages = [msg for msg in result["messages"] if isinstance(msg, AIMessage)]
-
-             if ai_messages:
-                 answer = ai_messages[-1].content
-                 print(f"Agent returning answer (first 100 chars): {answer[:100]}...")
-                 return answer
-             else:
-                 return "I couldn't generate a response. Please try again."
-
-         except Exception as e:
-             print(f"Error processing question: {e}")
-             return f"Error: {str(e)}"
-
- # For backwards compatibility and testing
- BasicAgent = LangGraphAgent

  if __name__ == "__main__":
-     # Test the agent
-     print("Testing LangGraphAgent...")
-     if not os.environ.get("OPENAI_KEY"):
-         print("Error: OPENAI_KEY environment variable not set")
-         print("Please set it with: export OPENAI_KEY=your-openai-api-key")
-         exit(1)
-
-     try:
-         agent = LangGraphAgent()
-         test_questions = [
-             "What is 15 * 23?",
-             "Search Wikipedia for information about quantum computing",
-             "What are the latest developments in AI according to recent papers on Arxiv?",
-         ]
-
-         for question in test_questions:
-             print(f"\nQuestion: {question}")
-             answer = agent(question)
-             print(f"Answer: {answer}")
-
-     except Exception as e:
-         print(f"Error during testing: {e}")

+ """LangGraph agent: GPT-4.1 / Hugging Face Spaces version."""
  import os
+ from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition
+ from langgraph.prebuilt import ToolNode
+
+ # LLM providers
+ from langchain_openai import ChatOpenAI  # NEW (GPT-4.1)
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import (
+     ChatHuggingFace,
+     HuggingFaceEndpoint,
+     HuggingFaceEmbeddings,
+ )
+
+ # Tools & loaders
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import Client, create_client

+ # --------------------------------------------------------------------------- #
+ # Load environment variables (.env if present + HF Spaces secrets)            #
+ # --------------------------------------------------------------------------- #
+ load_dotenv()  # on Spaces, the secrets are already in os.environ
+
+ # --------------------------------------------------------------------------- #
+ # Example tools (arithmetic)                                                  #
+ # Each @tool needs a docstring: it becomes the tool's description.            #
+ # --------------------------------------------------------------------------- #
  @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers."""
+     return a * b

  @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers."""
+     return a + b

  @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract two numbers."""
+     return a - b

  @tool
  def divide(a: int, b: int) -> float:
+     """Divide two numbers."""
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
+ def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers."""
+     return a % b

+ # --------------------------------------------------------------------------- #
+ # Tool: Wikipedia                                                             #
+ # --------------------------------------------------------------------------- #
  @tool
  def wiki_search(query: str) -> str:
+     """Search Wikipedia (max 2 docs) and return formatted results."""
+     docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     return "\n\n---\n\n".join(
+         f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page", "")}"/>\n'
+         f"{d.page_content}\n</Document>"
+         for d in docs
+     )

+ # --------------------------------------------------------------------------- #
+ # Tool: Tavily web search                                                     #
+ # --------------------------------------------------------------------------- #
+ @tool
+ def web_search(query: str) -> str:
+     """Search the web via Tavily (max 3 results) and return formatted results."""
+     # TavilySearchResults returns a list of dicts (url/content), not Documents.
+     results = TavilySearchResults(max_results=3).invoke(query)
+     return "\n\n---\n\n".join(
+         f'<Document source="{r["url"]}"/>\n{r["content"]}\n</Document>'
+         for r in results
+     )
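+ # NOTE: Tavily reads the TAVILY_API_KEY environment variable; set it as a
+ # Space secret or web_search will fail at call time.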
+ # --------------------------------------------------------------------------- #
+ # Tool: ArXiv                                                                 #
+ # --------------------------------------------------------------------------- #
  @tool
  def arxiv_search(query: str) -> str:
+     """Search ArXiv (max 3 docs) and return formatted snippets."""
+     docs = ArxivLoader(query=query, load_max_docs=3).load()
+     # ArxivLoader metadata carries Title/Authors/Published, not "source".
+     return "\n\n---\n\n".join(
+         f'<Document source="{d.metadata.get("Title", "arXiv")}"/>\n'
+         f"{d.page_content[:1000]}\n</Document>"
+         for d in docs
+     )
+
+ # --------------------------------------------------------------------------- #
+ # System prompt                                                               #
+ # --------------------------------------------------------------------------- #
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+ sys_msg = SystemMessage(content=system_prompt)
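+ # NOTE: system_prompt.txt must ship with the repo; this open() raises at
+ # import time if the file is missing.
+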
+ # --------------------------------------------------------------------------- #
+ # Vector store for the retriever                                              #
+ # --------------------------------------------------------------------------- #
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY"),
+ )
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
+ )
+ question_search_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(),
+     name="Question Search",
+     description="A tool to retrieve similar questions from a vector store.",
+ )
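+ # NOTE: assumes the Supabase project already has a "documents" table and a
+ # "match_documents_langchain" RPC function set up for similarity search.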
+
+ # --------------------------------------------------------------------------- #
+ # Tool registration                                                           #
+ # --------------------------------------------------------------------------- #
  tools = [
      multiply,
      add,
      divide,
      modulus,
      wiki_search,
+     web_search,
      arxiv_search,
+     question_search_tool,
  ]

+ # --------------------------------------------------------------------------- #
+ # Build the LangGraph graph                                                   #
+ # --------------------------------------------------------------------------- #
+ def build_graph(provider: str = "openai"):
+     """Return a ready-to-use LangGraph graph.
+
+     provider: "openai" (default), "google", "groq", "huggingface"
+     """
+     # --- LLM selection ---------------------------------------------------- #
+     if provider == "openai":
+         openai_key = os.getenv("OPENAI_KEY")
+         if not openai_key:
+             raise ValueError(
+                 "❌ Missing OPENAI_KEY environment variable. "
+                 "Add it as a secret in the Space's 'Secrets' tab."
+             )
          llm = ChatOpenAI(
+             model_name="gpt-4.1",
              temperature=0,
+             openai_api_key=openai_key,
+         )
+
+     elif provider == "google":
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+
+     elif provider == "groq":
+         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
+
+     elif provider == "huggingface":
+         llm = ChatHuggingFace(
+             llm=HuggingFaceEndpoint(
+                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                 temperature=0,
+             )
          )
+     else:
+         raise ValueError(
+             "Invalid provider. Choose 'openai', 'google', 'groq' or 'huggingface'."
          )
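+     # The google / groq / huggingface branches assume their API keys
+     # (GOOGLE_API_KEY, GROQ_API_KEY, HUGGINGFACEHUB_API_TOKEN) are already
+     # set in the environment.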

+     # Enable tool calling
+     llm_with_tools = llm.bind_tools(tools)
+
+     # ------------------------------ Nodes ----------------------------------- #
+     def assistant(state: MessagesState):
+         """Invoke the model."""
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+     def retriever(state: MessagesState):
+         """Prepend the system message and add a similar Q/A pair as an example."""
+         similar = vector_store.similarity_search(state["messages"][0].content)
+         if similar:
+             example_msg = HumanMessage(
+                 content=(
+                     "Here I provide a similar question and answer for reference:\n\n"
+                     f"{similar[0].page_content}"
+                 )
+             )
+             return {"messages": [sys_msg] + state["messages"] + [example_msg]}
+         return {"messages": [sys_msg] + state["messages"]}
+
+     # ------------------------------ Graph ----------------------------------- #
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
+     builder.add_conditional_edges("assistant", tools_condition)
+     builder.add_edge("tools", "assistant")
+
+     return builder.compile()
+
+
+ # --------------------------------------------------------------------------- #
+ # Quick test (python agent.py)                                                #
+ # --------------------------------------------------------------------------- #
  if __name__ == "__main__":
+     graph = build_graph(provider="openai")
+     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+     msgs = [HumanMessage(content=question)]
+     result = graph.invoke({"messages": msgs})
+     for m in result["messages"]:
+         m.pretty_print()
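
Note: this commit drops the callable LangGraphAgent / BasicAgent wrapper that
other code in the Space could previously import for evaluation. If anything
still instantiates BasicAgent, a thin shim along these lines would keep it
working. This is a sketch, assuming the evaluator only calls agent(question)
and expects a string back; the "agent" module path refers to this file.

from langchain_core.messages import HumanMessage

from agent import build_graph


class BasicAgent:
    """Callable wrapper around the compiled graph (hypothetical shim)."""

    def __init__(self, provider: str = "openai"):
        # Compile the graph once; reuse it across questions.
        self.graph = build_graph(provider=provider)

    def __call__(self, question: str) -> str:
        result = self.graph.invoke({"messages": [HumanMessage(content=question)]})
        # The last message in the final state is the model's reply.
        return result["messages"][-1].content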