import json
import operator
import os
import tempfile
from typing import TypedDict, Annotated, Sequence

import streamlit as st
import pygraphviz as pgv
from PIL import Image
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import BaseMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END

# Set API keys and validate credentials
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
    st.stop()

# Initialize OpenAI LLM
model = ChatOpenAI(temperature=0)
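# ChatOpenAI picks up OPENAI_API_KEY from the environment by default, so the
# key validated above does not need to be passed explicitly; temperature=0
# keeps the model's tool-routing decisions as deterministic as possible.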

# Define Tools
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two integers together."""
    return first_number * second_number


@tool
def search(query: str):
    """Performs a web search for the user query."""
    tavily = TavilySearchResults(max_results=1)
    return tavily.invoke(query)


tools = [search, multiply]
# Use `t` rather than `tool` so the @tool decorator is not shadowed.
tool_map = {t.name: t for t in tools}
model_with_tools = model.bind_tools(tools)
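
# For reference, a tool-calling reply from the bound model carries its calls
# in additional_kwargs, shaped roughly like this (illustrative sketch; the
# exact payload can vary across langchain-openai versions):
#
#   AIMessage(content="", additional_kwargs={"tool_calls": [{
#       "id": "call_abc123",
#       "type": "function",
#       "function": {"name": "multiply",
#                    "arguments": '{"first_number": 24, "second_number": 365}'},
#   }]})
#
# invoke_tool below parses the JSON-encoded "arguments" string before
# dispatching to the matching entry in tool_map.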

# Define Agent State class
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
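
# operator.add is the state reducer: when a node returns {"messages": [...]},
# LangGraph appends the new list to the existing history instead of replacing
# it, e.g. ["What is 24 * 365?"] + [AIMessage(...)] retains both entries.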

# Define workflow nodes
def invoke_model(state):
    """Call the tool-bound model on the most recent message."""
    messages = state['messages']
    question = messages[-1]
    return {"messages": [model_with_tools.invoke(question)]}


def invoke_tool(state):
    """Execute the tool requested by the model's latest message."""
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    # If the model requested several calls, only the last one is executed.
    tool_details = None
    for tool_call in tool_calls:
        tool_details = tool_call
    if tool_details is None:
        raise Exception("No tool input found.")

    selected_tool = tool_details.get("function").get("name")
    st.sidebar.write(f"Selected tool: {selected_tool}")

    # Optional human-in-the-loop gate before running a web search.
    if selected_tool == "search":
        if st.session_state.get('human_loop'):
            response = st.sidebar.radio("Proceed with web search?", ["Yes", "No"])
            if response == "No":
                raise ValueError("User canceled the search tool execution.")

    # Tool arguments arrive as a JSON-encoded string; decode before dispatch.
    response = tool_map[selected_tool].invoke(json.loads(tool_details.get("function").get("arguments")))
    return {"messages": [response]}


def router(state):
    # Route to the tool node if the model requested a tool call; otherwise end.
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    if tool_calls:
        return "tool"
    return "end"

# Graph setup
graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)
graph.add_node("tool", invoke_tool)
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})
graph.add_edge("tool", END)
graph.set_entry_point("agent")
compiled_app = graph.compile()
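
# compiled_app.stream(...) yields one dict per executed node, keyed by node
# name, e.g. (values illustrative):
#   {"agent": {"messages": [AIMessage(..., additional_kwargs={"tool_calls": [...]})]}}
#   {"tool": {"messages": [8760]}}
# The UI code below surfaces each of these intermediate steps as it arrives.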

# Function to render graph
def render_graph(app):
    """Draw the compiled workflow graph to a temporary PNG and return it as a PIL Image."""
    # get_graph() is exposed by the compiled app, not by the raw StateGraph.
    dot_string = app.get_graph().to_string()
    G = pgv.AGraph(string=dot_string)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        G.draw(temp_file.name, prog="dot", format="png")
    return Image.open(temp_file.name)
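
# NOTE: rendering assumes Graphviz and the pygraphviz bindings are installed
# in the runtime environment; without them, AGraph(...) or draw() will raise.
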
# Streamlit UI
st.title("LLM Tool Workflow Demo")
st.write("This app demonstrates LLM-based tool usage with and without human intervention.")
# Sidebar for options
st.sidebar.header("Configuration")
st.session_state['human_loop'] = st.sidebar.checkbox("Enable Human-in-the-Loop (For Search)", value=False)
# Input prompt
prompt = st.text_input("Enter your question:", "What is 24 * 365?")

if st.button("Run Workflow"):
    st.subheader("Execution Results")
    try:
        intermediate_outputs = []
        for s in compiled_app.stream({"messages": [prompt]}):
            intermediate_outputs.append(s)
            st.write("Response:", list(s.values())[0])
            st.write("---")
        st.sidebar.write("### Intermediate Outputs")
        for i, output in enumerate(intermediate_outputs):
            st.sidebar.write(f"Step {i+1}: {output}")
    except Exception as e:
        st.error(f"Error occurred: {e}")

# Display Graph
st.subheader("Workflow Graph")
graph_image = render_graph(compiled_app)
st.image(graph_image, caption="Workflow Graph", use_column_width=True)