import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from joblib import load

import figutils
import sequence_logo
# NOTE(review): `tf`, `Model`, `Input` and `collapse_positions` are used in this
# file without an explicit import; presumably they are provided by this star
# import (or defined elsewhere) -- confirm against quad_model.
from quad_model import *


def get_link_midpoint(
    link_function, midpoint=0.5, epsilon=1e-5, lb=-100, ub=100, max_iters=50
):
    """Find an input at which ``link_function`` evaluates to ``midpoint``.

    Assumes monotonicity and smoothness of link function.

    The function is evaluated on a 1000-point grid over ``[lb, ub]``. If a
    grid point is within ``epsilon`` of ``midpoint`` it is returned; otherwise
    the interval is narrowed to the last grid point below the target and the
    first grid point above it, and the search repeats.

    Args:
        link_function: Callable taking an array of shape ``(n, 1)`` and
            returning an object with a ``.numpy()`` method.
        midpoint: Target output value.
        epsilon: Absolute tolerance on the output value.
        lb: Lower bound of the initial search interval.
        ub: Upper bound of the initial search interval.
        max_iters: Maximum number of grid refinements.

    Returns:
        The first grid point whose output is within ``epsilon`` of
        ``midpoint``.

    Raises:
        RuntimeError: If the target is not bracketed by the current interval,
            or if ``max_iters`` refinements were not enough.
    """
    for _ in range(max_iters):
        xx = np.linspace(lb, ub, 1000)
        yy = link_function(xx[:, None]).numpy().flatten()
        residual = yy - midpoint

        close_enough = np.abs(residual) < epsilon
        if close_enough.any():
            return xx[close_enough][0]

        below = np.flatnonzero(residual < 0)
        above = np.flatnonzero(residual > 0)
        if below.size == 0 or above.size == 0:
            # Without this guard the indexing below fails with an opaque
            # IndexError when the link function never crosses the target.
            raise RuntimeError(
                f"Link function does not reach {midpoint} within [{lb}, {ub}]"
            )
        lb = xx[below[-1]]
        ub = xx[above[0]]

    raise RuntimeError(f"Max iterations ({max_iters}) reached without solution...")


def _collapse_filter_groups(
    activations, seq_groups, struct_groups, num_seq_filters, prefix
):
    """Sum the activation columns of every filter group into one column each."""
    names = [
        f"{prefix}_seq_" + "_".join(map(str, members))
        for members in seq_groups.values()
    ]
    names += [
        f"{prefix}_struct_" + "_".join(map(str, members))
        for members in struct_groups.values()
    ]

    collapsed = np.zeros((activations.shape[0], len(names)))
    column = 0
    for members in seq_groups.values():
        collapsed[:, column] = activations[:, members].sum(axis=1)
        column += 1
    for members in struct_groups.values():
        # Structure filter indices are offset by the number of sequence
        # filters; asarray lets plain lists work as well as ndarrays.
        struct_columns = num_seq_filters + np.asarray(members, dtype=int)
        collapsed[:, column] = activations[:, struct_columns].sum(axis=1)
        column += 1

    return pd.DataFrame(collapsed, columns=names)


def collapse_filters(
    test_act_incl,
    test_act_skip,
    iM,
    sM,
    iM_struct,
    sM_struct,
    num_seq_filters,
):
    """Collapse per-filter activations into per-group activations.

    Args:
        test_act_incl: 2-D array of inclusion activations, one column per
            filter.
        test_act_skip: 2-D array of skipping activations, one column per
            filter.
        iM: Mapping of group key to sequence filter indices (inclusion).
        sM: Mapping of group key to sequence filter indices (skipping).
        iM_struct: Mapping of group key to structure filter indices
            (inclusion), relative to the first structure filter.
        sM_struct: Same as ``iM_struct`` for skipping.
        num_seq_filters: Column offset at which structure filters start.

    Returns:
        Tuple ``(incl_df, skip_df)`` of DataFrames with one column per group,
        named ``"<incl|skip>_<seq|struct>_<member indices joined by '_'>"``.
    """
    incl_df = _collapse_filter_groups(
        test_act_incl, iM, iM_struct, num_seq_filters, "incl"
    )
    skip_df = _collapse_filter_groups(
        test_act_skip, sM, sM_struct, num_seq_filters, "skip"
    )
    return incl_df, skip_df


