# ai/plotting.py
# Uploaded by iamspriyadarshi ("Upload 13 files", commit 11767f5).
import numpy as np
def check_if_matplotlib(return_mpl=False):
    """
    Import matplotlib on demand and return it.

    Parameters
    ----------
    return_mpl : bool
        If True, return the top-level ``matplotlib`` module; otherwise
        return ``matplotlib.pyplot``.

    Returns
    -------
    module
        ``matplotlib`` if ``return_mpl`` else ``matplotlib.pyplot``.

    Raises
    ------
    ImportError
        If matplotlib cannot be imported. The original import error is
        chained as ``__cause__``. Any other exception raised while importing
        propagates unchanged, so real import-time failures are not
        misreported as a missing package.
    """
    try:
        if return_mpl:
            import matplotlib as module
        else:
            import matplotlib.pyplot as module
    except ImportError as err:
        raise ImportError('matplotlib is not installed. Please install it with: pip install matplotlib') from err
    return module
def check_if_seaborn():
    """
    Import seaborn on demand and return it.

    Returns
    -------
    module
        The ``seaborn`` module.

    Raises
    ------
    ImportError
        If seaborn cannot be imported. The original import error is chained
        as ``__cause__``. Any other exception raised while importing
        propagates unchanged instead of being reported as a missing package.
    """
    try:
        import seaborn as sns
    except ImportError as err:
        raise ImportError('seaborn is not installed. Please install it with: pip install seaborn') from err
    return sns
def set_limits(vmin, vcenter, vmax, values):
    """
    Fill in and sanitise the limits of a diverging color scale.

    Missing limits are taken from ``values``: minimum, mean and maximum.
    If the center is not strictly above ``vmin``, ``vmin`` is set to
    ``-vmax``. If the center is not strictly below ``vmax``, ``vmax`` is set
    to ``-vmin``.

    If the limits are still not strictly ascending after that, ``vmin`` and
    ``vmax`` are replaced by a range symmetric around ``vcenter``. This
    happens, for example, when every value equals the center, or when a
    non-zero center lies outside the data. The fallback keeps the result
    usable by ``matplotlib.colors.TwoSlopeNorm``, which requires
    ``vmin < vcenter < vmax``.

    Parameters
    ----------
    vmin, vcenter, vmax : float, None
        Requested limits; None means "derive from ``values``".
    values : array-like
        Object exposing ``min()``, ``max()`` and ``mean()`` (e.g. a NumPy
        array or pandas Series).

    Returns
    -------
    tuple
        ``(vmin, vcenter, vmax)``.
    """
    if vmin is None:
        vmin = values.min()
    if vmax is None:
        vmax = values.max()
    if vcenter is None:
        vcenter = values.mean()
    # Mirror the offending limit around zero (the original heuristic).
    if vmin >= vcenter:
        vmin = -vmax
    if vcenter >= vmax:
        vmax = -vmin
    # Mirroring around zero is not enough in degenerate cases, e.g. (0, 0, 0)
    # or a positive center above all data; fall back to a symmetric range.
    if not vmin < vcenter < vmax:
        half_width = max(abs(vmin - vcenter), abs(vmax - vcenter))
        if half_width == 0:
            # All limits coincide with the center: any positive span works.
            half_width = 1
        vmin, vmax = vcenter - half_width, vcenter + half_width
    return vmin, vcenter, vmax
def plot_barplot(acts, contrast, top=25, vertical=False, cmap='coolwarm', vmin=None, vcenter=0, vmax=None,
                 figsize=(7, 5), dpi=100, ax=None, return_fig=False, save=None):
    """
    Plot barplots showing the top absolute value activities.

    Parameters
    ----------
    acts : DataFrame
        Activities obtained from any method; each row is a contrast and each
        column a feature.
    contrast : str
        Name of the contrast (row) to plot.
    top : int
        Number of top features (by absolute activity) to plot.
    vertical : bool
        If True, features are placed on the y axis and bars extend
        horizontally; otherwise features are placed on the x axis.
    cmap : str
        Colormap to use.
    vmin : float, None
        The value representing the lower limit of the color scale. If None,
        it is derived from the plotted activities.
    vcenter : float, None
        The value representing the center of the color scale. If None, the
        mean of the plotted activities is used.
    vmax : float, None
        The value representing the upper limit of the color scale. If None,
        it is derived from the plotted activities.
    figsize : tuple
        Figure size (only used when ``ax`` is None).
    dpi : int
        DPI resolution of figure (only used when ``ax`` is None).
    ax : Axes, None
        A matplotlib axes object. If None, a new figure is created.
    return_fig : bool
        Whether to return the Figure object created by this function.
    save : str, None
        Path to where to save the plot. The file type is inferred from the
        extension (e.g. ``.pdf``, ``.png``, ``.svg``).

    Returns
    -------
    fig : Figure, None
        If ``return_fig``, the Figure created here; None when ``ax`` was given.

    Raises
    ------
    ValueError
        If ``acts`` contains non finite values.
    """
    # Load plotting packages lazily
    sns = check_if_seaborn()
    plt = check_if_matplotlib()
    mpl = check_if_matplotlib(return_mpl=True)

    # Check for non finite values (over the whole frame, not only `contrast`)
    if np.any(~np.isfinite(acts)):
        raise ValueError('Input acts contains non finite values.')

    # Select the contrast row and drop axis names so that reset_index() below
    # produces a column named 'index'. rename_axis returns a new frame;
    # assigning df.columns.name in place could also rename the caller's
    # acts.columns, because .loc[[contrast]] may share that Index object.
    df = acts.loc[[contrast]].rename_axis(index=None, columns=None)
    df = (df
          # Sort by absolute value and transpose to one row per feature
          .iloc[:, np.argsort(abs(df.values))[0]].T
          # Select top features and move feature names into an 'index' column
          .tail(top).reset_index()
          # Rename the activity column
          .rename({contrast: 'acts'}, axis=1)
          # Sort by signed activity for plotting
          .sort_values('acts'))

    if vertical:
        x, y = 'acts', 'index'
    else:
        x, y = 'index', 'acts'

    # Plot
    fig = None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    sns.barplot(data=df, x=x, y=y, ax=ax)
    if vertical:
        sizes = np.array([bar.get_width() for bar in ax.containers[0]])
        ax.set_xlabel('Activity')
        ax.set_ylabel('')
    else:
        sizes = np.array([bar.get_height() for bar in ax.containers[0]])
        ax.tick_params(axis='x', rotation=90)
        ax.set_ylabel('Activity')
        ax.set_xlabel('')

    # Compute color limits
    vmin, vcenter, vmax = set_limits(vmin, vcenter, vmax, df['acts'])

    # Rescale cmap so that `vcenter` maps to the middle color
    divnorm = mpl.colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
    cmap_f = plt.get_cmap(cmap)
    div_colors = cmap_f(divnorm(sizes))
    for bar, color in zip(ax.containers[0], div_colors):
        bar.set_facecolor(color)

    # Add legend
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=divnorm)
    sm.set_array([])
    ax.get_figure().colorbar(sm, ax=ax)

    # Save the figure that owns `ax`; bbox_inches='tight' keeps the rotated
    # tick labels and the colorbar inside the saved image.
    if save is not None:
        ax.get_figure().savefig(save, bbox_inches='tight')

    if return_fig:
        return fig