Spaces:
Sleeping
Sleeping
| """ | |
| Visualization module for the Business Intelligence Dashboard. | |
| This module creates various types of charts and visualizations | |
| using the Strategy Pattern for different chart types. | |
| """ | |
| from abc import ABC, abstractmethod | |
| from typing import Dict, List, Optional, Tuple, Any | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from utils import detect_column_types | |
| from constants import ( | |
| HISTOGRAM_BINS, | |
| MAX_CATEGORY_DISPLAY, | |
| MIN_NUMERICAL_COLUMNS_FOR_CORRELATION | |
| ) | |
class VisualizationStrategy(ABC):
    """Abstract base class for visualization strategies.

    Each concrete strategy turns a DataFrame (plus optional column and
    aggregation choices) into a Plotly figure. ``create_chart`` is declared
    abstract so that this base class cannot be instantiated directly and every
    subclass is forced to provide an implementation.
    """

    @abstractmethod
    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Create a visualization.

        Args:
            df: Input DataFrame
            x_column: X-axis column
            y_column: Y-axis column
            aggregation: Aggregation method (sum, mean, count, median), or
                'none' to skip aggregation in strategies that support it
            **kwargs: Additional, strategy-specific parameters

        Returns:
            Plotly figure object

        Raises:
            NotImplementedError: If a subclass calls this base implementation.
        """
        raise NotImplementedError
class TimeSeriesStrategy(VisualizationStrategy):
    """Strategy for creating time series plots."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """Create a line chart of ``y_column`` over the date column ``x_column``.

        Args:
            df: Input DataFrame (not modified; a copy is converted).
            x_column: Column parsed with ``pd.to_datetime``; values that cannot
                be parsed become NaT and their rows are dropped.
            y_column: Column plotted on the y-axis; rows where it is missing
                are dropped.
            aggregation: Pandas aggregation name applied per timestamp
                (e.g. 'sum', 'mean', 'count'), or 'none' to plot raw rows.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Plotly line figure using the ``plotly_white`` template.

        Raises:
            ValueError: If ``x_column`` or ``y_column`` is None.
        """
        if x_column is None or y_column is None:
            raise ValueError("Both x_column and y_column required for time series")

        # Convert date column; unparseable values become NaT and are dropped.
        df = df.copy()
        df[x_column] = pd.to_datetime(df[x_column], errors='coerce')
        df = df.dropna(subset=[x_column, y_column])

        # Aggregate if needed (one point per distinct timestamp).
        if aggregation != 'none':
            df = df.groupby(x_column)[y_column].agg(aggregation).reset_index()

        # A line trace connects points in row order, so unsorted raw rows
        # (aggregation='none') would zig-zag back and forth in time. groupby
        # output is already key-sorted, so this stable sort leaves it unchanged.
        df = df.sort_values(x_column, kind='mergesort')

        fig = px.line(
            df,
            x=x_column,
            y=y_column,
            title=f'Time Series: {y_column} over {x_column}',
            labels={x_column: x_column, y_column: y_column}
        )
        fig.update_layout(
            xaxis_title=x_column,
            yaxis_title=y_column,
            hovermode='x unified',
            template='plotly_white'
        )
        return fig
class DistributionStrategy(VisualizationStrategy):
    """Strategy for creating distribution plots (histogram or box plot)."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        sub_chart_type: str = 'histogram',
        **kwargs
    ) -> go.Figure:
        """Plot the distribution of a single column.

        Args:
            df: Input DataFrame (not modified).
            x_column: Column whose values are plotted; rows where it is
                missing are dropped first.
            y_column: Unused; accepted for interface compatibility.
            aggregation: Unused; accepted for interface compatibility.
            sub_chart_type: 'histogram' draws a histogram with
                HISTOGRAM_BINS bins; any other value draws a box plot.
            **kwargs: May carry the legacy ``chart_type`` key, which, when
                present, takes the place of ``sub_chart_type``.

        Returns:
            Plotly figure using the ``plotly_white`` template, legend hidden.

        Raises:
            ValueError: If ``x_column`` is None.
        """
        if x_column is None:
            raise ValueError("x_column required for distribution plot")

        # ``sub_chart_type`` is a named parameter, so it can never show up in
        # **kwargs; only the legacy ``chart_type`` key can arrive there.
        chosen_kind = kwargs.pop('chart_type', sub_chart_type)
        present = df.dropna(subset=[x_column])

        if chosen_kind != 'histogram':
            # Box plot: values go on the vertical axis.
            fig = px.box(
                present,
                y=x_column,
                title=f'Box Plot of {x_column}',
                labels={x_column: x_column},
            )
        else:
            fig = px.histogram(
                present,
                x=x_column,
                title=f'Distribution of {x_column}',
                labels={x_column: x_column, 'count': 'Frequency'},
                nbins=HISTOGRAM_BINS,
            )

        fig.update_layout(template='plotly_white', showlegend=False)
        return fig
