File size: 4,958 Bytes
c062c14
 
 
92d35f8
c062c14
 
 
 
 
 
 
6988bcd
 
c062c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92d35f8
 
 
 
 
 
c062c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6988bcd
92d35f8
6988bcd
 
 
 
 
 
 
 
 
 
 
 
c062c14
 
92d35f8
 
c062c14
92d35f8
c062c14
92d35f8
c062c14
92d35f8
c062c14
 
 
92d35f8
 
 
 
 
c062c14
 
 
 
92d35f8
 
 
 
c062c14
92d35f8
 
 
 
 
 
 
 
 
c062c14
92d35f8
 
c062c14
92d35f8
c062c14
92d35f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import streamlit as st
import json
import time
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage
import operator
import networkx as nx
import matplotlib.pyplot as plt

# Pull both service credentials from the process environment.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Either key being unset or empty is fatal: report it in the page and halt
# this script run before any client is constructed.
if not (OPENAI_API_KEY and TAVILY_API_KEY):
    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
    st.stop()

# Chat model used by the workflow, created with temperature=0.
model = ChatOpenAI(temperature=0)

# Define Tools
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two integers together."""
    # NOTE: the docstring above is left untouched on purpose. Per the LangChain
    # `@tool` documentation it becomes the tool description given to the model,
    # so rewording it would change model-facing behavior.
    product = first_number * second_number
    return product

@tool
def search(query: str):
    """Performs web search on the user query."""
    # NOTE: the docstring above is left untouched on purpose. Per the LangChain
    # `@tool` documentation it becomes the tool description given to the model.
    #
    # A Tavily search tool capped at a single result is built on every call and
    # its raw `invoke` output is returned unchanged.
    return TavilySearchResults(max_results=1).invoke(query)

# Tool registry. The list is what gets bound to the model; the dict lets the
# tool node look a tool up by the name the model asked for.
tools = [search, multiply]
tool_map = {registered.name: registered for registered in tools}
model_with_tools = model.bind_tools(tools)

# Define Agent State class
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the workflow graph.

    It has a single key, ``messages``. Every node in this file returns
    ``{"messages": [...]}`` and reads ``state['messages'][-1]``.
    """
    # `operator.add` is attached as Annotated metadata. Per LangGraph's state
    # documentation this is the reducer for the key: a node's returned list is
    # concatenated onto the existing sequence instead of replacing it.
    # NOTE(review): the annotation says BaseMessage, but this file seeds the
    # list with a plain prompt string and appends raw tool results - confirm
    # that the loose typing is intended.
    messages: Annotated[Sequence[BaseMessage], operator.add]

# Define workflow nodes
def invoke_model(state):
    """Ask the tool-enabled model to respond to the newest message.

    Only ``state['messages'][-1]`` is sent to the model; earlier entries in
    the list are not included in the call.

    Args:
        state: Workflow state containing a non-empty ``messages`` list.

    Returns:
        ``{"messages": [reply]}`` where ``reply`` is the model's response.
    """
    latest = state['messages'][-1]
    reply = model_with_tools.invoke(latest)
    return {"messages": [reply]}

