# BI-dashboard / visualizations.py
# Lohith Venkat Chamakura
# Initial commit
# 48909ac
"""
Visualization module for the Business Intelligence Dashboard.
This module creates various types of charts and visualizations
using the Strategy Pattern for different chart types.
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Any
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils import detect_column_types
from constants import (
HISTOGRAM_BINS,
MAX_CATEGORY_DISPLAY,
MIN_NUMERICAL_COLUMNS_FOR_CORRELATION
)
class VisualizationStrategy(ABC):
    """Common interface implemented by every chart-building strategy."""

    @abstractmethod
    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Build a chart from the supplied data.

        Concrete strategies override this method; the base implementation
        does nothing and returns None.

        Args:
            df: DataFrame holding the data to plot.
            x_column: Name of the column used for the x-axis, if any.
            y_column: Name of the column used for the y-axis, if any.
            aggregation: Name of the aggregation to apply
                (sum, mean, count, median).
            **kwargs: Strategy-specific extra options.

        Returns:
            The resulting Plotly figure.
        """
        ...
class TimeSeriesStrategy(VisualizationStrategy):
    """Strategy that draws a value column against a date column as a line."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Create a time series line plot.

        Args:
            df: Input DataFrame; it is copied, never modified in place.
            x_column: Column holding the dates. Values are parsed with
                ``pd.to_datetime``; rows whose value cannot be parsed are
                dropped.
            y_column: Column holding the values to plot.
            aggregation: Aggregation passed to ``groupby(...).agg`` to
                combine rows that share the same timestamp, or ``'none'``
                to plot every row individually.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Plotly line figure of ``y_column`` over ``x_column``.

        Raises:
            ValueError: If ``x_column`` or ``y_column`` is None.
        """
        if x_column is None or y_column is None:
            raise ValueError("Both x_column and y_column required for time series")

        # Work on a copy so the caller's DataFrame keeps its original dtype.
        df = df.copy()
        # Unparseable dates become NaT and are removed with the missing values.
        df[x_column] = pd.to_datetime(df[x_column], errors='coerce')
        df = df.dropna(subset=[x_column, y_column])

        if aggregation != 'none':
            # groupby sorts by its key, so the result is already chronological.
            df = df.groupby(x_column)[y_column].agg(aggregation).reset_index()
        else:
            # px.line joins points in row order; without sorting, unordered
            # input produces a line that zig-zags back and forth in time.
            df = df.sort_values(x_column, kind='stable')

        fig = px.line(
            df,
            x=x_column,
            y=y_column,
            title=f'Time Series: {y_column} over {x_column}',
            labels={x_column: x_column, y_column: y_column}
        )
        fig.update_layout(
            xaxis_title=x_column,
            yaxis_title=y_column,
            hovermode='x unified',
            template='plotly_white'
        )
        return fig
