File size: 6,013 Bytes
a1eb2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Matplotlib style configuration for consistent plotting across scripts.
"""

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from contextlib import contextmanager
import numpy as np

# Colorblind-friendly color palettes
COLORBLIND_PALETTES = dict(
    # Paul Tol's "bright" qualitative scheme: 7 colors, neutral grey last.
    tol_bright=[
        '#EE6677', '#228833', '#4477AA', '#CCBB44',
        '#66CCEE', '#AA3377', '#BBBBBB',
    ],
    # Paul Tol's "muted" qualitative scheme: 9 colors.
    tol_muted=[
        '#CC6677', '#332288', '#DDCC77', '#117733', '#88CCEE',
        '#882255', '#44AA99', '#999933', '#AA4499',
    ],
    # Paul Tol's "light" qualitative scheme: 9 colors, pale grey last.
    tol_light=[
        '#BBCC33', '#AAAA00', '#77AADD', '#EE8866', '#EEDD88',
        '#FFAABB', '#99DDFF', '#44BB99', '#DDDDDD',
    ],
    # Okabe-Ito palette: 8 colors, black last.
    okabe_ito=[
        '#E69F00', '#56B4E9', '#009E73', '#F0E442',
        '#0072B2', '#D55E00', '#CC79A7', '#000000',
    ],
    # Four discrete stops sampled along a viridis-like ramp.
    viridis_discrete=['#440154', '#31688e', '#35b779', '#fde725'],
    # Wong palette: the Okabe-Ito colors with black moved to the front.
    wong=[
        '#000000', '#E69F00', '#56B4E9', '#009E73',
        '#F0E442', '#0072B2', '#D55E00', '#CC79A7',
    ],
)

# Palette name used when callers do not pass one. The functions below bind
# this value as a default argument at definition time, so reassigning it
# later does not change their defaults.
default_palette = 'tol_bright'

def get_colorblind_palette(name=default_palette, n_colors=None):
    """
    Get a colorblind-friendly color palette.

    Parameters:
    -----------
    name : str
        Name of the palette ('okabe_ito', 'tol_bright', 'tol_muted',
        'tol_light', 'viridis_discrete', 'wong'). Unknown names silently
        fall back to 'okabe_ito'.
    n_colors : int, optional
        Number of colors to return. If None, returns all colors in the
        palette. Values larger than the palette size return the whole
        palette (colors are not repeated or interpolated); ordinary list
        slicing applies, so a negative value drops colors from the end.

    Returns:
    --------
    list : A new list of hex color codes. It is always a copy, so callers
        may mutate it without altering ``COLORBLIND_PALETTES``.
    """
    palette = COLORBLIND_PALETTES.get(name, COLORBLIND_PALETTES['okabe_ito'])
    if n_colors is None:
        # Copy: handing out the shared module-level list would let a caller's
        # append/sort/etc. corrupt the palette for every later user.
        return list(palette)
    # Slicing already produces a new list.
    return palette[:n_colors]

def set_colorblind_cycle(name=default_palette):
    """
    Make a colorblind-friendly palette matplotlib's default axes color cycle.

    Parameters:
    -----------
    name : str
        Palette name, resolved through ``get_colorblind_palette``.

    The change is written to the global ``plt.rcParams``.
    """
    cycle = plt.cycler(color=get_colorblind_palette(name))
    plt.rcParams['axes.prop_cycle'] = cycle

# Physical Review style settings
PHYSREV_STYLE = {
    # Typesetting: all text goes through LaTeX in a Computer Modern serif face.
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': ['Computer Modern Roman'],

    # Font sizes (points); the legend is set slightly smaller.
    'font.size': 16,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'legend.fontsize': 14,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,

    # Canvas (inches) and raster export resolution.
    'figure.figsize': (6.5, 3.7),
    'savefig.dpi': 300,

    # Stroke widths for data lines and the axes frame.
    'lines.linewidth': 1.5,
    'axes.linewidth': 0.8,
}

def apply_physrev_style(colorblind=True, palette=default_palette):
    """
    Globally switch matplotlib to the Physical Review figure style.

    Parameters:
    -----------
    colorblind : bool
        When true, additionally install ``palette`` as the color cycle.
    palette : str
        Palette name forwarded to ``set_colorblind_cycle``.

    The settings persist in ``plt.rcParams`` until changed or reset.
    """
    plt.rcParams.update(PHYSREV_STYLE)
    if not colorblind:
        return
    set_colorblind_cycle(palette)

def apply_style(style_dict):
    """
    Merge an arbitrary rcParams mapping into the global matplotlib config.

    Parameters:
    -----------
    style_dict : dict
        Mapping of rcParams keys to values (for example ``PHYSREV_STYLE``).
        Keys not present in the mapping keep their current values.
    """
    rc = plt.rcParams
    rc.update(style_dict)

def reset_style():
    """
    Restore matplotlib's built-in rcParams defaults.

    ``plt.rcdefaults`` reloads matplotlib's internal default style, so any
    customisations, including those from a user ``matplotlibrc``, are
    discarded.
    """
    restore_defaults = plt.rcdefaults
    restore_defaults()

@contextmanager
def physrev_style(colorblind=True, palette=default_palette):
    """
    Temporarily apply the Physical Review style.

    Usage::

        with physrev_style():
            fig, ax = plt.subplots()
            ...

    Parameters:
    -----------
    colorblind : bool
        When true, additionally install ``palette`` as the color cycle.
    palette : str
        Palette name forwarded to ``set_colorblind_cycle``.

    All rcParams changed inside the block are restored on exit, including
    when the body raises.
    """
    # plt.rc_context snapshots rcParams and restores them on exit but, by
    # design, leaves rcParams['backend'] untouched. The previous manual
    # copy()/update() round-trip also wrote the backend back, which can
    # force matplotlib to re-resolve or switch backends when leaving the block.
    with plt.rc_context():
        apply_physrev_style(colorblind=colorblind, palette=palette)
        yield

# Alternative style configurations
PRESENTATION_STYLE = {
    # Typesetting: all text goes through LaTeX in a Computer Modern serif face.
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': ['Computer Modern Roman'],

    # Font sizes (points), larger than PHYSREV_STYLE for slides.
    'font.size': 18,
    'axes.labelsize': 20,
    'axes.titlesize': 22,
    'legend.fontsize': 16,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,

    # Canvas (inches) and raster export resolution.
    'figure.figsize': (10, 6),
    'savefig.dpi': 300,

    # Heavier strokes so lines stay visible when projected.
    'lines.linewidth': 2.0,
    'axes.linewidth': 1.2,
}

def apply_presentation_style(colorblind=True, palette=default_palette):
    """
    Globally switch matplotlib to the large-font presentation style.

    Parameters:
    -----------
    colorblind : bool
        When true, additionally install ``palette`` as the color cycle.
    palette : str
        Palette name forwarded to ``set_colorblind_cycle``.

    The settings persist in ``plt.rcParams`` until changed or reset.
    """
    plt.rcParams.update(PRESENTATION_STYLE)
    if not colorblind:
        return
    set_colorblind_cycle(palette)

@contextmanager
def presentation_style(colorblind=True, palette=default_palette):
    """
    Temporarily apply the presentation style.

    Usage::

        with presentation_style():
            fig, ax = plt.subplots()
            ...

    Parameters:
    -----------
    colorblind : bool
        When true, additionally install ``palette`` as the color cycle.
    palette : str
        Palette name forwarded to ``set_colorblind_cycle``.

    All rcParams changed inside the block are restored on exit, including
    when the body raises.
    """
    # plt.rc_context snapshots rcParams and restores them on exit but, by
    # design, leaves rcParams['backend'] untouched. The previous manual
    # copy()/update() round-trip also wrote the backend back, which can
    # force matplotlib to re-resolve or switch backends when leaving the block.
    with plt.rc_context():
        apply_presentation_style(colorblind=colorblind, palette=palette)
        yield

# Utility functions for colorblind-friendly plotting
def create_colorblind_cmap(name=default_palette, n_colors=256):
    """
    Build a continuous colormap by interpolating a named palette.

    The palette's colors become equally spaced anchors on [0, 1], and
    matplotlib linearly interpolates between neighbouring anchors. The
    in-between colors are blends of qualitative palette entries, so the
    result is not guaranteed to be perceptually uniform, or colorblind-safe,
    between anchors.

    Parameters:
    -----------
    name : str
        Base palette name, resolved by ``get_colorblind_palette``. The
        colormap is named ``f"{name}_cmap"`` even if the palette lookup fell
        back to a default.
    n_colors : int
        Number of entries in the colormap's lookup table.

    Returns:
    --------
    matplotlib.colors.LinearSegmentedColormap
    """
    anchors = get_colorblind_palette(name)
    cmap_name = f"{name}_cmap"
    return mcolors.LinearSegmentedColormap.from_list(cmap_name, anchors, N=n_colors)

def show_colorblind_palettes():
    """
    Draw one labelled swatch strip per entry of ``COLORBLIND_PALETTES`` and show it.

    Each row is titled with the palette name and its color count. Every
    swatch carries its hex code, written in white on dark swatches and in
    black on light ones.
    """
    n_rows = len(COLORBLIND_PALETTES)
    fig, axes = plt.subplots(n_rows, 1, figsize=(10, 2 * n_rows))
    # plt.subplots hands back a bare Axes rather than an array for one row.
    rows = axes if n_rows > 1 else [axes]

    for ax, (name, colors) in zip(rows, COLORBLIND_PALETTES.items()):
        n = len(colors)

        # Unit-width swatches laid side by side along x.
        for x, color in enumerate(colors):
            ax.add_patch(plt.Rectangle((x, 0), 1, 1, facecolor=color))

        ax.set_xlim(0, n)
        ax.set_ylim(0, 1)
        ax.set_title(f'{name} ({n} colors)')
        ax.set_xticks([])
        ax.set_yticks([])

        # Centered hex labels with a contrasting text color.
        for x, color in enumerate(colors):
            label_color = 'white' if _is_dark_color(color) else 'black'
            ax.text(x + 0.5, 0.5, color, ha='center', va='center',
                    color=label_color, fontsize=8)

    plt.tight_layout()
    plt.show()

def _is_dark_color(hex_color):
    """Check if a hex color is dark (for text color selection)."""
    hex_color = hex_color.lstrip('#')
    rgb = tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
    luminance = (0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) / 255
    return luminance < 0.5