def invoke_tool(state):
    """Execute the tool requested by the most recent message in the state.

    The tool calls are read from
    ``state['messages'][-1].additional_kwargs["tool_calls"]``. When several
    calls are present, only the LAST one is executed (this matches the
    original behavior); the others are ignored.

    Args:
        state: Workflow state whose last message exposes ``additional_kwargs``.
            Each tool call is expected to be a dict with a ``"function"`` entry
            holding ``"name"`` and JSON-encoded ``"arguments"``.

    Returns:
        ``{"messages": [result]}`` where ``result`` is whatever the selected
        tool's ``invoke`` returned (not wrapped in a message object).

    Raises:
        ValueError: If there is no tool call, the requested tool is not
            registered in ``tool_map``, or the human-in-the-loop prompt was
            answered "No".
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    if not tool_calls:
        raise ValueError("No tool input found.")

    # Only the final requested call is honored.
    tool_details = tool_calls[-1]
    function_spec = tool_details.get("function")

    selected_tool = function_spec.get("name")
    st.sidebar.write(f"Selected tool: {selected_tool}")

    # Fail with a readable message instead of a bare KeyError when the model
    # asks for a tool that was never registered.
    if selected_tool not in tool_map:
        available = ", ".join(sorted(tool_map))
        raise ValueError(f"Unknown tool '{selected_tool}'. Available tools: {available}")

    # Add human-in-the-loop for all tools
    # NOTE(review): st.sidebar.radio returns its current selection immediately
    # (default "Yes") rather than waiting for the user - confirm this gate
    # behaves as intended inside a button-triggered run.
    if st.session_state['human_loop']:
        response = st.sidebar.radio(f"Proceed with tool '{selected_tool}'?", ["Yes", "No"], index=0)
        if response == "No":
            raise ValueError(f"Execution of '{selected_tool}' was canceled.")

    # Arguments arrive as a JSON string; treat a missing/empty value as "{}".
    arguments = json.loads(function_spec.get("arguments") or "{}")
    response = tool_map[selected_tool].invoke(arguments)
    return {"messages": [response]}

def router(state):
    """Pick the branch to follow after the agent node has run.

    Args:
        state: Workflow state whose last message exposes ``additional_kwargs``.

    Returns:
        ``"tool"`` when the last message carries at least one tool call,
        otherwise ``"end"``.
    """
    last_message = state['messages'][-1]
    requested_calls = last_message.additional_kwargs.get("tool_calls", [])
    return "tool" if len(requested_calls) > 0 else "end"

# Graph setup
# Topology: "agent" is the entry point; `router` then either sends the run to
# "tool" or ends it; "tool" always ends the run.
graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)  # queries the tool-enabled model
graph.add_node("tool", invoke_tool)  # executes the tool the model requested
# `router` returns "tool" or "end"; this mapping turns that label into the
# next node (or END).
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})
graph.add_edge("tool", END)  # the tool result is not fed back to the agent
graph.set_entry_point("agent")
compiled_app = graph.compile()  # runnable graph used by the UI below

# Function to render graph with NetworkX
def render_graph_nx():
    """Draw a static diagram of the workflow and display it in the page.

    The diagram is hard-coded (agent -> tool, agent -> end, tool -> end); it
    is not derived from the compiled LangGraph object.

    The figure is created explicitly, handed to ``st.pyplot`` and then closed.
    The previous version passed the ``pyplot`` module itself and never closed
    the figure, so every rerun left another open figure behind.
    """
    G = nx.DiGraph()
    G.add_edge("agent", "tool", label="invoke tool")
    G.add_edge("agent", "end", label="end condition")
    G.add_edge("tool", "end", label="finish")

    # Fixed seed keeps the node layout identical across reruns.
    pos = nx.spring_layout(G, seed=42)
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_color="lightblue",
            node_size=3000,
            font_size=10,
            font_weight="bold",
        )
        edge_labels = nx.get_edge_attributes(G, "label")
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9, ax=ax)
        ax.set_title("Workflow Graph")
        st.pyplot(fig)
    finally:
        # Release the figure even if drawing or rendering fails.
        plt.close(fig)

# Streamlit UI
st.title("Blah Blah Demo")
st.write("Compare results **with and without human intervention** in the workflow.")

# Sidebar for configuration
st.sidebar.header("Configuration")
# NOTE(review): both runs below overwrite this flag (False, then True), so the
# checkbox value currently has no effect on a run - confirm that is intended.
st.session_state['human_loop'] = st.sidebar.checkbox("Enable Human-in-the-Loop", value=False)


def _final_response(chunk):
    """Return the last message carried by one streamed workflow chunk.

    With LangGraph's default stream mode each chunk is keyed by node name
    (``{node_name: {"messages": [...]}}``), so the node level is unwrapped
    first. A chunk that already has a top-level ``"messages"`` key is used
    as is.
    """
    if "messages" not in chunk:
        chunk = list(chunk.values())[-1]
    return chunk['messages'][-1]


def _run_and_report(heading, question, human_loop):
    """Run the workflow once and write its response and timing to the page.

    Args:
        heading: Markdown heading shown above this run's results.
        question: User prompt used as the initial message.
        human_loop: Value stored in ``st.session_state['human_loop']`` for
            the duration of the run (read by the tool node).
    """
    st.markdown(heading)
    st.session_state['human_loop'] = human_loop
    start_time = time.time()
    try:
        intermediate_outputs = list(compiled_app.stream({"messages": [question]}))
        if intermediate_outputs:
            st.write("Response:", _final_response(intermediate_outputs[-1]))
        else:
            st.warning("Workflow produced no output.")
    except Exception as e:
        # UI boundary: surface any workflow failure (including a canceled
        # tool call) in the page instead of crashing the script run.
        st.error(f"Error: {e}")
    st.write(f"Execution Time: {time.time() - start_time:.2f} seconds")


# Input and comparison mode
prompt = st.text_input("Enter your question:", "What is 24 * 365?")
if st.button("Run Workflow"):
    st.subheader("Execution Results")

    # Without human-in-the-loop
    _run_and_report("### Without Human-in-the-Loop", prompt, False)

    # With human-in-the-loop
    _run_and_report("### With Human-in-the-Loop", prompt, True)

    # Display Workflow Graph
    st.subheader("Workflow Graph")
    render_graph_nx()