Divya Bharambe committed on
Commit
b61d46f
·
verified ·
1 Parent(s): b676402

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +337 -0
app.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import re
5
+ import os
6
+ from dotenv import load_dotenv
7
+ import plotly.express as px
8
+ import plotly.io as pio
9
+ import asyncio
10
+ import nest_asyncio
11
+ import json
12
+ from plotly.io import from_json
13
+
14
+ # Fix Streamlit event loop issue
15
+ nest_asyncio.apply()
16
+
17
+ # Updated LangChain imports
18
+ from langchain_community.document_loaders import PyPDFLoader
19
+ from langchain_text_splitters import CharacterTextSplitter
20
+ from langchain_huggingface import HuggingFaceEmbeddings # Updated import
21
+ from langchain_community.vectorstores import FAISS
22
+ from langchain.chains import RetrievalQA
23
+ from langchain_experimental.agents import create_pandas_dataframe_agent
24
+ from langchain_groq import ChatGroq
25
+ from langchain_core.tools import tool
26
+ from langchain.prompts import ChatPromptTemplate
27
+ from langchain.chains import LLMChain
28
+
29
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Every Plotly figure created afterwards uses the "plotly_white" template.
pio.templates.default = "plotly_white"

st.set_page_config(page_title="Chatlytics: Business Data Insights", layout="wide")
st.title("📊 Chatlytics: Business Data Insights Chatbot")

# Seed the per-session slots on first run only, so values stored by an
# earlier rerun of this script are left untouched:
#   qa_chain    - retrieval QA chain for the uploaded PDF
#   df          - DataFrame loaded from the uploaded data file
#   data_agent  - agent built around that DataFrame
#   active_mode - which kind of upload the chat should answer from
for _slot in ("qa_chain", "df", "data_agent", "active_mode"):
    if _slot not in st.session_state:
        st.session_state[_slot] = None
46
+
47
def get_chart_config_llm_chain(llm):
    """Return a runnable that asks *llm* for a chart configuration.

    The prompt template takes two inputs, ``query`` (the user's request) and
    ``columns`` (the available column names), and instructs the model to
    answer with JSON describing ``chart_type``, ``x_axis``, ``y_axis`` and
    ``group_by``.

    Args:
        llm: The chat model the prompt is piped into.

    Returns:
        The prompt piped into ``llm`` (``prompt | llm``).
    """
    template = """
You are a data visualization assistant. Based on the user's prompt and the dataset's columns, return a JSON with:
- chart_type: one of ["bar", "pie", "line", "scatter"]
- x_axis: (optional)
- y_axis: (optional)
- group_by: (optional)

Respond in JSON only. No explanation.

User prompt: {query}
Available columns: {columns}
"""
    chart_prompt = ChatPromptTemplate.from_template(template)
    return chart_prompt | llm
61
+
62
def process_pdf(pdf_path):
    """Build a retrieval QA chain over the PDF stored at *pdf_path*.

    The PDF is loaded and split into chunks, the chunks are embedded with the
    ``sentence-transformers/all-MiniLM-L6-v2`` model and indexed in an
    in-memory FAISS store, and a "stuff" RetrievalQA chain backed by the Groq
    chat model is created on top of that index.

    Args:
        pdf_path: Filesystem path of the PDF to index.

    Returns:
        A RetrievalQA chain configured to return its source documents.

    Raises:
        ValueError: If no text chunks could be extracted from the PDF.
    """
    loader = PyPDFLoader(pdf_path)
    pages = loader.load_and_split()

    text_splitter = CharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    texts = text_splitter.split_documents(pages)

    # Fail early with a readable message: building a vector store from an
    # empty list of documents otherwise surfaces as an obscure low-level error.
    if not texts:
        raise ValueError(
            "No extractable text found in the PDF "
            "(it may be empty or contain only scanned images)."
        )

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vectorstore = FAISS.from_documents(texts, embeddings)

    llm = ChatGroq(
        temperature=0,
        model_name="llama3-70b-8192",
        groq_api_key=os.getenv("GROQ_API_KEY"),
    )

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
        return_source_documents=True,
    )