def _add_bias_force(incl_forces, skip_forces, link_midpoint):
    """Add ``|link_midpoint|`` as a bias entry on the side selected by its sign."""
    if link_midpoint is None:
        return
    if link_midpoint < 0:
        incl_forces['incl_bias'] = np.abs(link_midpoint)
    else:
        skip_forces['skip_bias'] = np.abs(link_midpoint)


def create_force_data(
    test_act_incl,
    test_act_skip,
    seq_filters_grouping,
    struct_filters_grouping,
    num_seq_filters,
    sum_positions=True,
    link_midpoint=None,
):
    """Build grouped inclusion and skipping forces from filter activations.

    Args:
        test_act_incl: 2-D inclusion activations (rows summed when
            ``sum_positions`` is true).
        test_act_skip: 2-D skipping activations.
        seq_filters_grouping: Dict holding ``"incl_membership_dict"`` and
            ``"skip_membership_dict"`` for sequence filters.
        struct_filters_grouping: Same keys, for structure filters.
        num_seq_filters: Column offset at which structure filters start.
        sum_positions: If true, sum each group over axis 0; otherwise the
            grouped activations are passed to ``collapse_positions``.
        link_midpoint: Optional scalar. When negative, its absolute value is
            stored as ``'incl_bias'``; otherwise as ``'skip_bias'``.

    Returns:
        Tuple ``(incl_forces, skip_forces)``.
    """
    iT, sT = collapse_filters(
        test_act_incl,
        test_act_skip,
        seq_filters_grouping["incl_membership_dict"],
        seq_filters_grouping["skip_membership_dict"],
        struct_filters_grouping["incl_membership_dict"],
        struct_filters_grouping["skip_membership_dict"],
        num_seq_filters,
    )

    if sum_positions:
        A, B = iT.sum(axis=0), sT.sum(axis=0)
    else:
        # NOTE(review): collapse_positions is not defined in the visible part
        # of this file -- confirm where it comes from.
        A, B = collapse_positions(iT), collapse_positions(sT)

    _add_bias_force(A, B, link_midpoint)
    return A, B


def merge_small_forces(forces, threshold=1):
    """Merge all entries below ``threshold`` into a single summed entry.

    Args:
        forces: pandas Series with string index labels.
        threshold: Entries strictly below this value are merged.

    Returns:
        ``forces`` itself when nothing is below the threshold; otherwise a
        new Series with the entries ``>= threshold`` followed by one entry
        whose label is the merged labels joined by ``'___'`` and whose value
        is their sum.
    """
    is_small = forces < threshold
    if not is_small.any():
        return forces

    small = forces[is_small]
    # Select with >= (not ~is_small) so NaN entries are dropped as before.
    out = forces[forces >= threshold].copy()
    out['___'.join(small.index)] = small.sum()
    return out


####################
## INITIALIZE
####################

plt.style.use('clean.mplstyle')

# Load data and models
DATA_DIR = "../../sequence_analysis_utils/force_plots_es7/data/"
xTr = load(DATA_DIR + 'xTr_ES7_HeLa_ABC.pkl.gz')
yTr = load(DATA_DIR + 'yTr_ES7_HeLa_ABC.pkl.gz')
xTe = load(DATA_DIR + 'xTe_ES7_HeLa_ABC.pkl.gz')
yTe = load(DATA_DIR + 'yTe_ES7_HeLa_ABC.pkl.gz')

model_fname = 'custom_adjacency_regularizer_20210731_124_step3.h5'
model = tf.keras.models.load_model(model_fname)

num_seq_filters = model.get_layer('qc_incl').kernel.shape[2]
num_struct_filters = model.get_layer('c_incl_struct').kernel.shape[2]

position_bias_size = model.get_layer('position_bias_incl').kernel.shape[0]
struct_filter_width = model.get_layer("c_incl_struct").kernel.shape[0]
input_length = model.input[0].shape[1]


