File size: 4,462 Bytes
c062c14 b5a0276 c062c14 6988bcd c062c14 b5a0276 c062c14 b5a0276 c062c14 b5a0276 c062c14 b5a0276 c062c14 b5a0276 c062c14 6988bcd b5a0276 6988bcd c062c14 a8bed22 b5a0276 c062c14 b5a0276 c062c14 b5a0276 c062c14 2ef9839 c062c14 b5a0276 92d35f8 b5a0276 c062c14 b5a0276 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import streamlit as st
import json
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage
import operator
import networkx as nx
import matplotlib.pyplot as plt
# Set API keys and validate credentials
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    # Fail fast: surface the problem in the UI and halt this Streamlit run.
    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
    st.stop()
# Initialize OpenAI LLM (temperature=0 for deterministic tool-call decisions)
model = ChatOpenAI(temperature=0)
# Define Tools
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two integers together."""
    # Docstring above doubles as the tool description shown to the LLM.
    product = first_number * second_number
    return product
@tool
def search(query: str):
    """Performs web search on the user query."""
    # Delegate to Tavily; cap the response at a single result.
    client = TavilySearchResults(max_results=1)
    return client.invoke(query)
# Register the available tools, index them by name for dispatch in
# invoke_tool, and attach their schemas to the LLM.
tools = [search, multiply]
tool_map = {}
for registered in tools:
    tool_map[registered.name] = registered
model_with_tools = model.bind_tools(tools)
# Define Agent State class
class AgentState(TypedDict):
    # Conversation history. The operator.add annotation tells LangGraph to
    # APPEND each node's returned messages to this list instead of
    # overwriting it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
# Define workflow nodes
def invoke_model(state):
    """Agent node: ask the tool-bound LLM what to do next.

    The returned message is appended to state['messages'] by the
    operator.add reducer declared on AgentState.

    Args:
        state: Current AgentState with the accumulated message list.

    Returns:
        dict with a one-element "messages" list containing the LLM reply
        (which may carry OpenAI-style tool_calls in additional_kwargs).
    """
    messages = state['messages']
    # Bug fix: the original invoked the model with only messages[-1],
    # discarding any earlier context. ChatOpenAI.invoke accepts the whole
    # message sequence, so pass the full history.
    return {"messages": [model_with_tools.invoke(messages)]}
def invoke_tool(state):
    """Tool node: execute the tool requested by the latest LLM message.

    Reads OpenAI-style ``tool_calls`` from the last message's
    additional_kwargs, optionally asks the user to confirm a web search
    (human-in-the-loop), then dispatches to the matching tool with the
    JSON-decoded arguments.

    Raises:
        ValueError: if the message contains no tool calls, or the user
            declines the search. (Caught by the UI's except Exception.)
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    if not tool_calls:
        # Original raised a bare Exception; ValueError is more specific and
        # still caught by the top-level handler.
        raise ValueError("No tool input found.")
    # Execute only the last requested call — this mirrors the original loop,
    # which overwrote tool_details on every iteration.
    function = tool_calls[-1].get("function")
    selected_tool = function.get("name")
    st.sidebar.write(f"Selected tool: {selected_tool}")
    if selected_tool == "search":
        # Human-in-the-loop gate for web searches only.
        # NOTE(review): st.sidebar.radio inside a running workflow triggers a
        # Streamlit rerun; confirm this interaction behaves as intended.
        if st.session_state.get('human_loop'):
            decision = st.sidebar.radio("Proceed with web search?", ["Yes", "No"])
            if decision == "No":
                raise ValueError("User canceled the search tool execution.")
    # Arguments arrive as a JSON string per the OpenAI tool-call schema.
    arguments = json.loads(function.get("arguments"))
    response = tool_map[selected_tool].invoke(arguments)
    return {"messages": [response]}
def router(state):
    """Conditional edge: route to "tool" when the last message requests
    tool calls, otherwise to "end"."""
    last_message = state['messages'][-1]
    pending_calls = last_message.additional_kwargs.get("tool_calls", [])
    return "tool" if pending_calls else "end"
# Graph setup: agent -> (router) -> tool -> END, or agent -> END directly.
graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)  # LLM decides whether a tool is needed
graph.add_node("tool", invoke_tool)    # executes the requested tool
# After the agent runs, router() returns "tool" (tool call present) or "end".
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})
graph.add_edge("tool", END)  # single tool hop, then finish
graph.set_entry_point("agent")
compiled_app = graph.compile()
# Function to render graph with NetworkX
def render_graph_nx(graph):
    """Draw a static picture of the workflow topology in the Streamlit app.

    Note: the edges are hard-coded to mirror the compiled workflow; the
    ``graph`` argument itself is not inspected.
    """
    workflow = nx.DiGraph()
    for src, dst, text in [
        ("agent", "tool", "invoke tool"),
        ("agent", "end", "end condition"),
        ("tool", "end", "finish"),
    ]:
        workflow.add_edge(src, dst, label=text)
    # Fixed seed keeps the layout stable across reruns.
    layout = nx.spring_layout(workflow, seed=42)
    plt.figure(figsize=(8, 6))
    nx.draw(workflow, layout, with_labels=True, node_color="lightblue",
            node_size=3000, font_size=10, font_weight="bold")
    nx.draw_networkx_edge_labels(
        workflow, layout,
        edge_labels=nx.get_edge_attributes(workflow, "label"),
        font_size=9,
    )
    plt.title("Workflow Graph")
    st.pyplot(plt)
# Streamlit UI
st.title("Tool Usage with and w/o Human Intervention")
st.write("This app demonstrates LLM-based tool usage with and without human intervention.")
# Sidebar for options
st.sidebar.header("Configuration")
# Persist the toggle in session_state so invoke_tool can read it mid-run.
st.session_state['human_loop'] = st.sidebar.checkbox("Enable Human-in-the-Loop (For Search)", value=False)
# Input prompt
prompt = st.text_input("Enter your question:", "What is 24 * 365?")
if st.button("Run Workflow"):
    st.subheader("Execution Results")
    try:
        intermediate_outputs = []
        # Stream the compiled graph step by step; each yielded item maps a
        # node name to the state update that node produced.
        for s in compiled_app.stream({"messages": [prompt]}):
            intermediate_outputs.append(s)
            st.write("Response:", list(s.values())[0])
            st.write("---")
        # Mirror every intermediate step in the sidebar for debugging.
        st.sidebar.write("### Intermediate Outputs")
        for i, output in enumerate(intermediate_outputs):
            st.sidebar.write(f"Step {i+1}: {output}")
    except Exception as e:
        # Surfaces tool errors and user-canceled searches in the UI.
        st.error(f"Error occurred: {e}")
# Display Graph
st.subheader("Workflow Graph")
render_graph_nx(graph)