DrishtiSharma committed on
Commit
d0f4eb9
·
verified ·
1 Parent(s): c6a7187

Create interim.py

Browse files
Files changed (1) hide show
  1. interim.py +112 -0
interim.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import networkx as nx
6
+ import matplotlib.pyplot as plt
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_openai import ChatOpenAI
9
+ from langgraph.graph import MessagesState
10
+ from langgraph.graph import START, StateGraph
11
+ from langgraph.prebuilt import tools_condition
12
+ from langgraph.prebuilt import ToolNode
13
+ from langchain_core.messages import HumanMessage, SystemMessage
14
+
15
# ------------------- Environment Variable Setup -------------------
# Read the required API keys from the process environment.
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

# Fail fast at startup if either credential is absent.
for _key_name, _key_value in (("OPENAI_API_KEY", openai_api_key),
                              ("TAVILY_API_KEY", tavily_api_key)):
    if not _key_value:
        raise ValueError(f"Missing required environment variable: {_key_name}")

# ------------------- Tool Definitions -------------------
# Web-search tool backed by the Tavily API (top 5 results per query).
tavily_tool = TavilySearchResults(max_results=5)
29
+
30
def multiply(a: int, b: int) -> int:
    """Return the product of the two integers *a* and *b*."""
    product = a * b
    return product
33
+
34
def add(a: int, b: int) -> int:
    """Return the sum of the two integers *a* and *b*."""
    total = a + b
    return total
37
+
38
def divide(a: int, b: int) -> float:
    """Return *a* divided by *b*.

    Raises:
        ValueError: If *b* is zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Division by zero is not allowed.")
43
+
44
# Combine the arithmetic helpers with the web-search tool into one toolbox.
tools = [add, multiply, divide, tavily_tool]

# ------------------- LLM and System Message Setup -------------------
# Chat model with the tools bound; parallel tool calls are disabled so the
# agent executes one tool call per turn.
llm = ChatOpenAI(model="gpt-4o-mini")
llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)
sys_msg = SystemMessage(content="You are a helpful assistant tasked with performing arithmetic and search on a set of inputs.")
51
+
52
+ # ------------------- LangGraph Workflow -------------------
53
def assistant(state: MessagesState):
    """Assistant node: invoke the tool-bound LLM on the system prompt plus the
    accumulated conversation, returning the new message to append to state."""
    conversation = [sys_msg] + state["messages"]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}
56
+
57
# Define the graph: an assistant node plus a tool-execution node that loop
# until the assistant stops requesting tool calls.
app_graph = StateGraph(MessagesState)
app_graph.add_node("assistant", assistant)
app_graph.add_node("tools", ToolNode(tools))
app_graph.add_edge(START, "assistant")
# Routes to "tools" when the last message carries tool calls, otherwise ends.
app_graph.add_conditional_edges("assistant", tools_condition)
# After tools run, hand control back to the assistant for the next turn.
app_graph.add_edge("tools", "assistant")
react_graph = app_graph.compile()
65
+
66
# ------------------- Streamlit Interface -------------------
st.title("ReAct Agent")

# Display the workflow graph
st.header("LangGraph Workflow Visualization")

# Mirror the LangGraph workflow as a NetworkX digraph, for display only.
# (No `label` edge attribute: labels are supplied to
# draw_networkx_edge_labels below, so a second, contradictory attribute
# would never be rendered.)
G = nx.DiGraph()
G.add_edge("START", "assistant")
G.add_edge("assistant", "tools")
G.add_edge("tools", "assistant")

# Draw into an explicit Figure: st.pyplot(fig) is the supported API, while
# passing the global pyplot module is deprecated in Streamlit.
fig = plt.figure(figsize=(10, 6))
pos = nx.spring_layout(G, seed=42)  # fixed seed -> stable layout across reruns
nx.draw(G, pos, with_labels=True, node_size=3000, node_color="lightblue", font_size=10, font_weight="bold")
nx.draw_networkx_edge_labels(G, pos, edge_labels={
    ("assistant", "tools"): "tools_condition",
    ("tools", "assistant"): "loop back",
}, font_color="red")
st.pyplot(fig)
plt.close(fig)  # release the figure so reruns don't accumulate open figures

# Prompt user for inputs
user_question = st.text_area("Enter your question:",
                             placeholder="Example: 'Add 3 and 4. Multiply the result by 2. Divide it by 5.'")

if st.button("Submit"):
    if not user_question.strip():
        st.error("Please enter a valid question.")
        st.stop()

    st.info("Processing your question...")
    messages = [HumanMessage(content=user_question)]
    response = react_graph.invoke({"messages": messages})

    # Display results: every message in the final state, including the
    # user's question and intermediate tool outputs.
    st.subheader("Responses")
    for m in response['messages']:
        st.write(m.content)

    st.success("Processing complete!")

# Example Placeholder Suggestions
st.sidebar.subheader("Example Questions")
st.sidebar.write("- Add 3 and 4. Multiply the result by 2. Divide it by 5.")
st.sidebar.write("- Tell me how many centuries Virat Kohli scored.")
st.sidebar.write("- Search for the tallest building in the world.")