# splicing-predictor / figures / force_plot.py
# odedrr
# New figure code
# b4bd946
import matplotlib.pyplot as plt
from joblib import load
import scipy.cluster.hierarchy as sch
import pandas as pd
import figutils
import sequence_logo
from quad_model import *
def get_link_midpoint(
    link_function, midpoint=0.5, epsilon=1e-5, lb=-100, ub=100, max_iters=50
):
    """Find an input at which ``link_function`` outputs ``midpoint``.

    Assumes monotonicity and smoothness of the link function. The search
    samples 1000 evenly spaced points on ``[lb, ub]``, and either returns the
    first sample whose output lies within ``epsilon`` of ``midpoint`` or
    shrinks the interval to the two samples that straddle the target and
    repeats.

    Args:
        link_function: Callable taking a column array of shape ``(n, 1)`` and
            returning an object whose ``.numpy()`` yields the ``n`` outputs.
        midpoint: Target output value.
        epsilon: Absolute tolerance on the output.
        lb: Lower end of the initial search interval.
        ub: Upper end of the initial search interval.
        max_iters: Maximum number of refinement rounds.

    Returns:
        The first sampled input whose output is within ``epsilon`` of
        ``midpoint``.

    Raises:
        ValueError: If the sampled outputs never fall on both sides of
            ``midpoint``, i.e. the target is not bracketed by the interval.
        RuntimeError: If no sample gets within ``epsilon`` after
            ``max_iters`` rounds.
    """
    for _ in range(max_iters):
        xx = np.linspace(lb, ub, 1000)
        yy = link_function(xx[:, None]).numpy().flatten()
        offsets = yy - midpoint

        within_tolerance = np.abs(offsets) < epsilon
        if within_tolerance.any():
            return xx[within_tolerance][0]

        below = np.flatnonzero(offsets < 0)
        above = np.flatnonzero(offsets > 0)
        if below.size == 0 or above.size == 0:
            # Without samples on both sides there is nothing to bisect; fail
            # with a clear message instead of an opaque IndexError.
            raise ValueError(
                f"midpoint {midpoint} is not bracketed by the link function "
                f"on [{lb}, {ub}]"
            )

        # Narrow to the last sample below and the first sample above target.
        lb = xx[below[-1]]
        ub = xx[above[0]]
    raise RuntimeError(f"Max iterations ({max_iters}) reached without solution...")
def _collapse_filter_groups(
    activations, seq_groups, struct_groups, num_seq_filters, prefix
):
    """Sum filter columns of ``activations`` per group for one side."""
    names = []
    summed = []
    for members in seq_groups.values():
        names.append(f"{prefix}_seq_" + "_".join(map(str, members)))
        summed.append(activations[:, members].sum(axis=1))
    for members in struct_groups.values():
        names.append(f"{prefix}_struct_" + "_".join(map(str, members)))
        # Structure filters are stored after the sequence filters, so their
        # group-local indices are shifted by num_seq_filters. np.asarray lets
        # plain lists work as well as arrays.
        shifted = num_seq_filters + np.asarray(members)
        summed.append(activations[:, shifted].sum(axis=1))

    collapsed = np.zeros((activations.shape[0], len(names)))
    for col, values in enumerate(summed):
        collapsed[:, col] = values
    return pd.DataFrame(collapsed, columns=names)


def collapse_filters(
    test_act_incl,
    test_act_skip,
    iM,
    sM,
    iM_struct,
    sM_struct,
    num_seq_filters,
):
    """Collapse per-filter activations into per-group activations.

    Every group of filters is replaced by a single column holding the sum of
    its members' activations. Sequence groups come first, followed by
    structure groups, in dictionary order.

    Args:
        test_act_incl: 2-D inclusion activations; columns are filters, with
            the first ``num_seq_filters`` columns being sequence filters
            (rows are presumably positions -- confirm with the caller).
        test_act_skip: Same layout for the skipping side.
        iM: Mapping group id -> sequence filter indices (inclusion).
        sM: Mapping group id -> sequence filter indices (skipping).
        iM_struct: Mapping group id -> structure filter indices (inclusion),
            counted from zero within the structure filters.
        sM_struct: Same for the skipping side.
        num_seq_filters: Number of sequence filter columns preceding the
            structure filter columns.

    Returns:
        A pair ``(incl_df, skip_df)`` of DataFrames. Columns are named
        ``"<incl|skip>_<seq|struct>_<i>_<j>..."`` after the member indices.
    """
    incl_collapsed = _collapse_filter_groups(
        test_act_incl, iM, iM_struct, num_seq_filters, "incl"
    )
    skip_collapsed = _collapse_filter_groups(
        test_act_skip, sM, sM_struct, num_seq_filters, "skip"
    )
    return incl_collapsed, skip_collapsed
