SHAMIL SHAHBAZ AWAN committed on
Commit
0779acb
·
verified ·
1 Parent(s): f45cec7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -21
app.py CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
5
- import os
6
  from io import StringIO
7
  from transformers import pipeline
8
 
@@ -16,7 +15,8 @@ def load_file(uploaded_file):
16
  if uploaded_file.type == "text/csv":
17
  data = pd.read_csv(uploaded_file)
18
  elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
19
- data = pd.read_excel(uploaded_file)
 
20
  else:
21
  st.error("Unsupported file type.")
22
  return None
@@ -25,17 +25,9 @@ def load_file(uploaded_file):
25
  st.error(f"Error loading file: {e}")
26
  return None
27
 
28
- # Function to infer column names based on synonyms
29
- def infer_column(data, synonyms):
30
- """Infer a column name based on synonyms."""
31
- for column in data.columns:
32
- if column.lower() in synonyms:
33
- return column
34
- return None
35
-
36
  # Function to classify the user query
37
  def classify_query(query, candidate_labels):
38
- """Classify the user query into graph types."""
39
  results = nlp(query, candidate_labels)
40
  if results:
41
  return results['labels'][0]
@@ -47,17 +39,23 @@ def generate_graph(data, query):
47
  try:
48
  fig, ax = plt.subplots(figsize=(10, 6))
49
 
50
- # Infer column names dynamically
 
 
 
 
51
  numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
52
  categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
53
  datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()
54
 
55
  # Define possible graph types
56
- candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram"]
57
  query_type = classify_query(query, candidate_labels)
58
 
 
 
59
  if query_type == "bar chart" and categorical_columns and numerical_columns:
60
- # Bar chart for categorical vs numerical
61
  x_col = st.selectbox("Select the categorical column:", categorical_columns)
62
  y_col = st.selectbox("Select the numerical column:", numerical_columns)
63
  aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
@@ -67,7 +65,7 @@ def generate_graph(data, query):
67
  st.pyplot(fig)
68
 
69
  elif query_type == "line chart" and datetime_columns and numerical_columns:
70
- # Line chart for numerical trend over time
71
  x_col = st.selectbox("Select the datetime column:", datetime_columns)
72
  y_col = st.selectbox("Select the numerical column:", numerical_columns)
73
  data[x_col] = pd.to_datetime(data[x_col])
@@ -77,7 +75,7 @@ def generate_graph(data, query):
77
  st.pyplot(fig)
78
 
79
  elif query_type == "scatter plot" and len(numerical_columns) >= 2:
80
- # Scatter plot for numerical relationships
81
  x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
82
  y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns)
83
  sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
@@ -85,17 +83,44 @@ def generate_graph(data, query):
85
  st.pyplot(fig)
86
 
87
  elif query_type == "histogram" and numerical_columns:
88
- # Histogram for a numerical column
89
  hist_col = st.selectbox("Select the numerical column:", numerical_columns)
90
  sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
91
  ax.set_title(f"Histogram of {hist_col}")
92
  st.pyplot(fig)
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  else:
95
- st.error("Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, or histogram.")
 
 
 
 
96
  except Exception as e:
97
  st.error(f"Error generating graph: {e}")
98
 
 
 
 
 
 
 
 
 
99
  # Streamlit App Interface
100
  def main():
101
  st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")
@@ -123,13 +148,17 @@ def main():
123
  data = load_file(uploaded_file)
124
 
125
  if data is not None:
 
 
 
 
126
  st.write("Dataset preview:", data.head())
127
 
128
- # User input for graph generation
129
- query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and gross sales')")
130
 
131
  if query:
132
- # Generate the graph based on the query
133
  generate_graph(data, query)
134
 
135
  if __name__ == "__main__":
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
 
5
  from io import StringIO
6
  from transformers import pipeline
7
 
 
15
  if uploaded_file.type == "text/csv":
16
  data = pd.read_csv(uploaded_file)
17
  elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
18
+ # Load all sheets if it's an Excel file
19
+ data = pd.read_excel(uploaded_file, sheet_name=None) # Load all sheets into a dictionary
20
  else:
21
  st.error("Unsupported file type.")
22
  return None
 
25
  st.error(f"Error loading file: {e}")
26
  return None
27
 
 
 
 
 
 
 
 
 
28
  # Function to classify the user query
29
  def classify_query(query, candidate_labels):
30
+ """Classify the user query into graph types or general analysis queries."""
31
  results = nlp(query, candidate_labels)
32
  if results:
33
  return results['labels'][0]
 
39
  try:
40
  fig, ax = plt.subplots(figsize=(10, 6))
41
 
