SHAMIL SHAHBAZ AWAN commited on
Commit
a36f392
·
verified ·
1 Parent(s): 0779acb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -17
app.py CHANGED
@@ -2,8 +2,8 @@ import streamlit as st
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
5
- from io import StringIO
6
  from transformers import pipeline
 
7
 
8
  # Load a lightweight NLP model for query understanding
9
  nlp = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
@@ -15,8 +15,7 @@ def load_file(uploaded_file):
15
  if uploaded_file.type == "text/csv":
16
  data = pd.read_csv(uploaded_file)
17
  elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
18
- # Load all sheets if it's an Excel file
19
- data = pd.read_excel(uploaded_file, sheet_name=None) # Load all sheets into a dictionary
20
  else:
21
  st.error("Unsupported file type.")
22
  return None
@@ -33,23 +32,19 @@ def classify_query(query, candidate_labels):
33
  return results['labels'][0]
34
  return None
35
 
36
- # Function to generate graph based on user query
37
  def generate_graph(data, query):
38
  """Generate a graph based on user query."""
39
  try:
40
  fig, ax = plt.subplots(figsize=(10, 6))
41
 
42
- # Extract columns from data (if it's a dictionary of sheets, flatten it)
43
- if isinstance(data, dict):
44
- data = pd.concat(data.values(), ignore_index=True) # Combine all sheets into a single dataframe
45
-
46
  # Infer column types
47
  numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
48
  categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
49
  datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()
50
 
51
  # Define possible graph types
52
- candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram", "sales question"]
53
  query_type = classify_query(query, candidate_labels)
54
 
55
  # Provide text-based query response
@@ -104,6 +99,19 @@ def generate_graph(data, query):
104
  else:
105
  response += " Could not find relevant 'department' or 'sales' columns in the dataset."
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  else:
108
  response = "Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, histogram, or sales-related question."
109
 
@@ -148,18 +156,14 @@ def main():
148
  data = load_file(uploaded_file)
149
 
150
  if data is not None:
151
- if isinstance(data, dict): # For Excel with multiple sheets
152
- st.write("Sheets in Excel file:", list(data.keys()))
153
- sheet_name = st.selectbox("Select a sheet", list(data.keys()))
154
- data = data[sheet_name] # Use the selected sheet
155
  st.write("Dataset preview:", data.head())
156
 
157
- # User input for graph generation or general questions
158
- query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and gross sales', or 'Which department has the most sales?')")
159
 
160
- if query:
161
  # Generate the graph based on the query or handle general questions
162
- generate_graph(data, query)
163
 
164
  if __name__ == "__main__":
165
  main()
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
 
5
  from transformers import pipeline
6
+ import numpy as np
7
 
8
  # Load a lightweight NLP model for query understanding
9
  nlp = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
15
  if uploaded_file.type == "text/csv":
16
  data = pd.read_csv(uploaded_file)
17
  elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
18
+ data = pd.read_excel(uploaded_file)
 
19
  else:
20
  st.error("Unsupported file type.")
21
  return None
 
32
  return results['labels'][0]
33
  return None
34
 
35
+ # Function to generate a graph based on user query
36
  def generate_graph(data, query):
37
  """Generate a graph based on user query."""
38
  try:
39
  fig, ax = plt.subplots(figsize=(10, 6))
40
 
 
 
 
 
41
  # Infer column types
42
  numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
43
  categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
44
  datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()
45
 
46
  # Define possible graph types
47
+ candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram", "sales question", "general question"]
48
  query_type = classify_query(query, candidate_labels)
49
 
50
  # Provide text-based query response
 
99
  else:
100
  response += " Could not find relevant 'department' or 'sales' columns in the dataset."
101
 
102
+ elif query_type == "general question":
103
+ # Handle general questions
104
+ response = "Analyzing the data for your general question."
105
+ # Apply simple logic to answer the query based on dataset
106
+ if "sales" in query.lower():
107
+ response += " Checking for the highest sales..."
108
+ sales_column = infer_column(data, ["sales", "revenue"])
109
+ if sales_column:
110
+ top_country = data.loc[data[sales_column].idxmax(), 'country'] # Assuming 'country' column exists
111
+ response += f" The country with the highest sales is {top_country}."
112
+ else:
113
+ response += " Could not find a 'sales' column."
114
+
115
  else:
116
  response = "Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, histogram, or sales-related question."
117
 
 
156
  data = load_file(uploaded_file)
157
 
158
  if data is not None:
 
 
 
 
159
  st.write("Dataset preview:", data.head())
160
 
161
+ # User input for query
162
+ user_query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and sales', or 'Which country has the highest sales?')")
163
 
164
+ if user_query:
165
  # Generate the graph based on the query or handle general questions
166
+ generate_graph(data, user_query)
167
 
168
  if __name__ == "__main__":
169
  main()