DeekshithN05 commited on
Commit
c8d9c8c
·
verified ·
1 Parent(s): 9542791

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +146 -146
agent.py CHANGED
@@ -1,146 +1,146 @@
1
- """LangGraph Agent"""
2
-
3
- import os
4
- import json
5
- import getpass
6
- from dotenv import load_dotenv
7
-
8
- from langgraph.graph import START, StateGraph, MessagesState
9
- from langgraph.prebuilt import tools_condition, ToolNode
10
-
11
- from langchain_core.messages import SystemMessage, HumanMessage
12
- from langchain_core.vectorstores import InMemoryVectorStore
13
- from langchain_core.documents import Document
14
- from langchain_openai import ChatOpenAI, OpenAIEmbeddings
15
- from langchain_ollama import ChatOllama
16
-
17
- from tools.math.multiply import multiply
18
- from tools.math.add import add
19
- from tools.math.subtract import subtract
20
- from tools.math.divide import divide
21
- from tools.math.modulus import modulus
22
- from tools.math.power import power
23
- from tools.math.square_root import square_root
24
-
25
- from tools.search.arxiv_search import arxiv_search
26
- from tools.search.web_search import web_search
27
- from tools.search.wiki_search import wiki_search
28
-
29
- from tools.file.analyze_csv_file import analyze_csv_file
30
- from tools.file.analyze_excel_file import analyze_excel_file
31
- from tools.file.analyze_image import analyze_image
32
- from tools.file.download_file_from_url import download_file_from_url
33
- from tools.file.save_content_to_file import save_content_to_file
34
-
35
- # --- Load environment variables ---
36
- load_dotenv()
37
-
38
- # --- Constants ---
39
- DATASET_PATH = "dataset/metadata.jsonl"
40
- SYSTEM_PROMPT_PATH = "prompts/system_prompt.txt"
41
- TOOLS = [
42
- add,
43
- subtract,
44
- multiply,
45
- divide,
46
- modulus,
47
- power,
48
- square_root,
49
- web_search,
50
- wiki_search,
51
- arxiv_search,
52
- analyze_csv_file,
53
- analyze_excel_file,
54
- analyze_image,
55
- download_file_from_url,
56
- save_content_to_file,
57
- ]
58
-
59
-
60
- def load_vector_store() -> InMemoryVectorStore:
61
- """Load vector store with dataset examples."""
62
- if not os.path.exists(DATASET_PATH):
63
- raise FileNotFoundError(f"Dataset not found at {DATASET_PATH}.")
64
- embeddings = OpenAIEmbeddings()
65
- vector_store = InMemoryVectorStore(embeddings)
66
- documents = []
67
- with open(DATASET_PATH, "r", encoding="utf-8") as f:
68
- for line in f:
69
- entry = json.loads(line)
70
- content = (
71
- f"Question: {entry['Question']}\nFinal answer: {entry['Final answer']}"
72
- )
73
- doc = Document(page_content=content, metadata={"source": entry["task_id"]})
74
- documents.append(doc)
75
- vector_store.add_documents(documents)
76
- return vector_store
77
-
78
-
79
- def get_llm(provider: str):
80
- """Get LLM instance based on provider."""
81
- if provider == "openai":
82
- if not os.environ.get("OPENAI_API_KEY"):
83
- os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API key: ")
84
- return ChatOpenAI(model="gpt-4.1", temperature=0)
85
- elif provider == "ollama":
86
- return ChatOllama(model="llama3.2", temperature=0)
87
- else:
88
- raise ValueError("Unsupported provider: choose 'openai' or 'ollama'")
89
-
90
-
91
- def load_system_prompt() -> SystemMessage:
92
- """Load system prompt from file."""
93
- if not os.path.exists(SYSTEM_PROMPT_PATH):
94
- raise FileNotFoundError(f"System prompt not found at {SYSTEM_PROMPT_PATH}.")
95
- with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as f:
96
- return SystemMessage(content=f.read())
97
-
98
-
99
- def build_graph(provider: str = "openai"):
100
- """Build and compile the LangGraph agent."""
101
- llm = get_llm(provider).bind_tools(TOOLS)
102
- vector_store = load_vector_store()
103
- system_msg = load_system_prompt()
104
-
105
- def retriever(state: MessagesState):
106
- """Retrieve similar examples based on user query."""
107
- query = state["messages"][0].content
108
- similar = vector_store.similarity_search(query, k=3)
109
- if similar:
110
- refs = "\n\n".join(doc.page_content for doc in similar)
111
- example_msg = HumanMessage(content=f"Here are similar examples:\n\n{refs}")
112
- return {"messages": [system_msg] + state["messages"] + [example_msg]}
113
- return {"messages": [system_msg] + state["messages"]}
114
-
115
- def assistant(state: MessagesState):
116
- """Call LLM to generate next message."""
117
- response = llm.invoke(state["messages"])
118
- return {"messages": [response]}
119
-
120
- # --- Build graph ---
121
- graph = StateGraph(MessagesState)
122
- graph.add_node("retriever", retriever)
123
- graph.add_node("assistant", assistant)
124
- graph.add_node("tools", ToolNode(TOOLS))
125
-
126
- graph.add_edge(START, "retriever")
127
- graph.add_edge("retriever", "assistant")
128
- graph.add_conditional_edges("assistant", tools_condition)
129
- graph.add_edge("tools", "assistant")
130
-
131
- return graph.compile()
132
-
133
-
134
- def run_agent(query: str, provider: str = "openai"):
135
- """Run the agent on a given query."""
136
- graph = build_graph(provider)
137
- messages = [HumanMessage(content=query)]
138
- result = graph.invoke({"messages": messages})
139
- for msg in result["messages"]:
140
- msg.pretty_print()
141
-
142
-
143
- # --- Run locally ---
144
- if __name__ == "__main__":
145
- user_query = input("Enter your question: ")
146
- run_agent(user_query)
 
1
+ """LangGraph Agent"""
2
+
3
+ import os
4
+ import json
5
+ import getpass
6
+ from dotenv import load_dotenv
7
+
8
+ from langgraph.graph import START, StateGraph, MessagesState
9
+ from langgraph.prebuilt import tools_condition, ToolNode
10
+
11
+ from langchain_core.messages import SystemMessage, HumanMessage
12
+ from langchain_core.vectorstores import InMemoryVectorStore
13
+ from langchain_core.documents import Document
14
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
15
+ from langchain_ollama import ChatOllama
16
+
17
+ from tools.math.multiply import multiply
18
+ from tools.math.add import add
19
+ from tools.math.subtract import subtract
20
+ from tools.math.divide import divide
21
+ from tools.math.modulus import modulus
22
+ from tools.math.power import power
23
+ from tools.math.square_root import square_root
24
+
25
+ from tools.search.arxiv_search import arxiv_search
26
+ from tools.search.web_search import web_search
27
+ from tools.search.wiki_search import wiki_search
28
+
29
+ from tools.file.analyze_csv_file import analyze_csv_file
30
+ from tools.file.analyze_excel_file import analyze_excel_file
31
+ from tools.file.analyze_image import analyze_image
32
+ from tools.file.download_file_from_url import download_file_from_url
33
+ from tools.file.save_content_to_file import save_content_to_file
34
+
35
+ # --- Load environment variables ---
36
+ load_dotenv()
37
+
38
+ # --- Constants ---
39
+ DATASET_PATH = "dataset/metadata.jsonl"
40
+ SYSTEM_PROMPT_PATH = "prompts/system_prompt.txt"
41
+ TOOLS = [
42
+ add,
43
+ subtract,
44
+ multiply,
45
+ divide,
46
+ modulus,
47
+ power,
48
+ square_root,
49
+ web_search,
50
+ wiki_search,
51
+ arxiv_search,
52
+ analyze_csv_file,
53
+ analyze_excel_file,
54
+ analyze_image,
55
+ download_file_from_url,
56
+ save_content_to_file,
57
+ ]
58
+
59
+
60
+ def load_vector_store() -> InMemoryVectorStore:
61
+ """Load vector store with dataset examples."""
62
+ if not os.path.exists(DATASET_PATH):
63
+ raise FileNotFoundError(f"Dataset not found at {DATASET_PATH}.")
64
+ embeddings = OpenAIEmbeddings()
65
+ vector_store = InMemoryVectorStore(embeddings)
66
+ documents = []
67
+ with open(DATASET_PATH, "r", encoding="utf-8") as f:
68
+ for line in f:
69
+ entry = json.loads(line)
70
+ content = (
71
+ f"Question: {entry['Question']}\nFinal answer: {entry['Final answer']}"
72
+ )
73
+ doc = Document(page_content=content, metadata={"source": entry["task_id"]})
74
+ documents.append(doc)
75
+ vector_store.add_documents(documents)
76
+ return vector_store
77
+
78
+
79
+ def get_llm(provider: str):
80
+ """Get LLM instance based on provider."""
81
+ if provider == "openai":
82
+ if not os.environ.get("OPENAI_API_KEY"):
83
+ os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API key: ")
84
+ return ChatOpenAI(model="gpt-4.1", temperature=0)
85
+ elif provider == "ollama":
86
+ return ChatOllama(model="llama3", temperature=0)
87
+ else:
88
+ raise ValueError("Unsupported provider: choose 'openai' or 'ollama'")
89
+
90
+
91
+ def load_system_prompt() -> SystemMessage:
92
+ """Load system prompt from file."""
93
+ if not os.path.exists(SYSTEM_PROMPT_PATH):
94
+ raise FileNotFoundError(f"System prompt not found at {SYSTEM_PROMPT_PATH}.")
95
+ with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as f:
96
+ return SystemMessage(content=f.read())
97
+
98
+
99
+ def build_graph(provider: str = "openai"):
100
+ """Build and compile the LangGraph agent."""
101
+ llm = get_llm(provider).bind_tools(TOOLS)
102
+ vector_store = load_vector_store()
103
+ system_msg = load_system_prompt()
104
+
105
+ def retriever(state: MessagesState):
106
+ """Retrieve similar examples based on user query."""
107
+ query = state["messages"][0].content
108
+ similar = vector_store.similarity_search(query, k=3)
109
+ if similar:
110
+ refs = "\n\n".join(doc.page_content for doc in similar)
111
+ example_msg = HumanMessage(content=f"Here are similar examples:\n\n{refs}")
112
+ return {"messages": [system_msg] + state["messages"] + [example_msg]}
113
+ return {"messages": [system_msg] + state["messages"]}
114
+
115
+ def assistant(state: MessagesState):
116
+ """Call LLM to generate next message."""
117
+ response = llm.invoke(state["messages"])
118
+ return {"messages": [response]}
119
+
120
+ # --- Build graph ---
121
+ graph = StateGraph(MessagesState)
122
+ graph.add_node("retriever", retriever)
123
+ graph.add_node("assistant", assistant)
124
+ graph.add_node("tools", ToolNode(TOOLS))
125
+
126
+ graph.add_edge(START, "retriever")
127
+ graph.add_edge("retriever", "assistant")
128
+ graph.add_conditional_edges("assistant", tools_condition)
129
+ graph.add_edge("tools", "assistant")
130
+
131
+ return graph.compile()
132
+
133
+
134
+ def run_agent(query: str, provider: str = "openai"):
135
+ """Run the agent on a given query."""
136
+ graph = build_graph(provider)
137
+ messages = [HumanMessage(content=query)]
138
+ result = graph.invoke({"messages": messages})
139
+ for msg in result["messages"]:
140
+ msg.pretty_print()
141
+
142
+
143
+ # --- Run locally ---
144
+ if __name__ == "__main__":
145
+ user_query = input("Enter your question: ")
146
+ run_agent(user_query)