"""
Unified styling module for both Streamlit UI and matplotlib plots.
Contains all styling definitions to ensure consistency across the application.
Note: When used outside a Streamlit environment (e.g., in Jupyter notebooks),
you may see warnings about a missing ScriptRunContext or Session state. These
warnings are harmless and can be safely ignored: the core plotting helpers
(get_plot_style, set_plot_style, PLOT_COLORS) work correctly regardless.
"""
import functools
import logging
import warnings

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import seaborn as sns
# Suppress Streamlit warnings when running outside a streamlit environment.
# NOTE: these filters are process-wide and are installed at import time.
warnings.filterwarnings('ignore', category=UserWarning, module='streamlit')
warnings.filterwarnings('ignore', message='.*ScriptRunContext.*')
warnings.filterwarnings('ignore', message='.*Session state.*')
warnings.filterwarnings('ignore', message='.*missing ScriptRunContext.*')
warnings.filterwarnings('ignore', message='.*does not function when running.*')
warnings.filterwarnings('ignore', module='streamlit.runtime.*')
warnings.filterwarnings('ignore', module='streamlit.runtime.scriptrunner_utils.*')
warnings.filterwarnings('ignore', module='streamlit.runtime.state.*')
# Streamlit loggers that emit warnings outside a `streamlit run` session.
# They are raised to ERROR *before* streamlit is imported so that messages
# logged during the import itself are suppressed as well. Each logger is
# configured exactly once.
_QUIET_STREAMLIT_LOGGERS = (
    'streamlit',
    'streamlit.runtime',
    'streamlit.runtime.scriptrunner_utils',
    'streamlit.runtime.state',
    'streamlit.runtime.scriptrunner_utils.script_run_context',
    'streamlit.runtime.state.session_state_proxy',
)
for _logger_name in _QUIET_STREAMLIT_LOGGERS:
    logging.getLogger(_logger_name).setLevel(logging.ERROR)
del _logger_name  # keep the loop variable out of the module namespace
try:
    import streamlit as st
    _STREAMLIT_AVAILABLE = True
except ImportError:
    _STREAMLIT_AVAILABLE = False

    # Minimal stand-in for non-streamlit environments so that
    # `st.session_state.dark_theme` still resolves (to False).
    class MockStreamlit:
        class session_state:
            dark_theme = False

    st = MockStreamlit()