class CategoryAnalysisStrategy(VisualizationStrategy):
    """Strategy for creating category analysis charts (bar or pie)."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        sub_chart_type: str = 'bar',
        **kwargs
    ) -> go.Figure:
        """Summarise a categorical column as a bar or pie chart.

        Args:
            df: Input DataFrame (not modified).
            x_column: Category column; rows where it is missing are dropped.
            y_column: Optional value column. When truthy, its values are
                aggregated per category with ``aggregation`` (or used raw
                when ``aggregation == 'none'``), sorted descending, and the
                first MAX_CATEGORY_DISPLAY rows are charted. When falsy, the
                MAX_CATEGORY_DISPLAY most frequent categories are charted.
            aggregation: Pandas aggregation name, or 'none'.
            sub_chart_type: 'bar' draws a bar chart; any other value draws a
                pie chart.
            **kwargs: May carry the legacy ``chart_type`` key, which, when
                present, takes the place of ``sub_chart_type``.

        Returns:
            Plotly figure using the ``plotly_white`` template.

        Raises:
            ValueError: If ``x_column`` is None.
        """
        if x_column is None:
            raise ValueError("x_column required for category analysis")

        # Only the legacy ``chart_type`` key can reach **kwargs; the
        # ``sub_chart_type`` name always binds to the parameter above.
        chosen_kind = kwargs.pop('chart_type', sub_chart_type)
        as_bar = chosen_kind == 'bar'
        present = df.dropna(subset=[x_column])

        if not y_column:
            # No value column: chart how often each category occurs.
            counts = present[x_column].value_counts().head(MAX_CATEGORY_DISPLAY)
            if as_bar:
                fig = px.bar(
                    x=counts.index,
                    y=counts.values,
                    title=f'Count by {x_column}',
                    labels={'x': x_column, 'y': 'Count'},
                )
            else:
                fig = px.pie(
                    values=counts.values,
                    names=counts.index,
                    title=f'Distribution of {x_column}',
                )
        else:
            if aggregation == 'none':
                summary = present[[x_column, y_column]]
            else:
                summary = present.groupby(x_column)[y_column].agg(aggregation).reset_index()
                # Normalise the header (e.g. list-style aggregations name the
                # value column after the function rather than y_column).
                summary.columns = [x_column, y_column]
            # Keep only the largest values so the chart stays readable.
            summary = summary.sort_values(y_column, ascending=False).head(MAX_CATEGORY_DISPLAY)
            if as_bar:
                fig = px.bar(
                    summary,
                    x=x_column,
                    y=y_column,
                    title=f'{y_column} by {x_column}',
                    labels={x_column: x_column, y_column: y_column},
                )
            else:
                fig = px.pie(
                    summary,
                    names=x_column,
                    values=y_column,
                    title=f'{y_column} Distribution by {x_column}',
                )

        fig.update_layout(template='plotly_white')
        return fig
class ScatterStrategy(VisualizationStrategy):
    """Strategy for creating scatter plots."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        color_column: Optional[str] = None,
        **kwargs
    ) -> go.Figure:
        """Plot ``y_column`` against ``x_column`` as individual points.

        Args:
            df: Input DataFrame (not modified).
            x_column: Column for the x-axis.
            y_column: Column for the y-axis. Rows missing either axis value
                are dropped.
            aggregation: Unused; accepted for interface compatibility.
            color_column: Optional column used to colour the points.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Plotly scatter figure using the ``plotly_white`` template; every
            column of the (filtered) frame is shown in the hover tooltip.

        Raises:
            ValueError: If ``x_column`` or ``y_column`` is None.
        """
        axis_missing = x_column is None or y_column is None
        if axis_missing:
            raise ValueError("Both x_column and y_column required for scatter plot")

        points = df.dropna(subset=[x_column, y_column])
        tooltip_columns = points.columns.tolist()

        fig = px.scatter(
            points,
            x=x_column,
            y=y_column,
            color=color_column,
            title=f'Scatter Plot: {y_column} vs {x_column}',
            labels={x_column: x_column, y_column: y_column},
            hover_data=tooltip_columns,
        )
        fig.update_layout(template='plotly_white')
        return fig
