MoBnJlal commited on
Commit
0aa27ac
·
verified ·
1 Parent(s): 60f2d59

Update Agent.py

Browse files
Files changed (1) hide show
  1. Agent.py +10 -7
Agent.py CHANGED
@@ -7,7 +7,7 @@ from langchain.memory import ConversationBufferMemory
7
  from langchain_community.tools import DuckDuckGoSearchRun
8
  #from langchain.schema import SystemMessage
9
  #from langchain.schema.messages import HumanMessage
10
- from langchain_core.messages import SystemMessage, HumanMessage
11
  GAIA_SYSTEM_PROMPT = """
12
  You are an advanced autonomous agent designed to pass the GAIA benchmark.
13
  You must:
@@ -34,8 +34,8 @@ llm = ChatGoogleGenerativeAI(
34
  max_output_tokens=1024
35
  )
36
  messages = [
37
- SystemMessage(content=GAIA_SYSTEM_PROMPT) ,
38
- HumanMessage(content="Get started, messages on the way.")
39
  ]
40
 
41
  response = llm.invoke(messages)
@@ -100,10 +100,13 @@ agent = initialize_agent(
100
  )
101
 
102
  def run_agent(state: MessagesState) -> MessagesState:
103
- result = agent.run(state.input)
104
- state.output = result
105
- state.intermediate_steps.append(result)
106
- return state
 
 
 
107
 
108
  graph = StateGraph(MessagesState)
109
 
 
7
  from langchain_community.tools import DuckDuckGoSearchRun
8
  #from langchain.schema import SystemMessage
9
  #from langchain.schema.messages import HumanMessage
10
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
11
  GAIA_SYSTEM_PROMPT = """
12
  You are an advanced autonomous agent designed to pass the GAIA benchmark.
13
  You must:
 
34
  max_output_tokens=1024
35
  )
36
  messages = [
37
+ SystemMessage(content=GAIA_SYSTEM_PROMPT) #,
38
+ #HumanMessage(content="Get started, messages on the way.")
39
  ]
40
 
41
  response = llm.invoke(messages)
 
100
  )
101
 
102
  def run_agent(state: MessagesState) -> MessagesState:
103
+ # Let the agent respond based on the full message history
104
+ result = agent.invoke(state["messages"])
105
+
106
+ # Add the AI's response to the message history
107
+ state["messages"].append(AIMessage(content=result.content))
108
+
109
+ return state
110
 
111
  graph = StateGraph(MessagesState)
112