42
+ # Extract columns from data (if it's a dictionary of sheets, flatten it)
43
+ if isinstance(data, dict):
44
+ data = pd.concat(data.values(), ignore_index=True) # Combine all sheets into a single dataframe
45
+
46
+ # Infer column types
47
  numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
48
  categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
49
  datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()
50
 
51
  # Define possible graph types
52
+ candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram", "sales question"]
53
  query_type = classify_query(query, candidate_labels)
54
 
55
+ # Provide text-based query response
56
+ response = ""
57
  if query_type == "bar chart" and categorical_columns and numerical_columns:
58
+ response = f"Generating a bar chart for {query}"
59
  x_col = st.selectbox("Select the categorical column:", categorical_columns)
60
  y_col = st.selectbox("Select the numerical column:", numerical_columns)
61
  aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
 
65
  st.pyplot(fig)
66
 
67
  elif query_type == "line chart" and datetime_columns and numerical_columns:
68
+ response = f"Generating a line chart for {query}"
69
  x_col = st.selectbox("Select the datetime column:", datetime_columns)
70
  y_col = st.selectbox("Select the numerical column:", numerical_columns)
71
  data[x_col] = pd.to_datetime(data[x_col])
 
75
  st.pyplot(fig)
76
 
77
  elif query_type == "scatter plot" and len(numerical_columns) >= 2:
78
+ response = f"Generating a scatter plot for {query}"
79
  x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
80
  y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns)
81
  sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
 
83
  st.pyplot(fig)
84
 
85
  elif query_type == "histogram" and numerical_columns:
86
+ response = f"Generating a histogram for {query}"
87
  hist_col = st.selectbox("Select the numerical column:", numerical_columns)
88
  sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
89
  ax.set_title(f"Histogram of {hist_col}")
90
  st.pyplot(fig)
91
 
92
+ elif query_type == "sales question":
93
+ # General sales-related question (e.g., "Which department has the most sales?")
94
+ response = "Analyzing the sales data for your query."
95
+ # Assuming the file has columns like "Department" and "Sales"
96
+ department_column = infer_column(data, ["department", "dept"])
97
+ sales_column = infer_column(data, ["sales", "revenue"])
98
+
99
+ if department_column and sales_column:
100
+ # Answer the query: Which department has the most sales?
101
+ top_department = data.groupby(department_column)[sales_column].sum().idxmax()
102
+ top_sales = data.groupby(department_column)[sales_column].sum().max()
103
+ response += f" The department with the most sales is {top_department} with total sales of {top_sales:.2f}."
104
+ else:
105
+ response += " Could not find relevant 'department' or 'sales' columns in the dataset."
106
+
107
  else:
108
+ response = "Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, histogram, or sales-related question."
109
+
110
+ # Show text-based response
111
+ st.write(response)
112
+
113
  except Exception as e:
114
  st.error(f"Error generating graph: {e}")
115
 
116
+ # Helper function to infer column names based on synonyms
117
def infer_column(data, synonyms):
    """Return the first column of ``data`` whose name matches a synonym.

    Matching ignores letter case and surrounding whitespace on both the
    column labels and the synonyms. Column labels that are not strings
    (for example integer headers) are compared by their string form
    instead of raising ``AttributeError`` on ``.lower()``.

    Args:
        data: Object exposing a ``columns`` iterable of labels (the callers
            in this file pass a pandas DataFrame).
        synonyms: Iterable of acceptable column names. A single string is
            treated as one synonym rather than as a sequence of characters.

    Returns:
        The original, unnormalized label of the first matching column, or
        ``None`` when no column matches.
    """
    if isinstance(synonyms, str):
        synonyms = [synonyms]

    # Normalize once up front so each column costs a single set lookup.
    wanted = {str(name).strip().lower() for name in synonyms}

    for column in data.columns:
        # str() keeps non-string labels (ints, timestamps) from crashing.
        if str(column).strip().lower() in wanted:
            return column
    return None
123
+
124
  # Streamlit App Interface
125
  def main():
126
  st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")
 
148
  data = load_file(uploaded_file)
149
 
150
  if data is not None:
151
+ if isinstance(data, dict): # For Excel with multiple sheets
152
+ st.write("Sheets in Excel file:", list(data.keys()))
153
+ sheet_name = st.selectbox("Select a sheet", list(data.keys()))
154
+ data = data[sheet_name] # Use the selected sheet
155
  st.write("Dataset preview:", data.head())
156
 
157
+ # User input for graph generation or general questions
158
+ query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and gross sales', or 'Which department has the most sales?')")
159
 
160
  if query:
161
+ # Generate the graph based on the query or handle general questions
162
  generate_graph(data, query)
163
 
164
  if __name__ == "__main__":