nadim71 commited on
Commit
47b8bcf
·
verified ·
1 Parent(s): 643f23f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +164 -217
agent.py CHANGED
@@ -1,222 +1,169 @@
1
- """LangGraph Agent"""
2
  import os
3
- from langgraph.graph import START, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition
5
- from langgraph.prebuilt import ToolNode
6
- from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_groq import ChatGroq
8
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
- from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_community.document_loaders import WikipediaLoader
11
- from langchain_community.document_loaders import ArxivLoader
12
- from langchain_community.vectorstores import SupabaseVectorStore
13
- from langchain_core.messages import SystemMessage, HumanMessage
14
- from langchain_core.tools import tool
15
- from langchain.tools.retriever import create_retriever_tool
16
- from supabase.client import Client, create_client
17
-
18
-
19
-
20
- @tool
21
- def multiply(a: int, b: int) -> int:
22
- """Multiply two numbers.
23
- Args:
24
- a: first int
25
- b: second int
26
- """
27
- return a * b
28
-
29
- @tool
30
- def add(a: int, b: int) -> int:
31
- """Add two numbers.
32
-
33
- Args:
34
- a: first int
35
- b: second int
36
- """
37
- return a + b
38
-
39
- @tool
40
- def subtract(a: int, b: int) -> int:
41
- """Subtract two numbers.
42
-
43
- Args:
44
- a: first int
45
- b: second int
46
- """
47
- return a - b
48
-
49
- @tool
50
- def divide(a: int, b: int) -> int:
51
- """Divide two numbers.
52
-
53
- Args:
54
- a: first int
55
- b: second int
56
- """
57
- if b == 0:
58
- raise ValueError("Cannot divide by zero.")
59
- return a / b
60
-
61
- @tool
62
- def modulus(a: int, b: int) -> int:
63
- """Get the modulus of two numbers.
64
-
65
- Args:
66
- a: first int
67
- b: second int
68
- """
69
- return a % b
70
-
71
- @tool
72
- def wiki_search(query: str) -> str:
73
- """Search Wikipedia for a query and return maximum 2 results.
74
-
75
- Args:
76
- query: The search query."""
77
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
78
- formatted_search_docs = "\n\n---\n\n".join(
79
- [
80
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
81
- for doc in search_docs
82
- ])
83
- return {"wiki_results": formatted_search_docs}
84
-
85
- @tool
86
- def web_search(query: str) -> str:
87
- """Search Tavily for a query and return maximum 3 results.
88
-
89
- Args:
90
- query: The search query."""
91
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
92
- formatted_search_docs = "\n\n---\n\n".join(
93
- [
94
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
95
- for doc in search_docs
96
- ])
97
- return {"web_results": formatted_search_docs}
98
-
99
- @tool
100
- def arvix_search(query: str) -> str:
101
- """Search Arxiv for a query and return maximum 3 result.
102
-
103
- Args:
104
- query: The search query."""
105
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
106
- formatted_search_docs = "\n\n---\n\n".join(
107
- [
108
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
109
- for doc in search_docs
110
- ])
111
- return {"arvix_results": formatted_search_docs}
112
-
113
-
114
-
115
- # load the system prompt from the file
116
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
117
- system_prompt = f.read()
118
-
119
- # System message
120
- sys_msg = SystemMessage(content=system_prompt)
121
-
122
- # build a retriever
123
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
124
- supabase: Client = create_client(
125
- os.environ.get("SUPABASE_URL"),
126
- os.environ.get("SUPABASE_SERVICE_KEY"))
127
- vector_store = SupabaseVectorStore(
128
- client=supabase,
129
- embedding= embeddings,
130
- table_name="documents",
131
- query_name="match_documents_langchain",
132
  )