def create_force_data(
    test_act_incl,
    test_act_skip,
    seq_filters_grouping,
    struct_filters_grouping,
    num_seq_filters,
    sum_positions=True,
    link_midpoint=None
):
    """Build the inclusion and skipping force tables for one input.

    Activations are first collapsed into filter groups via
    ``collapse_filters``. With ``sum_positions`` the groups are then summed
    over rows; otherwise they are handed to ``collapse_positions``
    (NOTE(review): ``collapse_positions`` is not defined in this part of the
    file -- presumably provided by a star import; confirm).

    When ``link_midpoint`` is given, its magnitude is added as a bias entry:
    ``'incl_bias'`` on the inclusion side if the midpoint is negative,
    ``'skip_bias'`` on the skipping side otherwise.

    Returns:
        ``(incl_forces, skip_forces)``.
    """
    grouped_incl, grouped_skip = collapse_filters(
        test_act_incl,
        test_act_skip,
        seq_filters_grouping["incl_membership_dict"],
        seq_filters_grouping["skip_membership_dict"],
        struct_filters_grouping["incl_membership_dict"],
        struct_filters_grouping["skip_membership_dict"],
        num_seq_filters,
    )

    if sum_positions:
        incl_forces = grouped_incl.sum(axis=0)
        skip_forces = grouped_skip.sum(axis=0)
    else:
        incl_forces = collapse_positions(grouped_incl)
        skip_forces = collapse_positions(grouped_skip)

    if link_midpoint is None:
        return incl_forces, skip_forces

    # The sign of the midpoint decides which side receives the bias.
    if link_midpoint < 0:
        incl_forces['incl_bias'] = np.abs(link_midpoint)
    else:
        skip_forces['skip_bias'] = np.abs(link_midpoint)
    return incl_forces, skip_forces
def merge_small_forces(forces, threshold=1):
    """Fold all entries below ``threshold`` into one combined entry.

    Args:
        forces: pandas Series of force values indexed by string names.
        threshold: Entries strictly below this value are merged.

    Returns:
        ``forces`` itself when nothing is below the threshold. Otherwise a
        copy of the entries at or above the threshold, with one extra entry
        appended whose name is the merged names joined by ``'___'`` and whose
        value is their sum.
    """
    is_small = forces < threshold
    if is_small.sum() == 0:
        return forces

    small = forces[is_small]
    merged = forces[forces >= threshold].copy()
    merged['___'.join(small.index)] = small.sum()
    return merged
####################
## INITIALIZE
####################
plt.style.use('clean.mplstyle')

# Load data and models.
# The four joblib archives share one naming scheme; xTr/yTr and xTe/yTe are
# presumably the train and test inputs/targets -- confirm against the data.
DATA_DIR = "../../sequence_analysis_utils/force_plots_es7/data/"
xTr, yTr, xTe, yTe = [
    load(f'{DATA_DIR}{split}_ES7_HeLa_ABC.pkl.gz')
    for split in ('xTr', 'yTr', 'xTe', 'yTe')
]

model_fname = 'custom_adjacency_regularizer_20210731_124_step3.h5'
model = tf.keras.models.load_model(model_fname)

