DrishtiSharma committed on
Commit
6ec949a
·
verified ·
1 Parent(s): 18cfa89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -7,7 +7,7 @@ from langchain_core.tools import tool
7
  from langchain_community.tools.tavily_search import TavilySearchResults
8
  from langgraph.graph import StateGraph, END
9
  from typing import TypedDict, Annotated, Sequence
10
- from langchain_core.messages import BaseMessage, HumanMessage
11
  import operator
12
  import networkx as nx
13
  import matplotlib.pyplot as plt
@@ -38,12 +38,11 @@ def search(query: str):
38
 
39
  tools = [search, multiply]
40
  tool_map = {tool.name: tool for tool in tools}
41
-
42
  model_with_tools = model.bind_tools(tools)
43
 
44
  # Define Agent State class
45
  class AgentState(TypedDict):
46
- messages: Annotated[Sequence[BaseMessage], operator.add]
47
 
48
  # Define workflow nodes
49
  def invoke_model(state):
@@ -70,8 +69,9 @@ def invoke_tool(state):
70
  if response == "No":
71
  raise ValueError(f"Execution of '{selected_tool}' was canceled.")
72
 
 
73
  response = tool_map[selected_tool].invoke(json.loads(tool_details.get("function").get("arguments")))
74
- return {"messages": [BaseMessage(content=str(response))]}
75
 
76
  def router(state):
77
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
@@ -129,11 +129,14 @@ if st.button("Run Workflow"):
129
  for s in compiled_app.stream({"messages": [HumanMessage(content=prompt)]}):
130
  intermediate_outputs.append(s)
131
 
132
- # Extract and display the final response safely
133
  final_output = intermediate_outputs[-1]
134
- messages = final_output.get('messages', [{}])[-1]
135
- response_content = getattr(messages, 'content', 'No content generated.')
136
- st.write("Response:", response_content)
 
 
 
137
  except Exception as e:
138
  st.error(f"Error: {e}")
139
 
 
7
  from langchain_community.tools.tavily_search import TavilySearchResults
8
  from langgraph.graph import StateGraph, END
9
  from typing import TypedDict, Annotated, Sequence
10
+ from langchain_core.messages import AIMessage, HumanMessage
11
  import operator
12
  import networkx as nx
13
  import matplotlib.pyplot as plt
 
38
 
39
  tools = [search, multiply]
40
  tool_map = {tool.name: tool for tool in tools}
 
41
  model_with_tools = model.bind_tools(tools)
42
 
43
  # Define Agent State class
44
  class AgentState(TypedDict):
45
+ messages: Annotated[Sequence[HumanMessage], operator.add]
46
 
47
  # Define workflow nodes
48
  def invoke_model(state):
 
69
  if response == "No":
70
  raise ValueError(f"Execution of '{selected_tool}' was canceled.")
71
 
72
+ # Invoke tool and return response as AIMessage
73
  response = tool_map[selected_tool].invoke(json.loads(tool_details.get("function").get("arguments")))
74
+ return {"messages": [AIMessage(content=str(response))]}
75
 
76
  def router(state):
77
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
 
129
  for s in compiled_app.stream({"messages": [HumanMessage(content=prompt)]}):
130
  intermediate_outputs.append(s)
131
 
132
+ # Extract and display the final response
133
  final_output = intermediate_outputs[-1]
134
+ messages = final_output.get('messages', [])
135
+ if messages:
136
+ response_content = messages[-1].content
137
+ st.write("Response:", response_content)
138
+ else:
139
+ st.write("Response: No content generated.")
140
  except Exception as e:
141
  st.error(f"Error: {e}")
142