|
|
"""Visualization utilities leveraging the Strategy Pattern for the BI dashboard.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from io import BytesIO |
|
|
from typing import Any, Dict, Iterable, Optional |
|
|
|
|
|
import matplotlib |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib.figure import Figure |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Select Matplotlib's non-interactive "Agg" backend. This module only builds
# Figure objects and renders them into in-memory PNG buffers, so no GUI
# backend is required.
# NOTE(review): this call runs after `import matplotlib.pyplot` above; confirm
# the installed Matplotlib version switches backends after pyplot is imported.
matplotlib.use('Agg')
|
|
|
|
|
# Aggregation names accepted by the time-series and category strategies.
# Each name maps to the identically named pandas aggregation function.
AGGREGATIONS = {
    name: name
    for name in ("sum", "mean", "median", "count")
}
|
|
|
|
|
|
|
|
class VisualizationStrategy(ABC):
    """Abstract base class for visualization strategies.

    Subclasses implement :meth:`generate`, which receives a DataFrame plus
    strategy-specific keyword arguments and returns a Matplotlib ``Figure``.
    Shared helpers for column validation and figure creation live here.
    """

    @abstractmethod
    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Generate a Matplotlib figure from the provided dataframe and arguments."""

    def validate_columns(self, df: pd.DataFrame, columns: Iterable[str]) -> None:
        """Ensure every column exists inside the DataFrame.

        Args:
            df: DataFrame whose ``columns`` are checked.
            columns: Column labels that must be present. Labels do not have to
                be strings (e.g. integer labels from a headerless CSV).

        Raises:
            ValueError: If any label is absent. Each missing label is listed
                once, in the order it was requested.
        """
        # dict.fromkeys de-duplicates while preserving the caller's order.
        missing = [col for col in dict.fromkeys(columns) if col not in df.columns]
        if missing:
            # str() every label: str.join raises TypeError on non-string labels,
            # which would mask the intended ValueError.
            raise ValueError(
                f"Column(s) not found in dataset: {', '.join(map(str, missing))}"
            )

    def _create_figure(self) -> Figure:
        """Create a blank 10x6-inch figure that uses the "tight" layout engine.

        The Figure is instantiated directly rather than through ``plt.figure``,
        so it is not registered with pyplot's global figure manager.
        """
        fig = Figure(figsize=(10, 6))
        fig.set_layout_engine("tight")
        return fig