133
- create_retriever_tool = create_retriever_tool(
134
- retriever=vector_store.as_retriever(),
135
- name="Question Search",
136
- description="A tool to retrieve similar questions from a vector store.",
137
- )
138
-
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- tools = [
142
- multiply,
143
- add,
144
- subtract,
145
- divide,
146
- modulus,
147
- wiki_search,
148
- web_search,
149
- arvix_search,
150
- ]
151
-
152
- # Build graph function
153
- def build_graph(provider: str = "google"):
154
- """Build the graph"""
155
- # Load environment variables from .env file
156
- if provider == "google":
157
- # Google Gemini
158
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
159
- elif provider == "groq":
160
- # Groq https://console.groq.com/docs/models
161
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
162
- elif provider == "huggingface":
163
- # TODO: Add huggingface endpoint
164
- llm = ChatHuggingFace(
165
- llm=HuggingFaceEndpoint(
166
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
167
- temperature=0,
168
- ),
169
- )
170
- else:
171
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
172
- # Bind tools to LLM
173
- llm_with_tools = llm.bind_tools(tools)
174
-
175
- # Node
176
- def assistant(state: MessagesState):
177
- """Assistant node"""
178
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
179
-
180
- # def retriever(state: MessagesState):
181
- # """Retriever node"""
182
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
183
- #example_msg = HumanMessage(
184
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
185
- # )
186
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
187
-
188
- from langchain_core.messages import AIMessage
189
-
190
- def retriever(state: MessagesState):
191
- query = state["messages"][-1].content
192
- similar_doc = vector_store.similarity_search(query, k=1)[0]
193
-
194
- content = similar_doc.page_content
195
- if "Final answer :" in content:
196
- answer = content.split("Final answer :")[-1].strip()
197
  else:
198
- answer = content.strip()
199
-
200
- return {"messages": [AIMessage(content=answer)]}
201
-
202
- # builder = StateGraph(MessagesState)
203
- #builder.add_node("retriever", retriever)
204
- #builder.add_node("assistant", assistant)
205
- #builder.add_node("tools", ToolNode(tools))
206
- #builder.add_edge(START, "retriever")
207
- #builder.add_edge("retriever", "assistant")
208
- #builder.add_conditional_edges(
209
- # "assistant",
210
- # tools_condition,
211
- #)
212
- #builder.add_edge("tools", "assistant")
213
-
214
- builder = StateGraph(MessagesState)
215
- builder.add_node("retriever", retriever)
216
-
217
- # Retriever ist Start und Endpunkt
218
- builder.set_entry_point("retriever")
219
- builder.set_finish_point("retriever")
220
-
221
- # Compile graph
222
- return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import List
3
+ from langchain_community.tools import DuckDuckGoSearchRun
4
+ from langchain_experimental.tools import PythonREPLTool
5
+ from langchain_community.vectorstores import FAISS
6
+ #from langchain.embeddings import OpenAIEmbeddings
7
+ from langchain_core.documents import Document
8
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace, HuggingFaceEmbeddings
9
+ from langchain_core.messages import HumanMessage
10
+
11
+
12
+
13
+
14
+
15
+ # -----------------------------
16
+ # LLM
17
+ # -----------------------------
18
+ #llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
19
+ #repo_id="deepseek-ai/DeepSeek-V4-Pro"
20
+ llm = ChatHuggingFace(
21
+ llm=HuggingFaceEndpoint(
22
+ repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
23
+ huggingfacehub_api_token=HF_KEY,
24
+ task="conversational", # Specify task for the conversational model
25
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  )
 
 
 
 
 
 
27
 
28
+ # -----------------------------
29
+ # Tools
30
+ # -----------------------------
31
+ search = DuckDuckGoSearchRun()
32
+ python_tool = PythonREPLTool()
33
+
34
+ TOOLS = {
35
+ "search": search.run,
36
+ "python": python_tool.run,
37
+ "llm": lambda x: llm.invoke([HumanMessage(content=x)]).content,
38
+ "summarize": lambda text: llm.invoke([HumanMessage(content=f"Summarize the following:\n{text}")]).content
39
+ }
40
+
41
+ # -----------------------------
42
+ # Memory (Vector DB)
43
+ # -----------------------------
44
+ embeddings = HuggingFaceEmbeddings()
45
+ # Initialize FAISS with a dummy document to prevent IndexError when trying to determine embedding dimension
46
+ vectorstore = FAISS.from_documents([Document(page_content="initialization_document_for_dimension_inference")], embeddings)
47
+
48
+ def store_memory(text: str):
49
+ vectorstore.add_documents([Document(page_content=text)])
50
+
51
+ def retrieve_memory(query: str):
52
+ docs = vectorstore.similarity_search(query, k=3)
53
+ return "\n".join([d.page_content for d in docs])
54
+
55
+ # -----------------------------
56
+ # Planner
57
+ # -----------------------------
58
+ def plan(goal, history):
59
+ prompt = f"""
60
+ You are an autonomous agent.
61
+
62
+ Goal: {goal}
63
+
64
+ Previous steps:
65
+ {history}
66
+
67
+ Decide the NEXT action:
68
+ - search(query)
69
+ - python(code)
70
+ - llm(prompt)
71
+ - summarize(text)
72
+ - finish(answer)
73
+
74
+ Respond ONLY in one line.
75
+ """
76
+ return llm.invoke([HumanMessage(content=prompt)]).content.strip()
77
+
78
+ # -----------------------------
79
+ # Executor
80
+ # -----------------------------
81
+ def execute(action: str):
82
+ try:
83
+ if action.startswith("search("):
84
+ query = action[len("search("):-1]
85
+ return TOOLS["search"](query)
86
+
87
+ elif action.startswith("python("):
88
+ code = action[len("python("):-1]
89
+ return TOOLS["python"](code)
90
+
91
+ elif action.startswith("llm("):
92
+ prompt = action[len("llm("):-1]
93
+ return TOOLS["llm"](prompt)
94
+
95
+ elif action.startswith("summarize("):
96
+ text_to_summarize = action[len("summarize("):-1]
97
+ return TOOLS["summarize"](text_to_summarize)
98
+
99
+ elif action.startswith("finish("):
100
+ return action[len("finish("):-1]
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  else:
103
+ return "Invalid action"
104
+
105
+ except Exception as e:
106
+ return f"Error: {str(e)}"
107
+
108
+ # -----------------------------
109
+ # Critic (loop control)
110
+ # -----------------------------
111
+ def critic(goal, last_result):
112
+ prompt = f"""
113
+ Goal: {goal}
114
+
115
+ Latest result:
116
+ {last_result}
117
+
118
+ Is the goal achieved? Answer YES or NO.
119
+ """
120
+ return "YES" in llm.invoke([HumanMessage(content=prompt)]).content.upper()
121
+
122
+ # -----------------------------
123
+ # Autonomous Loop
124
+ # -----------------------------
125
+ def autonomous_agent(goal: str, max_steps=15):
126
+
127
+ history = ""
128
+ print(f"\n🎯 Goal: {goal}\n")
129
+
130
+ for step in range(max_steps):
131
+ print(f"--- Step {step+1} ---")
132
+
133
+ # Retrieve memory
134
+ memory_context = retrieve_memory(goal)
135
+
136
+ action = plan(goal, history + "\nMemory:\n" + memory_context)
137
+ print(f"🧠 Plan: {action}")
138
+
139
+ result = execute(action)
140
+ print(f"⚙️ Result: {result[:300]}...\n")
141
+
142
+ # Store memory
143
+ store_memory(f"Action: {action}\nResult: {result}")
144
+
145
+ history += f"\nStep {step+1}: {action} → {result}"
146
+
147
+ # Finish condition
148
+ if action.startswith("finish("):
149
+ print("✅ Finished by agent")
150
+ return result
151
+
152
+ # Critic check
153
+ if critic(goal, result):
154
+ print("✅ Critic determined goal achieved")
155
+ return result
156
+
157
+ return "❌ Max steps reached without completion"
158
+
159
+ # -----------------------------
160
+ # Run
161
+ # -----------------------------
162
+ if __name__ == "__main__":
163
+ while True:
164
+ goal = input("\nEnter goal (or 'exit'): ")
165
+ if goal == "exit":
166
+ break
167
+
168
+ result = autonomous_agent(goal)
169
+ print(f"\n🤖 Final Output:\n{result}\n")