# Dimensions read directly from the loaded model's layers.
input_length = model.input[0].shape[1]
num_seq_filters = model.get_layer('qc_incl').kernel.shape[2]
num_struct_filters = model.get_layer('c_incl_struct').kernel.shape[2]
struct_filter_width = model.get_layer("c_incl_struct").kernel.shape[0]
position_bias_size = model.get_layer('position_bias_incl').kernel.shape[0]
# Group sequence filters
def get_membership_dict(ind):
    """Invert a per-item group assignment into ``group -> [item indices]``.

    Args:
        ind: Iterable whose ``i``-th element is the group label of item ``i``.

    Returns:
        Dict mapping each label to the list of positions carrying it, in
        ascending order. Labels appear in order of first occurrence.
    """
    members_by_group = {}
    for position, label in enumerate(ind):
        members_by_group.setdefault(label, []).append(position)
    return members_by_group
def get_fig_num_rows_cols(membership_dict):
    """Return the grid size needed to lay out one group per column.

    Rows equal the size of the largest group; columns equal the number of
    groups. Raises ``ValueError`` for an empty mapping (``max`` of nothing).
    """
    largest_group = max(len(members) for members in membership_dict.values())
    return largest_group, len(membership_dict)
# Sub-model exposing the two intermediate activation layers of the main model.
structure_out_model = Model(
    inputs=model.inputs,
    outputs=[
        model.get_layer(layer_name).output
        for layer_name in ('activation_2', 'activation_3')
    ],
)
incl_act, skip_act = structure_out_model.predict(xTr, verbose=1, batch_size=1024)

# The leading num_seq_filters channels of the last axis are the sequence filters.
incl_act_seq = incl_act[:, :, :num_seq_filters]
skip_act_seq = skip_act[:, :, :num_seq_filters]

# Cluster filters by the correlation of their activations summed over axis 1
# (one row per filter after the transpose): complete linkage, flat clusters
# cut at distance 0.9.
incl_inds, skip_inds = (
    sch.fcluster(
        sch.linkage(
            seq_act.sum(axis=1).T, metric='correlation', method='complete'
        ),
        t=0.9,
        criterion='distance',
    )
    for seq_act in (incl_act_seq, skip_act_seq)
)
incl_membership_dict = get_membership_dict(incl_inds)
skip_membership_dict = get_membership_dict(skip_inds)
def get_representative_dict(membership_dict, activations):
    """Pick, for every cluster, the member filter with the highest score.

    A filter's score is its activation summed over axis 1 and then averaged
    over axis 0 of ``activations`` (a 3-D array whose last axis indexes
    filters). Ties go to the member listed first.

    Args:
        membership_dict: Mapping cluster id -> list of filter indices.
        activations: Array indexed ``[sample, axis-1, filter]``.

    Returns:
        Dict mapping cluster id -> index of its top-scoring filter.
    """
    filter_scores = activations.sum(axis=1).mean(axis=0)
    return {
        cluster_id: members[np.argmax(filter_scores[members])]
        for cluster_id, members in membership_dict.items()
    }
incl_representative_dict = get_representative_dict(incl_membership_dict, incl_act_seq)
skip_representative_dict = get_representative_dict(skip_membership_dict, skip_act_seq)

# Manually modify representatives
incl_representative_dict[3] = 5
skip_representative_dict[4] = 10

# Group sequence filters
seq_filters_grouping = {
    'incl_membership_dict': incl_membership_dict,
    'skip_membership_dict': skip_membership_dict,
    'incl_representative_dict': incl_representative_dict,
    'skip_representative_dict': skip_representative_dict,
}

# Group structure filters (hand-curated; indices count from zero within the
# structure filters)
incl_membership_struct_dict = {1: np.array([0, 1, 2, 3, 4, 5, 6, 7])}
skip_membership_struct_dict = {
    1: np.array([1]),
    2: np.array([0, 2, 3]),
    3: np.array([5, 6, 7]),
    4: np.array([4]),
}
incl_representative_struct_dict = {1: 0}
skip_representative_struct_dict = {1: 1, 2: 0, 3: 5}
struct_filters_grouping = {
    'incl_membership_dict': incl_membership_struct_dict,
    'skip_membership_dict': skip_membership_struct_dict,
    'incl_representative_dict': incl_representative_struct_dict,
    'skip_representative_dict': skip_representative_struct_dict,
}

# Plot colours: strong shades for highlighted forces, light shades otherwise.
incl_color = '#669aff'
skip_color = '#ff6666'
light_incl_color = '#C5D6FB'
light_skip_color = '#F6C3C2'

