DrishtiSharma commited on
Commit
c05a5d6
·
verified ·
1 Parent(s): c33ab38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -33
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
 
2
  import chromadb
3
  import streamlit as st
 
4
  from langchain_openai import ChatOpenAI
5
  from langchain.agents import AgentExecutor, create_openai_tools_agent
6
  from langchain_core.messages import BaseMessage, HumanMessage
@@ -22,12 +24,10 @@ import operator
22
  from langchain_core.tools import tool
23
  from glob import glob
24
 
25
-
26
- # Clear ChromaDB cache to fix tenant issue
27
  chromadb.api.client.SharedSystemClient.clear_system_cache()
28
 
29
  # Load environment variables
30
-
31
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
32
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
33
 
@@ -35,7 +35,7 @@ if not OPENAI_API_KEY or not TAVILY_API_KEY:
35
  st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
36
  st.stop()
37
 
38
- # Initialize API keys and LLM
39
  llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
40
 
41
  # Utility Functions
@@ -53,19 +53,15 @@ def agent_node(state, agent, name):
53
  result = agent.invoke(state)
54
  output_content = result["output"]
55
 
56
- # Check if the output contains Python code that generates a graph
57
  if "matplotlib" in output_content or "plt." in output_content:
58
  exec_locals = {}
59
  try:
60
- exec(output_content, {}, exec_locals) # Safely execute the code
61
- fig = plt.gcf() # Get the current matplotlib figure
62
-
63
- # Save the figure to a buffer
64
  buf = io.BytesIO()
65
  fig.savefig(buf, format="png")
66
  buf.seek(0)
67
-
68
- # Add image to session state for display
69
  st.session_state.graph_image = buf
70
  except Exception as e:
71
  output_content += f"\nError: {str(e)}"
@@ -88,28 +84,24 @@ def RAG(state):
88
  result = retrieval_chain.invoke(question)
89
  return result
90
 
91
- # Load Tools
92
  tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
93
  python_repl_tool = PythonREPLTool()
94
 
95
  # Streamlit UI
96
- st.title("Multi-Agent w Supervisor")
97
 
98
- # Example questions for immediate testing
99
  example_questions = [
100
- #"Code hello world and print it",
101
  "What is James McIlroy aiming for in sports?",
102
  "Fetch India's GDP over the past 5 years and draw a line graph.",
103
  "Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph."
104
  ]
105
 
106
- # File Selection Section
107
  source_files = glob("sources/*.txt")
108
  selected_files = st.multiselect("Select files from the source directory:", source_files, default=source_files[:2])
109
-
110
  uploaded_files = st.file_uploader("Or upload your TXT files:", accept_multiple_files=True, type=['txt'])
111
 
112
- # Combine Files
113
  all_docs = []
114
  if selected_files:
115
  for file_path in selected_files:
@@ -122,18 +114,17 @@ if uploaded_files:
122
  all_docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
123
 
124
  if not all_docs:
125
- st.warning("Please select files from the source directory or upload TXT files.")
126
  st.stop()
127
 
128
- # Process Documents
129
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
130
  split_docs = text_splitter.split_documents(all_docs)
131
-
132
  embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
133
  db = Chroma.from_documents(split_docs, embeddings)
134
  retriever = db.as_retriever(search_kwargs={"k": 4})
135
 
136
- # Create Agents
137
  research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
138
  code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
139
  RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to Japan or Sports category.")
@@ -143,10 +134,7 @@ code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
143
  rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")
144
 
145
  members = ["RAG", "Researcher", "Coder"]
146
- system_prompt = (
147
- "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
148
- "Use RAG tool for Japan or Sports questions."
149
- )
150
  options = ["FINISH"] + members
