fix: visualization refactor to matplotlib now working as intended
Browse files
- app.py +6 -13
- filtered_htzxc454.csv +0 -0
- requirements.txt +1 -2
- utils.py +1 -1
- visualizations.py +146 -47
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
import pandas as pd
|
| 10 |
-
import
|
| 11 |
|
| 12 |
from data_processor import (
|
| 13 |
DatasetBundle,
|
|
@@ -35,7 +35,6 @@ from visualizations import (
|
|
| 35 |
create_distribution_plot,
|
| 36 |
create_scatter_plot,
|
| 37 |
create_time_series_plot,
|
| 38 |
-
figure_to_png_bytes,
|
| 39 |
)
|
| 40 |
|
| 41 |
|
|
@@ -184,7 +183,7 @@ def _populate_column_options(
|
|
| 184 |
dropdown(datetime_cols), # date filter column
|
| 185 |
gr.update(choices=[], value=[], visible=False, interactive=False), # categorical values reset
|
| 186 |
dropdown(categorical), # categorical filter column
|
| 187 |
-
dropdown(
|
| 188 |
dropdown(numeric, defaults.get("numeric")), # time series value
|
| 189 |
dropdown(numeric), # distribution numeric
|
| 190 |
dropdown(categorical), # category column
|
|
@@ -328,7 +327,7 @@ def _generate_chart(
|
|
| 328 |
scatter_x: Optional[str],
|
| 329 |
scatter_y: Optional[str],
|
| 330 |
scatter_color: Optional[str],
|
| 331 |
-
) -> Tuple[Optional[
|
| 332 |
"""Create a visualization based on user selections."""
|
| 333 |
state = _ensure_state(state)
|
| 334 |
try:
|
|
@@ -376,7 +375,7 @@ def _download_filtered(state) -> str:
|
|
| 376 |
return temp.name
|
| 377 |
|
| 378 |
|
| 379 |
-
def _download_chart(fig: Optional[
|
| 380 |
"""Export the most recent chart to PNG."""
|
| 381 |
if fig is None:
|
| 382 |
raise ValueError("Generate a visualization before exporting.")
|
|
@@ -521,8 +520,7 @@ def create_dashboard():
|
|
| 521 |
|
| 522 |
generate_chart_button = gr.Button("Generate Visualization", variant="primary")
|
| 523 |
chart_output = gr.Plot(label="Visualization")
|
| 524 |
-
|
| 525 |
-
chart_file_output = gr.File(label="Chart PNG", interactive=False)
|
| 526 |
|
| 527 |
with gr.Tab("Insights"):
|
| 528 |
insights_status = gr.Markdown()
|
|
@@ -716,12 +714,7 @@ def create_dashboard():
|
|
| 716 |
outputs=[last_figure_state, chart_output, viz_status],
|
| 717 |
)
|
| 718 |
|
| 719 |
-
|
| 720 |
-
fn=_download_chart,
|
| 721 |
-
inputs=[last_figure_state],
|
| 722 |
-
outputs=[chart_file_output],
|
| 723 |
-
)
|
| 724 |
-
|
| 725 |
generate_insights_button.click(
|
| 726 |
fn=_generate_insights,
|
| 727 |
inputs=[
|
|
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
import pandas as pd
|
| 10 |
+
import matplotlib.figure as mpl_fig
|
| 11 |
|
| 12 |
from data_processor import (
|
| 13 |
DatasetBundle,
|
|
|
|
| 35 |
create_distribution_plot,
|
| 36 |
create_scatter_plot,
|
| 37 |
create_time_series_plot,
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
|
|
|
|
| 183 |
dropdown(datetime_cols), # date filter column
|
| 184 |
gr.update(choices=[], value=[], visible=False, interactive=False), # categorical values reset
|
| 185 |
dropdown(categorical), # categorical filter column
|
| 186 |
+
dropdown(all_columns, defaults.get("datetime")), # time series date
|
| 187 |
dropdown(numeric, defaults.get("numeric")), # time series value
|
| 188 |
dropdown(numeric), # distribution numeric
|
| 189 |
dropdown(categorical), # category column
|
|
|
|
| 327 |
scatter_x: Optional[str],
|
| 328 |
scatter_y: Optional[str],
|
| 329 |
scatter_color: Optional[str],
|
| 330 |
+
) -> Tuple[Optional[mpl_fig.Figure], Optional[mpl_fig.Figure], str]:
|
| 331 |
"""Create a visualization based on user selections."""
|
| 332 |
state = _ensure_state(state)
|
| 333 |
try:
|
|
|
|
| 375 |
return temp.name
|
| 376 |
|
| 377 |
|
| 378 |
+
def _download_chart(fig: Optional[mpl_fig.Figure]) -> str:
|
| 379 |
"""Export the most recent chart to PNG."""
|
| 380 |
if fig is None:
|
| 381 |
raise ValueError("Generate a visualization before exporting.")
|
|
|
|
| 520 |
|
| 521 |
generate_chart_button = gr.Button("Generate Visualization", variant="primary")
|
| 522 |
chart_output = gr.Plot(label="Visualization")
|
| 523 |
+
|
|
|
|
| 524 |
|
| 525 |
with gr.Tab("Insights"):
|
| 526 |
insights_status = gr.Markdown()
|
|
|
|
| 714 |
outputs=[last_figure_state, chart_output, viz_status],
|
| 715 |
)
|
| 716 |
|
| 717 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 718 |
generate_insights_button.click(
|
| 719 |
fn=_generate_insights,
|
| 720 |
inputs=[
|
filtered_htzxc454.csv
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
gradio==4.42.0
|
| 2 |
pandas>=2.0,<3.0
|
| 3 |
-
|
| 4 |
-
kaleido>=0.2.1
|
| 5 |
numpy>=1.24
|
| 6 |
openpyxl>=3.1
|
| 7 |
huggingface_hub<0.25.0
|
|
|
|
| 1 |
gradio==4.42.0
|
| 2 |
pandas>=2.0,<3.0
|
| 3 |
+
matplotlib>=3.8.0
|
|
|
|
| 4 |
numpy>=1.24
|
| 5 |
openpyxl>=3.1
|
| 6 |
huggingface_hub<0.25.0
|
utils.py
CHANGED
|
@@ -58,7 +58,7 @@ def coerce_datetime_columns(df: pd.DataFrame, threshold: float = 0.6) -> Tuple[p
|
|
| 58 |
non_null_ratio = series.notna().mean()
|
| 59 |
if non_null_ratio == 0 or non_null_ratio < threshold:
|
| 60 |
continue
|
| 61 |
-
converted = pd.to_datetime(series, errors="coerce", utc=False
|
| 62 |
success_ratio = converted.notna().mean()
|
| 63 |
if success_ratio >= threshold:
|
| 64 |
df[col] = converted
|
|
|
|
| 58 |
non_null_ratio = series.notna().mean()
|
| 59 |
if non_null_ratio == 0 or non_null_ratio < threshold:
|
| 60 |
continue
|
| 61 |
+
converted = pd.to_datetime(series, errors="coerce", utc=False)
|
| 62 |
success_ratio = converted.notna().mean()
|
| 63 |
if success_ratio >= threshold:
|
| 64 |
df[col] = converted
|
visualizations.py
CHANGED
|
@@ -6,9 +6,14 @@ from abc import ABC, abstractmethod
|
|
| 6 |
from io import BytesIO
|
| 7 |
from typing import Any, Dict, Iterable, Optional
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
import pandas as pd
|
| 10 |
-
import
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
|
| 13 |
AGGREGATIONS = {
|
| 14 |
"sum": "sum",
|
|
@@ -22,8 +27,8 @@ class VisualizationStrategy(ABC):
|
|
| 22 |
"""Abstract base class for visualization strategies."""
|
| 23 |
|
| 24 |
@abstractmethod
|
| 25 |
-
def generate(self, df: pd.DataFrame, **kwargs: Any) ->
|
| 26 |
-
"""Generate a
|
| 27 |
pass
|
| 28 |
|
| 29 |
def validate_columns(self, df: pd.DataFrame, columns: Iterable[str]) -> None:
|
|
@@ -32,11 +37,17 @@ class VisualizationStrategy(ABC):
|
|
| 32 |
if missing:
|
| 33 |
raise ValueError(f"Column(s) not found in dataset: {', '.join(missing)}")
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
class TimeSeriesStrategy(VisualizationStrategy):
|
| 37 |
"""Strategy for generating time-series plots."""
|
| 38 |
|
| 39 |
-
def generate(self, df: pd.DataFrame, **kwargs: Any) ->
|
| 40 |
date_column = kwargs.get("date_column")
|
| 41 |
value_column = kwargs.get("value_column")
|
| 42 |
aggregation = kwargs.get("aggregation", "sum")
|
|
@@ -53,21 +64,29 @@ class TimeSeriesStrategy(VisualizationStrategy):
|
|
| 53 |
subset = df.loc[date_series.notna(), [date_column, value_column]].copy()
|
| 54 |
subset[date_column] = pd.to_datetime(subset[date_column])
|
| 55 |
grouped = subset.groupby(subset[date_column].dt.date)[value_column].agg(aggregation).reset_index()
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
fig =
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
)
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
return fig
|
| 65 |
|
| 66 |
|
| 67 |
class DistributionStrategy(VisualizationStrategy):
|
| 68 |
"""Strategy for generating distribution plots (histogram/box)."""
|
| 69 |
|
| 70 |
-
def generate(self, df: pd.DataFrame, **kwargs: Any) ->
|
| 71 |
column = kwargs.get("column")
|
| 72 |
plot_type = kwargs.get("plot_type", "histogram")
|
| 73 |
|
|
@@ -75,26 +94,34 @@ class DistributionStrategy(VisualizationStrategy):
|
|
| 75 |
raise ValueError("Numeric column is required for Distribution plot.")
|
| 76 |
|
| 77 |
self.validate_columns(df, [column])
|
|
|
|
|
|
|
| 78 |
numeric_series = pd.to_numeric(df[column], errors="coerce").dropna()
|
| 79 |
if numeric_series.empty:
|
| 80 |
raise ValueError("Selected column does not contain numeric data.")
|
| 81 |
|
|
|
|
|
|
|
|
|
|
| 82 |
if plot_type == "box":
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
| 84 |
else:
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
return fig
|
| 92 |
|
| 93 |
|
| 94 |
class CategoryStrategy(VisualizationStrategy):
|
| 95 |
"""Strategy for generating categorical charts (bar/pie)."""
|
| 96 |
|
| 97 |
-
def generate(self, df: pd.DataFrame, **kwargs: Any) ->
|
| 98 |
category_column = kwargs.get("category_column")
|
| 99 |
value_column = kwargs.get("value_column")
|
| 100 |
aggregation = kwargs.get("aggregation", "sum")
|
|
@@ -114,11 +141,29 @@ class CategoryStrategy(VisualizationStrategy):
|
|
| 114 |
.sort_values(by=value_column, ascending=False)
|
| 115 |
)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
| 117 |
if chart_type == "pie":
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
else:
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
return fig
|
| 124 |
|
|
@@ -126,7 +171,7 @@ class CategoryStrategy(VisualizationStrategy):
|
|
| 126 |
class ScatterStrategy(VisualizationStrategy):
|
| 127 |
"""Strategy for generating scatter plots."""
|
| 128 |
|
| 129 |
-
def generate(self, df: pd.DataFrame, **kwargs: Any) ->
|
| 130 |
x_column = kwargs.get("x_column")
|
| 131 |
y_column = kwargs.get("y_column")
|
| 132 |
color_column = kwargs.get("color_column")
|
|
@@ -139,46 +184,100 @@ class ScatterStrategy(VisualizationStrategy):
|
|
| 139 |
columns.append(color_column)
|
| 140 |
self.validate_columns(df, columns)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
return fig
|
| 145 |
|
| 146 |
|
| 147 |
class CorrelationHeatmapStrategy(VisualizationStrategy):
|
| 148 |
"""Strategy for generating correlation heatmaps."""
|
| 149 |
|
| 150 |
-
def generate(self, df: pd.DataFrame, **kwargs: Any) ->
|
| 151 |
-
numeric_df = df.select_dtypes(include=["number"])
|
| 152 |
if numeric_df.shape[1] < 2:
|
| 153 |
raise ValueError("At least two numeric columns are required for a correlation heatmap.")
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
corr = numeric_df.corr()
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
return fig
|
| 164 |
|
| 165 |
|
| 166 |
-
def figure_to_png_bytes(fig:
|
| 167 |
"""Export the figure to an in-memory PNG buffer."""
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
return BytesIO(image_bytes)
|
| 173 |
|
| 174 |
|
| 175 |
-
def create_time_series_plot(df: pd.DataFrame, date_column: str, value_column: str, aggregation: str = "sum") ->
|
| 176 |
"""Generate a time-series plot using the TimeSeriesStrategy."""
|
| 177 |
strategy = TimeSeriesStrategy()
|
| 178 |
return strategy.generate(df, date_column=date_column, value_column=value_column, aggregation=aggregation)
|
| 179 |
|
| 180 |
|
| 181 |
-
def create_distribution_plot(df: pd.DataFrame, column: str, plot_type: str = "histogram") ->
|
| 182 |
"""Generate a distribution plot using the DistributionStrategy."""
|
| 183 |
strategy = DistributionStrategy()
|
| 184 |
return strategy.generate(df, column=column, plot_type=plot_type)
|
|
@@ -186,7 +285,7 @@ def create_distribution_plot(df: pd.DataFrame, column: str, plot_type: str = "hi
|
|
| 186 |
|
| 187 |
def create_category_plot(
|
| 188 |
df: pd.DataFrame, category_column: str, value_column: str, aggregation: str = "sum", chart_type: str = "bar"
|
| 189 |
-
) ->
|
| 190 |
"""Generate a category plot using the CategoryStrategy."""
|
| 191 |
strategy = CategoryStrategy()
|
| 192 |
return strategy.generate(
|
|
@@ -196,13 +295,13 @@ def create_category_plot(
|
|
| 196 |
|
| 197 |
def create_scatter_plot(
|
| 198 |
df: pd.DataFrame, x_column: str, y_column: str, color_column: Optional[str] = None
|
| 199 |
-
) ->
|
| 200 |
"""Generate a scatter plot using the ScatterStrategy."""
|
| 201 |
strategy = ScatterStrategy()
|
| 202 |
return strategy.generate(df, x_column=x_column, y_column=y_column, color_column=color_column)
|
| 203 |
|
| 204 |
|
| 205 |
-
def create_correlation_heatmap(df: pd.DataFrame) ->
|
| 206 |
"""Generate a correlation heatmap using the CorrelationHeatmapStrategy."""
|
| 207 |
strategy = CorrelationHeatmapStrategy()
|
| 208 |
return strategy.generate(df)
|
|
|
|
| 6 |
from io import BytesIO
|
| 7 |
from typing import Any, Dict, Iterable, Optional
|
| 8 |
|
| 9 |
+
import matplotlib
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
from matplotlib.figure import Figure
|
| 12 |
import pandas as pd
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
# Use a non-interactive backend to avoid issues in some environments
|
| 16 |
+
matplotlib.use('Agg')
|
| 17 |
|
| 18 |
AGGREGATIONS = {
|
| 19 |
"sum": "sum",
|
|
|
|
| 27 |
"""Abstract base class for visualization strategies."""
|
| 28 |
|
| 29 |
@abstractmethod
|
| 30 |
+
def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
|
| 31 |
+
"""Generate a Matplotlib figure from the provided dataframe and arguments."""
|
| 32 |
pass
|
| 33 |
|
| 34 |
def validate_columns(self, df: pd.DataFrame, columns: Iterable[str]) -> None:
|
|
|
|
| 37 |
if missing:
|
| 38 |
raise ValueError(f"Column(s) not found in dataset: {', '.join(missing)}")
|
| 39 |
|
| 40 |
+
def _create_figure(self) -> Figure:
|
| 41 |
+
"""Helper to create a standard figure with tight layout."""
|
| 42 |
+
fig = Figure(figsize=(10, 6))
|
| 43 |
+
fig.set_layout_engine("tight")
|
| 44 |
+
return fig
|
| 45 |
+
|
| 46 |
|
| 47 |
class TimeSeriesStrategy(VisualizationStrategy):
|
| 48 |
"""Strategy for generating time-series plots."""
|
| 49 |
|
| 50 |
+
def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
|
| 51 |
date_column = kwargs.get("date_column")
|
| 52 |
value_column = kwargs.get("value_column")
|
| 53 |
aggregation = kwargs.get("aggregation", "sum")
|
|
|
|
| 64 |
subset = df.loc[date_series.notna(), [date_column, value_column]].copy()
|
| 65 |
subset[date_column] = pd.to_datetime(subset[date_column])
|
| 66 |
grouped = subset.groupby(subset[date_column].dt.date)[value_column].agg(aggregation).reset_index()
|
| 67 |
+
|
| 68 |
+
# Sort by date to ensure the line plot makes sense
|
| 69 |
+
grouped = grouped.sort_values(by=date_column)
|
| 70 |
|
| 71 |
+
fig = self._create_figure()
|
| 72 |
+
ax = fig.add_subplot(111)
|
| 73 |
+
|
| 74 |
+
ax.plot(grouped[date_column], grouped[value_column], marker='o', linestyle='-')
|
| 75 |
+
ax.set_title(f"{value_column} over time ({aggregation})")
|
| 76 |
+
ax.set_xlabel(date_column)
|
| 77 |
+
ax.set_ylabel(value_column)
|
| 78 |
+
ax.grid(True, linestyle='--', alpha=0.7)
|
| 79 |
+
|
| 80 |
+
# Rotate date labels for better readability
|
| 81 |
+
fig.autofmt_xdate()
|
| 82 |
+
|
| 83 |
return fig
|
| 84 |
|
| 85 |
|
| 86 |
class DistributionStrategy(VisualizationStrategy):
|
| 87 |
"""Strategy for generating distribution plots (histogram/box)."""
|
| 88 |
|
| 89 |
+
def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
|
| 90 |
column = kwargs.get("column")
|
| 91 |
plot_type = kwargs.get("plot_type", "histogram")
|
| 92 |
|
|
|
|
| 94 |
raise ValueError("Numeric column is required for Distribution plot.")
|
| 95 |
|
| 96 |
self.validate_columns(df, [column])
|
| 97 |
+
|
| 98 |
+
# Convert column to numeric, dropping non-numeric values
|
| 99 |
numeric_series = pd.to_numeric(df[column], errors="coerce").dropna()
|
| 100 |
if numeric_series.empty:
|
| 101 |
raise ValueError("Selected column does not contain numeric data.")
|
| 102 |
|
| 103 |
+
fig = self._create_figure()
|
| 104 |
+
ax = fig.add_subplot(111)
|
| 105 |
+
|
| 106 |
if plot_type == "box":
|
| 107 |
+
ax.boxplot(numeric_series, vert=True, patch_artist=True)
|
| 108 |
+
ax.set_title(f"Distribution of {column}")
|
| 109 |
+
ax.set_ylabel(column)
|
| 110 |
+
ax.set_xticks([]) # Remove x-axis ticks for single boxplot
|
| 111 |
else:
|
| 112 |
+
ax.hist(numeric_series, bins=30, edgecolor='black', alpha=0.7)
|
| 113 |
+
ax.set_title(f"Distribution of {column}")
|
| 114 |
+
ax.set_xlabel(column)
|
| 115 |
+
ax.set_ylabel("Frequency")
|
| 116 |
+
ax.grid(axis='y', linestyle='--', alpha=0.7)
|
| 117 |
+
|
| 118 |
return fig
|
| 119 |
|
| 120 |
|
| 121 |
class CategoryStrategy(VisualizationStrategy):
|
| 122 |
"""Strategy for generating categorical charts (bar/pie)."""
|
| 123 |
|
| 124 |
+
def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
|
| 125 |
category_column = kwargs.get("category_column")
|
| 126 |
value_column = kwargs.get("value_column")
|
| 127 |
aggregation = kwargs.get("aggregation", "sum")
|
|
|
|
| 141 |
.sort_values(by=value_column, ascending=False)
|
| 142 |
)
|
| 143 |
|
| 144 |
+
fig = self._create_figure()
|
| 145 |
+
ax = fig.add_subplot(111)
|
| 146 |
+
|
| 147 |
if chart_type == "pie":
|
| 148 |
+
# Pie chart
|
| 149 |
+
wedges, texts, autotexts = ax.pie(
|
| 150 |
+
grouped[value_column],
|
| 151 |
+
labels=grouped[category_column],
|
| 152 |
+
autopct='%1.1f%%',
|
| 153 |
+
startangle=90
|
| 154 |
+
)
|
| 155 |
+
ax.set_title(f"{value_column} by {category_column}")
|
| 156 |
else:
|
| 157 |
+
# Bar chart
|
| 158 |
+
bars = ax.bar(grouped[category_column], grouped[value_column], alpha=0.7, edgecolor='black')
|
| 159 |
+
ax.set_title(f"{value_column} by {category_column}")
|
| 160 |
+
ax.set_xlabel(category_column)
|
| 161 |
+
ax.set_ylabel(f"{aggregation} of {value_column}")
|
| 162 |
+
ax.grid(axis='y', linestyle='--', alpha=0.7)
|
| 163 |
+
|
| 164 |
+
# Rotate x labels if there are many categories
|
| 165 |
+
if len(grouped) > 5:
|
| 166 |
+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
|
| 167 |
|
| 168 |
return fig
|
| 169 |
|
|
|
|
| 171 |
class ScatterStrategy(VisualizationStrategy):
|
| 172 |
"""Strategy for generating scatter plots."""
|
| 173 |
|
| 174 |
+
def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
|
| 175 |
x_column = kwargs.get("x_column")
|
| 176 |
y_column = kwargs.get("y_column")
|
| 177 |
color_column = kwargs.get("color_column")
|
|
|
|
| 184 |
columns.append(color_column)
|
| 185 |
self.validate_columns(df, columns)
|
| 186 |
|
| 187 |
+
# Convert X and Y columns to numeric where possible
|
| 188 |
+
x = pd.to_numeric(df[x_column], errors="coerce")
|
| 189 |
+
y = pd.to_numeric(df[y_column], errors="coerce")
|
| 190 |
+
|
| 191 |
+
valid_mask = ~(x.isna() | y.isna())
|
| 192 |
+
if valid_mask.sum() == 0:
|
| 193 |
+
raise ValueError("Scatter plot requires numeric data in both X and Y columns.")
|
| 194 |
+
|
| 195 |
+
plot_df = df.loc[valid_mask].copy()
|
| 196 |
+
plot_df[x_column] = x[valid_mask]
|
| 197 |
+
plot_df[y_column] = y[valid_mask]
|
| 198 |
+
|
| 199 |
+
fig = self._create_figure()
|
| 200 |
+
ax = fig.add_subplot(111)
|
| 201 |
+
|
| 202 |
+
if color_column:
|
| 203 |
+
# If color column is present, we need to map categories to colors
|
| 204 |
+
# or use a colormap if numeric
|
| 205 |
+
c_data = plot_df[color_column]
|
| 206 |
+
if pd.api.types.is_numeric_dtype(c_data):
|
| 207 |
+
sc = ax.scatter(plot_df[x_column], plot_df[y_column], c=c_data, cmap='viridis', alpha=0.7)
|
| 208 |
+
fig.colorbar(sc, ax=ax, label=color_column)
|
| 209 |
+
else:
|
| 210 |
+
# Categorical coloring
|
| 211 |
+
categories = c_data.unique()
|
| 212 |
+
colors = plt.cm.tab10(np.linspace(0, 1, len(categories)))
|
| 213 |
+
for cat, color in zip(categories, colors):
|
| 214 |
+
mask = c_data == cat
|
| 215 |
+
ax.scatter(plot_df.loc[mask, x_column], plot_df.loc[mask, y_column], label=str(cat), color=color, alpha=0.7)
|
| 216 |
+
ax.legend(title=color_column)
|
| 217 |
+
else:
|
| 218 |
+
ax.scatter(plot_df[x_column], plot_df[y_column], alpha=0.7)
|
| 219 |
+
|
| 220 |
+
ax.set_title(f"{y_column} vs {x_column}")
|
| 221 |
+
ax.set_xlabel(x_column)
|
| 222 |
+
ax.set_ylabel(y_column)
|
| 223 |
+
ax.grid(True, linestyle='--', alpha=0.7)
|
| 224 |
+
|
| 225 |
return fig
|
| 226 |
|
| 227 |
|
| 228 |
class CorrelationHeatmapStrategy(VisualizationStrategy):
|
| 229 |
"""Strategy for generating correlation heatmaps."""
|
| 230 |
|
| 231 |
+
def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
|
| 232 |
+
numeric_df = df.select_dtypes(include=["number"]).copy()
|
| 233 |
if numeric_df.shape[1] < 2:
|
| 234 |
raise ValueError("At least two numeric columns are required for a correlation heatmap.")
|
| 235 |
|
| 236 |
+
# Drop rows that are completely NaN in numeric columns
|
| 237 |
+
numeric_df = numeric_df.dropna(how="all")
|
| 238 |
+
if numeric_df.empty:
|
| 239 |
+
raise ValueError("No valid numeric data available for correlation heatmap.")
|
| 240 |
+
|
| 241 |
corr = numeric_df.corr()
|
| 242 |
+
|
| 243 |
+
fig = self._create_figure()
|
| 244 |
+
ax = fig.add_subplot(111)
|
| 245 |
+
|
| 246 |
+
cax = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1)
|
| 247 |
+
fig.colorbar(cax, ax=ax)
|
| 248 |
+
|
| 249 |
+
# Set ticks
|
| 250 |
+
ax.set_xticks(range(len(corr.columns)))
|
| 251 |
+
ax.set_yticks(range(len(corr.columns)))
|
| 252 |
+
ax.set_xticklabels(corr.columns, rotation=45, ha="right")
|
| 253 |
+
ax.set_yticklabels(corr.columns)
|
| 254 |
+
|
| 255 |
+
# Annotate values
|
| 256 |
+
for i in range(len(corr.columns)):
|
| 257 |
+
for j in range(len(corr.columns)):
|
| 258 |
+
text = ax.text(j, i, f"{corr.iloc[i, j]:.2f}",
|
| 259 |
+
ha="center", va="center", color="black")
|
| 260 |
+
|
| 261 |
+
ax.set_title("Correlation Heatmap")
|
| 262 |
+
|
| 263 |
return fig
|
| 264 |
|
| 265 |
|
| 266 |
+
def figure_to_png_bytes(fig: Figure) -> BytesIO:
|
| 267 |
"""Export the figure to an in-memory PNG buffer."""
|
| 268 |
+
buf = BytesIO()
|
| 269 |
+
fig.savefig(buf, format="png")
|
| 270 |
+
buf.seek(0)
|
| 271 |
+
return buf
|
|
|
|
| 272 |
|
| 273 |
|
| 274 |
+
def create_time_series_plot(df: pd.DataFrame, date_column: str, value_column: str, aggregation: str = "sum") -> Figure:
|
| 275 |
"""Generate a time-series plot using the TimeSeriesStrategy."""
|
| 276 |
strategy = TimeSeriesStrategy()
|
| 277 |
return strategy.generate(df, date_column=date_column, value_column=value_column, aggregation=aggregation)
|
| 278 |
|
| 279 |
|
| 280 |
+
def create_distribution_plot(df: pd.DataFrame, column: str, plot_type: str = "histogram") -> Figure:
|
| 281 |
"""Generate a distribution plot using the DistributionStrategy."""
|
| 282 |
strategy = DistributionStrategy()
|
| 283 |
return strategy.generate(df, column=column, plot_type=plot_type)
|
|
|
|
| 285 |
|
| 286 |
def create_category_plot(
|
| 287 |
df: pd.DataFrame, category_column: str, value_column: str, aggregation: str = "sum", chart_type: str = "bar"
|
| 288 |
+
) -> Figure:
|
| 289 |
"""Generate a category plot using the CategoryStrategy."""
|
| 290 |
strategy = CategoryStrategy()
|
| 291 |
return strategy.generate(
|
|
|
|
| 295 |
|
| 296 |
def create_scatter_plot(
|
| 297 |
df: pd.DataFrame, x_column: str, y_column: str, color_column: Optional[str] = None
|
| 298 |
+
) -> Figure:
|
| 299 |
"""Generate a scatter plot using the ScatterStrategy."""
|
| 300 |
strategy = ScatterStrategy()
|
| 301 |
return strategy.generate(df, x_column=x_column, y_column=y_column, color_column=color_column)
|
| 302 |
|
| 303 |
|
| 304 |
+
def create_correlation_heatmap(df: pd.DataFrame) -> Figure:
|
| 305 |
"""Generate a correlation heatmap using the CorrelationHeatmapStrategy."""
|
| 306 |
strategy = CorrelationHeatmapStrategy()
|
| 307 |
return strategy.generate(df)
|