# Score each sequence cluster by the 0.9 quantile (across samples) of its
# total activation, then order clusters from highest to lowest score.
incl_membership_scores = {
    cluster_id: np.quantile(incl_act_seq[:, :, members].sum(axis=(1, 2)), 0.9)
    for cluster_id, members in incl_membership_dict.items()
}
skip_membership_scores = {
    cluster_id: np.quantile(skip_act_seq[:, :, members].sum(axis=(1, 2)), 0.9)
    for cluster_id, members in skip_membership_dict.items()
}
incl_plot_order = sorted(
    incl_membership_scores, key=lambda cluster_id: -incl_membership_scores[cluster_id]
)
skip_plot_order = sorted(
    skip_membership_scores, key=lambda cluster_id: -skip_membership_scores[cluster_id]
)

top_k = 4
num_extra_filters = 2

# Map every force name to the label drawn on its bar (1-based rank in the
# plot order).
# NOTE(review): structure group ids are also looked up in the *sequence* plot
# order; this only works when a sequence cluster with the same id exists and
# looks coincidental -- confirm the intended label for structure groups.
incl_filter_lookup = {}
for grouping, kind in ((seq_filters_grouping, 'seq'), (struct_filters_grouping, 'struct')):
    for group_id, members in grouping['incl_membership_dict'].items():
        force_name = f'incl_{kind}_' + '_'.join(map(str, sorted(members)))
        incl_filter_lookup[force_name] = incl_plot_order.index(group_id) + 1

skip_filter_lookup = {}
for grouping, kind in ((seq_filters_grouping, 'seq'), (struct_filters_grouping, 'struct')):
    for group_id, members in grouping['skip_membership_dict'].items():
        force_name = f'skip_{kind}_' + '_'.join(map(str, sorted(members)))
        skip_filter_lookup[force_name] = skip_plot_order.index(group_id) + 1

# Manually modify symbols for G-poor, structure, inclusion bias, and others
skip_filter_lookup.update({
    'skip_struct_0_2_3': 'S',
    'skip_struct_1': 'P',
    'skip_struct_5_6_7': '.',
    'skip_struct_4': ' ',
})
incl_filter_lookup['incl_bias'] = 'B'
skip_filter_lookup['skip_bias'] = 'B'

# Manually relabel skipping filters: the group containing each cluster's
# representative filter is numbered after the top_k + num_extra_filters slots.
for rank, skip_key in enumerate(skip_plot_order):
    representative = skip_representative_dict[skip_key]
    owning_group = [
        members
        for members in seq_filters_grouping['skip_membership_dict'].values()
        if representative in members
    ][0]
    force_name = 'skip_seq_' + '_'.join(map(str, sorted(owning_group)))
    skip_filter_lookup[force_name] = rank + 1 + top_k + num_extra_filters
def get_model_midpoint(model, midpoint=0.5):
    """Return the link-function input at which the model outputs ``midpoint``.

    The link function is rebuilt from the model's own layers: the scalar input
    is scaled and shifted by the ``w`` and ``b`` of ``'energy_seq_struct'``,
    then passed through ``'gen_func'`` and ``'output_activation'``. The result
    is the negation of the basal strength, i.e. a positive value corresponds
    to a skipping basal strength.
    """
    energy_layer = model.get_layer('energy_seq_struct')
    scale = energy_layer.w.numpy()
    shift = energy_layer.b.numpy()

    probe = Input(shape=(1,))
    output_activation = model.get_layer('output_activation')
    gen_func = model.get_layer('gen_func')
    response = output_activation(gen_func(scale * probe + shift))

    return get_link_midpoint(Model(inputs=probe, outputs=response), midpoint)
def _sort_forces_bias_first(forces, bias_name):
    """Sort forces in descending order, forcing ``bias_name`` to the front."""
    return forces.sort_values(
        ascending=False,
        # Adding 1000 to the bias entry puts it first, so it is drawn at the
        # bottom of the stack.
        key=lambda _: (forces.keys() == bias_name) * 1000 + forces.values,
    )


