SHAMIL SHAHBAZ AWAN committed on
Commit
9dd7961
·
verified ·
1 Parent(s): b5e2763

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -36
app.py CHANGED
@@ -21,53 +21,65 @@ def load_file(uploaded_file):
21
  st.error(f"Error loading file: {e}")
22
  return None
23
 
 
 
 
 
 
 
 
 
24
  # Function to generate graph based on user query
25
  def generate_graph(data, query):
26
  """Generate a graph based on user query."""
27
  try:
28
  fig, ax = plt.subplots(figsize=(10, 6))
29
 
30
- if "bar" in query.lower() and "gross sales" in query.lower():
 
 
 
 
 
31
  # Bar chart for countries and gross sales
32
- if 'country' in data.columns and 'gross_sales' in data.columns:
33
- country_data = data[['country', 'gross_sales']].groupby('country').sum().reset_index()
34
- sns.barplot(x='country', y='gross_sales', data=country_data, ax=ax, color='skyblue')
35
- ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
36
- st.pyplot(fig)
37
- else:
38
- st.error("The dataset must contain 'country' and 'gross_sales' columns.")
39
- elif "line" in query.lower() and "sales trend" in query.lower():
40
  # Line chart for sales trend over time
41
- if 'date' in data.columns and 'sales' in data.columns:
42
- data['date'] = pd.to_datetime(data['date'])
43
- sales_trend = data.groupby('date')['sales'].sum().reset_index()
44
- sns.lineplot(x='date', y='sales', data=sales_trend, ax=ax)
45
- ax.set_title("Sales Trend Over Time")
46
- st.pyplot(fig)
47
- else:
48
- st.error("The dataset must contain 'date' and 'sales' columns.")
49
- elif "scatter" in query.lower() and "relationship" in query.lower():
50
  # Scatter plot for relationships
51
- columns = query.lower().split("between")[-1].strip().split("and")
52
- x_col = columns[0].strip()
53
- y_col = columns[1].strip()
54
-
55
- if x_col in data.columns and y_col in data.columns:
56
- sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
57
- ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
58
- st.pyplot(fig)
59
- else:
60
- st.error(f"The dataset must contain '{x_col}' and '{y_col}' columns.")
 
 
61
  elif "histogram" in query.lower():
62
  # Histogram for a specified column
63
- column = query.lower().split("for")[-1].strip()
64
-
65
- if column in data.columns:
66
- sns.histplot(data[column], bins=20, kde=True, ax=ax, color='green')
67
- ax.set_title(f"Histogram of {column}")
68
- st.pyplot(fig)
69
- else:
70
- st.error(f"The dataset must contain the column '{column}'.")
 
 
71
  else:
72
  st.error("Unsupported graph type. Try asking for a bar chart, line chart, scatter plot, or histogram.")
73
  except Exception as e:
 
21
  st.error(f"Error loading file: {e}")
22
  return None
23
 
24
+ # Function to infer column names based on synonyms
25
def infer_column(data, synonyms):
    """Return the first column of *data* whose name matches one of *synonyms*.

    Matching is case-insensitive on the column side: each column label is
    converted to text, stripped of surrounding whitespace and lower-cased
    before being looked up in *synonyms*. The entries of *synonyms* are
    expected to already be lower-case strings.

    Args:
        data: Any object exposing an iterable ``columns`` attribute.
        synonyms: Collection of lower-case candidate names.

    Returns:
        The original (un-normalised) column label of the first match, so it
        can be used directly to index *data*; ``None`` if nothing matches.
    """
    for column in data.columns:
        # Column labels are not guaranteed to be strings (e.g. integer labels
        # for header-less files), so normalise via str() instead of calling
        # .lower() on the label directly, which would raise AttributeError.
        if str(column).strip().lower() in synonyms:
            return column
    return None
31
+
32
  # Function to generate graph based on user query
33
  def generate_graph(data, query):
34
  """Generate a graph based on user query."""
35
  try:
36
  fig, ax = plt.subplots(figsize=(10, 6))
37
 
38
+ # Infer column names
39
+ country_col = infer_column(data, {"country", "countries"})
40
+ sales_col = infer_column(data, {"gross_sales", "sales", "revenue"})
41
+ date_col = infer_column(data, {"date", "time"})
42
+
43
+ if "bar" in query.lower() and country_col and sales_col:
44
  # Bar chart for countries and gross sales
45
+ country_data = data[[country_col, sales_col]].groupby(country_col).sum().reset_index()
46
+ sns.barplot(x=country_col, y=sales_col, data=country_data, ax=ax, color='skyblue')
47
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
48
+ ax.set_title(f"Bar Chart: {country_col} vs {sales_col}")
49
+ st.pyplot(fig)
50
+ elif "line" in query.lower() and date_col and sales_col:
 
 
51
  # Line chart for sales trend over time
52
+ data[date_col] = pd.to_datetime(data[date_col])
53
+ sales_trend = data.groupby(date_col)[sales_col].sum().reset_index()
54
+ sns.lineplot(x=date_col, y=sales_col, data=sales_trend, ax=ax)
55
+ ax.set_title(f"Line Chart: {sales_col} Over Time")
56
+ st.pyplot(fig)
57
+ elif "scatter" in query.lower():
 
 
 
58
  # Scatter plot for relationships
59
+ if "between" in query.lower():
60
+ columns = query.lower().split("between")[-1].strip().split("and")
61
+ if len(columns) == 2:
62
+ x_col = infer_column(data, {columns[0].strip()})
63
+ y_col = infer_column(data, {columns[1].strip()})
64
+
65
+ if x_col and y_col:
66
+ sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
67
+ ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
68
+ st.pyplot(fig)
69
+ return
70
+ st.error("Please specify valid columns for the scatter plot.")
71
  elif "histogram" in query.lower():
72
  # Histogram for a specified column
73
+ if "for" in query.lower():
74
+ column = query.lower().split("for")[-1].strip()
75
+ hist_col = infer_column(data, {column})
76
+
77
+ if hist_col:
78
+ sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
79
+ ax.set_title(f"Histogram of {hist_col}")
80
+ st.pyplot(fig)
81
+ return
82
+ st.error("Please specify a valid column for the histogram.")
83
  else:
84
  st.error("Unsupported graph type. Try asking for a bar chart, line chart, scatter plot, or histogram.")
85
  except Exception as e: