File size: 4,958 Bytes
c062c14 92d35f8 c062c14 6988bcd c062c14 92d35f8 c062c14 6988bcd 92d35f8 6988bcd c062c14 92d35f8 c062c14 92d35f8 c062c14 92d35f8 c062c14 92d35f8 c062c14 92d35f8 c062c14 92d35f8 c062c14 92d35f8 c062c14 92d35f8 c062c14 92d35f8 c062c14 92d35f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
import streamlit as st
import json
import time
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage
import operator
import networkx as nx
import matplotlib.pyplot as plt
# Set API keys and validate credentials before building anything else.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if not (OPENAI_API_KEY and TAVILY_API_KEY):
    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
    st.stop()
# Deterministic LLM (temperature=0) so the two comparison runs are comparable.
model = ChatOpenAI(temperature=0)
# Define Tools
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two integers together."""
    # Docstring above doubles as the tool description shown to the LLM.
    product = first_number * second_number
    return product
@tool
def search(query: str):
    """Performs web search on the user query."""
    # A single result keeps the payload small for the model's context window.
    client = TavilySearchResults(max_results=1)
    return client.invoke(query)
# Register tools and expose them to the model.
tools = [search, multiply]
# NOTE: use `t`, not `tool`, as the loop variable — `tool` would shadow the
# imported @tool decorator from langchain_core.tools.
tool_map = {t.name: t for t in tools}
model_with_tools = model.bind_tools(tools)
# Shared graph state. `operator.add` as the reducer means each node's returned
# {"messages": [...]} list is concatenated onto the existing history rather
# than replacing it.
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
# Define workflow nodes
def invoke_model(state):
    """Agent node: call the tool-bound LLM and append its reply to the state.

    Forwards the full message history to the model; the original passed only
    ``messages[-1]``, silently dropping any earlier conversational context.
    """
    messages = state['messages']
    return {"messages": [model_with_tools.invoke(messages)]}
def invoke_tool(state):
    """Tool node: execute the tool requested by the agent's last message.

    Reads the OpenAI-style ``tool_calls`` entries from the last message's
    ``additional_kwargs``, optionally asks the operator for confirmation,
    then invokes the matching tool with its JSON-decoded arguments.

    Raises:
        Exception: when the agent produced no tool call.
        ValueError: when the operator declines execution.
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    if not tool_calls:
        raise Exception("No tool input found.")
    # The original looped just to keep the final entry; take it directly.
    tool_details = tool_calls[-1]
    selected_tool = tool_details.get("function").get("name")
    st.sidebar.write(f"Selected tool: {selected_tool}")
    # Human-in-the-loop gate for all tools. `.get` avoids a KeyError when the
    # flag has not been initialised in session state yet.
    if st.session_state.get('human_loop', False):
        response = st.sidebar.radio(f"Proceed with tool '{selected_tool}'?", ["Yes", "No"], index=0)
        if response == "No":
            raise ValueError(f"Execution of '{selected_tool}' was canceled.")
    arguments = json.loads(tool_details.get("function").get("arguments"))
    response = tool_map[selected_tool].invoke(arguments)
    return {"messages": [response]}
def router(state):
    """Route to the tool node if the last message requested a tool, else end."""
    last_message = state['messages'][-1]
    pending_calls = last_message.additional_kwargs.get("tool_calls", [])
    return "tool" if pending_calls else "end"
# Graph setup: a two-node workflow — the agent decides, then the router either
# dispatches to the tool node or terminates the run.
graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)
graph.add_node("tool", invoke_tool)
# router() returns "tool" or "end"; the mapping ties those labels to targets.
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})
# Single tool round: tool output ends the run instead of looping to the agent.
graph.add_edge("tool", END)
graph.set_entry_point("agent")
compiled_app = graph.compile()
# Function to render graph with NetworkX
def render_graph_nx():
    """Draw the static workflow topology with NetworkX inside Streamlit."""
    workflow = nx.DiGraph()
    # Edge list mirrors the LangGraph wiring above (hand-maintained, not derived).
    for src, dst, text in [
        ("agent", "tool", "invoke tool"),
        ("agent", "end", "end condition"),
        ("tool", "end", "finish"),
    ]:
        workflow.add_edge(src, dst, label=text)
    # Fixed seed keeps the layout stable across reruns.
    layout = nx.spring_layout(workflow, seed=42)
    plt.figure(figsize=(8, 6))
    nx.draw(workflow, layout, with_labels=True, node_color="lightblue",
            node_size=3000, font_size=10, font_weight="bold")
    labels = nx.get_edge_attributes(workflow, "label")
    nx.draw_networkx_edge_labels(workflow, layout, edge_labels=labels, font_size=9)
    plt.title("Workflow Graph")
    st.pyplot(plt)
# Streamlit UI
st.title("Blah Blah Demo")
st.write("Compare results **with and without human intervention** in the workflow.")

# Sidebar for configuration
st.sidebar.header("Configuration")
st.session_state['human_loop'] = st.sidebar.checkbox("Enable Human-in-the-Loop", value=False)

# Input and comparison mode
prompt = st.text_input("Enter your question:", "What is 24 * 365?")


def _run_and_report(question, human_loop):
    """Stream the compiled graph once with the given human-in-the-loop flag,
    displaying the final response and the elapsed wall-clock time.

    Extracted because the original duplicated this whole section verbatim
    for the with/without comparison.
    """
    st.session_state['human_loop'] = human_loop
    start_time = time.time()
    try:
        intermediate_outputs = []
        for step in compiled_app.stream({"messages": [question]}):
            intermediate_outputs.append(step)
        # NOTE(review): assumes each streamed item exposes a 'messages' key;
        # newer langgraph stream modes key items by node name — verify against
        # the installed langgraph version.
        st.write("Response:", intermediate_outputs[-1]['messages'][0])
    except Exception as e:
        # Surface tool cancellations and runtime failures in the UI rather
        # than crashing the app.
        st.error(f"Error: {e}")
    st.write(f"Execution Time: {time.time() - start_time:.2f} seconds")


if st.button("Run Workflow"):
    st.subheader("Execution Results")

    # Without human-in-the-loop
    st.markdown("### Without Human-in-the-Loop")
    _run_and_report(prompt, False)

    # With human-in-the-loop
    st.markdown("### With Human-in-the-Loop")
    _run_and_report(prompt, True)

# Display Workflow Graph
st.subheader("Workflow Graph")
render_graph_nx()
|