Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from langchain_core.tools import tool
|
|
| 7 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
| 8 |
from langgraph.graph import StateGraph, END
|
| 9 |
from typing import TypedDict, Annotated, Sequence
|
| 10 |
-
from langchain_core.messages import
|
| 11 |
import operator
|
| 12 |
import networkx as nx
|
| 13 |
import matplotlib.pyplot as plt
|
|
@@ -38,12 +38,11 @@ def search(query: str):
|
|
| 38 |
|
| 39 |
tools = [search, multiply]
|
| 40 |
tool_map = {tool.name: tool for tool in tools}
|
| 41 |
-
|
| 42 |
model_with_tools = model.bind_tools(tools)
|
| 43 |
|
| 44 |
# Define Agent State class
|
| 45 |
class AgentState(TypedDict):
|
| 46 |
-
messages: Annotated[Sequence[
|
| 47 |
|
| 48 |
# Define workflow nodes
|
| 49 |
def invoke_model(state):
|
|
@@ -70,8 +69,9 @@ def invoke_tool(state):
|
|
| 70 |
if response == "No":
|
| 71 |
raise ValueError(f"Execution of '{selected_tool}' was canceled.")
|
| 72 |
|
|
|
|
| 73 |
response = tool_map[selected_tool].invoke(json.loads(tool_details.get("function").get("arguments")))
|
| 74 |
-
return {"messages": [
|
| 75 |
|
| 76 |
def router(state):
|
| 77 |
tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
|
|
@@ -129,11 +129,14 @@ if st.button("Run Workflow"):
|
|
| 129 |
for s in compiled_app.stream({"messages": [HumanMessage(content=prompt)]}):
|
| 130 |
intermediate_outputs.append(s)
|
| 131 |
|
| 132 |
-
# Extract and display the final response
|
| 133 |
final_output = intermediate_outputs[-1]
|
| 134 |
-
messages = final_output.get('messages', [
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
except Exception as e:
|
| 138 |
st.error(f"Error: {e}")
|
| 139 |
|
|
|
|
| 7 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
| 8 |
from langgraph.graph import StateGraph, END
|
| 9 |
from typing import TypedDict, Annotated, Sequence
|
| 10 |
+
from langchain_core.messages import AIMessage, HumanMessage
|
| 11 |
import operator
|
| 12 |
import networkx as nx
|
| 13 |
import matplotlib.pyplot as plt
|
|
|
|
| 38 |
|
| 39 |
tools = [search, multiply]
|
| 40 |
tool_map = {tool.name: tool for tool in tools}
|
|
|
|
| 41 |
model_with_tools = model.bind_tools(tools)
|
| 42 |
|
| 43 |
# Define Agent State class
|
| 44 |
class AgentState(TypedDict):
|
| 45 |
+
messages: Annotated[Sequence[HumanMessage], operator.add]
|
| 46 |
|
| 47 |
# Define workflow nodes
|
| 48 |
def invoke_model(state):
|
|
|
|
| 69 |
if response == "No":
|
| 70 |
raise ValueError(f"Execution of '{selected_tool}' was canceled.")
|
| 71 |
|
| 72 |
+
# Invoke tool and return response as AIMessage
|
| 73 |
response = tool_map[selected_tool].invoke(json.loads(tool_details.get("function").get("arguments")))
|
| 74 |
+
return {"messages": [AIMessage(content=str(response))]}
|
| 75 |
|
| 76 |
def router(state):
|
| 77 |
tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
|
|
|
|
| 129 |
for s in compiled_app.stream({"messages": [HumanMessage(content=prompt)]}):
|
| 130 |
intermediate_outputs.append(s)
|
| 131 |
|
| 132 |
+
# Extract and display the final response
|
| 133 |
final_output = intermediate_outputs[-1]
|
| 134 |
+
messages = final_output.get('messages', [])
|
| 135 |
+
if messages:
|
| 136 |
+
response_content = messages[-1].content
|
| 137 |
+
st.write("Response:", response_content)
|
| 138 |
+
else:
|
| 139 |
+
st.write("Response: No content generated.")
|
| 140 |
except Exception as e:
|
| 141 |
st.error(f"Error: {e}")
|
| 142 |
|