"""
Visualization Utility Functions

This module provides utility functions for creating common visualizations
used in pharmaceutical analytics dashboards.
"""

import datetime as dt
from typing import List, Dict, Any, Optional, Tuple, Union

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

# Fallback palette shared by the line/bar helpers; cycled when there are more
# series than colors.
_DEFAULT_COLORS = ['blue', 'green', 'red', 'orange', 'purple']

# Chart types understood by create_comparison_chart.
_COMPARISON_CHART_TYPES = ('bar', 'line', 'area')


def _vline_position(x: Any) -> Any:
    """
    Normalise an annotation x-position for ``Figure.add_vline``.

    Strings are parsed with ``pd.to_datetime``. Date-like values are returned
    as milliseconds since the Unix epoch; anything else (e.g. a plain number)
    is returned unchanged.

    Why milliseconds: when ``annotation_text`` is given, Plotly averages the
    line's x-coordinates to place the label, which fails for Timestamp/str
    values on affected Plotly versions. Numbers avoid that code path and are
    accepted as positions on a date axis.
    """
    if isinstance(x, str):
        x = pd.to_datetime(x)

    # pd.Timestamp and datetime.datetime are both subclasses of datetime.date.
    if isinstance(x, (dt.date, np.datetime64)):
        stamp = pd.Timestamp(x)
        if stamp.tzinfo is not None:
            # Keep the wall-clock reading rather than shifting to UTC.
            # NOTE(review): assumes the date axis renders tz-aware data by
            # wall-clock time -- confirm if tz-aware annotations are used.
            stamp = stamp.tz_localize(None)
        return stamp.timestamp() * 1000

    return x


def create_trend_chart(
    df: pd.DataFrame,
    date_column: str,
    value_columns: List[str],
    title: str = "Trend Analysis",
    colors: Optional[List[str]] = None,
    markers: bool = True,
    annotations: Optional[List[Dict[str, Any]]] = None,
    height: int = 400
) -> go.Figure:
    """
    Create a time series trend chart with Plotly

    Neither ``df`` nor the ``annotations`` dictionaries are modified.

    Parameters:
    -----------
    df : DataFrame
        Pandas DataFrame containing the data
    date_column : str
        Name of the column containing dates; parsed with ``pd.to_datetime``
        when it is not already a datetime dtype
    value_columns : List[str]
        List of column names to plot as lines
    title : str
        Chart title
    colors : List[str], optional
        List of colors for each line (cycled if shorter than value_columns)
    markers : bool
        Whether to show markers on lines
    annotations : List[Dict], optional
        List of annotation dictionaries. Entries need the keys 'x' (a date
        string, date-like object or number) and 'text'; 'color' and
        'position' are optional. Entries missing 'x' or 'text' are skipped.
    height : int
        Height of the chart in pixels

    Returns:
    --------
    go.Figure
        Plotly figure object
    """
    fig = go.Figure()

    palette = colors if colors else _DEFAULT_COLORS

    # Parse only the date column instead of copying the whole frame.
    dates = df[date_column]
    if not pd.api.types.is_datetime64_any_dtype(dates):
        dates = pd.to_datetime(dates)

    # The trace mode is the same for every series, so compute it once.
    mode = 'lines+markers' if markers else 'lines'

    for i, column in enumerate(value_columns):
        fig.add_trace(go.Scatter(
            x=dates,
            y=df[column],
            mode=mode,
            name=column,
            line=dict(color=palette[i % len(palette)], width=2)
        ))

    for annotation in annotations or []:
        if 'x' not in annotation or 'text' not in annotation:
            continue

        # Work on a local value so the caller's dictionary is left untouched.
        fig.add_vline(
            x=_vline_position(annotation['x']),
            line_dash="dash",
            line_color=annotation.get('color', 'red'),
            annotation_text=annotation['text'],
            annotation_position=annotation.get('position', 'top right')
        )

    fig.update_layout(
        title=title,
        xaxis_title=date_column,
        yaxis_title="Value",
        height=height,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        margin=dict(l=20, r=20, t=40, b=20)
    )

    return fig


