File size: 4,317 Bytes
11767f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import numpy as np
def check_if_matplotlib(return_mpl=False):
    """
    Import matplotlib on demand and return it.

    Parameters
    ----------
    return_mpl : bool
        If True, return the top-level ``matplotlib`` package; otherwise return
        ``matplotlib.pyplot``.

    Returns
    -------
    module
        ``matplotlib`` if ``return_mpl`` is True, else ``matplotlib.pyplot``.

    Raises
    ------
    ImportError
        If matplotlib cannot be imported. The original import error is chained
        as ``__cause__``.
    """
    try:
        if return_mpl:
            import matplotlib as module
        else:
            import matplotlib.pyplot as module
    except ImportError as err:
        # Only a genuine import failure means "not installed"; any other error
        # raised while importing matplotlib propagates unchanged.
        raise ImportError('matplotlib is not installed. Please install it with: pip install matplotlib') from err
    return module
def check_if_seaborn():
    """
    Import seaborn on demand and return it.

    Returns
    -------
    module
        The ``seaborn`` package.

    Raises
    ------
    ImportError
        If seaborn cannot be imported. The original import error is chained
        as ``__cause__``.
    """
    try:
        import seaborn as sns
    except ImportError as err:
        # Only a genuine import failure means "not installed"; any other error
        # raised while importing seaborn propagates unchanged.
        raise ImportError('seaborn is not installed. Please install it with: pip install seaborn') from err
    return sns
def set_limits(vmin, vcenter, vmax, values):
    """
    Resolve color-scale limits so that ``vmin < vcenter < vmax``.

    Missing limits are filled from ``values``: ``vmin`` from its minimum,
    ``vmax`` from its maximum and ``vcenter`` from its mean. If a limit ends
    up on the wrong side of ``vcenter``, it is replaced by the opposite limit
    mirrored through ``vcenter``. With ``vcenter == 0`` this is simply
    ``-vmax`` / ``-vmin``.

    Parameters
    ----------
    vmin : float, None
        Lower limit of the color scale, or None to use ``values.min()``.
    vcenter : float, None
        Center of the color scale, or None to use ``values.mean()``.
    vmax : float, None
        Upper limit of the color scale, or None to use ``values.max()``.
    values : array-like
        Object exposing ``min``, ``max`` and ``mean`` methods (e.g. a pandas
        Series or numpy array).

    Returns
    -------
    vmin, vcenter, vmax : float
        The resolved limits.
    """
    if vmin is None:
        vmin = values.min()
    if vmax is None:
        vmax = values.max()
    if vcenter is None:
        vcenter = values.mean()
    # Mirror through vcenter, not through 0. Mirroring through 0 produces
    # non-ascending limits whenever vcenter != 0 (e.g. values in [1, 4] with
    # vcenter=5 would give vmax=-1).
    if vmin >= vcenter:
        vmin = 2 * vcenter - vmax
    if vcenter >= vmax:
        vmax = 2 * vcenter - vmin
    # Degenerate case: every value sits exactly on vcenter (e.g. an all-zero
    # contrast), so mirroring cannot open up a range. Pad symmetrically so the
    # resulting limits are strictly ascending.
    if vmin == vcenter == vmax:
        pad = abs(vcenter) if vcenter else 1.0
        vmin, vmax = vcenter - pad, vcenter + pad
    return vmin, vcenter, vmax
def plot_barplot(acts, contrast, top=25, vertical=False, cmap='coolwarm', vmin=None, vcenter=0, vmax=None,
                 figsize=(7, 5), dpi=100, ax=None, return_fig=False, save=None):
    """
    Plot barplots showing the top absolute value activities.

    Parameters
    ----------
    acts : DataFrame
        Activities obtained from any method. Rows are contrasts, columns are
        features.
    contrast : str
        Name of the contrast (row) to plot.
    top : int
        Number of top features (ranked by absolute activity) to plot.
    vertical : bool
        If True, features are placed on the y-axis and bars extend
        horizontally; otherwise features are placed on the x-axis (with
        labels rotated 90 degrees).
    cmap : str
        Colormap to use.
    vmin : float, None
        The value representing the lower limit of the color scale.
    vcenter : float, None
        The value representing the center of the color scale.
    vmax : float, None
        The value representing the upper limit of the color scale.
    figsize : tuple
        Figure size.
    dpi : int
        DPI resolution of figure.
    ax : Axes, None
        A matplotlib axes object. If None returns new figure.
    return_fig : bool
        Whether to return a Figure object or not.
    save : str, None
        Path to where to save the plot. The file type is inferred from the
        extension (e.g. ``.pdf``, ``.png``, ``.svg``).

    Returns
    -------
    fig : Figure, None
        If return_fig, returns the Figure created by this function, or None
        when an existing ``ax`` was passed in.

    Raises
    ------
    ValueError
        If ``acts`` contains non-finite values.
    """
    # Load plotting packages
    sns = check_if_seaborn()
    plt = check_if_matplotlib()
    mpl = check_if_matplotlib(return_mpl=True)

    # Check for non finite values
    if np.any(~np.isfinite(acts)):
        raise ValueError('Input acts contains non finite values.')

    # Select the contrast row and drop axis names. rename_axis returns a new
    # object. Assigning df.columns.name would mutate the Index shared with
    # the caller's `acts`. Clearing the names makes reset_index() below
    # create a column literally called 'index'.
    df = acts.loc[[contrast]].rename_axis(index=None, columns=None)
    df = (df
          # Sort features by absolute value and transpose to one row per feature
          .iloc[:, np.argsort(abs(df.values))[0]].T
          # Keep the `top` largest-magnitude features and expose names as a column
          .tail(top).reset_index()
          # Rename the activity column
          .rename({contrast: 'acts'}, axis=1)
          # Order bars by signed activity
          .sort_values('acts'))

    if vertical:
        x, y = 'acts', 'index'
    else:
        x, y = 'index', 'acts'

    # Plot
    fig = None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    sns.barplot(data=df, x=x, y=y, ax=ax)
    if vertical:
        # Bars run along x, so their length is the width.
        sizes = np.array([bar.get_width() for bar in ax.containers[0]])
        ax.set_xlabel('Activity')
        ax.set_ylabel('')
    else:
        sizes = np.array([bar.get_height() for bar in ax.containers[0]])
        ax.tick_params(axis='x', rotation=90)
        ax.set_ylabel('Activity')
        ax.set_xlabel('')

    # Compute color limits
    vmin, vcenter, vmax = set_limits(vmin, vcenter, vmax, df['acts'])

    # Color each bar by its activity using a norm with two linear halves
    # meeting at vcenter.
    divnorm = mpl.colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
    cmap_f = plt.get_cmap(cmap)
    div_colors = cmap_f(divnorm(sizes))
    for bar, color in zip(ax.containers[0], div_colors):
        bar.set_facecolor(color)

    # Add legend
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=divnorm)
    sm.set_array([])
    ax.get_figure().colorbar(sm, ax=ax)

    if save is not None:
        # bbox_inches='tight' keeps rotated tick labels and the colorbar in frame.
        ax.get_figure().savefig(save, bbox_inches='tight')

    if return_fig:
        return fig
|