anandshende-videocx commited on
Commit
4870d9e
·
verified ·
1 Parent(s): 88f4382

Update my_agent.py

Browse files
Files changed (1) hide show
  1. my_agent.py +73 -10
my_agent.py CHANGED
@@ -9,6 +9,10 @@ from langchain_ollama import ChatOllama
9
  from langchain.agents.middleware.types import AgentState
10
  from langchain.messages import HumanMessage, AIMessage, SystemMessage
11
 
 
 
 
 
12
  hf_token = os.getenv("HF_TOKEN")
13
 
14
  class AgentResponseState(AgentState):
@@ -34,13 +38,22 @@ class BasicAgent:
34
  # # debug=True,
35
  # )
36
  tools = [
37
- DuckDuckGoSearchRun(),
 
 
 
 
38
  ]
39
 
40
- builder = StateGraph(MessagesState)
41
 
42
- model = create_agent(llm, tools)
 
 
 
 
43
  builder.add_node("assistant", model)
 
44
  builder.add_node("tools", ToolNode(tools))
45
 
46
  # Define edges: these determine how the control flow moves
@@ -50,8 +63,20 @@ class BasicAgent:
50
  # If the latest message requires a tool, route to tools
51
  # Otherwise, provide a direct response
52
  tools_condition,
 
 
 
 
53
  )
54
  builder.add_edge("tools", "assistant")
 
 
 
 
 
 
 
 
55
  self.agent = builder.compile()
56
 
57
  print("BasicAgent initialized.")
@@ -64,18 +89,56 @@ class BasicAgent:
64
  print(f"Agent returning fixed answer: {fixed_answer}")
65
  return fixed_answer
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def generate_answer(self, question: str) -> str:
68
  response = self.agent.invoke(
69
  {
70
  "messages": [
71
  {
72
- "role": "user",
 
 
 
 
73
  "content": question,
74
- }
75
- ]
76
- }
 
 
77
  )
78
  print(f"Agent raw response: {response}")
79
- print(f"response.content => {response['messages']}")
80
- print(f"AI response => {response['messages'][-1].content}")
81
- return response['messages'][-1].content
 
9
  from langchain.agents.middleware.types import AgentState
10
  from langchain.messages import HumanMessage, AIMessage, SystemMessage
11
 
12
+
13
+ from prompts import system_prompt, qa_system_prompt
14
+ from my_tools import wiki_search, arxiv_search, web_search, visit_webpage, translate_to_english
15
+
16
  hf_token = os.getenv("HF_TOKEN")
17
 
18
  class AgentResponseState(AgentState):
 
38
  # # debug=True,
39
  # )
40
  tools = [
41
+ wiki_search,
42
+ arxiv_search,
43
+ web_search,
44
+ visit_webpage,
45
+ translate_to_english,
46
  ]
47
 
48
+ builder = StateGraph(GraphMessagesState)
49
 
50
+ model = create_agent(
51
+ llm,
52
+ tools,
53
+ system_prompt=system_prompt,
54
+ )
55
  builder.add_node("assistant", model)
56
+ builder.add_node("assistant_qa", self.call_qa)
57
  builder.add_node("tools", ToolNode(tools))
58
 
59
  # Define edges: these determine how the control flow moves
 
63
  # If the latest message requires a tool, route to tools
64
  # Otherwise, provide a direct response
65
  tools_condition,
66
+ {
67
+ "tools": "tools",
68
+ END: "assistant_qa",
69
+ },
70
  )
71
  builder.add_edge("tools", "assistant")
72
+ builder.add_conditional_edges(
73
+ "assistant_qa",
74
+ tools_condition,
75
+ {
76
+ "tools": "tools",
77
+ END: END,
78
+ },
79
+ )
80
  self.agent = builder.compile()
81
 
82
  print("BasicAgent initialized.")
 
89
  print(f"Agent returning fixed answer: {fixed_answer}")
90
  return fixed_answer
91
 
92
+ def call_qa(self, graph_state: GraphMessagesState) -> str:
93
+ # print(f"Calling LLM QA for question: {graph_state['question']}")
94
+ # print(type(graph_state["messages"]))
95
+ # print(graph_state["messages"])
96
+
97
+ # parsed_messages = [
98
+ # {"role": m.type, "content": m.content} for m in graph_state["messages"]
99
+ # ]
100
+ parsed_messages = [
101
+ SystemMessage(content=qa_system_prompt)
102
+ ]
103
+ parsed_messages.extend(graph_state["messages"][1:])
104
+ parsed_messages.append(HumanMessage(content=f"Question: {graph_state['question']}"))
105
+ print(f"\n\n\n parsed_messages => {parsed_messages}")
106
+
107
+ # response = self.llm_qa.invoke(
108
+ # {
109
+ # "messages": [
110
+ # *parsed_messages,
111
+ # {
112
+ # "role": "human",
113
+ # "content": graph_state["question"],
114
+ # },
115
+ # ]
116
+ # },
117
+ # {"callbacks": [langfuse_handler]},
118
+ # )
119
+ response = self.llm_qa.invoke(
120
+ parsed_messages,
121
+ {"callbacks": [langfuse_handler]},
122
+ )
123
+ print(f"LLAMA 2 -> QA Agent raw response: {response}")
124
+ return response.model_dump()
125
+
126
  def generate_answer(self, question: str) -> str:
127
  response = self.agent.invoke(
128
  {
129
  "messages": [
130
  {
131
+ "role": "system",
132
+ "content": system_prompt,
133
+ },
134
+ {
135
+ "role": "human",
136
  "content": question,
137
+ },
138
+ ],
139
+ "question": question,
140
+ },
141
+ {"callbacks": [langfuse_handler]},
142
  )
143
  print(f"Agent raw response: {response}")
144
+ return response["messages"][-1].content