class CorrelationHeatmapStrategy(VisualizationStrategy):
    """Strategy for creating correlation heatmaps."""

    def create_chart(
        self,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """Create a heatmap of pairwise correlations between numerical columns.

        Args:
            df: Input DataFrame.
            x_column: Unused; accepted for interface compatibility.
            y_column: Unused; accepted for interface compatibility.
            aggregation: Unused; accepted for interface compatibility.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Plotly heatmap of ``DataFrame.corr()`` over the numerical columns
            (the first list returned by ``detect_column_types``), coloured on
            a fixed [-1, 1] scale.

        Raises:
            ValueError: If fewer than MIN_NUMERICAL_COLUMNS_FOR_CORRELATION
                numerical columns are available.
        """
        numerical, _, _ = detect_column_types(df)
        if len(numerical) < MIN_NUMERICAL_COLUMNS_FOR_CORRELATION:
            raise ValueError(
                f"Need at least {MIN_NUMERICAL_COLUMNS_FOR_CORRELATION} "
                "numerical columns for correlation"
            )

        corr_matrix = df[numerical].corr()

        # Pin the colour range to the full correlation domain [-1, 1].
        # Otherwise plotly stretches the diverging RdBu scale over the observed
        # min/max, the neutral midpoint drifts away from 0, and e.g. weak
        # positive correlations can be painted in the "negative" hue.
        fig = px.imshow(
            corr_matrix,
            title='Correlation Heatmap',
            labels=dict(x="Column", y="Column", color="Correlation"),
            color_continuous_scale='RdBu',
            zmin=-1,
            zmax=1,
            aspect="auto"
        )
        fig.update_layout(template='plotly_white')
        return fig
class VisualizationFactory:
    """Factory class for creating visualizations using Strategy Pattern.

    Holds one instance of each built-in strategy, keyed by chart-type name,
    and forwards chart requests to the matching strategy.
    """

    def __init__(self):
        """Register one instance of each built-in visualization strategy."""
        self._strategies = dict(
            time_series=TimeSeriesStrategy(),
            distribution=DistributionStrategy(),
            category=CategoryAnalysisStrategy(),
            scatter=ScatterStrategy(),
            correlation=CorrelationHeatmapStrategy(),
        )

    def create_visualization(
        self,
        chart_type: str,
        df: pd.DataFrame,
        x_column: Optional[str] = None,
        y_column: Optional[str] = None,
        aggregation: str = 'sum',
        **kwargs
    ) -> go.Figure:
        """
        Dispatch a chart request to the strategy registered for ``chart_type``.

        Args:
            chart_type: Registered strategy key ('time_series', 'distribution',
                'category', 'scatter' or 'correlation')
            df: Input DataFrame
            x_column: X-axis column
            y_column: Y-axis column
            aggregation: Aggregation method
            **kwargs: Extra strategy-specific parameters, passed through as-is

        Returns:
            Plotly figure object produced by the selected strategy

        Raises:
            ValueError: If ``chart_type`` is not a registered key.
        """
        # Registered strategies are never None, so None means "not found".
        handler = self._strategies.get(chart_type)
        if handler is None:
            raise ValueError(f"Unknown chart type: {chart_type}")

        return handler.create_chart(
            df,
            x_column=x_column,
            y_column=y_column,
            aggregation=aggregation,
            **kwargs
        )