def create_comparison_chart(
    df: pd.DataFrame,
    category_column: str,
    value_columns: List[str],
    title: str = "Comparison Analysis",
    chart_type: str = "bar",
    stacked: bool = False,
    colors: Optional[List[str]] = None,
    height: int = 400,
    horizontal: bool = False
) -> go.Figure:
    """
    Create a comparison chart (bar, line, area) with Plotly

    Parameters:
    -----------
    df : DataFrame
        Pandas DataFrame containing the data
    category_column : str
        Name of the column containing categories
    value_columns : List[str]
        List of column names to plot
    title : str
        Chart title
    chart_type : str
        Type of chart ('bar', 'line', 'area')
    stacked : bool
        Whether to stack the bars/areas. Has no effect on line charts.
    colors : List[str], optional
        List of colors for each series (cycled if shorter than value_columns)
    height : int
        Height of the chart in pixels
    horizontal : bool
        If True, create horizontal bar chart. Ignored for 'line' and 'area'.

    Returns:
    --------
    go.Figure
        Plotly figure object

    Raises:
    -------
    ValueError
        If chart_type is not one of 'bar', 'line' or 'area'
    """
    # Fail loudly on a typo instead of silently returning an empty figure.
    if chart_type not in _COMPARISON_CHART_TYPES:
        raise ValueError(
            f"Unsupported chart_type {chart_type!r}; "
            f"expected one of {_COMPARISON_CHART_TYPES}"
        )

    palette = colors if colors else _DEFAULT_COLORS

    # Only bars are ever drawn horizontally, so axis titles must follow this
    # flag rather than the raw `horizontal` argument.
    horizontal_bars = horizontal and chart_type == 'bar'

    fig = go.Figure()

    for i, column in enumerate(value_columns):
        color = palette[i % len(palette)]
        categories = df[category_column]
        values = df[column]

        if chart_type == 'bar':
            if horizontal_bars:
                fig.add_trace(go.Bar(
                    y=categories,
                    x=values,
                    name=column,
                    marker_color=color,
                    orientation='h'
                ))
            else:
                fig.add_trace(go.Bar(
                    x=categories,
                    y=values,
                    name=column,
                    marker_color=color
                ))
        elif chart_type == 'line':
            fig.add_trace(go.Scatter(
                x=categories,
                y=values,
                mode='lines+markers',
                name=column,
                line=dict(color=color)
            ))
        else:  # 'area'
            if stacked:
                # A shared stackgroup makes Plotly sum the series and fill
                # between them; fill='tonexty' alone does not add the values.
                area_style = dict(stackgroup='one')
            else:
                # Fill down to the axis so the series renders as an area.
                area_style = dict(fill='tozeroy')

            fig.add_trace(go.Scatter(
                x=categories,
                y=values,
                mode='lines',
                name=column,
                line=dict(color=color),
                **area_style
            ))

    fig.update_layout(
        title=title,
        xaxis_title=None if horizontal_bars else category_column,
        yaxis_title=category_column if horizontal_bars else None,
        barmode='stack' if stacked else 'group',
        height=height,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return fig


def create_heatmap(
    df: pd.DataFrame,
    x_column: str,
    y_column: str,
    value_column: str,
    title: str = "Heatmap Analysis",
    colorscale: str = "Blues",
    height: int = 500,
    width: int = 700,
    text_format: Optional[str] = None
) -> go.Figure:
    """
    Create a heatmap with Plotly

    The data is pivoted to a y-by-x grid; when several rows share the same
    (x, y) pair their values are averaged.

    Parameters:
    -----------
    df : DataFrame
        Pandas DataFrame containing the data
    x_column : str
        Name of the column for x-axis categories
    y_column : str
        Name of the column for y-axis categories
    value_column : str
        Name of the column containing values to plot
    title : str
        Chart title
    colorscale : str
        Colorscale for the heatmap
    height : int
        Height of the chart in pixels
    width : int
        Width of the chart in pixels
    text_format : str, optional
        Python format spec for cell labels (e.g., ".1f" for float with 1
        decimal). When omitted, Plotly's automatic cell text is used. Cells
        with no data are left blank.

    Returns:
    --------
    go.Figure
        Plotly figure object
    """
    pivot_df = df.pivot_table(
        index=y_column,
        columns=x_column,
        values=value_column,
        aggfunc='mean'
    )

    text_values = None
    if text_format:
        def _format_cell(value: Any) -> str:
            # Missing (x, y) combinations become NaN after the pivot; show
            # them as empty cells instead of the literal text "nan".
            return "" if pd.isna(value) else format(value, text_format)

        # DataFrame.applymap is deprecated since pandas 2.1 in favour of
        # DataFrame.map; checking the class keeps older pandas working.
        if hasattr(pd.DataFrame, "map"):
            text_values = pivot_df.map(_format_cell)
        else:
            text_values = pivot_df.applymap(_format_cell)

    fig = px.imshow(
        pivot_df,
        labels=dict(x=x_column, y=y_column, color=value_column),
        x=pivot_df.columns,
        y=pivot_df.index,
        color_continuous_scale=colorscale,
        text_auto=text_format is None,  # Auto text if format not specified
        aspect="auto"
    )

    # Overlay the pre-formatted labels when a format was requested.
    if text_values is not None:
        fig.update_traces(text=text_values.values, texttemplate="%{text}")

    fig.update_layout(
        title=title,
        height=height,
        width=width,
        xaxis=dict(side="bottom"),
        margin=dict(l=20, r=20, t=40, b=20)
    )

    return fig


def create_pie_chart(
    df: pd.DataFrame,
    names_column: str,
    values_column: str,
    title: str = "Distribution Analysis",
    colors: Optional[List[str]] = None,
    hole: float = 0.0,
    height: int = 400
) -> go.Figure:
    """
    Create a pie or donut chart with Plotly

    Parameters:
    -----------
    df : DataFrame
        Pandas DataFrame containing the data
    names_column : str
        Name of the column containing category names
    values_column : str
        Name of the column containing values
    title : str
        Chart title
    colors : List[str], optional
        List of colors for pie slices
    hole : float
        Size of hole for donut chart (0.0 for pie chart)
    height : int
        Height of the chart in pixels

    Returns:
    --------
    go.Figure
        Plotly figure object
    """
    fig = px.pie(
        df,
        names=names_column,
        values=values_column,
        title=title,
        color_discrete_sequence=colors,
        hole=hole,
        height=height
    )

    # Horizontal legend centred underneath the chart.
    fig.update_layout(
        margin=dict(l=20, r=20, t=40, b=20),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.2,
            xanchor="center",
            x=0.5
        )
    )

    # Label each slice with its share and category name.
    fig.update_traces(
        textposition='inside',
        textinfo='percent+label'
    )

    return fig