|
|
|
|
|
|
|
|
class TimeSeriesStrategy(VisualizationStrategy):
    """Strategy for generating time-series plots.

    Rows are bucketed by calendar day of the date column, the value column is
    aggregated within each day, and the result is drawn as a line with markers.
    """

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Plot a daily aggregate of ``value_column`` against ``date_column``.

        Keyword Args:
            date_column (str): Column parsed with ``pd.to_datetime``; entries
                that cannot be parsed are dropped.
            value_column (str): Column aggregated per day. For every
                aggregation other than ``"count"`` it is first coerced to
                numeric (unparseable entries become NaN).
            aggregation (str): A key of ``AGGREGATIONS``; defaults to ``"sum"``.

        Returns:
            A new ``Figure`` containing one line plot in chronological order.

        Raises:
            ValueError: If a column name is not given or not present, the
                aggregation is unsupported, no date can be parsed, or (for
                non-count aggregations) the value column has no numeric data.
        """
        date_column = kwargs.get("date_column")
        value_column = kwargs.get("value_column")
        aggregation = kwargs.get("aggregation", "sum")

        if not date_column or not value_column:
            raise ValueError("Date and value columns are required for Time Series.")

        self.validate_columns(df, [date_column, value_column])

        if aggregation not in AGGREGATIONS:
            raise ValueError("Unsupported aggregation method.")

        # Parse once; the coerced result drives both row filtering and grouping.
        dates = pd.to_datetime(df[date_column], errors="coerce")
        has_date = dates.notna()
        if not has_date.any():
            raise ValueError("Date column does not contain any parseable dates.")

        values = df.loc[has_date, value_column]
        if aggregation != "count":
            # Same numeric coercion as the distribution/scatter strategies, so
            # e.g. "sum" of numeric strings adds numbers instead of concatenating.
            values = pd.to_numeric(values, errors="coerce")
            if values.isna().all():
                raise ValueError("Value column does not contain numeric data.")

        # Bucket by calendar day; groupby sorts its keys, so the result is
        # already chronological.
        day = dates[has_date].dt.date.rename(date_column)
        grouped = (
            values.groupby(day, sort=True)
            .agg(AGGREGATIONS[aggregation])
            .reset_index()
        )

        fig = self._create_figure()
        ax = fig.add_subplot(111)

        ax.plot(grouped[date_column], grouped[value_column], marker="o", linestyle="-")
        ax.set_title(f"{value_column} over time ({aggregation})")
        ax.set_xlabel(date_column)
        ax.set_ylabel(value_column)
        ax.grid(True, linestyle="--", alpha=0.7)

        # Rotate/right-align the date tick labels so they do not overlap.
        fig.autofmt_xdate()

        return fig
|
|
|
|
|
|
|
|
class DistributionStrategy(VisualizationStrategy):
    """Strategy for generating distribution plots (histogram/box)."""

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Plot the distribution of a single column.

        Keyword Args:
            column (str): Column to plot. Values are coerced to numeric, and
                NaN and +/-inf entries are discarded.
            plot_type (str): ``"box"`` (case-insensitive) draws a box plot.
                Any other value, including ``None``, draws a histogram
                (default ``"histogram"``).
            bins: Forwarded to ``Axes.hist``; defaults to 30. Unused for box
                plots.

        Returns:
            A new ``Figure`` with one axes.

        Raises:
            ValueError: If ``column`` is not given or not present, or no finite
                numeric value remains after coercion.
        """
        column = kwargs.get("column")
        # ``or`` also maps an explicit ``plot_type=None`` to the default.
        plot_type = str(kwargs.get("plot_type") or "histogram").lower()
        bins = kwargs.get("bins")
        if bins is None:
            bins = 30

        if not column:
            raise ValueError("Numeric column is required for Distribution plot.")

        self.validate_columns(df, [column])

        # Non-numeric entries become NaN. Infinities are dropped as well,
        # because Axes.hist cannot derive a finite bin range when they are present.
        numeric_series = (
            pd.to_numeric(df[column], errors="coerce")
            .replace([np.inf, -np.inf], np.nan)
            .dropna()
        )
        if numeric_series.empty:
            raise ValueError("Selected column does not contain numeric data.")
        values = numeric_series.to_numpy()

        fig = self._create_figure()
        ax = fig.add_subplot(111)
        ax.set_title(f"Distribution of {column}")

        if plot_type == "box":
            # Vertical is already the default orientation. The ``vert`` argument
            # is being deprecated (Matplotlib 3.10+), so it is not passed.
            ax.boxplot(values, patch_artist=True)
            ax.set_ylabel(column)
            ax.set_xticks([])
        else:
            ax.hist(values, bins=bins, edgecolor="black", alpha=0.7)
            ax.set_xlabel(column)
            ax.set_ylabel("Frequency")
            ax.grid(axis="y", linestyle="--", alpha=0.7)

        return fig