def _suppress_streamlit_warnings(func):
"""Decorator to suppress streamlit warnings in functions."""
def wrapper(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter('ignore')
return func(*args, **kwargs)
return wrapper
# ==========================
# Shared Color Themes
# ==========================
# Light theme: the default palette shared by the UI and by plots.
LIGHT_COLORS = dict(
    # Page and figure backgrounds
    background='#F7FAFC',
    figure_background='#FFFFFF',
    sidebar_bg_start='#F0F4F8',
    sidebar_bg_end='#E4EBF3',
    # Borders
    border_light='#E2E8F0',
    border_medium='#CBD5E1',
    # Text, from strongest to faintest
    text_primary='#1F2933',
    text_secondary='#334E68',
    text_tertiary='#52606D',
    text_light='#829AB1',
    # Cards and code
    card_background='#FFFFFF',
    code_background='#EEF2FF',
    code_text='#1E3A8A',
    # Buttons
    button_bg_start='#2563EB',
    button_bg_end='#1D4ED8',
    button_hover_start='#1D4ED8',
    button_hover_end='#1E40AF',
    # Alerts and status messages
    alert_error_bg='#FDE8E8',
    alert_error_border='#F76B6B',
    alert_error_text='#B91C1C',
    alert_info_bg='#E0F2FE',
    alert_info_border='#3B82F6',
    alert_info_text='#1E3A8A',
    warning_bg='#FEF3C7',
    warning_border='#F59E0B',
    success_bg='#DCFCE7',
    success_border='#16A34A',
    # "Generate" action button
    generate_button_bg='#047857',
    generate_button_hover='#0F9D58',
    # Panels
    panel_background='#FFFFFF',
    panel_border='#E2E8F0',
    panel_shadow='0 8px 24px rgba(15, 23, 42, 0.08)',
    # Plot-specific colors
    axes_background='#FFFFFF',
    grid_color='#E2E8F0',
    spine_color='#CBD5E1',
)
# Paper theme for publication figures. It is identical to LIGHT_COLORS except
# for the two overrides below; the figure and axes backgrounds are already
# pure white in the light palette. Key order matches LIGHT_COLORS, and this is
# an independent copy, not a view.
PAPER_COLORS = dict(
    LIGHT_COLORS,
    background='#FFFFFF',
    sidebar_bg_start='#F5F7FA',
)
# Dark theme: same keys, in the same order, as LIGHT_COLORS.
DARK_COLORS = dict(
    # Page and figure backgrounds
    background='#0F172A',
    figure_background='#1E293B',
    sidebar_bg_start='#111C2E',
    sidebar_bg_end='#1B2537',
    # Borders
    border_light='#27364C',
    border_medium='#334155',
    # Text, from strongest to faintest
    text_primary='#F8FAFC',
    # NOTE(review): possibly a typo for '#CBD5E1' (LIGHT border_medium) — confirm.
    text_secondary='#CBD5F5',
    text_tertiary='#94A3B8',
    text_light='#64748B',
    # Cards and code
    card_background='#1F2937',
    code_background='#1E3A5F',
    code_text='#C7D2FE',
    # Buttons
    button_bg_start='#3B82F6',
    button_bg_end='#2563EB',
    button_hover_start='#2563EB',
    button_hover_end='#1D4ED8',
    # Alerts and status messages
    alert_error_bg='#451A1A',
    alert_error_border='#F87171',
    alert_error_text='#FCA5A5',
    alert_info_bg='#1E293B',
    alert_info_border='#60A5FA',
    alert_info_text='#BFDBFE',
    warning_bg='#3D2D12',
    warning_border='#FBBF24',
    success_bg='#163225',
    success_border='#34D399',
    # "Generate" action button
    generate_button_bg='#10B981',
    generate_button_hover='#34D399',
    # Panels
    panel_background='#1F2937',
    panel_border='#334155',
    panel_shadow='0 18px 36px rgba(2, 6, 23, 0.55)',
    # Plot-specific colors
    axes_background='#0F172A',
    grid_color='#27364C',
    spine_color='#334155',
)
@_suppress_streamlit_warnings
def get_current_colors():
    """Pick the palette that matches the current UI theme.

    Reads ``dark_theme`` from ``st.session_state`` (the real Streamlit
    session, or this module's mock when Streamlit is not installed). A missing
    flag, or any exception raised while reading it, counts as "not dark".

    Returns:
        DARK_COLORS when the flag is truthy, otherwise LIGHT_COLORS.
    """
    try:
        use_dark = getattr(st.session_state, 'dark_theme', False)
    except Exception:
        return LIGHT_COLORS
    if use_dark:
        return DARK_COLORS
    return LIGHT_COLORS
# ==========================
# Plot Styling
# ==========================
# Font configuration
DEFAULT_FONT_FAMILY = 'sans-serif'
# Font loading is reported through logging rather than print(). This module
# is imported by the app and by notebooks, and unconditional stdout output on
# every import is exactly the kind of noise the rest of the module suppresses.
_logger = logging.getLogger(__name__)
_ARIAL_FONT_PATH = '/usr/share/fonts/truetype/msttcorefonts/Arial.ttf'
try:
    fm.fontManager.addfont(_ARIAL_FONT_PATH)
except FileNotFoundError:
    # Expected on systems without the msttcorefonts package.
    _logger.info("Arial.ttf not found. Using default system font.")
    PLOT_STYLE_FONT_FAMILY = DEFAULT_FONT_FAMILY
except Exception as e:
    # Best effort: any other font-loading failure falls back to the default.
    _logger.warning(
        "An error occurred while trying to load Arial font: %s. "
        "Using default system font.",
        e,
    )
    PLOT_STYLE_FONT_FAMILY = DEFAULT_FONT_FAMILY
else:
    PLOT_STYLE_FONT_FAMILY = 'Arial'
    _logger.info("Successfully loaded Arial font.")
# Color constants for plots
def _reversed_cubehelix_cmap(start, rot):
    """Return a reversed cubehelix colormap using this module's shared bounds.

    Every cubehelix palette below uses ``dark=0`` and ``light=0.85``, reversed
    and returned as a colormap; only ``start`` and ``rot`` vary.
    """
    return sns.cubehelix_palette(start=start, rot=rot, dark=0, light=0.85,
                                 reverse=True, as_cmap=True)


PLOT_COLORS = {
    'input_similarity': sns.color_palette('rocket', as_cmap=True),
    'output_difference': _reversed_cubehelix_cmap(.2, -.3),
    'conflict': _reversed_cubehelix_cmap(2, 0),
    'output_biomechanical': _reversed_cubehelix_cmap(2.8, 0.4),
}
# Additional palettes
purple_helix = _reversed_cubehelix_cmap(.2, -.4)
my_purple_helix = _reversed_cubehelix_cmap(.2, -.1)
def get_plot_style(style='default'):
    """Return plot settings (fonts, ticks, layout and theme colors) as a dict.

    Args:
        style: 'paper' selects PAPER_COLORS (pure white backgrounds), 'dark'
            selects DARK_COLORS. Any other value, including the default
            'default', uses get_current_colors(), i.e. the palette matching
            the current UI theme.

    Returns:
        A new dict on every call. The size and layout entries are fixed; the
        ``figure_facecolor``, ``axes_facecolor``, ``text_color``,
        ``grid_color`` and ``spine_color`` entries come from the palette.
    """
    palette = (
        PAPER_COLORS if style == 'paper'
        else DARK_COLORS if style == 'dark'
        else get_current_colors()
    )
    settings = dict(
        font_family=PLOT_STYLE_FONT_FAMILY,
        font_size=18,
        title_size=20,
        label_size=18,
        tick_size=15,
        tick_length=5,
        tick_width=0.5,
        tick_pad=5,
        label_pad_x=-15,
        label_pad_y=-35,
        figure_dpi=300,
        aspect_ratio='equal',
        subplot_wspace=0.05,
        subplot_hspace=0.1,
    )
    # Theme-specific styling, appended after the fixed entries.
    settings.update(
        figure_facecolor=palette['figure_background'],
        axes_facecolor=palette['axes_background'],
        text_color=palette['text_primary'],
        grid_color=palette['grid_color'],
        spine_color=palette['spine_color'],
    )
    return settings
def set_plot_style(style='default'):
    """Install the chosen theme into matplotlib's global ``rcParams``.

    Args:
        style: theme name passed through to get_plot_style(). Use 'paper' for
            pure white backgrounds, 'dark' for the dark palette, or anything
            else for the palette currently selected by the UI.

    Mutates the process-wide ``plt.rcParams`` defaults; returns None.
    """
    settings = get_plot_style(style=style)
    ink = settings['text_color']
    tick_size = settings['tick_size']
    tick_pad = settings['tick_pad']
    # Applied in this order, one key at a time, through rcParams' validators.
    rc_overrides = {
        # Typography
        'font.family': settings['font_family'],
        'font.size': settings['font_size'],
        'axes.labelsize': settings['label_size'],
        'axes.titlesize': settings['title_size'],
        'xtick.labelsize': tick_size,
        'ytick.labelsize': tick_size,
        'xtick.major.pad': tick_pad,
        'ytick.major.pad': tick_pad,
        # Figure geometry
        'figure.dpi': settings['figure_dpi'],
        'figure.subplot.wspace': settings['subplot_wspace'],
        'figure.subplot.hspace': settings['subplot_hspace'],
        # Theme colors
        'figure.facecolor': settings['figure_facecolor'],
        'axes.facecolor': settings['axes_facecolor'],
        'text.color': ink,
        'axes.labelcolor': ink,
        'xtick.color': ink,
        'ytick.color': ink,
        'axes.edgecolor': settings['spine_color'],
        'grid.color': settings['grid_color'],
        'grid.alpha': 0.7,
    }
    for rc_key, rc_value in rc_overrides.items():
        plt.rcParams[rc_key] = rc_value
def _iter_axes(ax):
    """Yield each non-None Axes-like leaf of ``ax``, descending into containers.

    Accepts a single Axes, a numpy array of Axes of any shape (anything with
    ``.flatten()``), or arbitrarily nested lists, tuples or iterables of those.
    """
    if ax is None:
        return
    if hasattr(ax, 'flatten'):
        children = ax.flatten()
    elif isinstance(ax, (str, bytes)) or not hasattr(ax, '__iter__'):
        # Leaf: a single Axes. Strings are treated as leaves so they cannot
        # recurse forever (each character is itself an iterable string).
        yield ax
        return
    else:
        children = ax
    for child in children:
        yield from _iter_axes(child)


def apply_theme_to_figure(fig, ax=None, *, grid=True):
    """Recolor an existing figure and its axes with the active UI palette.

    Colors come from get_current_colors(): DARK_COLORS when the Streamlit
    session's ``dark_theme`` flag is set, LIGHT_COLORS otherwise.

    Args:
        fig: Figure whose background patch is recolored; skipped if None.
        ax: Optional single Axes, numpy array of Axes (e.g. from
            ``plt.subplots``), or a possibly nested list/tuple of Axes.
            None entries are ignored.
        grid: When True (the default, matching the historical behavior),
            each axes' grid is switched on and recolored. Pass False to leave
            the grid untouched, e.g. for heatmaps or images.

    Returns:
        The ``(fig, ax)`` arguments, unchanged, for call chaining.
    """
    palette = get_current_colors()
    if fig is not None:
        fig.patch.set_facecolor(palette['figure_background'])
    text_color = palette['text_primary']
    for axis in _iter_axes(ax):
        axis.set_facecolor(palette['axes_background'])
        # Update text colors
        axis.title.set_color(text_color)
        axis.xaxis.label.set_color(text_color)
        axis.yaxis.label.set_color(text_color)
        # Update tick colors (tick marks and tick labels)
        axis.tick_params(colors=text_color)
        # Update spine colors
        for spine in axis.spines.values():
            spine.set_color(palette['spine_color'])
        if grid:
            axis.grid(True, color=palette['grid_color'], alpha=0.7)
    return fig, ax
def set_paper_plot_style():
    """Switch matplotlib's global rcParams to the publication ("paper") theme.

    Shorthand for ``set_plot_style('paper')``, which takes its colors from
    PAPER_COLORS (pure white figure and axes backgrounds).
    """
    set_plot_style('paper')
# Legacy function name for backward compatibility
def apply_cream_theme_to_figure(fig, ax=None):
    """Legacy alias of apply_theme_to_figure(), kept for older callers.

    Despite the name, the colors come from the currently active palette
    (light or dark), not from a dedicated cream theme.
    """
    themed = apply_theme_to_figure(fig, ax=ax)
    return themed
# ==========================
# Streamlit UI Styling
# ==========================
def get_base_css():
    """Return the CSS shared by every page (currently an empty placeholder).

    The page-styling helpers inject this string with
    ``st.markdown(..., unsafe_allow_html=True)``.
    """
    base_css = """
"""
    return base_css
def get_home_page_css():
    """Return the extra CSS for the home page (currently an empty placeholder)."""
    home_css = """
"""
    return home_css
def get_documentation_page_css():
    """Return the extra CSS for the documentation page (currently empty)."""
    docs_css = """
"""
    return docs_css
def get_tool_page_css():
    """Return the extra CSS for the analysis-tool page (currently empty)."""
    tool_css = """
"""
    return tool_css
def apply_base_styling():
    """Inject the shared base CSS into the current Streamlit page.

    Does nothing when Streamlit could not be imported.
    """
    if _STREAMLIT_AVAILABLE:
        st.markdown(get_base_css(), unsafe_allow_html=True)
def apply_home_page_styling():
    """Inject the base CSS, then the home-page CSS, into the current page.

    Does nothing when Streamlit could not be imported.
    """
    if not _STREAMLIT_AVAILABLE:
        return
    for css_source in (get_base_css, get_home_page_css):
        st.markdown(css_source(), unsafe_allow_html=True)
def apply_documentation_page_styling():
    """Inject the base CSS, then the documentation-page CSS, into the page.

    Does nothing when Streamlit could not be imported.
    """
    if not _STREAMLIT_AVAILABLE:
        return
    for css_source in (get_base_css, get_documentation_page_css):
        st.markdown(css_source(), unsafe_allow_html=True)
def apply_tool_page_styling():
    """Inject the base CSS, then the analysis-tool CSS, into the page.

    Does nothing when Streamlit could not be imported.
    """
    if not _STREAMLIT_AVAILABLE:
        return
    for css_source in (get_base_css, get_tool_page_css):
        st.markdown(css_source(), unsafe_allow_html=True)