"""
Visualization module for the Business Intelligence Dashboard.

This module creates various types of charts and visualizations using the
Strategy Pattern: each chart family is a ``VisualizationStrategy`` subclass
and ``VisualizationFactory`` dispatches to the right one by name.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Any

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from utils import detect_column_types
from constants import (
    HISTOGRAM_BINS,
    MAX_CATEGORY_DISPLAY,
    MIN_NUMERICAL_COLUMNS_FOR_CORRELATION
)


class VisualizationStrategy(ABC):
    """Abstract base class for visualization strategies."""

    @abstractmethod
    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Create a visualization.

        Args:
            df: Input DataFrame
            x_column: X-axis column
            y_column: Y-axis column
            aggregation: Aggregation method handed to pandas ``agg``
                (e.g. sum, mean, count, median). Strategies that aggregate
                treat the literal ``'none'`` as "do not aggregate".
            **kwargs: Additional strategy-specific parameters

        Returns:
            Plotly figure object
        """


class TimeSeriesStrategy(VisualizationStrategy):
    """Strategy for creating time series plots."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Create a time series line plot of ``y_column`` over ``x_column``.

        ``x_column`` is parsed with ``pd.to_datetime(errors='coerce')``; rows
        whose date cannot be parsed, or whose x/y value is missing, are
        dropped before plotting.

        Args:
            df: Input DataFrame (not modified; a copy is used)
            x_column: Column holding the dates
            y_column: Column holding the values
            aggregation: Aggregation applied per distinct date, or ``'none'``
                to plot the raw rows
            **kwargs: Accepted for interface compatibility; unused here

        Returns:
            Plotly figure object

        Raises:
            ValueError: If ``x_column`` or ``y_column`` is None
        """
        if x_column is None or y_column is None:
            raise ValueError("Both x_column and y_column required for time series")

        # Copy first: the date column is overwritten below and the caller's
        # frame must stay untouched.
        df = df.copy()
        df[x_column] = pd.to_datetime(df[x_column], errors='coerce')
        df = df.dropna(subset=[x_column, y_column])

        if aggregation != 'none':
            # groupby sorts by the group key, so the result is chronological.
            df = df.groupby(x_column)[y_column].agg(aggregation).reset_index()
        else:
            # A line trace joins points in row order, so unsorted dates would
            # draw a zigzag. A stable sort keeps same-date rows in input order.
            df = df.sort_values(x_column, kind='stable')

        fig = px.line(
            df,
            x=x_column,
            y=y_column,
            title=f'Time Series: {y_column} over {x_column}',
            labels={x_column: x_column, y_column: y_column}
        )
        fig.update_layout(
            xaxis_title=x_column,
            yaxis_title=y_column,
            hovermode='x unified',
            template='plotly_white'
        )
        return fig


class DistributionStrategy(VisualizationStrategy):
    """Strategy for creating distribution plots."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        sub_chart_type: str = 'histogram',
        **kwargs
    ) -> go.Figure:
        """
        Create a distribution plot (histogram or box plot) of ``x_column``.

        Args:
            df: Input DataFrame
            x_column: Column whose distribution is plotted
            y_column: Unused by this strategy
            aggregation: Unused by this strategy
            sub_chart_type: ``'histogram'`` for a histogram; any other value
                produces a box plot
            **kwargs: May carry the legacy ``chart_type`` key, which, when
                present, takes precedence over ``sub_chart_type``

        Returns:
            Plotly figure object

        Raises:
            ValueError: If ``x_column`` is None
        """
        if x_column is None:
            raise ValueError("x_column required for distribution plot")

        # ``sub_chart_type`` is a named parameter, so it can never arrive via
        # **kwargs; only the legacy ``chart_type`` spelling can.
        sub_chart_type = kwargs.pop('chart_type', sub_chart_type)

        # dropna returns a new frame, so no explicit copy is needed.
        df = df.dropna(subset=[x_column])

        if sub_chart_type == 'histogram':
            fig = px.histogram(
                df,
                x=x_column,
                title=f'Distribution of {x_column}',
                labels={x_column: x_column, 'count': 'Frequency'},
                nbins=HISTOGRAM_BINS
            )
        else:  # box plot
            fig = px.box(
                df,
                y=x_column,
                title=f'Box Plot of {x_column}',
                labels={x_column: x_column}
            )

        fig.update_layout(
            template='plotly_white',
            showlegend=False
        )
        return fig


class CategoryAnalysisStrategy(VisualizationStrategy):
    """Strategy for creating category analysis charts."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        sub_chart_type: str = 'bar',
        **kwargs
    ) -> go.Figure:
        """
        Create a category analysis chart (bar chart or pie chart).

        With ``y_column`` the values are aggregated per category (unless
        ``aggregation`` is ``'none'``), sorted descending and limited to the
        top ``MAX_CATEGORY_DISPLAY`` rows. Without ``y_column`` the chart
        shows occurrence counts of the ``MAX_CATEGORY_DISPLAY`` most frequent
        categories.

        Args:
            df: Input DataFrame
            x_column: Category column
            y_column: Optional value column
            aggregation: Aggregation applied per category, or ``'none'``
            sub_chart_type: ``'bar'`` for a bar chart; any other value
                produces a pie chart
            **kwargs: May carry the legacy ``chart_type`` key, which, when
                present, takes precedence over ``sub_chart_type``

        Returns:
            Plotly figure object

        Raises:
            ValueError: If ``x_column`` is None
        """
        if x_column is None:
            raise ValueError("x_column required for category analysis")

        # ``sub_chart_type`` is a named parameter, so it can never arrive via
        # **kwargs; only the legacy ``chart_type`` spelling can.
        sub_chart_type = kwargs.pop('chart_type', sub_chart_type)

        # dropna returns a new frame, so no explicit copy is needed.
        df = df.dropna(subset=[x_column])

        if y_column:
            if aggregation != 'none':
                df_agg = df.groupby(x_column)[y_column].agg(aggregation).reset_index()
                df_agg.columns = [x_column, y_column]
            else:
                df_agg = df[[x_column, y_column]]

            # Largest values first, capped so the chart stays readable.
            df_agg = df_agg.sort_values(y_column, ascending=False).head(MAX_CATEGORY_DISPLAY)

            if sub_chart_type == 'bar':
                fig = px.bar(
                    df_agg,
                    x=x_column,
                    y=y_column,
                    title=f'{y_column} by {x_column}',
                    labels={x_column: x_column, y_column: y_column}
                )
            else:  # pie
                fig = px.pie(
                    df_agg,
                    names=x_column,
                    values=y_column,
                    title=f'{y_column} Distribution by {x_column}'
                )
        else:
            # No value column: chart how often each category occurs.
            value_counts = df[x_column].value_counts().head(MAX_CATEGORY_DISPLAY)

            if sub_chart_type == 'bar':
                fig = px.bar(
                    x=value_counts.index,
                    y=value_counts.values,
                    title=f'Count by {x_column}',
                    labels={'x': x_column, 'y': 'Count'}
                )
            else:  # pie
                fig = px.pie(
                    values=value_counts.values,
                    names=value_counts.index,
                    title=f'Distribution of {x_column}'
                )

        fig.update_layout(template='plotly_white')
        return fig