|
|
|
|
|
|
|
|
class CategoryStrategy(VisualizationStrategy):
    """Strategy for generating categorical charts (bar/pie)."""

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Plot ``value_column`` aggregated per category as a bar or pie chart.

        Keyword Args:
            category_column (str): Column whose distinct values form the groups.
                Group labels are drawn as text, so numeric categories become
                evenly spaced bars instead of being placed at numeric x positions.
            value_column (str): Column aggregated per category. It is coerced
                to numeric for every aggregation except ``"count"``.
            aggregation (str): A key of ``AGGREGATIONS``; defaults to ``"sum"``.
            chart_type (str): ``"pie"`` (case-insensitive) draws a pie chart.
                Anything else, including ``None``, draws a bar chart (default).

        Returns:
            A new ``Figure``. Categories are ordered by descending aggregate, and
            categories whose aggregate is NaN are omitted.

        Raises:
            ValueError: If a column name is not given or not present, the
                aggregation is unsupported, the value column has no numeric
                data, no category remains, or a pie chart would need a
                negative wedge size.
        """
        category_column = kwargs.get("category_column")
        value_column = kwargs.get("value_column")
        aggregation = kwargs.get("aggregation", "sum")
        # ``or`` also maps an explicit ``chart_type=None`` to the default.
        chart_type = str(kwargs.get("chart_type") or "bar").lower()

        if not category_column or not value_column:
            raise ValueError("Category and value columns are required for Category plot.")

        self.validate_columns(df, [category_column, value_column])
        if aggregation not in AGGREGATIONS:
            raise ValueError("Unsupported aggregation method.")

        values = df[value_column]
        if aggregation != "count":
            # Without coercion, "sum" over numeric strings concatenates them and
            # "mean"/"median" raise TypeError.
            values = pd.to_numeric(values, errors="coerce")
            if values.isna().all():
                raise ValueError("Value column does not contain numeric data.")

        grouped = (
            values.groupby(df[category_column])
            .agg(AGGREGATIONS[aggregation])
            .reset_index()
            .dropna(subset=[value_column])
            .sort_values(by=value_column, ascending=False)
        )
        if grouped.empty:
            raise ValueError("No category data available to plot.")

        # String labels make Matplotlib use a categorical axis, preserving the
        # descending-value order even when the raw categories are numbers.
        labels = grouped[category_column].astype(str)
        heights = grouped[value_column]

        if chart_type == "pie" and (heights < 0).any():
            raise ValueError("Pie charts require non-negative aggregated values.")

        fig = self._create_figure()
        ax = fig.add_subplot(111)
        ax.set_title(f"{value_column} by {category_column}")

        if chart_type == "pie":
            ax.pie(heights, labels=labels, autopct="%1.1f%%", startangle=90)
        else:
            ax.bar(labels, heights, alpha=0.7, edgecolor="black")
            ax.set_xlabel(category_column)
            ax.set_ylabel(f"{aggregation} of {value_column}")
            ax.grid(axis="y", linestyle="--", alpha=0.7)
            # Slant the tick labels once there are enough bars for them to collide.
            if len(grouped) > 5:
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

        return fig
|
|
|
|
|
|
|
|
class ScatterStrategy(VisualizationStrategy):
    """Strategy for generating scatter plots."""

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Scatter ``y_column`` against ``x_column``, optionally colored by a third column.

        Keyword Args:
            x_column (str): X values, coerced to numeric.
            y_column (str): Y values, coerced to numeric. Rows where either X or
                Y is NaN after coercion are dropped.
            color_column (str, optional): A numeric, non-boolean column is mapped
                through the ``viridis`` colormap with a colorbar. Any other column
                is treated as categorical, with one legend entry per distinct
                value and missing values grouped under a "missing" entry.

        Returns:
            A new ``Figure`` with one axes.

        Raises:
            ValueError: If X/Y are not given, any named column is absent, or no
                row has numeric X and Y.
        """
        x_column = kwargs.get("x_column")
        y_column = kwargs.get("y_column")
        color_column = kwargs.get("color_column")

        if not x_column or not y_column:
            raise ValueError("X and Y columns are required for Scatter plot.")

        columns = [x_column, y_column]
        if color_column:
            columns.append(color_column)
        self.validate_columns(df, columns)

        x = pd.to_numeric(df[x_column], errors="coerce")
        y = pd.to_numeric(df[y_column], errors="coerce")
        valid_mask = x.notna() & y.notna()
        if not valid_mask.any():
            raise ValueError("Scatter plot requires numeric data in both X and Y columns.")

        # Only the plotted columns are sliced; the rest of the frame is not copied.
        x_vals = x[valid_mask]
        y_vals = y[valid_mask]

        fig = self._create_figure()
        ax = fig.add_subplot(111)

        if not color_column:
            ax.scatter(x_vals, y_vals, alpha=0.7)
        else:
            # A color column that doubles as X or Y uses its numeric-coerced
            # values, i.e. exactly what is plotted on that axis.
            coerced = {x_column: x_vals, y_column: y_vals}
            if color_column in coerced:
                c_data = coerced[color_column]
            else:
                c_data = df.loc[valid_mask, color_column]

            # Booleans are numeric to pandas, but a True/False legend is more
            # readable than a 0..1 colorbar.
            is_continuous = (
                pd.api.types.is_numeric_dtype(c_data)
                and not pd.api.types.is_bool_dtype(c_data)
            )
            if is_continuous:
                sc = ax.scatter(x_vals, y_vals, c=c_data, cmap="viridis", alpha=0.7)
                fig.colorbar(sc, ax=ax, label=color_column)
            else:
                # NaN never compares equal to itself, so missing colors get an
                # explicit group instead of an empty legend entry.
                missing = c_data.isna().to_numpy()
                groups = [
                    (str(cat), (c_data == cat).to_numpy())
                    for cat in c_data[~missing].unique()
                ]
                if missing.any():
                    groups.append(("missing", missing))
                colors = plt.cm.tab10(np.linspace(0, 1, len(groups)))
                for (label, mask), color in zip(groups, colors):
                    ax.scatter(x_vals[mask], y_vals[mask], label=label, color=color, alpha=0.7)
                ax.legend(title=color_column)

        ax.set_title(f"{y_column} vs {x_column}")
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.grid(True, linestyle="--", alpha=0.7)

        return fig
