File size: 4,462 Bytes
c062c14
 
 
 
 
 
 
 
b5a0276
c062c14
6988bcd
 
c062c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5a0276
c062c14
 
 
 
b5a0276
c062c14
 
 
 
 
b5a0276
c062c14
 
 
 
 
 
 
 
 
 
 
 
 
 
b5a0276
 
 
 
 
 
c062c14
b5a0276
c062c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6988bcd
b5a0276
6988bcd
 
 
 
 
 
 
 
 
 
 
 
c062c14
 
a8bed22
b5a0276
c062c14
b5a0276
c062c14
b5a0276
c062c14
2ef9839
c062c14
 
 
b5a0276
92d35f8
b5a0276
 
 
 
 
 
 
 
 
 
 
 
c062c14
b5a0276
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import streamlit as st
import json
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage
import operator
import networkx as nx
import matplotlib.pyplot as plt

# Read both service credentials from the environment up front; the app cannot
# do anything useful without them, so halt the Streamlit script early if either
# one is missing or empty.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

if not (OPENAI_API_KEY and TAVILY_API_KEY):
    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
    st.stop()

# Deterministic (temperature 0) OpenAI chat model. The key is not passed in
# explicitly; presumably ChatOpenAI reads OPENAI_API_KEY from the environment.
model = ChatOpenAI(temperature=0)

# Define Tools
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two integers together."""
    # The docstring above is kept verbatim: the @tool decorator exposes it to
    # the model as this tool's description, and the parameter names/annotations
    # form the tool's argument schema.
    product = first_number * second_number
    return product

@tool
def search(query: str):
    """Performs web search on the user query."""
    # The docstring above is the tool description shown to the model, so it is
    # kept verbatim. A fresh Tavily client is created on every call and asked
    # for only the single top result; its raw output is returned unchanged.
    return TavilySearchResults(max_results=1).invoke(query)

# Tools the model may call, in the order they are advertised to it.
tools = [search, multiply]
# Name -> tool lookup used by invoke_tool to dispatch the model's tool calls.
tool_map = {registered.name: registered for registered in tools}

# Chat model variant that is told about every tool in `tools`.
model_with_tools = model.bind_tools(tools)

# Define Agent State class
class AgentState(TypedDict):
    """State shared between the LangGraph nodes of this workflow."""

    # `operator.add` is the reducer for this key: nodes return
    # {"messages": [...]} and the returned list is concatenated onto the
    # existing one instead of replacing it.
    # NOTE(review): despite the BaseMessage annotation, the app also stores the
    # raw prompt string (the initial stream input) and raw tool return values
    # here — confirm that is intended.
    messages: Annotated[Sequence[BaseMessage], operator.add]

# Define workflow nodes
def invoke_model(state):
    """Send the newest message in *state* to the tool-aware chat model.

    Only the last entry of ``state['messages']`` is passed to the model;
    earlier entries are not included. The reply is wrapped in a one-element
    list so the ``operator.add`` reducer on ``messages`` appends it.
    """
    latest = state['messages'][-1]
    reply = model_with_tools.invoke(latest)
    return {"messages": [reply]}

def invoke_tool(state):
    """Run every tool call requested by the most recent model message.

    Reads the OpenAI-style ``tool_calls`` list from the last message's
    ``additional_kwargs``. For each call it looks the tool up in ``tool_map``,
    JSON-decodes its arguments and invokes it. Calls run in the order the model
    listed them. The raw tool outputs are returned as a list so the
    ``operator.add`` reducer on ``messages`` appends all of them.

    Raises:
        ValueError: if the last message carries no tool calls, if a call names
            a tool that is not registered in ``tool_map``, or if the user
            declines a web search while human-in-the-loop mode is enabled.
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    if not tool_calls:
        raise ValueError("No tool input found.")

    results = []
    # Previously only the last call was executed and the others were silently
    # dropped; every requested call is honoured now.
    for index, tool_call in enumerate(tool_calls):
        function = tool_call.get("function") or {}
        selected_tool = function.get("name")
        st.sidebar.write(f"Selected tool: {selected_tool}")

        if selected_tool not in tool_map:
            raise ValueError(f"Model requested unknown tool: {selected_tool!r}")

        if selected_tool == "search" and st.session_state.get('human_loop', False):
            # A per-call key keeps several approval radios from colliding when
            # the model issues more than one search call.
            # NOTE(review): st.sidebar.radio does not pause the graph; on a
            # fresh render it returns its default ("Yes"), so the search
            # proceeds unless the widget already holds "No" — confirm this
            # actually gates the search as intended.
            response = st.sidebar.radio(
                "Proceed with web search?",
                ["Yes", "No"],
                key=f"approve_search_{index}",
            )
            if response == "No":
                raise ValueError("User canceled the search tool execution.")

        # A missing or empty arguments string is treated as "no arguments".
        arguments = json.loads(function.get("arguments") or "{}")
        results.append(tool_map[selected_tool].invoke(arguments))

    return {"messages": results}