90
+
91
def process_data_file(file):
    """Load an uploaded CSV/Excel file into a cleaned DataFrame.

    The file type is chosen from the extension of ``file.name``, compared
    case-insensitively so that e.g. ``REPORT.CSV`` is accepted. After loading,
    every string cell is round-tripped through UTF-8 with errors ignored,
    which drops characters that cannot be encoded.

    Args:
        file: A file-like object with a ``name`` attribute.

    Returns:
        The cleaned DataFrame, or ``None`` when the extension is unsupported
        or loading fails (the failure is reported through ``st.error``).
    """
    try:
        filename = file.name.lower()
        if filename.endswith(".csv"):
            df = pd.read_csv(file)
        elif filename.endswith((".xls", ".xlsx")):
            df = pd.read_excel(file)
        else:
            return None

        def _drop_unencodable(value):
            # Non-string cells (numbers, NaN, timestamps, ...) pass through.
            if isinstance(value, str):
                return value.encode("utf-8", "ignore").decode("utf-8")
            return value

        # DataFrame.map only exists in newer pandas releases; older ones
        # expose the same elementwise operation as DataFrame.applymap.
        apply_elementwise = df.map if hasattr(df, "map") else df.applymap
        return apply_elementwise(_drop_unencodable)
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return None
108
+
109
def _parse_chart_config(reply_text):
    """Return the JSON object embedded in *reply_text* as a dict.

    Only the span from the first ``{`` to the last ``}`` is parsed, so a reply
    wrapped in a Markdown code fence or surrounded by extra wording still
    works.

    Raises:
        ValueError: If the text contains no ``{...}`` span or the span is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    start = reply_text.find("{")
    end = reply_text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("LLM reply did not contain a JSON object.")
    return json.loads(reply_text[start:end + 1])


@tool
def generate_visualization(query: str) -> str:
    """
    Dynamically generate Plotly visualizations using LLM-based interpretation of user prompts.
    """
    # NOTE(review): the docstring above is deliberately left unchanged because
    # @tool presumably exposes it to the model as the tool description.
    #
    # Returns a string of the form "CHART|||<payload>|||ANALYSIS|||<text>",
    # where <payload> is the figure's Plotly JSON, "NO_DATA" when nothing can
    # be plotted, or "ERROR" when an exception occurred.
    #
    # The chart always shows value counts of a single column: `group_by` if it
    # names a column, otherwise `x_axis`. The `y_axis` field requested in the
    # prompt is currently not used.
    try:
        # Check for a missing DataFrame before touching it, so that case is
        # reported as NO_DATA rather than as an AttributeError.
        df = st.session_state.get("df")
        if df is None or df.empty:
            return "CHART|||NO_DATA|||ANALYSIS|||No data available."

        llm = ChatGroq(
            temperature=0,
            model_name="llama3-70b-8192",
            groq_api_key=os.getenv("GROQ_API_KEY"),
        )

        chain = get_chart_config_llm_chain(llm)
        result = chain.invoke({
            "query": query,
            # Column labels are not guaranteed to be strings.
            "columns": ", ".join(map(str, df.columns)),
        })

        # Chat models return a message object carrying `.content`; anything
        # else is coerced to text.
        content = getattr(result, "content", result)
        result_text = content if isinstance(content, str) else str(content)

        config = _parse_chart_config(result_text)

        # `or "bar"` also covers an explicit null chart_type in the JSON.
        chart_type = str(config.get("chart_type") or "bar").lower()
        x = config.get("x_axis")
        group_by = config.get("group_by")

        if group_by and group_by in df.columns:
            column = group_by
        elif x and x in df.columns:
            column = x
        else:
            return "CHART|||NO_DATA|||ANALYSIS|||Insufficient or invalid columns to generate chart."

        agg_df = df[column].value_counts().reset_index()
        agg_df.columns = [column, "Count"]

        if chart_type == "pie":
            fig = px.pie(agg_df, names=column, values="Count", title=f"{column} Distribution")
        elif chart_type == "line":
            fig = px.line(agg_df, x=column, y="Count", title=f"{column} Trend")
        elif chart_type == "scatter":
            fig = px.scatter(agg_df, x=column, y="Count", title=f"{column} Scatter")
        else:
            # Any other chart type falls back to a bar chart.
            fig = px.bar(agg_df, x=column, y="Count", title=f"{column} Bar Chart")

        return f"CHART|||{fig.to_json()}|||ANALYSIS|||Successfully generated a {chart_type} chart for '{column}'."

    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return f"CHART|||ERROR|||ANALYSIS|||Error generating chart: {str(e)}"
173
+
174
def create_dataframe_agent(df):
    """Build the pandas DataFrame agent used for data questions.

    The agent runs on the Groq chat model, is given ``generate_visualization``
    as an extra tool, and receives a prefix that tells it to use that tool for
    charts and to answer in the ``CHART|||...|||ANALYSIS|||...`` format.

    Args:
        df: The DataFrame the agent will analyse.

    Returns:
        The agent returned by ``create_pandas_dataframe_agent``.
    """
    # Instructions plus one worked example of the expected answer format.
    prefix = """