# Group sequence filters
def get_membership_dict(ind):
    """Map each group label in ``ind`` to the list of positions carrying it."""
    out = {}
    for i, group_i in enumerate(ind):
        out.setdefault(group_i, []).append(i)
    return out


def get_fig_num_rows_cols(membership_dict):
    """Return ``(largest group size, number of groups)`` for a figure grid."""
    fig_rows = max(len(members) for members in membership_dict.values())
    fig_cols = len(membership_dict)
    return fig_rows, fig_cols


structure_out_model = Model(
    inputs=model.inputs,
    outputs=[
        model.get_layer('activation_2').output,
        model.get_layer('activation_3').output,
    ],
)
incl_act, skip_act = structure_out_model.predict(xTr, verbose=1, batch_size=1024)
incl_act_seq = incl_act[:, :, :num_seq_filters]
skip_act_seq = skip_act[:, :, :num_seq_filters]

# Cluster filters by the correlation of their activations summed over axis 1.
incl_inds = sch.fcluster(
    sch.linkage(
        incl_act_seq.sum(axis=1).T, metric='correlation', method='complete'
    ),
    t=0.9,
    criterion='distance',
)
skip_inds = sch.fcluster(
    sch.linkage(
        skip_act_seq.sum(axis=1).T, metric='correlation', method='complete'
    ),
    t=0.9,
    criterion='distance',
)
incl_membership_dict = get_membership_dict(incl_inds)
skip_membership_dict = get_membership_dict(skip_inds)


def get_representative_dict(membership_dict, activations):
    """Pick, for every cluster, the member filter with the highest score.

    A filter's score is its activation summed over axis 1 and averaged over
    axis 0 of ``activations``.
    """
    scores = activations.sum(axis=1).mean(axis=0)
    out = dict()
    for cluster_id, filter_ids in membership_dict.items():
        top_filter = np.argmax(scores[filter_ids])
        out[cluster_id] = filter_ids[top_filter]
    return out


incl_representative_dict = get_representative_dict(incl_membership_dict, incl_act_seq)
skip_representative_dict = get_representative_dict(skip_membership_dict, skip_act_seq)

# Manually modify representatives
skip_representative_dict[4] = 10
incl_representative_dict[3] = 5

# Group sequence filters
seq_filters_grouping = dict(
    incl_membership_dict=incl_membership_dict,
    skip_membership_dict=skip_membership_dict,
    incl_representative_dict=incl_representative_dict,
    skip_representative_dict=skip_representative_dict,
)

# Group structure filters
incl_membership_struct_dict = {1: np.array([0, 1, 2, 3, 4, 5, 6, 7])}
skip_membership_struct_dict = {
    1: np.array([1]),
    2: np.array([0, 2, 3]),
    3: np.array([5, 6, 7]),
    4: np.array([4]),
}
incl_representative_struct_dict = {1: 0}
skip_representative_struct_dict = {1: 1, 2: 0, 3: 5}
struct_filters_grouping = dict(
    incl_membership_dict=incl_membership_struct_dict,
    skip_membership_dict=skip_membership_struct_dict,
    incl_representative_dict=incl_representative_struct_dict,
    skip_representative_dict=skip_representative_struct_dict,
)

incl_color = '#669aff'
skip_color = '#ff6666'
light_incl_color = '#C5D6FB'
light_skip_color = '#F6C3C2'

# Score each sequence cluster by the 0.9 quantile of its total activation and
# order clusters from the highest score to the lowest.
incl_membership_scores = {
    key: np.quantile(
        incl_act_seq[:, :, incl_membership_dict[key]].sum(axis=(1, 2)), 0.9
    )
    for key in incl_membership_dict
}
skip_membership_scores = {
    key: np.quantile(
        skip_act_seq[:, :, skip_membership_dict[key]].sum(axis=(1, 2)), 0.9
    )
    for key in skip_membership_dict
}
incl_plot_order = sorted(
    incl_membership_scores, key=lambda x: -incl_membership_scores[x]
)
skip_plot_order = sorted(
    skip_membership_scores, key=lambda x: -skip_membership_scores[x]
)