def create_scatter_plot(
    df: pd.DataFrame,
    x_column: str,
    y_column: str,
    size_column: Optional[str] = None,
    color_column: Optional[str] = None,
    title: str = "Correlation Analysis",
    height: int = 500,
    trendline: bool = False,
    hover_data: Optional[List[str]] = None
) -> go.Figure:
    """
    Create a scatter plot with Plotly

    Parameters:
    -----------
    df : DataFrame
        Pandas DataFrame containing the data
    x_column : str
        Name of the column for x-axis values
    y_column : str
        Name of the column for y-axis values
    size_column : str, optional
        Name of the column for point sizes
    color_column : str, optional
        Name of the column for point colors
    title : str
        Chart title
    height : int
        Height of the chart in pixels
    trendline : bool
        Whether to add a trendline (requested from Plotly Express as 'ols';
        NOTE(review): this relies on the optional statsmodels package being
        installed -- confirm in the deployment environment)
    hover_data : List[str], optional
        List of column names to include in hover data

    Returns:
    --------
    go.Figure
        Plotly figure object
    """
    fig = px.scatter(
        df,
        x=x_column,
        y=y_column,
        size=size_column,
        color=color_column,
        title=title,
        height=height,
        hover_data=hover_data,
        trendline='ols' if trendline else None
    )

    fig.update_layout(
        xaxis_title=x_column,
        yaxis_title=y_column,
        margin=dict(l=20, r=20, t=40, b=20)
    )

    return fig


# Example usage
if __name__ == "__main__":
    # Create sample data. 'MS' (month start) is used because the 'M' alias is
    # deprecated in recent pandas, and it places the example annotation date
    # on an actual data point.
    dates = pd.date_range(start='2023-01-01', periods=12, freq='MS')
    data = {
        'date': dates,
        'sales': [100, 110, 120, 115, 130, 140, 135, 150, 145, 160, 155, 170],
        'target': [105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 155, 160],
        'region': ['Northeast'] * 12
    }
    df = pd.DataFrame(data)

    # Create trend chart
    fig = create_trend_chart(
        df,
        date_column='date',
        value_columns=['sales', 'target'],
        title='Sales vs Target',
        annotations=[{'x': '2023-06-01', 'text': 'Campaign Launch'}]
    )

    # Display the chart (in a notebook or Streamlit app)
    print("Trend chart created successfully!")