SHAMIL SHAHBAZ AWAN committed on
Commit
508be90
·
verified ·
1 Parent(s): 66bd671

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -48
app.py CHANGED
@@ -7,7 +7,7 @@ from io import StringIO
7
  from transformers import pipeline
8
 
9
  # Load a lightweight NLP model for query understanding
10
- nlp = pipeline("text-classification", model="distilbert-base-uncased", tokenizer="distilbert-base-uncased")
11
 
12
  # Function to load the uploaded file (CSV or Excel)
13
  def load_file(uploaded_file):
@@ -34,11 +34,11 @@ def infer_column(data, synonyms):
34
  return None
35
 
36
  # Function to classify the user query
37
- def classify_query(query):
38
  """Classify the user query into graph types."""
39
- results = nlp(query)
40
  if results:
41
- return results[0]['label']
42
  return None
43
 
44
  # Function to generate graph based on user query
@@ -47,55 +47,52 @@ def generate_graph(data, query):
47
  try:
48
  fig, ax = plt.subplots(figsize=(10, 6))
49
 
50
- # Infer column names
51
- country_col = infer_column(data, {"country", "countries"})
52
- sales_col = infer_column(data, {"gross_sales", "sales", "revenue"})
53
- date_col = infer_column(data, {"date", "time"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- query_type = classify_query(query)
 
 
 
 
 
 
 
 
56
 
57
- if "bar" in query.lower() and country_col and sales_col:
58
- # Bar chart for countries and gross sales
59
- country_data = data[[country_col, sales_col]].groupby(country_col).sum().reset_index()
60
- sns.barplot(x=country_col, y=sales_col, data=country_data, ax=ax, color='skyblue')
61
- ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
62
- ax.set_title(f"Bar Chart: {country_col} vs {sales_col}")
63
  st.pyplot(fig)
64
- elif "line" in query.lower() and date_col and sales_col:
65
- # Line chart for sales trend over time
66
- data[date_col] = pd.to_datetime(data[date_col])
67
- sales_trend = data.groupby(date_col)[sales_col].sum().reset_index()
68
- sns.lineplot(x=date_col, y=sales_col, data=sales_trend, ax=ax)
69
- ax.set_title(f"Line Chart: {sales_col} Over Time")
70
  st.pyplot(fig)
71
- elif "scatter" in query.lower():
72
- # Scatter plot for relationships
73
- if "between" in query.lower():
74
- columns = query.lower().split("between")[-1].strip().split("and")
75
- if len(columns) == 2:
76
- x_col = infer_column(data, {columns[0].strip()})
77
- y_col = infer_column(data, {columns[1].strip()})
78
-
79
- if x_col and y_col:
80
- sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
81
- ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
82
- st.pyplot(fig)
83
- return
84
- st.error("Please specify valid columns for the scatter plot.")
85
- elif "histogram" in query.lower():
86
- # Histogram for a specified column
87
- if "for" in query.lower():
88
- column = query.lower().split("for")[-1].strip()
89
- hist_col = infer_column(data, {column})
90
-
91
- if hist_col:
92
- sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
93
- ax.set_title(f"Histogram of {hist_col}")
94
- st.pyplot(fig)
95
- return
96
- st.error("Please specify a valid column for the histogram.")
97
  else:
98
- st.error("Unsupported graph type. Try asking for a bar chart, line chart, scatter plot, or histogram.")
99
  except Exception as e:
100
  st.error(f"Error generating graph: {e}")
101
 
 
7
  from transformers import pipeline
8
 
9
  # Load a lightweight NLP model for query understanding
10
+ nlp = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
11
 
12
  # Function to load the uploaded file (CSV or Excel)
13
  def load_file(uploaded_file):
 
34
  return None
35
 
36
  # Function to classify the user query
37
+ def classify_query(query, candidate_labels):
38
  """Classify the user query into graph types."""
39
+ results = nlp(query, candidate_labels)
40
  if results:
41
+ return results['labels'][0]
42
  return None
43
 
44
  # Function to generate graph based on user query
 
47
  try:
48
  fig, ax = plt.subplots(figsize=(10, 6))
49
 
50
+ # Infer column names dynamically
51
+ numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
52
+ categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
53
+ datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()
54
+
55
+ # Define possible graph types
56
+ candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram"]
57
+ query_type = classify_query(query, candidate_labels)
58
+
59
+ if query_type == "bar chart" and categorical_columns and numerical_columns:
60
+ # Bar chart for categorical vs numerical
61
+ x_col = st.selectbox("Select the categorical column:", categorical_columns)
62
+ y_col = st.selectbox("Select the numerical column:", numerical_columns)
63
+ aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
64
+ sns.barplot(x=x_col, y=y_col, data=aggregated_data, ax=ax, color='skyblue')
65
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
66
+ ax.set_title(f"Bar Chart: {x_col} vs {y_col}")
67
+ st.pyplot(fig)
68
 
69
+ elif query_type == "line chart" and datetime_columns and numerical_columns:
70
+ # Line chart for numerical trend over time
71
+ x_col = st.selectbox("Select the datetime column:", datetime_columns)
72
+ y_col = st.selectbox("Select the numerical column:", numerical_columns)
73
+ data[x_col] = pd.to_datetime(data[x_col])
74
+ trend_data = data.groupby(x_col)[y_col].sum().reset_index()
75
+ sns.lineplot(x=x_col, y=y_col, data=trend_data, ax=ax)
76
+ ax.set_title(f"Line Chart: {y_col} Over Time")
77
+ st.pyplot(fig)
78
 
79
+ elif query_type == "scatter plot" and len(numerical_columns) >= 2:
80
+ # Scatter plot for numerical relationships
81
+ x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
82
+ y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns)
83
+ sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
84
+ ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
85
  st.pyplot(fig)
86
+
87
+ elif query_type == "histogram" and numerical_columns:
88
+ # Histogram for a numerical column
89
+ hist_col = st.selectbox("Select the numerical column:", numerical_columns)
90
+ sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
91
+ ax.set_title(f"Histogram of {hist_col}")
92
  st.pyplot(fig)
93
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  else:
95
+ st.error("Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, or histogram.")
96
  except Exception as e:
97
  st.error(f"Error generating graph: {e}")
98