class DistributionStrategy(VisualizationStrategy):
    """Strategy that shows how the values of one column are distributed."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        sub_chart_type: str = 'histogram',
        **kwargs
    ) -> go.Figure:
        """
        Create a histogram or a box plot for a single column.

        Args:
            df: Input DataFrame; rows with a missing ``x_column`` value are
                excluded.
            x_column: Column whose distribution is plotted.
            y_column: Not used by this strategy.
            aggregation: Not used by this strategy.
            sub_chart_type: ``'histogram'`` for a histogram; any other value
                produces a box plot.
            **kwargs: May carry the legacy ``chart_type`` key, which takes
                the place of ``sub_chart_type`` when present.

        Returns:
            Plotly histogram or box figure.

        Raises:
            ValueError: If ``x_column`` is None.
        """
        if x_column is None:
            raise ValueError("x_column required for distribution plot")

        # Resolve the variant: the legacy 'chart_type' key is consulted first,
        # then a 'sub_chart_type' entry in kwargs (which, being a named
        # parameter, normally never lands there) gets the final say.
        legacy_choice = kwargs.pop('chart_type', sub_chart_type)
        sub_chart_type = kwargs.pop('sub_chart_type', legacy_choice)

        data = df.copy().dropna(subset=[x_column])

        if sub_chart_type == 'histogram':
            fig = px.histogram(
                data,
                x=x_column,
                title=f'Distribution of {x_column}',
                labels={x_column: x_column, 'count': 'Frequency'},
                nbins=HISTOGRAM_BINS
            )
        else:
            # Everything that is not a histogram is rendered as a box plot.
            fig = px.box(
                data,
                y=x_column,
                title=f'Box Plot of {x_column}',
                labels={x_column: x_column}
            )

        fig.update_layout(template='plotly_white', showlegend=False)
        return fig
class CategoryAnalysisStrategy(VisualizationStrategy):
    """Strategy that compares categories with a bar chart or a pie chart."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        sub_chart_type: str = 'bar',
        **kwargs
    ) -> go.Figure:
        """
        Create a bar or pie chart broken down by category.

        With a ``y_column`` the chart shows that column per category
        (aggregated unless ``aggregation`` is ``'none'``); without one it
        shows how many rows fall into each category. In both cases only the
        top ``MAX_CATEGORY_DISPLAY`` entries are kept.

        Args:
            df: Input DataFrame; rows with a missing ``x_column`` value are
                excluded.
            x_column: Column holding the categories.
            y_column: Optional value column; when empty or None, row counts
                per category are plotted instead.
            aggregation: Aggregation passed to ``groupby(...).agg``, or
                ``'none'`` to use the raw rows.
            sub_chart_type: ``'bar'`` for a bar chart; any other value
                produces a pie chart.
            **kwargs: May carry the legacy ``chart_type`` key, which takes
                the place of ``sub_chart_type`` when present.

        Returns:
            Plotly bar or pie figure.

        Raises:
            ValueError: If ``x_column`` is None.
        """
        if x_column is None:
            raise ValueError("x_column required for category analysis")

        # Resolve the variant: the legacy 'chart_type' key is consulted first,
        # then a 'sub_chart_type' entry in kwargs (which, being a named
        # parameter, normally never lands there) gets the final say.
        legacy_choice = kwargs.pop('chart_type', sub_chart_type)
        sub_chart_type = kwargs.pop('sub_chart_type', legacy_choice)
        want_bar = sub_chart_type == 'bar'

        data = df.copy().dropna(subset=[x_column])

        if not y_column:
            # No value column: plot the frequency of each category.
            counts = data[x_column].value_counts().head(MAX_CATEGORY_DISPLAY)
            if want_bar:
                fig = px.bar(
                    x=counts.index,
                    y=counts.values,
                    title=f'Count by {x_column}',
                    labels={'x': x_column, 'y': 'Count'}
                )
            else:
                fig = px.pie(
                    values=counts.values,
                    names=counts.index,
                    title=f'Distribution of {x_column}'
                )
        else:
            if aggregation == 'none':
                summary = data[[x_column, y_column]]
            else:
                summary = data.groupby(x_column)[y_column].agg(aggregation).reset_index()
                summary.columns = [x_column, y_column]

            # Largest values first, limited to the display cap.
            top = summary.sort_values(y_column, ascending=False).head(MAX_CATEGORY_DISPLAY)

            if want_bar:
                fig = px.bar(
                    top,
                    x=x_column,
                    y=y_column,
                    title=f'{y_column} by {x_column}',
                    labels={x_column: x_column, y_column: y_column}
                )
            else:
                fig = px.pie(
                    top,
                    names=x_column,
                    values=y_column,
                    title=f'{y_column} Distribution by {x_column}'
                )

        fig.update_layout(template='plotly_white')
        return fig