class ScatterStrategy(VisualizationStrategy):
    """Strategy for creating scatter plots."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        color_column: Optional[str] = None,
        **kwargs
    ) -> go.Figure:
        """
        Create a scatter plot of ``y_column`` against ``x_column``.

        Rows with a missing x or y value are dropped. Every column of the
        frame is exposed in the hover tooltip.

        Args:
            df: Input DataFrame
            x_column: X-axis column
            y_column: Y-axis column
            aggregation: Unused by this strategy
            color_column: Optional column used to colour the points
            **kwargs: Accepted for interface compatibility; unused here

        Returns:
            Plotly figure object

        Raises:
            ValueError: If ``x_column`` or ``y_column`` is None
        """
        if x_column is None or y_column is None:
            raise ValueError("Both x_column and y_column required for scatter plot")

        # dropna returns a new frame, so no explicit copy is needed.
        df = df.dropna(subset=[x_column, y_column])

        fig = px.scatter(
            df,
            x=x_column,
            y=y_column,
            color=color_column,
            title=f'Scatter Plot: {y_column} vs {x_column}',
            labels={x_column: x_column, y_column: y_column},
            hover_data=df.columns.tolist()
        )
        fig.update_layout(template='plotly_white')
        return fig


class CorrelationHeatmapStrategy(VisualizationStrategy):
    """Strategy for creating correlation heatmaps."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Create a correlation heatmap over the numerical columns of ``df``.

        Args:
            df: Input DataFrame
            x_column: Unused by this strategy
            y_column: Unused by this strategy
            aggregation: Unused by this strategy
            **kwargs: Accepted for interface compatibility; unused here

        Returns:
            Plotly figure object

        Raises:
            ValueError: If fewer than ``MIN_NUMERICAL_COLUMNS_FOR_CORRELATION``
                numerical columns are detected
        """
        # Presumably the first element is the list of numerical column names;
        # it is used as a column selector below - verify against utils.
        numerical, _, _ = detect_column_types(df)

        if len(numerical) < MIN_NUMERICAL_COLUMNS_FOR_CORRELATION:
            raise ValueError(
                f"Need at least {MIN_NUMERICAL_COLUMNS_FOR_CORRELATION} "
                "numerical columns for correlation"
            )

        corr_matrix = df[numerical].corr()

        fig = px.imshow(
            corr_matrix,
            title='Correlation Heatmap',
            labels=dict(x="Column", y="Column", color="Correlation"),
            color_continuous_scale='RdBu',
            # Correlations live in [-1, 1]. Pinning the range keeps zero on
            # the neutral midpoint of the diverging scale and makes colours
            # comparable between datasets.
            zmin=-1,
            zmax=1,
            aspect="auto"
        )
        fig.update_layout(template='plotly_white')
        return fig


class VisualizationFactory:
    """Factory class for creating visualizations using Strategy Pattern."""

    def __init__(self):
        """Initialize with one strategy instance per supported chart type."""
        self._strategies = {
            'time_series': TimeSeriesStrategy(),
            'distribution': DistributionStrategy(),
            'category': CategoryAnalysisStrategy(),
            'scatter': ScatterStrategy(),
            'correlation': CorrelationHeatmapStrategy()
        }

    def create_visualization(
        self,
        chart_type: str,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Create visualization using appropriate strategy.

        Args:
            chart_type: Type of chart to create; one of the keys registered
                in ``__init__`` (time_series, distribution, category,
                scatter, correlation)
            df: Input DataFrame
            x_column: X-axis column
            y_column: Y-axis column
            aggregation: Aggregation method
            **kwargs: Additional parameters forwarded unchanged to the
                strategy (e.g. ``sub_chart_type``, ``color_column``)

        Returns:
            Plotly figure object

        Raises:
            ValueError: If ``chart_type`` is not a registered strategy name
        """
        if chart_type not in self._strategies:
            raise ValueError(f"Unknown chart type: {chart_type}")

        strategy = self._strategies[chart_type]
        return strategy.create_chart(
            df,
            x_column=x_column,
            y_column=y_column,
            aggregation=aggregation,
            **kwargs
        )