top_k = 4
num_extra_filters = 2

# Map force names (as produced by collapse_filters) to the label drawn on bars.
# NOTE(review): the structure-group keys below are looked up in the *sequence*
# plot order. The skip structure labels are overwritten further down, the incl
# structure label is not -- confirm this is intended.
incl_filter_lookup = {}
for key, value in seq_filters_grouping['incl_membership_dict'].items():
    name = 'incl_seq_' + '_'.join(map(str, sorted(value)))
    incl_filter_lookup[name] = incl_plot_order.index(key) + 1
for key, value in struct_filters_grouping['incl_membership_dict'].items():
    name = 'incl_struct_' + '_'.join(map(str, sorted(value)))
    incl_filter_lookup[name] = incl_plot_order.index(key) + 1

skip_filter_lookup = {}
for key, value in seq_filters_grouping['skip_membership_dict'].items():
    name = 'skip_seq_' + '_'.join(map(str, sorted(value)))
    skip_filter_lookup[name] = skip_plot_order.index(key) + 1
for key, value in struct_filters_grouping['skip_membership_dict'].items():
    name = 'skip_struct_' + '_'.join(map(str, sorted(value)))
    skip_filter_lookup[name] = skip_plot_order.index(key) + 1

# Manually modify symbols for G-poor, structure, inclusion bias, and others
skip_filter_lookup['skip_struct_0_2_3'] = 'S'
skip_filter_lookup['skip_struct_1'] = 'P'
skip_filter_lookup['skip_struct_5_6_7'] = '.'
skip_filter_lookup['skip_struct_4'] = ' '
incl_filter_lookup['incl_bias'] = 'B'
skip_filter_lookup['skip_bias'] = 'B'

# Manually relabel skipping filters
for idx, skip_key in enumerate(skip_plot_order):
    skip_filter_num = skip_representative_dict[skip_key]
    skip_filter_group = [
        members
        for members in seq_filters_grouping['skip_membership_dict'].values()
        if skip_filter_num in members
    ][0]
    key = 'skip_seq_' + '_'.join(map(str, sorted(skip_filter_group)))
    skip_filter_lookup[key] = idx + 1 + top_k + num_extra_filters


def _build_link_function(model):
    """Build a Keras model mapping a scalar input to the model's output link."""
    link_input = Input(shape=(1,))
    energy_layer = model.get_layer('energy_seq_struct')
    w = energy_layer.w.numpy()
    b = energy_layer.b.numpy()
    link_output = model.get_layer('output_activation')(
        model.get_layer('gen_func')(w * link_input + b)
    )
    return Model(inputs=link_input, outputs=link_output)


def get_model_midpoint(model, midpoint=0.5):
    """
    Compute the midpoint using the model's link function. This is the negation
    of the basal strength. I.e., positive value corresponds to a skipping
    basal strength.

    Args:
        model: Keras model exposing the layers ``energy_seq_struct``,
            ``gen_func`` and ``output_activation``.
        midpoint: Output value whose corresponding input is searched for.

    Returns:
        The input at which the link function evaluates to ``midpoint``, as
        found by ``get_link_midpoint``.
    """
    return get_link_midpoint(_build_link_function(model), midpoint)


def _sort_forces_bias_first(forces, bias_name):
    """Sort forces in descending order, keeping ``bias_name`` in front."""
    return forces.sort_values(
        ascending=False,
        # +1000 on the bias entry puts it first regardless of its magnitude.
        key=lambda _: (forces.index == bias_name) * 1000 + forces.values,
    )