|
|
|
|
|
|
|
|
class CorrelationHeatmapStrategy(VisualizationStrategy):
    """Strategy for generating correlation heatmaps."""

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Draw an annotated heatmap of pairwise correlations between numeric columns.

        Only columns selected by ``select_dtypes(include=["number"])`` are used.
        Columns with no non-missing value are discarded. ``kwargs`` is accepted
        for interface compatibility and ignored.

        Returns:
            A new ``Figure`` with the heatmap, a colorbar scaled to [-1, 1], and
            each cell annotated with its coefficient to two decimals.

        Raises:
            ValueError: If fewer than two numeric columns exist, or fewer than
                two of them contain any data.
        """
        numeric_df = df.select_dtypes(include=["number"])
        if numeric_df.shape[1] < 2:
            raise ValueError("At least two numeric columns are required for a correlation heatmap.")

        # Entirely empty columns would only contribute rows/columns of NaN.
        numeric_df = numeric_df.dropna(axis=1, how="all").dropna(how="all")
        if numeric_df.empty or numeric_df.shape[1] < 2:
            raise ValueError("No valid numeric data available for correlation heatmap.")

        corr = numeric_df.corr()
        matrix = corr.to_numpy()
        labels = list(corr.columns)
        positions = range(len(labels))

        fig = self._create_figure()
        ax = fig.add_subplot(111)

        image = ax.imshow(matrix, cmap="RdBu", vmin=-1, vmax=1)
        fig.colorbar(image, ax=ax)

        ax.set_xticks(positions)
        ax.set_yticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_yticklabels(labels)

        # Annotate every cell (row i, column j). Use white text on the saturated
        # ends of RdBu, where black is hard to read. NaN fails the comparison and
        # stays black.
        for (i, j), r in np.ndenumerate(matrix):
            text_color = "white" if abs(r) > 0.5 else "black"
            ax.text(j, i, f"{r:.2f}", ha="center", va="center", color=text_color)

        ax.set_title("Correlation Heatmap")

        return fig
|
|
|
|
|
|
|
|
def figure_to_png_bytes(fig: Figure) -> BytesIO:
    """Render *fig* as PNG into a fresh in-memory buffer.

    Args:
        fig: Figure to serialize via its ``savefig`` method.

    Returns:
        A ``BytesIO`` holding the PNG data, rewound to position 0 so callers
        can read or stream it immediately.
    """
    png_stream = BytesIO()
    fig.savefig(png_stream, format="png")
    # savefig leaves the cursor at the end; rewind for the consumer.
    png_stream.seek(0)
    return png_stream
|
|
|
|
|
|
|
|
def create_time_series_plot(df: pd.DataFrame, date_column: str, value_column: str, aggregation: str = "sum") -> Figure:
    """Build a time-series figure by delegating to ``TimeSeriesStrategy``.

    Args:
        df: Source data.
        date_column: Column holding the dates.
        value_column: Column to aggregate over time.
        aggregation: Aggregation name forwarded to the strategy.

    Returns:
        The figure produced by ``TimeSeriesStrategy.generate``.
    """
    options = {
        "date_column": date_column,
        "value_column": value_column,
        "aggregation": aggregation,
    }
    return TimeSeriesStrategy().generate(df, **options)
|
|
|
|
|
|
|
|
def create_distribution_plot(df: pd.DataFrame, column: str, plot_type: str = "histogram") -> Figure:
    """Build a distribution figure by delegating to ``DistributionStrategy``.

    Args:
        df: Source data.
        column: Column whose distribution is plotted.
        plot_type: Plot kind forwarded to the strategy.

    Returns:
        The figure produced by ``DistributionStrategy.generate``.
    """
    options = {"column": column, "plot_type": plot_type}
    return DistributionStrategy().generate(df, **options)
|
|
|
|
|
|
|
|
def create_category_plot(
    df: pd.DataFrame, category_column: str, value_column: str, aggregation: str = "sum", chart_type: str = "bar"
) -> Figure:
    """Build a categorical figure by delegating to ``CategoryStrategy``.

    Args:
        df: Source data.
        category_column: Column defining the groups.
        value_column: Column aggregated per group.
        aggregation: Aggregation name forwarded to the strategy.
        chart_type: Chart kind forwarded to the strategy.

    Returns:
        The figure produced by ``CategoryStrategy.generate``.
    """
    options = {
        "category_column": category_column,
        "value_column": value_column,
        "aggregation": aggregation,
        "chart_type": chart_type,
    }
    return CategoryStrategy().generate(df, **options)
|
|
|
|
|
|
|
|
def create_scatter_plot(
    df: pd.DataFrame, x_column: str, y_column: str, color_column: Optional[str] = None
) -> Figure:
    """Build a scatter figure by delegating to ``ScatterStrategy``.

    Args:
        df: Source data.
        x_column: Column plotted on the X axis.
        y_column: Column plotted on the Y axis.
        color_column: Optional column used to color the points.

    Returns:
        The figure produced by ``ScatterStrategy.generate``.
    """
    options = {"x_column": x_column, "y_column": y_column, "color_column": color_column}
    return ScatterStrategy().generate(df, **options)
|
|
|
|
|
|
|
|
def create_correlation_heatmap(df: pd.DataFrame) -> Figure:
    """Build a correlation heatmap by delegating to ``CorrelationHeatmapStrategy``.

    Args:
        df: Source data; the strategy picks its numeric columns.

    Returns:
        The figure produced by ``CorrelationHeatmapStrategy.generate``.
    """
    heatmap = CorrelationHeatmapStrategy()
    figure = heatmap.generate(df)
    return figure
|
|
|