DrishtiSharma committed on
Commit
c062c14
·
verified ·
1 Parent(s): 4f3be30

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import json
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.tools import tool
6
+ from langchain_community.tools.tavily_search import TavilySearchResults
7
+ from langgraph.graph import StateGraph, END
8
+ from typing import TypedDict, Annotated, Sequence
9
+ from langchain_core.messages import BaseMessage
10
+ import operator
11
+ import pygraphviz as pgv
12
+ from PIL import Image
13
+ import tempfile
14
+
15
# Read API credentials from the environment and refuse to start without them.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

if not (OPENAI_API_KEY and TAVILY_API_KEY):
    # Halt the Streamlit script here — nothing below works without both keys.
    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
    st.stop()

# Deterministic LLM (temperature=0) used for all agent calls.
model = ChatOpenAI(temperature=0)
25
+
26
+ # Define Tools
27
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two integers together."""
    # NOTE: the docstring above doubles as the tool description the LLM sees,
    # so it is part of the runtime contract, not just documentation.
    return first_number * second_number
31
+
32
@tool
def search(query: str):
    """Performs web search on the user query."""
    # The docstring above is the tool description sent to the LLM.
    # max_results=1 restricts Tavily to the single top hit.
    tavily = TavilySearchResults(max_results=1)
    result = tavily.invoke(query)
    return result
38
+
39
# Register the available tools and build a name -> tool lookup table
# used by invoke_tool to dispatch the model's tool calls.
tools = [search, multiply]
tool_map = {t.name: t for t in tools}

# LLM instance that can emit structured tool calls for the tools above.
model_with_tools = model.bind_tools(tools)
43
+
44
# Define Agent State class
class AgentState(TypedDict):
    """State shared between graph nodes.

    ``messages`` is annotated with ``operator.add``, so each node's returned
    message list is concatenated onto the existing history rather than
    replacing it.
    """
    messages: Annotated[Sequence[BaseMessage], operator.add]
47
+
48
# Workflow nodes
def invoke_model(state):
    """Run the tool-bound LLM on the latest message and append its reply.

    Only the most recent entry of ``state['messages']`` is sent to the model;
    the reply is returned as a one-element list so the graph's ``operator.add``
    annotation appends it to the history.
    """
    latest = state["messages"][-1]
    reply = model_with_tools.invoke(latest)
    return {"messages": [reply]}
53
+
54
def invoke_tool(state):
    """Execute the tool requested by the model's most recent message.

    Reads the ``tool_calls`` attached to the last message, optionally asks the
    user to confirm before running a web search (human-in-the-loop mode), then
    invokes the selected tool with its JSON-decoded arguments.

    Returns:
        dict: ``{"messages": [tool_result]}`` to be appended to the state.

    Raises:
        ValueError: if the message carries no tool call, or the user cancels
            the search. (Narrowed from the original bare ``Exception``; still
            caught by the UI's ``except Exception`` handler.)
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    if not tool_calls:
        raise ValueError("No tool input found.")

    # Preserves the original loop's behavior: when several tool calls are
    # present, only the last one is executed.
    tool_details = tool_calls[-1]
    selected_tool = tool_details.get("function").get("name")
    st.sidebar.write(f"Selected tool: {selected_tool}")

    # Human-in-the-loop confirmation applies to the search tool only.
    if selected_tool == "search" and st.session_state.get('human_loop'):
        response = st.sidebar.radio("Proceed with web search?", ["Yes", "No"])
        if response == "No":
            raise ValueError("User canceled the search tool execution.")

    arguments = json.loads(tool_details.get("function").get("arguments"))
    response = tool_map[selected_tool].invoke(arguments)
    return {"messages": [response]}
75
+
76
def router(state):
    """Route after the agent node: 'tool' if the last message requests tool
    calls, otherwise 'end' to terminate the graph run."""
    last_message = state['messages'][-1]
    requested_calls = last_message.additional_kwargs.get("tool_calls", [])
    return "tool" if requested_calls else "end"
82
+
83
# Build the workflow: agent -> (tool | END), tool -> END.
graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)
graph.add_node("tool", invoke_tool)
graph.set_entry_point("agent")
# After the agent runs, the router inspects its reply and either dispatches
# to the tool node or terminates the run.
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})
graph.add_edge("tool", END)
compiled_app = graph.compile()
91
+
92
# Function to render graph
def render_graph(graph):
    """Render a compiled LangGraph workflow as a PIL image.

    Converts the graph to DOT, lays it out with Graphviz (``dot``), and
    returns the rendered PNG as an in-memory PIL ``Image``.

    Fixes two issues in the original:
    - the ``delete=False`` temp file was never removed (one leaked PNG per
      render); it is now deleted in a ``finally`` block.
    - Graphviz wrote into a file whose handle was still open; the handle is
      now closed before drawing (required on e.g. Windows).
    """
    dot_string = graph.get_graph().to_string()
    G = pgv.AGraph(string=dot_string)
    # Reserve a file path, then close the handle so Graphviz can write to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        png_path = temp_file.name
    try:
        G.draw(png_path, prog="dot", format="png")
        with Image.open(png_path) as image:
            image.load()  # pull pixel data into memory before deleting the file
        return image
    finally:
        os.remove(png_path)
99
+
100
# Streamlit UI
st.title("LLM Tool Workflow Demo")
st.write("This app demonstrates LLM-based tool usage with and without human intervention.")

# Sidebar: toggle human-in-the-loop confirmation for the search tool.
st.sidebar.header("Configuration")
st.session_state['human_loop'] = st.sidebar.checkbox("Enable Human-in-the-Loop (For Search)", value=False)

# Input prompt
prompt = st.text_input("Enter your question:", "What is 24 * 365?")
if st.button("Run Workflow"):
    st.subheader("Execution Results")
    try:
        # Stream the graph so each node's output is shown as it arrives.
        intermediate_outputs = []
        for s in compiled_app.stream({"messages": [prompt]}):
            intermediate_outputs.append(s)
            st.write("Response:", list(s.values())[0])
            st.write("---")

        st.sidebar.write("### Intermediate Outputs")
        for i, output in enumerate(intermediate_outputs):
            st.sidebar.write(f"Step {i+1}: {output}")
    except Exception as e:
        # Top-level UI boundary: surface node/tool failures to the user.
        st.error(f"Error occurred: {e}")

# Display Graph
st.subheader("Workflow Graph")
# Fix: render the *compiled* app, not the StateGraph builder — get_graph()
# (used inside render_graph) is exposed by the compiled graph.
graph_image = render_graph(compiled_app)
st.image(graph_image, caption="Workflow Graph", use_column_width=True)