class ScatterStrategy(VisualizationStrategy):
    """Strategy that plots two columns against each other as points."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        color_column: Optional[str] = None,
        **kwargs
    ) -> go.Figure:
        """
        Create a scatter plot of ``y_column`` against ``x_column``.

        Args:
            df: Input DataFrame; rows missing either plotted value are
                excluded.
            x_column: Column for the x-axis.
            y_column: Column for the y-axis.
            aggregation: Not used by this strategy.
            color_column: Optional column passed to Plotly as the point
                colour.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Plotly scatter figure whose hover box lists every column.

        Raises:
            ValueError: If ``x_column`` or ``y_column`` is None.
        """
        if x_column is None or y_column is None:
            raise ValueError("Both x_column and y_column required for scatter plot")

        points = df.copy().dropna(subset=[x_column, y_column])
        # Expose every column of the row in the hover tooltip.
        hover_columns = points.columns.tolist()

        fig = px.scatter(
            points,
            x=x_column,
            y=y_column,
            color=color_column,
            title=f'Scatter Plot: {y_column} vs {x_column}',
            labels={x_column: x_column, y_column: y_column},
            hover_data=hover_columns
        )
        fig.update_layout(template='plotly_white')
        return fig
class CorrelationHeatmapStrategy(VisualizationStrategy):
    """Strategy that renders the pairwise correlations of numerical columns."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Create a correlation heatmap.

        Args:
            df: Input DataFrame.
            x_column: Not used by this strategy.
            y_column: Not used by this strategy.
            aggregation: Not used by this strategy.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Plotly heatmap of ``DataFrame.corr()`` over the numerical columns.

        Raises:
            ValueError: If fewer than ``MIN_NUMERICAL_COLUMNS_FOR_CORRELATION``
                numerical columns are detected.
        """
        # Only the first element of the helper's result is used here; it is
        # treated as the list of numerical column names.
        numerical, _, _ = detect_column_types(df)
        if len(numerical) < MIN_NUMERICAL_COLUMNS_FOR_CORRELATION:
            raise ValueError(
                f"Need at least {MIN_NUMERICAL_COLUMNS_FOR_CORRELATION} "
                "numerical columns for correlation"
            )

        corr_matrix = df[numerical].corr()
        fig = px.imshow(
            corr_matrix,
            title='Correlation Heatmap',
            labels=dict(x="Column", y="Column", color="Correlation"),
            color_continuous_scale='RdBu',
            # Correlations live in [-1, 1]. Pinning the range keeps the
            # diverging scale centred on 0 and makes colours comparable
            # between datasets, instead of stretching to the data min/max.
            zmin=-1,
            zmax=1,
            aspect="auto"
        )
        fig.update_layout(template='plotly_white')
        return fig
class VisualizationFactory:
    """Dispatches chart requests to the strategy registered for a chart type."""

    def __init__(self):
        """Register one strategy instance per supported chart type."""
        self._strategies = {
            'time_series': TimeSeriesStrategy(),
            'distribution': DistributionStrategy(),
            'category': CategoryAnalysisStrategy(),
            'scatter': ScatterStrategy(),
            'correlation': CorrelationHeatmapStrategy()
        }

    def create_visualization(
        self,
        chart_type: str,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Build a chart by delegating to the strategy for ``chart_type``.

        Args:
            chart_type: Key of a registered strategy (for example
                ``'time_series'`` or ``'scatter'``).
            df: Input DataFrame.
            x_column: X-axis column, forwarded to the strategy.
            y_column: Y-axis column, forwarded to the strategy.
            aggregation: Aggregation method, forwarded to the strategy.
            **kwargs: Extra options forwarded unchanged to the strategy.

        Returns:
            Whatever the selected strategy's ``create_chart`` returns.

        Raises:
            ValueError: If no strategy is registered under ``chart_type``.
        """
        chosen = self._strategies.get(chart_type)
        if chosen is None:
            raise ValueError(f"Unknown chart type: {chart_type}")

        forwarded = dict(
            x_column=x_column,
            y_column=y_column,
            aggregation=aggregation,
        )
        return chosen.create_chart(df, **forwarded, **kwargs)