def _draw_force_stack(ax, x, forces, label_lookup, highlight_forces,
                      strong_color, light_color, draw_numbers,
                      numbers_min_bar_height, vertical_adjustement):
    """Draw ``forces`` as one stacked bar at ``x``; return the stack height."""
    total = 0
    for name, value in forces.items():
        color = strong_color if name in highlight_forces else light_color
        ax.bar([x], [value], bottom=[total], color=color, linewidth=1,
               edgecolor='#6b6b6b', width=1, zorder=2)
        total += value
        if draw_numbers:
            # Merged entries carry several names joined by '___'; all of them
            # are resolved, but only the first label is drawn.
            labels = [label_lookup[part] for part in name.split("___")]
            if value > numbers_min_bar_height:
                ax.text(x, total - value / 2 - vertical_adjustement,
                        labels[0], ha='center', va='center')
    return total


# The main function for drawing a force plot
def draw_force_plot(sequences,
                    annotations,
                    highlight_forces=[],
                    incl_color=incl_color, skip_color=skip_color,
                    light_incl_color=light_incl_color,
                    light_skip_color=light_skip_color,
                    figsize=(40/2, 10/2), force_y_range=(0, 90),
                    delta_force_y_range=(-15, 25),
                    ys=[0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6,
                        0.7, 0.8, 0.9, 0.95, 0.98, 0.986],
                    draw_numbers=False,
                    numbers_min_bar_height=4,
                    label_rotation=0, label_alignment='center',
                    delta_bar_width=2,
                    width_ratios=[2, 2],   # for horizontal plots
                    height_ratios=[1, 1],  # for vertical plots
                    vertical=False, sharex=False, axislinewidth=1,
                    vertical_adjustement=0,
                    parent_figure=None,
                    custom_model=model):
    """Draw stacked inclusion/skipping force bars and their difference.

    For every sequence, two stacked bars (inclusion and skipping forces) are
    drawn on the first axes and one bar with their difference on the second
    axes. The second axes also gets a right-hand scale labelled
    'Predicted PSI'.

    Args:
        sequences: Input sequences, flanked by 7 intronic nucleotides on each
            side; for our dataset, that gives 76+7+7=90.
        annotations: One x tick label per sequence (same length as
            ``sequences``).
        highlight_forces: Force names drawn in the strong colour; all others
            use the light colour. The list defaults are never mutated.
        incl_color, skip_color: Colours of highlighted forces.
        light_incl_color, light_skip_color: Colours of the other forces.
        figsize: Figure size, used only when no ``parent_figure`` is given.
        force_y_range: y limits of the force axes.
        delta_force_y_range: y limits of the difference axes and PSI scale.
        ys: Model output levels to mark on the 'Predicted PSI' scale.
        draw_numbers: Whether to write each force's label inside its bar.
        numbers_min_bar_height: Bars no taller than this get no label.
        label_rotation, label_alignment: x tick label rotation and alignment.
        delta_bar_width: Width of the difference bars.
        width_ratios: Relative panel widths for horizontal plots.
        height_ratios: Relative panel heights for vertical plots.
        vertical: Stack the two panels vertically instead of side by side.
        sharex: Share the x axis between panels (vertical layout only).
        axislinewidth: Currently unused.
        vertical_adjustement: Downward offset applied to bar labels.
        parent_figure: If the force plot should be a subfigure, pass the
            subfigure here.
        custom_model: In case exons are non-standard size or a basal shift is
            required, a custom model can be provided here; layers should
            follow the same names as the main model.

    Returns:
        The figure that was drawn on: ``parent_figure`` when supplied,
        otherwise the newly created figure.
    """
    assert len(sequences) == len(annotations), \
        "sequences and annotations must have the same length"
    link_midpoint = get_model_midpoint(custom_model)

    # compute activations for given sequences
    activations_model = Model(inputs=custom_model.inputs, outputs=[
        custom_model.get_layer('activation_2').output,
        custom_model.get_layer('activation_3').output
    ])
    data_incl_act, data_skip_act = activations_model.predict(
        figutils.create_input_data(sequences))

    if parent_figure is not None:
        figure = parent_figure
        if vertical:
            axarr = parent_figure.subplots(
                2, 1, gridspec_kw={'height_ratios': height_ratios}, sharex=sharex)
        else:
            axarr = parent_figure.subplots(
                1, 2, gridspec_kw={'width_ratios': width_ratios})
    elif vertical:
        figure, axarr = plt.subplots(
            2, 1, figsize=figsize, dpi=150,
            gridspec_kw={'height_ratios': height_ratios}, sharex=sharex)
    else:
        figure, axarr = plt.subplots(
            1, 2, figsize=figsize, dpi=150,
            gridspec_kw={'width_ratios': width_ratios})
    force_ax, delta_ax = axarr[0], axarr[1]

    for idx in range(data_incl_act.shape[0]):
        incl_forces, skip_forces = create_force_data(
            data_incl_act[idx], data_skip_act[idx],
            seq_filters_grouping, struct_filters_grouping,
            num_seq_filters, sum_positions=True, link_midpoint=link_midpoint)
        incl_forces = _sort_forces_bias_first(
            merge_small_forces(incl_forces, threshold=0), "incl_bias")
        skip_forces = _sort_forces_bias_first(
            merge_small_forces(skip_forces, threshold=0), "skip_bias")

        # Each stack is drawn on its own so that the longer of the two is not
        # truncated to the length of the shorter one.
        _draw_force_stack(force_ax, 3*idx, incl_forces, incl_filter_lookup,
                          highlight_forces, incl_color, light_incl_color,
                          draw_numbers, numbers_min_bar_height,
                          vertical_adjustement)
        _draw_force_stack(force_ax, 3*idx + 1, skip_forces, skip_filter_lookup,
                          highlight_forces, skip_color, light_skip_color,
                          draw_numbers, numbers_min_bar_height,
                          vertical_adjustement)

        delta_force = incl_forces.sum() - skip_forces.sum()
        delta_ax.bar([3*idx + 0.5], [delta_force], color='#dad7cd',
                     linewidth=1, edgecolor='#6b6b6b', width=delta_bar_width,
                     zorder=2)

    force_ax.set_ylim(*force_y_range)
    # Ticks at every multiple of 20 up to (and including) the upper limit.
    y_top = force_y_range[1]
    tick_span = y_top + 1 if y_top % 20 == 0 else y_top
    force_ax.set_yticks(20 * np.unique(np.arange(tick_span) // 20))
    delta_ax.set_ylim(*delta_force_y_range)
    force_ax.grid(axis='y', which='both', zorder=0)
    delta_ax.grid(axis='y', which='both', zorder=0)

    # Zero line across the difference panel; restore the x limits afterwards
    # so the line does not stretch the axes.
    xmin, xmax = delta_ax.get_xlim()
    delta_ax.hlines(0, xmin, xmax, color='k', zorder=3, linewidth=1)
    delta_ax.set_xlim(xmin, xmax)

    force_ax.set_ylabel('Strength (a.u.)', fontsize=14)
    delta_ax.set_ylabel(r'$\Delta$ Strength (a.u.)', fontsize=14)
    force_ax.spines['right'].set_visible(True)

    # One tick per sequence, centred between its two force bars.
    tick_positions = [3*i + 0.5 for i in range(len(annotations))]
    tick_labels = list(annotations)
    for ax in (force_ax, delta_ax):
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=label_rotation,
                           ha=label_alignment)

    # Secondary scale: place each output level in `ys` at the force
    # difference that produces it, relative to the model's own midpoint.
    psi_positions = [
        get_model_midpoint(custom_model, midpoint=m) - link_midpoint for m in ys
    ]
    psi_labels = [np.round(e, 2) for e in ys]
    ax2 = delta_ax.twinx()
    delta_ax.spines['right'].set_visible(True)
    ax2.set_yticks(psi_positions)
    ax2.set_yticklabels(psi_labels)
    ax2.set_ylim(*delta_force_y_range)
    ax2.set_ylabel('Predicted PSI', fontsize=14)

    for ax in axarr:
        ax.tick_params(axis='both', labelsize=12)
    force_ax.tick_params(axis='x', length=0)

    figure.align_ylabels(axarr)
    return figure