def router(state):
    """Pick the next graph node after the agent has replied.

    Returns ``"tool"`` when the last message's ``additional_kwargs`` carries a
    non-empty ``tool_calls`` list, otherwise ``"end"``.
    """
    last_message = state['messages'][-1]
    pending_calls = last_message.additional_kwargs.get("tool_calls", [])
    return "tool" if len(pending_calls) > 0 else "end"

# Graph setup: START -> agent -> (tool | END); tool -> END.
# Note that the tool result is not routed back to the agent.
graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)
graph.add_node("tool", invoke_tool)
graph.set_entry_point("agent")

# The tool node always finishes the run.
graph.add_edge("tool", END)
# After the agent replies, `router` returns "tool" or "end"; map those labels
# onto the real destinations.
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})

compiled_app = graph.compile()

# Function to render graph with NetworkX
def render_graph_nx(graph):
    """Draw the agent workflow as a NetworkX diagram and show it in Streamlit.

    Args:
        graph: The workflow graph object. Currently unused: the nodes and edges
            drawn here are hard-coded to mirror the module-level graph setup,
            so they must be kept in sync by hand if that workflow changes.
    """
    G = nx.DiGraph()
    G.add_edge("agent", "tool", label="invoke tool")
    G.add_edge("agent", "end", label="end condition")
    G.add_edge("tool", "end", label="finish")

    # Fixed seed so the layout is identical on every Streamlit rerun.
    pos = nx.spring_layout(G, seed=42)

    # Draw on an explicit Figure rather than pyplot's implicit global one, and
    # hand that Figure to st.pyplot (not the `plt` module).
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_color="lightblue",
            node_size=3000,
            font_size=10,
            font_weight="bold",
        )
        edge_labels = nx.get_edge_attributes(G, "label")
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9, ax=ax)
        ax.set_title("Workflow Graph")
        st.pyplot(fig)
    finally:
        # Streamlit re-executes the script on every interaction; without this,
        # each rerun leaves one more open figure in pyplot's registry.
        plt.close(fig)

# Streamlit UI: page heading and a one-line description.
st.title("Tool Usage with and w/o Human Intervention")
st.write("This app demonstrates LLM-based tool usage with and without human intervention.")

# Sidebar configuration. The checkbox value is copied into session state on
# every rerun so invoke_tool can read st.session_state['human_loop'].
with st.sidebar:
    st.header("Configuration")
    st.session_state['human_loop'] = st.checkbox("Enable Human-in-the-Loop (For Search)", value=False)

# Input prompt
prompt = st.text_input("Enter your question:", "What is 24 * 365?")
if st.button("Run Workflow"):
    st.subheader("Execution Results")
    if not prompt or not prompt.strip():
        # Nothing to ask: skip the model call instead of sending an empty prompt.
        st.warning("Please enter a question before running the workflow.")
    else:
        intermediate_outputs = []
        try:
            for step in compiled_app.stream({"messages": [prompt]}):
                intermediate_outputs.append(step)
                # Each streamed step is presumably {node_name: node_update}
                # (LangGraph's default "updates" mode); show the update itself.
                st.write("Response:", next(iter(step.values())))
                st.write("---")
        except Exception as e:
            # Top-level UI boundary: report any workflow failure to the user.
            st.error(f"Error occurred: {e}")

        # Rendered outside the try so the steps that did complete are still
        # listed when a later node (e.g. a canceled search) raises.
        if intermediate_outputs:
            st.sidebar.write("### Intermediate Outputs")
            for i, output in enumerate(intermediate_outputs, start=1):
                st.sidebar.write(f"Step {i}: {output}")

    # Display Graph
    st.subheader("Workflow Graph")
    render_graph_nx(graph)