You are a data analysis expert. Follow these rules:
1. ALWAYS use generate_visualization for charts
2. Never use matplotlib or python_repl_ast
3. Provide final answer in the format: CHART|||<chart JSON>|||ANALYSIS|||<analysis text>
4. Handle dates carefully

Below is an example of how you should respond:

EXAMPLE
-------
User: "Can you create a pie chart of Sales by Region?"
Assistant:
Thought: "I should use the generate_visualization tool to build the chart"
Action: generate_visualization
Action Input: "Pie chart of Sales by Region"

Observation:
CHART|||{"data": [...], "layout": {...}}|||ANALYSIS|||Based on the pie chart, Region A leads in sales.

# Final Answer from the assistant:
CHART|||{"data": [...], "layout": {...}}|||ANALYSIS|||Based on the pie chart, Region A leads in sales...
-------
END OF EXAMPLE
"""

    groq_llm = ChatGroq(
        temperature=0,
        model_name="llama3-70b-8192",
        groq_api_key=os.getenv("GROQ_API_KEY"),
    )

    agent_options = {
        "llm": groq_llm,
        "df": df,
        "verbose": True,
        "agent_type": "openai-tools",
        "max_iterations": 5,
        "extra_tools": [generate_visualization],
        "allow_dangerous_code": True,
        "prefix": prefix,
    }
    return create_pandas_dataframe_agent(**agent_options)
219
+
220
+
221
# Sidebar for file uploads
def _build_pdf_chain(uploaded_pdf):
    """Index an uploaded PDF via a private temp file and return its QA chain.

    A uniquely named temporary file is used instead of a fixed path in the
    working directory, so concurrent sessions cannot overwrite each other's
    upload. The file is removed once indexing has finished or failed.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
        tmp.write(uploaded_pdf.getbuffer())
        tmp_path = tmp.name
    try:
        return process_pdf(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup; a leftover temp file must not mask the
            # real outcome of process_pdf.
            pass


with st.sidebar:
    st.header("Upload Files")

    pdf_file = st.file_uploader("PDF Document", type="pdf")
    data_file = st.file_uploader("Data File (CSV/Excel)", type=["csv", "xls", "xlsx"])

    # Only one source can be active at a time: warn and halt this run.
    if pdf_file and data_file:
        st.warning("Please upload only one file at a time! Remove one of them.")
        st.stop()

    # PDF uploaded (and no data file).
    if pdf_file:
        # The script reruns on every interaction, so only re-index when the
        # upload actually changed; otherwise each chat message would rebuild
        # the embeddings and the vector index.
        pdf_key = (pdf_file.name, pdf_file.size)
        if pdf_key != st.session_state.get("current_pdf_file"):
            try:
                with st.spinner("Indexing PDF document..."):
                    st.session_state.qa_chain = _build_pdf_chain(pdf_file)
                st.session_state.active_mode = "pdf"
                st.session_state.current_pdf_file = pdf_key
                # Clear data file context
                st.session_state.df = None
                st.session_state.data_agent = None
                st.session_state.current_data_file = None
                st.success("PDF document processed!")
            except Exception as e:
                st.error(f"PDF Error: {str(e)}")

    # Data file uploaded (and no PDF); skipped when this file is already loaded.
    if data_file and data_file.name != st.session_state.get('current_data_file'):
        with st.spinner("Analyzing data file..."):
            df = process_data_file(data_file)
            if df is not None:
                st.session_state.df = df
                st.session_state.data_agent = create_dataframe_agent(df)
                st.session_state.current_data_file = data_file.name
                st.session_state.active_mode = "data"
                # Clear PDF context, including the marker, so that uploading
                # the same PDF again later is indexed afresh.
                st.session_state.qa_chain = None
                st.session_state.current_pdf_file = None
                st.success("Data file processed!")
260
+
261
# Chat interface
def _split_chart_reply(text):
    """Split ``...CHART|||<payload>|||ANALYSIS|||<analysis>`` into its parts.

    Returns:
        ``(payload, analysis)`` with surrounding whitespace removed, or
        ``None`` when either marker is missing. Text before ``CHART|||`` is
        ignored, and any further ``|||`` inside the analysis is kept.
    """
    _, chart_marker, remainder = text.partition("CHART|||")
    if not chart_marker:
        return None
    payload, analysis_marker, analysis = remainder.partition("|||ANALYSIS|||")
    if not analysis_marker:
        return None
    return payload.strip(), analysis.strip()


if prompt := st.chat_input("Ask about your data or document"):
    # Route the question to whichever upload is currently active.
    if st.session_state.active_mode == "data" and st.session_state.data_agent and st.session_state.df is not None:
        try:
            response = st.session_state.data_agent.invoke({"input": prompt})

            # Normalise the agent response to plain text.
            if isinstance(response, dict) and "output" in response:
                output_text = response["output"]
            elif isinstance(response, str):
                output_text = response
            else:
                output_text = str(response)

            with st.chat_message("assistant"):
                if "CHART|||" in output_text:
                    parsed = _split_chart_reply(output_text)
                    if parsed is None:
                        st.warning("CHART message has unexpected format.")
                    else:
                        chart_json, analysis_text = parsed
                        if chart_json in ("NO_DATA", "ERROR"):
                            # Sentinel payloads from generate_visualization:
                            # there is no figure, only the explanation.
                            st.markdown("**Analysis (No Chart):**")
                            st.write(analysis_text)
                        else:
                            try:
                                fig = from_json(chart_json)
                                st.plotly_chart(fig, use_container_width=True)
                            except Exception:
                                # Show the raw payload so the failure can be diagnosed.
                                st.error("⚠️ Could not render chart.")
                                st.code(chart_json, language="json")

                            st.markdown("**Analysis:**")
                            st.write(analysis_text)
                else:
                    # No chart marker at all: show the answer as-is.
                    st.write(output_text)

            # Always show a data sample; cap the size so DataFrames with
            # fewer than three rows do not make sample() raise.
            current_df = st.session_state.df
            st.write("**Data Sample:**")
            st.dataframe(current_df.sample(min(3, len(current_df))))

        except Exception as e:
            st.error(f"Data Analysis Error: {str(e)}")

    elif st.session_state.active_mode == "pdf" and st.session_state.qa_chain:
        try:
            # Use .invoke(), matching the data agent call above, instead of
            # the deprecated direct call on the chain.
            result = st.session_state.qa_chain.invoke({"query": prompt})
            with st.chat_message("assistant"):
                st.write(result["result"])
                sources = result.get("source_documents") or []
                # Only offer the expander when the retriever returned something.
                if sources:
                    with st.expander("Source Context"):
                        st.write(sources[0].page_content)
        except Exception as e:
            st.error(f"Document Query Error: {str(e)}")

    else:
        st.warning("Please upload a file first!")

# Surface a missing API key on every run, independent of chat input.
if not os.getenv("GROQ_API_KEY"):
    st.error("Missing GROQ_API_KEY in .env file!")