Teoman21 commited on
Commit
f81a8b5
·
1 Parent(s): ec32a5b

fix: visualiztion refactor to matplotlib now working as intended

Browse files
Files changed (5) hide show
  1. app.py +6 -13
  2. filtered_htzxc454.csv +0 -0
  3. requirements.txt +1 -2
  4. utils.py +1 -1
  5. visualizations.py +146 -47
app.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
7
 
8
  import gradio as gr
9
  import pandas as pd
10
- import plotly.graph_objects as go
11
 
12
  from data_processor import (
13
  DatasetBundle,
@@ -35,7 +35,6 @@ from visualizations import (
35
  create_distribution_plot,
36
  create_scatter_plot,
37
  create_time_series_plot,
38
- figure_to_png_bytes,
39
  )
40
 
41
 
@@ -184,7 +183,7 @@ def _populate_column_options(
184
  dropdown(datetime_cols), # date filter column
185
  gr.update(choices=[], value=[], visible=False, interactive=False), # categorical values reset
186
  dropdown(categorical), # categorical filter column
187
- dropdown(datetime_cols, defaults.get("datetime")), # time series date
188
  dropdown(numeric, defaults.get("numeric")), # time series value
189
  dropdown(numeric), # distribution numeric
190
  dropdown(categorical), # category column
@@ -328,7 +327,7 @@ def _generate_chart(
328
  scatter_x: Optional[str],
329
  scatter_y: Optional[str],
330
  scatter_color: Optional[str],
331
- ) -> Tuple[Optional[go.Figure], Optional[go.Figure], str]:
332
  """Create a visualization based on user selections."""
333
  state = _ensure_state(state)
334
  try:
@@ -376,7 +375,7 @@ def _download_filtered(state) -> str:
376
  return temp.name
377
 
378
 
379
- def _download_chart(fig: Optional[go.Figure]) -> str:
380
  """Export the most recent chart to PNG."""
381
  if fig is None:
382
  raise ValueError("Generate a visualization before exporting.")
@@ -521,8 +520,7 @@ def create_dashboard():
521
 
522
  generate_chart_button = gr.Button("Generate Visualization", variant="primary")
523
  chart_output = gr.Plot(label="Visualization")
524
- download_chart_button = gr.Button("Download Chart as PNG", variant="secondary")
525
- chart_file_output = gr.File(label="Chart PNG", interactive=False)
526
 
527
  with gr.Tab("Insights"):
528
  insights_status = gr.Markdown()
@@ -716,12 +714,7 @@ def create_dashboard():
716
  outputs=[last_figure_state, chart_output, viz_status],
717
  )
718
 
719
- download_chart_button.click(
720
- fn=_download_chart,
721
- inputs=[last_figure_state],
722
- outputs=[chart_file_output],
723
- )
724
-
725
  generate_insights_button.click(
726
  fn=_generate_insights,
727
  inputs=[
 
7
 
8
  import gradio as gr
9
  import pandas as pd
10
+ import matplotlib.figure as mpl_fig
11
 
12
  from data_processor import (
13
  DatasetBundle,
 
35
  create_distribution_plot,
36
  create_scatter_plot,
37
  create_time_series_plot,
 
38
  )
39
 
40
 
 
183
  dropdown(datetime_cols), # date filter column
184
  gr.update(choices=[], value=[], visible=False, interactive=False), # categorical values reset
185
  dropdown(categorical), # categorical filter column
186
+ dropdown(all_columns, defaults.get("datetime")), # time series date
187
  dropdown(numeric, defaults.get("numeric")), # time series value
188
  dropdown(numeric), # distribution numeric
189
  dropdown(categorical), # category column
 
327
  scatter_x: Optional[str],
328
  scatter_y: Optional[str],
329
  scatter_color: Optional[str],
330
+ ) -> Tuple[Optional[mpl_fig.Figure], Optional[mpl_fig.Figure], str]:
331
  """Create a visualization based on user selections."""
332
  state = _ensure_state(state)
333
  try:
 
375
  return temp.name
376
 
377
 
378
+ def _download_chart(fig: Optional[mpl_fig.Figure]) -> str:
379
  """Export the most recent chart to PNG."""
380
  if fig is None:
381
  raise ValueError("Generate a visualization before exporting.")
 
520
 
521
  generate_chart_button = gr.Button("Generate Visualization", variant="primary")
522
  chart_output = gr.Plot(label="Visualization")
523
+
 
524
 
525
  with gr.Tab("Insights"):
526
  insights_status = gr.Markdown()
 
714
  outputs=[last_figure_state, chart_output, viz_status],
715
  )
716
 
717
+
 
 
 
 
 
718
  generate_insights_button.click(
719
  fn=_generate_insights,
720
  inputs=[
filtered_htzxc454.csv DELETED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  gradio==4.42.0
2
  pandas>=2.0,<3.0
3
- plotly>=5.18
4
- kaleido>=0.2.1
5
  numpy>=1.24
6
  openpyxl>=3.1
7
  huggingface_hub<0.25.0
 
1
  gradio==4.42.0
2
  pandas>=2.0,<3.0
3
+ matplotlib>=3.8.0
 
4
  numpy>=1.24
5
  openpyxl>=3.1
6
  huggingface_hub<0.25.0
utils.py CHANGED
@@ -58,7 +58,7 @@ def coerce_datetime_columns(df: pd.DataFrame, threshold: float = 0.6) -> Tuple[p
58
  non_null_ratio = series.notna().mean()
59
  if non_null_ratio == 0 or non_null_ratio < threshold:
60
  continue
61
- converted = pd.to_datetime(series, errors="coerce", utc=False, infer_datetime_format=True)
62
  success_ratio = converted.notna().mean()
63
  if success_ratio >= threshold:
64
  df[col] = converted
 
58
  non_null_ratio = series.notna().mean()
59
  if non_null_ratio == 0 or non_null_ratio < threshold:
60
  continue
61
+ converted = pd.to_datetime(series, errors="coerce", utc=False)
62
  success_ratio = converted.notna().mean()
63
  if success_ratio >= threshold:
64
  df[col] = converted
visualizations.py CHANGED
@@ -6,9 +6,14 @@ from abc import ABC, abstractmethod
6
  from io import BytesIO
7
  from typing import Any, Dict, Iterable, Optional
8
 
 
 
 
9
  import pandas as pd
10
- import plotly.express as px
11
- import plotly.graph_objects as go
 
 
12
 
13
  AGGREGATIONS = {
14
  "sum": "sum",
@@ -22,8 +27,8 @@ class VisualizationStrategy(ABC):
22
  """Abstract base class for visualization strategies."""
23
 
24
  @abstractmethod
25
- def generate(self, df: pd.DataFrame, **kwargs: Any) -> go.Figure:
26
- """Generate a Plotly figure from the provided dataframe and arguments."""
27
  pass
28
 
29
  def validate_columns(self, df: pd.DataFrame, columns: Iterable[str]) -> None:
@@ -32,11 +37,17 @@ class VisualizationStrategy(ABC):
32
  if missing:
33
  raise ValueError(f"Column(s) not found in dataset: {', '.join(missing)}")
34
 
 
 
 
 
 
 
35
 
36
  class TimeSeriesStrategy(VisualizationStrategy):
37
  """Strategy for generating time-series plots."""
38
 
39
- def generate(self, df: pd.DataFrame, **kwargs: Any) -> go.Figure:
40
  date_column = kwargs.get("date_column")
41
  value_column = kwargs.get("value_column")
42
  aggregation = kwargs.get("aggregation", "sum")
@@ -53,21 +64,29 @@ class TimeSeriesStrategy(VisualizationStrategy):
53
  subset = df.loc[date_series.notna(), [date_column, value_column]].copy()
54
  subset[date_column] = pd.to_datetime(subset[date_column])
55
  grouped = subset.groupby(subset[date_column].dt.date)[value_column].agg(aggregation).reset_index()
 
 
 
56
 
57
- fig = px.line(
58
- grouped,
59
- x=date_column,
60
- y=value_column,
61
- title=f"{value_column} over time ({aggregation})",
62
- )
63
- fig.update_layout(xaxis_title=date_column, yaxis_title=value_column)
 
 
 
 
 
64
  return fig
65
 
66
 
67
  class DistributionStrategy(VisualizationStrategy):
68
  """Strategy for generating distribution plots (histogram/box)."""
69
 
70
- def generate(self, df: pd.DataFrame, **kwargs: Any) -> go.Figure:
71
  column = kwargs.get("column")
72
  plot_type = kwargs.get("plot_type", "histogram")
73
 
@@ -75,26 +94,34 @@ class DistributionStrategy(VisualizationStrategy):
75
  raise ValueError("Numeric column is required for Distribution plot.")
76
 
77
  self.validate_columns(df, [column])
 
 
78
  numeric_series = pd.to_numeric(df[column], errors="coerce").dropna()
79
  if numeric_series.empty:
80
  raise ValueError("Selected column does not contain numeric data.")
81
 
 
 
 
82
  if plot_type == "box":
83
- fig = px.box(numeric_series, y=column, points="suspectedoutliers", title=f"Distribution of {column}")
 
 
 
84
  else:
85
- fig = px.histogram(
86
- numeric_series,
87
- nbins=30,
88
- title=f"Distribution of {column}",
89
- )
90
- fig.update_layout(xaxis_title=column, yaxis_title="Frequency")
91
  return fig
92
 
93
 
94
  class CategoryStrategy(VisualizationStrategy):
95
  """Strategy for generating categorical charts (bar/pie)."""
96
 
97
- def generate(self, df: pd.DataFrame, **kwargs: Any) -> go.Figure:
98
  category_column = kwargs.get("category_column")
99
  value_column = kwargs.get("value_column")
100
  aggregation = kwargs.get("aggregation", "sum")
@@ -114,11 +141,29 @@ class CategoryStrategy(VisualizationStrategy):
114
  .sort_values(by=value_column, ascending=False)
115
  )
116
 
 
 
 
117
  if chart_type == "pie":
118
- fig = px.pie(grouped, names=category_column, values=value_column, title=f"{value_column} by {category_column}")
 
 
 
 
 
 
 
119
  else:
120
- fig = px.bar(grouped, x=category_column, y=value_column, title=f"{value_column} by {category_column}")
121
- fig.update_layout(xaxis_title=category_column, yaxis_title=f"{aggregation} of {value_column}")
 
 
 
 
 
 
 
 
122
 
123
  return fig
124
 
@@ -126,7 +171,7 @@ class CategoryStrategy(VisualizationStrategy):
126
  class ScatterStrategy(VisualizationStrategy):
127
  """Strategy for generating scatter plots."""
128
 
129
- def generate(self, df: pd.DataFrame, **kwargs: Any) -> go.Figure:
130
  x_column = kwargs.get("x_column")
131
  y_column = kwargs.get("y_column")
132
  color_column = kwargs.get("color_column")
@@ -139,46 +184,100 @@ class ScatterStrategy(VisualizationStrategy):
139
  columns.append(color_column)
140
  self.validate_columns(df, columns)
141
 
142
- fig = px.scatter(df, x=x_column, y=y_column, color=color_column, title=f"{y_column} vs {x_column}")
143
- fig.update_layout(xaxis_title=x_column, yaxis_title=y_column)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  return fig
145
 
146
 
147
  class CorrelationHeatmapStrategy(VisualizationStrategy):
148
  """Strategy for generating correlation heatmaps."""
149
 
150
- def generate(self, df: pd.DataFrame, **kwargs: Any) -> go.Figure:
151
- numeric_df = df.select_dtypes(include=["number"])
152
  if numeric_df.shape[1] < 2:
153
  raise ValueError("At least two numeric columns are required for a correlation heatmap.")
154
 
 
 
 
 
 
155
  corr = numeric_df.corr()
156
- fig = px.imshow(
157
- corr,
158
- text_auto=True,
159
- title="Correlation Heatmap",
160
- color_continuous_scale="RdBu",
161
- aspect="auto",
162
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  return fig
164
 
165
 
166
- def figure_to_png_bytes(fig: go.Figure) -> BytesIO:
167
  """Export the figure to an in-memory PNG buffer."""
168
- try:
169
- image_bytes = fig.to_image(format="png")
170
- except ValueError as exc: # pragma: no cover - fallback for environments without kaleido
171
- raise ValueError("PNG export requires the 'kaleido' package. Please install it to enable downloads.") from exc
172
- return BytesIO(image_bytes)
173
 
174
 
175
- def create_time_series_plot(df: pd.DataFrame, date_column: str, value_column: str, aggregation: str = "sum") -> go.Figure:
176
  """Generate a time-series plot using the TimeSeriesStrategy."""
177
  strategy = TimeSeriesStrategy()
178
  return strategy.generate(df, date_column=date_column, value_column=value_column, aggregation=aggregation)
179
 
180
 
181
- def create_distribution_plot(df: pd.DataFrame, column: str, plot_type: str = "histogram") -> go.Figure:
182
  """Generate a distribution plot using the DistributionStrategy."""
183
  strategy = DistributionStrategy()
184
  return strategy.generate(df, column=column, plot_type=plot_type)
@@ -186,7 +285,7 @@ def create_distribution_plot(df: pd.DataFrame, column: str, plot_type: str = "hi
186
 
187
  def create_category_plot(
188
  df: pd.DataFrame, category_column: str, value_column: str, aggregation: str = "sum", chart_type: str = "bar"
189
- ) -> go.Figure:
190
  """Generate a category plot using the CategoryStrategy."""
191
  strategy = CategoryStrategy()
192
  return strategy.generate(
@@ -196,13 +295,13 @@ def create_category_plot(
196
 
197
  def create_scatter_plot(
198
  df: pd.DataFrame, x_column: str, y_column: str, color_column: Optional[str] = None
199
- ) -> go.Figure:
200
  """Generate a scatter plot using the ScatterStrategy."""
201
  strategy = ScatterStrategy()
202
  return strategy.generate(df, x_column=x_column, y_column=y_column, color_column=color_column)
203
 
204
 
205
- def create_correlation_heatmap(df: pd.DataFrame) -> go.Figure:
206
  """Generate a correlation heatmap using the CorrelationHeatmapStrategy."""
207
  strategy = CorrelationHeatmapStrategy()
208
  return strategy.generate(df)
 
6
  from io import BytesIO
7
  from typing import Any, Dict, Iterable, Optional
8
 
9
+ import matplotlib
10
+ import matplotlib.pyplot as plt
11
+ from matplotlib.figure import Figure
12
  import pandas as pd
13
+ import numpy as np
14
+
15
+ # Use a non-interactive backend to avoid issues in some environments
16
+ matplotlib.use('Agg')
17
 
18
  AGGREGATIONS = {
19
  "sum": "sum",
 
27
  """Abstract base class for visualization strategies."""
28
 
29
  @abstractmethod
30
+ def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
31
+ """Generate a Matplotlib figure from the provided dataframe and arguments."""
32
  pass
33
 
34
  def validate_columns(self, df: pd.DataFrame, columns: Iterable[str]) -> None:
 
37
  if missing:
38
  raise ValueError(f"Column(s) not found in dataset: {', '.join(missing)}")
39
 
40
+ def _create_figure(self) -> Figure:
41
+ """Helper to create a standard figure with tight layout."""
42
+ fig = Figure(figsize=(10, 6))
43
+ fig.set_layout_engine("tight")
44
+ return fig
45
+
46
 
47
  class TimeSeriesStrategy(VisualizationStrategy):
48
  """Strategy for generating time-series plots."""
49
 
50
+ def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
51
  date_column = kwargs.get("date_column")
52
  value_column = kwargs.get("value_column")
53
  aggregation = kwargs.get("aggregation", "sum")
 
64
  subset = df.loc[date_series.notna(), [date_column, value_column]].copy()
65
  subset[date_column] = pd.to_datetime(subset[date_column])
66
  grouped = subset.groupby(subset[date_column].dt.date)[value_column].agg(aggregation).reset_index()
67
+
68
+ # Sort by date to ensure the line plot makes sense
69
+ grouped = grouped.sort_values(by=date_column)
70
 
71
+ fig = self._create_figure()
72
+ ax = fig.add_subplot(111)
73
+
74
+ ax.plot(grouped[date_column], grouped[value_column], marker='o', linestyle='-')
75
+ ax.set_title(f"{value_column} over time ({aggregation})")
76
+ ax.set_xlabel(date_column)
77
+ ax.set_ylabel(value_column)
78
+ ax.grid(True, linestyle='--', alpha=0.7)
79
+
80
+ # Rotate date labels for better readability
81
+ fig.autofmt_xdate()
82
+
83
  return fig
84
 
85
 
86
  class DistributionStrategy(VisualizationStrategy):
87
  """Strategy for generating distribution plots (histogram/box)."""
88
 
89
+ def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
90
  column = kwargs.get("column")
91
  plot_type = kwargs.get("plot_type", "histogram")
92
 
 
94
  raise ValueError("Numeric column is required for Distribution plot.")
95
 
96
  self.validate_columns(df, [column])
97
+
98
+ # Convert column to numeric, dropping non-numeric values
99
  numeric_series = pd.to_numeric(df[column], errors="coerce").dropna()
100
  if numeric_series.empty:
101
  raise ValueError("Selected column does not contain numeric data.")
102
 
103
+ fig = self._create_figure()
104
+ ax = fig.add_subplot(111)
105
+
106
  if plot_type == "box":
107
+ ax.boxplot(numeric_series, vert=True, patch_artist=True)
108
+ ax.set_title(f"Distribution of {column}")
109
+ ax.set_ylabel(column)
110
+ ax.set_xticks([]) # Remove x-axis ticks for single boxplot
111
  else:
112
+ ax.hist(numeric_series, bins=30, edgecolor='black', alpha=0.7)
113
+ ax.set_title(f"Distribution of {column}")
114
+ ax.set_xlabel(column)
115
+ ax.set_ylabel("Frequency")
116
+ ax.grid(axis='y', linestyle='--', alpha=0.7)
117
+
118
  return fig
119
 
120
 
121
  class CategoryStrategy(VisualizationStrategy):
122
  """Strategy for generating categorical charts (bar/pie)."""
123
 
124
+ def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
125
  category_column = kwargs.get("category_column")
126
  value_column = kwargs.get("value_column")
127
  aggregation = kwargs.get("aggregation", "sum")
 
141
  .sort_values(by=value_column, ascending=False)
142
  )
143
 
144
+ fig = self._create_figure()
145
+ ax = fig.add_subplot(111)
146
+
147
  if chart_type == "pie":
148
+ # Pie chart
149
+ wedges, texts, autotexts = ax.pie(
150
+ grouped[value_column],
151
+ labels=grouped[category_column],
152
+ autopct='%1.1f%%',
153
+ startangle=90
154
+ )
155
+ ax.set_title(f"{value_column} by {category_column}")
156
  else:
157
+ # Bar chart
158
+ bars = ax.bar(grouped[category_column], grouped[value_column], alpha=0.7, edgecolor='black')
159
+ ax.set_title(f"{value_column} by {category_column}")
160
+ ax.set_xlabel(category_column)
161
+ ax.set_ylabel(f"{aggregation} of {value_column}")
162
+ ax.grid(axis='y', linestyle='--', alpha=0.7)
163
+
164
+ # Rotate x labels if there are many categories
165
+ if len(grouped) > 5:
166
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
167
 
168
  return fig
169
 
 
171
  class ScatterStrategy(VisualizationStrategy):
172
  """Strategy for generating scatter plots."""
173
 
174
+ def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
175
  x_column = kwargs.get("x_column")
176
  y_column = kwargs.get("y_column")
177
  color_column = kwargs.get("color_column")
 
184
  columns.append(color_column)
185
  self.validate_columns(df, columns)
186
 
187
+ # Convert X and Y columns to numeric where possible
188
+ x = pd.to_numeric(df[x_column], errors="coerce")
189
+ y = pd.to_numeric(df[y_column], errors="coerce")
190
+
191
+ valid_mask = ~(x.isna() | y.isna())
192
+ if valid_mask.sum() == 0:
193
+ raise ValueError("Scatter plot requires numeric data in both X and Y columns.")
194
+
195
+ plot_df = df.loc[valid_mask].copy()
196
+ plot_df[x_column] = x[valid_mask]
197
+ plot_df[y_column] = y[valid_mask]
198
+
199
+ fig = self._create_figure()
200
+ ax = fig.add_subplot(111)
201
+
202
+ if color_column:
203
+ # If color column is present, we need to map categories to colors
204
+ # or use a colormap if numeric
205
+ c_data = plot_df[color_column]
206
+ if pd.api.types.is_numeric_dtype(c_data):
207
+ sc = ax.scatter(plot_df[x_column], plot_df[y_column], c=c_data, cmap='viridis', alpha=0.7)
208
+ fig.colorbar(sc, ax=ax, label=color_column)
209
+ else:
210
+ # Categorical coloring
211
+ categories = c_data.unique()
212
+ colors = plt.cm.tab10(np.linspace(0, 1, len(categories)))
213
+ for cat, color in zip(categories, colors):
214
+ mask = c_data == cat
215
+ ax.scatter(plot_df.loc[mask, x_column], plot_df.loc[mask, y_column], label=str(cat), color=color, alpha=0.7)
216
+ ax.legend(title=color_column)
217
+ else:
218
+ ax.scatter(plot_df[x_column], plot_df[y_column], alpha=0.7)
219
+
220
+ ax.set_title(f"{y_column} vs {x_column}")
221
+ ax.set_xlabel(x_column)
222
+ ax.set_ylabel(y_column)
223
+ ax.grid(True, linestyle='--', alpha=0.7)
224
+
225
  return fig
226
 
227
 
228
  class CorrelationHeatmapStrategy(VisualizationStrategy):
229
  """Strategy for generating correlation heatmaps."""
230
 
231
+ def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
232
+ numeric_df = df.select_dtypes(include=["number"]).copy()
233
  if numeric_df.shape[1] < 2:
234
  raise ValueError("At least two numeric columns are required for a correlation heatmap.")
235
 
236
+ # Drop rows that are completely NaN in numeric columns
237
+ numeric_df = numeric_df.dropna(how="all")
238
+ if numeric_df.empty:
239
+ raise ValueError("No valid numeric data available for correlation heatmap.")
240
+
241
  corr = numeric_df.corr()
242
+
243
+ fig = self._create_figure()
244
+ ax = fig.add_subplot(111)
245
+
246
+ cax = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1)
247
+ fig.colorbar(cax, ax=ax)
248
+
249
+ # Set ticks
250
+ ax.set_xticks(range(len(corr.columns)))
251
+ ax.set_yticks(range(len(corr.columns)))
252
+ ax.set_xticklabels(corr.columns, rotation=45, ha="right")
253
+ ax.set_yticklabels(corr.columns)
254
+
255
+ # Annotate values
256
+ for i in range(len(corr.columns)):
257
+ for j in range(len(corr.columns)):
258
+ text = ax.text(j, i, f"{corr.iloc[i, j]:.2f}",
259
+ ha="center", va="center", color="black")
260
+
261
+ ax.set_title("Correlation Heatmap")
262
+
263
  return fig
264
 
265
 
266
+ def figure_to_png_bytes(fig: Figure) -> BytesIO:
267
  """Export the figure to an in-memory PNG buffer."""
268
+ buf = BytesIO()
269
+ fig.savefig(buf, format="png")
270
+ buf.seek(0)
271
+ return buf
 
272
 
273
 
274
+ def create_time_series_plot(df: pd.DataFrame, date_column: str, value_column: str, aggregation: str = "sum") -> Figure:
275
  """Generate a time-series plot using the TimeSeriesStrategy."""
276
  strategy = TimeSeriesStrategy()
277
  return strategy.generate(df, date_column=date_column, value_column=value_column, aggregation=aggregation)
278
 
279
 
280
+ def create_distribution_plot(df: pd.DataFrame, column: str, plot_type: str = "histogram") -> Figure:
281
  """Generate a distribution plot using the DistributionStrategy."""
282
  strategy = DistributionStrategy()
283
  return strategy.generate(df, column=column, plot_type=plot_type)
 
285
 
286
  def create_category_plot(
287
  df: pd.DataFrame, category_column: str, value_column: str, aggregation: str = "sum", chart_type: str = "bar"
288
+ ) -> Figure:
289
  """Generate a category plot using the CategoryStrategy."""
290
  strategy = CategoryStrategy()
291
  return strategy.generate(
 
295
 
296
  def create_scatter_plot(
297
  df: pd.DataFrame, x_column: str, y_column: str, color_column: Optional[str] = None
298
+ ) -> Figure:
299
  """Generate a scatter plot using the ScatterStrategy."""
300
  strategy = ScatterStrategy()
301
  return strategy.generate(df, x_column=x_column, y_column=y_column, color_column=color_column)
302
 
303
 
304
+ def create_correlation_heatmap(df: pd.DataFrame) -> Figure:
305
  """Generate a correlation heatmap using the CorrelationHeatmapStrategy."""
306
  strategy = CorrelationHeatmapStrategy()
307
  return strategy.generate(df)