mhamzaanjum380 committed on
Commit
d5fbf43
·
verified ·
1 Parent(s): 47f5f4e

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +144 -0
agent.py CHANGED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition
5
+ from langgraph.prebuilt import ToolNode
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
8
+ from langchain_community.tools.tavily_search import TavilySearchResults
9
+ from langchain_community.document_loaders import WikipediaLoader
10
+ from langchain_community.document_loaders import ArxivLoader
11
+ from langchain_core.messages import SystemMessage, HumanMessage
12
+ from langchain.tools.retriever import create_retriever_tool
13
+ from langchain.vectorstores import FAISS
14
+ from langchain.schema import Document
15
+
16
+ from .helping_tools import (
17
+ multiply,
18
+ add,
19
+ subtract,
20
+ divide,
21
+ modulus,
22
+ wiki_search,
23
+ web_search,
24
+ arvix_search,
25
+ )
26
+ # Load metadata.jsonl
27
+ import json
28
+
29
+ # Load the metadata.jsonl file
30
# Read the raw JSONL lines. The encoding is pinned to UTF-8 so the result does
# not depend on the platform default (system_prompt.txt is read the same way).
with open('metadata.jsonl', 'r', encoding='utf-8') as jsonl_file:
    json_list = list(jsonl_file)

# Load dotenv file
load_dotenv()

# Parse one JSON record per line. Blank lines (for example a trailing newline
# at the end of the file) are skipped because json.loads rejects them.
json_QA = [json.loads(json_str) for json_str in json_list if json_str.strip()]

# Turn every question/answer record into a Document for the vector store.
# Each record is expected to carry 'Question', 'Final answer' and 'task_id'.
docs = [
    Document(
        page_content=f"Question : {sample['Question']}\n\nFinal answer : {sample['Final answer']}",
        metadata={
            "source": sample['task_id']
        },
    )
    for sample in json_QA
]
53
+
54
+ # load the system prompt from the file
55
# Read the whole prompt file as UTF-8 text.
with open("system_prompt.txt", mode="r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()

# Wrap the prompt text in a SystemMessage so it can be added to the
# conversation state.
sys_msg = SystemMessage(content=system_prompt)
60
+
61
+ # build a retriever
62
# Embed the reference question/answer documents and index them with FAISS.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vector_store = FAISS.from_documents(documents=docs, embedding=embeddings)

# Expose the vector store as a tool. The tool is bound to the name that the
# `tools` list below actually references; previously it was assigned to
# `create_retrieve_tool` while `tools` used the undefined
# `question_retrieve_tool`, which raised NameError at import time.
# The tool name avoids spaces because function-calling chat APIs only accept
# identifier-like tool names.
question_retrieve_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)
# Backward-compatible alias for the original variable name.
create_retrieve_tool = question_retrieve_tool

# Tools offered to the LLM. `similar_question_search` was listed here before
# but is neither defined in this file nor imported from helping_tools, so it
# is omitted; the retriever tool above covers similar-question lookup.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
    question_retrieve_tool,
]
83
+
84
+ # Build graph function
85
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    The graph runs START -> retriever -> assistant, then loops between
    assistant and tools for as long as ``tools_condition`` routes to the
    tool node.

    Args:
        provider: Which chat model backend to use: "google", "groq" or
            "huggingface".

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.

    Raises:
        ValueError: If ``provider`` is not one of the supported values.
    """
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        # ChatGroq is not imported at module level, so import it here; this
        # also keeps langchain_groq optional for the other providers.
        from langchain_groq import ChatGroq

        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "huggingface":
        # TODO: Add huggingface endpoint
        # NOTE(review): the model id in this URL is unverified - confirm it
        # points at a model that actually exists before relying on this branch.
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # `endpoint_url` is the documented parameter name (not `url`).
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: call the tool-enabled LLM on the conversation so far."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: add the system prompt and one similar solved question."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        if not similar_question:
            # Nothing retrieved: still add the system prompt, but skip the
            # example instead of failing on similar_question[0].
            return {"messages": [sys_msg] + state["messages"]}
        example_msg = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition decides whether the assistant's reply goes to the tool node.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
134
+
135
if __name__ == "__main__":
    # Manual smoke test: ask one fixed question through the Google-backed
    # graph and print every message in the resulting conversation.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph(provider="google")
    initial_state = {"messages": [HumanMessage(content=question)]}
    final_state = graph.invoke(initial_state)
    for message in final_state["messages"]:
        message.pretty_print()
144
+