151
  function_def = {
152
  "name": "route", "description": "Select the next role.",
@@ -157,10 +145,8 @@ prompt = ChatPromptTemplate.from_messages([
157
  MessagesPlaceholder(variable_name="messages"),
158
  ("system", "Given the conversation above, who should act next? Select one of: {options}"),
159
  ]).partial(options=str(options), members=", ".join(members))
160
-
161
  supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
162
 
163
- # Workflow
164
  class AgentState(TypedDict):
165
  messages: Annotated[Sequence[BaseMessage], operator.add]
166
  next: str
@@ -188,6 +174,7 @@ user_input = st.text_area("Enter your task or question:", placeholder=example_qu
188
  def run_workflow(task):
189
  st.session_state.outputs.clear()
190
  st.session_state.outputs.append(f"User Input: {task}")
 
191
  for state in graph.stream({"messages": [HumanMessage(content=task)]}):
192
  if "__end__" not in state:
193
  st.session_state.outputs.append(str(state))
@@ -199,10 +186,10 @@ if st.button("Run Workflow"):
199
  else:
200
  st.warning("Please enter a task or question.")
201
 
202
- st.subheader("Example Questions:")
203
- for example in example_questions:
204
- st.text(f"- {example}")
205
-
206
  st.subheader("Workflow Output:")
207
  for output in st.session_state.outputs:
208
  st.text(output)
 
 
 
 
 
1
  import os
2
+ import io
3
  import chromadb
4
  import streamlit as st
5
+ import matplotlib.pyplot as plt # For matplotlib graph handling
6
  from langchain_openai import ChatOpenAI
7
  from langchain.agents import AgentExecutor, create_openai_tools_agent
8
  from langchain_core.messages import BaseMessage, HumanMessage
 
24
  from langchain_core.tools import tool
25
  from glob import glob
26
 
27
+ # Clear ChromaDB cache
 
28
  chromadb.api.client.SharedSystemClient.clear_system_cache()
29
 
30
  # Load environment variables
 
31
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
32
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
33
 
 
35
  st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
36
  st.stop()
37
 
38
+ # Initialize LLM
39
  llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
40
 
41
  # Utility Functions
 
53
  result = agent.invoke(state)
54
  output_content = result["output"]
55
 
56
+ # Check if Python code generates a graph
57
  if "matplotlib" in output_content or "plt." in output_content:
58
  exec_locals = {}
59
  try:
60
+ exec(output_content, {}, exec_locals)
61
+ fig = plt.gcf()
 
 
62
  buf = io.BytesIO()
63
  fig.savefig(buf, format="png")
64
  buf.seek(0)
 
 
65
  st.session_state.graph_image = buf
66
  except Exception as e:
67
  output_content += f"\nError: {str(e)}"
 
84
  result = retrieval_chain.invoke(question)
85
  return result
86
 
87
+ # Tools Setup
88
  tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
89
  python_repl_tool = PythonREPLTool()
90
 
91
  # Streamlit UI
92
+ st.title("Multi-Agent Workflow with Supervisor")
93
 
 
94
  example_questions = [
 
95
  "What is James McIlroy aiming for in sports?",
96
  "Fetch India's GDP over the past 5 years and draw a line graph.",
97
  "Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph."
98
  ]
99
 
 
100
  source_files = glob("sources/*.txt")
101
  selected_files = st.multiselect("Select files from the source directory:", source_files, default=source_files[:2])
 
102
  uploaded_files = st.file_uploader("Or upload your TXT files:", accept_multiple_files=True, type=['txt'])
103
 
104
+ # Document Handling
105
  all_docs = []
106
  if selected_files:
107
  for file_path in selected_files:
 
114
  all_docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
115
 
116
  if not all_docs:
117
+ st.warning("Please select files or upload TXT files.")
118
  st.stop()
119
 
120
+ # Document Splitting and Embedding
121
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
122
  split_docs = text_splitter.split_documents(all_docs)
 
123
  embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
124
  db = Chroma.from_documents(split_docs, embeddings)
125
  retriever = db.as_retriever(search_kwargs={"k": 4})
126
 
127
+ # Agents
128
  research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
129
  code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
130
  RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to Japan or Sports category.")
 
134
  rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")
135
 
136
  members = ["RAG", "Researcher", "Coder"]
137
+ system_prompt = "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH."
 
 
 
138
  options = ["FINISH"] + members
139
  function_def = {
140
  "name": "route", "description": "Select the next role.",
 
145
  MessagesPlaceholder(variable_name="messages"),
146
  ("system", "Given the conversation above, who should act next? Select one of: {options}"),
147
  ]).partial(options=str(options), members=", ".join(members))
 
148
  supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
149
 
 
150
  class AgentState(TypedDict):
151
  messages: Annotated[Sequence[BaseMessage], operator.add]
152
  next: str
 
174
  def run_workflow(task):
175
  st.session_state.outputs.clear()
176
  st.session_state.outputs.append(f"User Input: {task}")
177
+ st.session_state.graph_image = None
178
  for state in graph.stream({"messages": [HumanMessage(content=task)]}):
179
  if "__end__" not in state:
180
  st.session_state.outputs.append(str(state))
 
186
  else:
187
  st.warning("Please enter a task or question.")
188
 
 
 
 
 
189
  st.subheader("Workflow Output:")
190
  for output in st.session_state.outputs:
191
  st.text(output)
192
+
193
+ if "graph_image" in st.session_state and st.session_state.graph_image:
194
+ st.subheader("Generated Graph:")
195
+ st.image(st.session_state.graph_image, caption="Generated Line Graph", use_column_width=True)