def _draw_force_stack(
    ax,
    x,
    forces,
    highlight_forces,
    color,
    light_color,
    label_lookup,
    draw_numbers,
    numbers_min_bar_height,
    vertical_adjustement,
):
    """Draw one stacked bar at ``x`` with a segment for every entry of ``forces``."""
    total = 0
    for name, force in forces.items():
        ax.bar(
            [x],
            [force],
            bottom=[total],
            color=color if name in highlight_forces else light_color,
            linewidth=1,
            edgecolor='#6b6b6b',
            width=1,
            zorder=2,
        )
        total += force
        if draw_numbers and force > numbers_min_bar_height:
            # Merged forces carry several names joined by '___'; label the
            # segment with the first one.
            ax.text(
                x,
                total - force / 2 - vertical_adjustement,
                label_lookup[name.split("___")[0]],
                ha='center',
                va='center',
            )


# The main function for drawing a force plot
def draw_force_plot(
    sequences,  # sequences should be flanked by 7 intronic nucleotide on each side; for our dataset, that gives 76+7+7=90
    annotations,
    highlight_forces=[],  # never mutated here
    incl_color=incl_color,
    skip_color=skip_color,
    light_incl_color=light_incl_color,
    light_skip_color=light_skip_color,
    figsize=(40/2, 10/2),
    force_y_range=(0, 90),
    delta_force_y_range=(-15, 25),
    ys=[0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.98, 0.986],
    draw_numbers=False,
    numbers_min_bar_height=4,
    label_rotation=0,
    label_alignment='center',
    delta_bar_width=2,
    width_ratios=[2, 2],  # for horizontal plots
    height_ratios=[1, 1],  # for vertical plots
    vertical=False,
    sharex=False,
    axislinewidth=1,
    vertical_adjustement=0,
    parent_figure=None,  # if force plot should be a subfigure, pass the subfigure here
    custom_model=model,  # in case exons are non-standard size or a basal shift is required, a custom model can be provided here; layers should be following the same names as the main model
):
    """Draw stacked inclusion/skipping force bars and their difference.

    For every sequence, the first axes gets one stacked bar of inclusion
    forces and one of skipping forces; the second axes gets a single bar with
    the difference of their sums, plus a secondary y-axis labelled with the
    values in ``ys``.

    Args:
        sequences: Input sequences, converted with
            ``figutils.create_input_data``.
        annotations: One x tick label per sequence.
        highlight_forces: Force names drawn in the strong colour; all other
            segments use the light colour.
        incl_color, skip_color: Colours of highlighted segments.
        light_incl_color, light_skip_color: Colours of the other segments.
        figsize: Figure size, used only when no ``parent_figure`` is given.
        force_y_range: y limits of the force axes.
        delta_force_y_range: y limits of the difference axes.
        ys: Link-function output values shown as ticks on the secondary axis.
        draw_numbers: Whether to write a label inside each segment.
        numbers_min_bar_height: Only segments taller than this are labelled.
        label_rotation, label_alignment: Styling of the x tick labels.
        delta_bar_width: Width of the difference bars.
        width_ratios: Axes width ratios for the horizontal layout.
        height_ratios: Axes height ratios for the vertical layout.
        vertical: Stack the two axes vertically instead of side by side.
        sharex: Share the x axis in the vertical layout.
        axislinewidth: Unused in this function.
        vertical_adjustement: Downward offset applied to segment labels.
        parent_figure: Optional (sub)figure to draw into.
        custom_model: Model used for the activations and the link function.

    Returns:
        The newly created figure, or ``parent_figure`` when one was given.
    """
    assert len(sequences) == len(annotations)

    # Build the link function once; it is reused for every tick in ``ys``.
    link_function = _build_link_function(custom_model)
    link_midpoint = get_link_midpoint(link_function, 0.5)

    # compute activations for given sequences
    activations_model = Model(
        inputs=custom_model.inputs,
        outputs=[
            custom_model.get_layer('activation_2').output,
            custom_model.get_layer('activation_3').output,
        ],
    )
    data_incl_act, data_skip_act = activations_model.predict(
        figutils.create_input_data(sequences)
    )
    num_sequences = data_incl_act.shape[0]

    if parent_figure is not None:
        # Bind fig here so the final return works for this branch as well.
        fig = parent_figure
        if vertical:
            axarr = parent_figure.subplots(
                2, 1, gridspec_kw={'height_ratios': height_ratios}, sharex=sharex
            )
        else:
            axarr = parent_figure.subplots(
                1, 2, gridspec_kw={'width_ratios': width_ratios}
            )
    elif vertical:
        fig, axarr = plt.subplots(
            2,
            1,
            figsize=figsize,
            dpi=150,
            gridspec_kw={'height_ratios': height_ratios},
            sharex=sharex,
        )
    else:
        fig, axarr = plt.subplots(
            1,
            2,
            figsize=figsize,
            dpi=150,
            gridspec_kw={'width_ratios': width_ratios},
        )

    for idx in range(num_sequences):
        incl_forces, skip_forces = create_force_data(
            data_incl_act[idx],
            data_skip_act[idx],
            seq_filters_grouping,
            struct_filters_grouping,
            num_seq_filters,
            sum_positions=True,
            link_midpoint=link_midpoint,
        )
        incl_forces = _sort_forces_bias_first(
            merge_small_forces(incl_forces, threshold=0), "incl_bias"
        )
        skip_forces = _sort_forces_bias_first(
            merge_small_forces(skip_forces, threshold=0), "skip_bias"
        )

        # Each stack is drawn on its own: zipping the two series would drop
        # the trailing segments of whichever one is longer.
        _draw_force_stack(
            axarr[0],
            3 * idx,
            incl_forces,
            highlight_forces,
            incl_color,
            light_incl_color,
            incl_filter_lookup,
            draw_numbers,
            numbers_min_bar_height,
            vertical_adjustement,
        )
        _draw_force_stack(
            axarr[0],
            3 * idx + 1,
            skip_forces,
            highlight_forces,
            skip_color,
            light_skip_color,
            skip_filter_lookup,
            draw_numbers,
            numbers_min_bar_height,
            vertical_adjustement,
        )

        delta_force = incl_forces.sum() - skip_forces.sum()
        axarr[1].bar(
            [3 * idx + 0.5],
            [delta_force],
            color='#dad7cd',
            linewidth=1,
            edgecolor='#6b6b6b',
            width=delta_bar_width,
            zorder=2,
        )

    axarr[0].set_ylim(*force_y_range)
    # Ticks every 20 units, including the upper limit when it is a multiple of 20.
    tick_stop = (
        force_y_range[1] + 1 if force_y_range[1] % 20 == 0 else force_y_range[1]
    )
    axarr[0].set_yticks(20 * np.unique(np.arange(tick_stop) // 20))
    axarr[1].set_ylim(*delta_force_y_range)

    axarr[0].grid(axis='y', which='both', zorder=0)
    axarr[1].grid(axis='y', which='both', zorder=0)

    # Zero line across the difference axes; restore the x limits afterwards.
    xmin, xmax = axarr[1].get_xlim()
    axarr[1].hlines(0, xmin, xmax, color='k', zorder=3, linewidth=1)
    axarr[1].set_xlim(xmin, xmax)

    axarr[0].set_ylabel('Strength (a.u.)', fontsize=14)
    axarr[1].set_ylabel(r'$\Delta$ Strength (a.u.)', fontsize=14)
    axarr[0].spines['right'].set_visible(True)

    tick_positions = [3 * i + 0.5 for i in range(len(annotations))]
    tick_labels = list(annotations)
    for ax in (axarr[0], axarr[1]):
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(
            tick_labels, rotation=label_rotation, ha=label_alignment
        )

    # Secondary axis: place each value of ``ys`` at the strength difference
    # that produces it, measured relative to the 0.5 midpoint.
    xs = [get_link_midpoint(link_function, m) - link_midpoint for m in ys]
    psi_labels = [np.round(e, 2) for e in ys]
    ax2 = axarr[1].twinx()
    axarr[1].spines['right'].set_visible(True)
    ax2.set_yticks(xs)
    ax2.set_yticklabels(psi_labels)
    ax2.set_ylim(*delta_force_y_range)
    ax2.set_ylabel('Predicted PSI', fontsize=14)

    for ax in axarr:
        ax.tick_params(axis='both', labelsize=12)
    axarr[0].tick_params(axis='x', length=0)

    fig.align_ylabels(axarr)

    return fig