Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.express as px | |
def get_best_alpha(stats_df, modality):
    '''
    Takes a DataFrame of scMKL results and returns the alpha with the best mean AUROC.

    stats_df: a DataFrame with (at least) the columns 'Model', 'Modality',
        'Alpha' and 'AUROC'
    modality: the modality to find the best alpha for

    Returns the alpha whose mean AUROC over the scMKL runs of `modality` is
    highest; on ties the smallest such alpha is returned.
    Raises ValueError if stats_df has no scMKL rows for `modality`.
    '''
    is_target = (stats_df['Model'] == 'scMKL') & (stats_df['Modality'] == modality)
    mean_auroc = stats_df.loc[is_target].groupby('Alpha')['AUROC'].mean()
    if mean_auroc.empty:
        raise ValueError(f"No scMKL results found for modality {modality!r}")
    # groupby sorts alphas ascending and idxmax returns the first maximum,
    # so ties resolve to the smallest alpha
    return mean_auroc.idxmax()
def format_datatype_grouping(dtype_grouping):
    '''
    Formats display names so they match the labels used in the DataFrames.

    The capitalised grouping names "Hallmark", "Cistrome", "Motifs" and
    "Neuronal" are lower-cased wherever they occur in a name.

    dtype_grouping: a str, or a list/tuple of str
    Returns the formatted str for str input, otherwise a list of formatted str
    '''
    replacements = (("Hallmark", "hallmark"),
                    ("Cistrome", "cistrome"),
                    ("Motifs", "motifs"),
                    ("Neuronal", "neuronal"))

    def _format(name):
        # Apply every replacement to a single name
        for old, new in replacements:
            name = name.replace(old, new)
        return name

    if isinstance(dtype_grouping, (list, tuple)):
        return [_format(selection) for selection in dtype_grouping]
    return _format(dtype_grouping)
def performance_boxplot(stats_df: pd.DataFrame, dataset: str, modality, metric: str, x_flag = "intersect", x_var = 'Alpha', color_dict = None):
    '''
    Plots a given metric of scMKL runs for a given dataset as box plots.

    stats_df: a DataFrame with (at least) the columns 'Dataset', 'Modality',
        'Model', `x_var` and `metric`; 'Alpha Star' is also read when
        x_flag == 'best', and 'Number_of_Selected_Groups' when
        x_var == 'Mean_Number_of_Selected_Groups'
    dataset: MCF7, T47D, lymphoma, prostate
    modality: a modality name, or a list/tuple of modality names, to visualize
    metric: which metric (column) should be displayed on the y-axis
    x_flag: "intersect" keeps only x values shared by every selected modality
        (only applied when a list/tuple of modalities is given); "best" keeps
        only rows with 'Alpha Star' == 'Yes' and, when x_var is 'Alpha',
        plots them against each modality's mean alpha star
    x_var: column used for the x-axis
    color_dict: optional mapping of modality -> color
    Returns a plotly Figure
    '''
    # Formatting modality names; keep a list form for iteration and ordering
    modality = format_datatype_grouping(modality)
    modality_list = modality if isinstance(modality, list) else [modality]

    # Filtering data frame to desired dataset and modality(s); .copy() so the
    # column assignments below never write into a view of the caller's frame
    stats_df = stats_df[(stats_df['Dataset'] == dataset)
                        & (np.isin(stats_df['Modality'], modality_list))
                        & (stats_df['Model'] == 'scMKL')].copy()

    if isinstance(modality, list) and x_flag == "intersect":
        # Only keep x values that every selected modality was run with
        shared_x = set(np.unique(stats_df[x_var]))
        for mod in modality_list:
            shared_x &= set(np.unique(stats_df.loc[stats_df['Modality'] == mod, x_var]))
        stats_df = stats_df[np.isin(stats_df[x_var], list(shared_x))].copy()

    if x_flag == 'best':
        stats_df = stats_df[stats_df['Alpha Star'] == 'Yes'].copy()
        # Mean of the selected (best) alphas per modality, rounded for display
        modality_alpha_means = {
            mod: round(np.mean(stats_df.loc[stats_df['Modality'] == mod, 'Alpha']), 3)
            for mod in np.unique(stats_df['Modality'])
        }
        stats_df['Mean Alpha Star'] = stats_df['Modality'].map(modality_alpha_means)
        if x_var == 'Alpha':
            x_var = 'Mean Alpha Star'

    if x_var == 'Mean_Number_of_Selected_Groups':
        # Every run of a modality is placed at that modality's mean group count
        for mod in modality_list:
            is_mod = stats_df['Modality'] == mod
            stats_df.loc[is_mod, 'Mean_Number_of_Selected_Groups'] = np.mean(
                stats_df.loc[is_mod, 'Number_of_Selected_Groups'])

    # Making x_var categorical for plotting
    if metric in ('RAM_usage', 'Inference_time'):
        # Resource metrics are compared per modality rather than per x value
        x_var = 'Modality'
    else:
        stats_df = stats_df.sort_values(by = x_var)
        x_categories = np.unique(stats_df[x_var])
        if 'Alpha' in x_var:
            # Alpha-based axes list categories from largest to smallest
            x_categories = x_categories[::-1]
        stats_df[x_var] = pd.Categorical(stats_df[x_var], categories = x_categories)

    width_x = None
    if x_var != 'Modality':
        x_values = np.unique(stats_df[x_var])
        # Guard: max()/min() of an empty selection would raise
        if len(x_values) > 0:
            # Boxes are drawn 2% of the x range wide
            width_x = (max(x_values) - min(x_values)) * 0.02

    performance_bp = px.box(
        data_frame = stats_df,
        x = x_var,
        y = metric,
        color = 'Modality',
        template = 'plotly_white',
        height = 800,
        hover_name = 'Modality',
        # plotly expects a list here; a bare str would be iterated per character
        category_orders = {'Modality' : modality_list},
        color_discrete_map = color_dict
    ).update_traces(width = width_x,
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            namelength = 40),
        font = dict(
            size = 20
        )
    ).update_xaxes(autorange = 'reversed' if x_var == 'Alpha' else None)
    return performance_bp
def comparison_boxplot(stats_df: pd.DataFrame, dataset: str, model, metric: str):
    '''
    Makes a box plot of the selected metric for the purpose of comparing models.

    stats_df: a DataFrame with the columns 'Dataset', 'Model', 'Modality',
        'Alpha Star' and `metric`
    dataset: dataset to plot
    model: a model name or a list-like of model names (e.g. scMKL, XGBoost, MLP)
    metric: column plotted on the y-axis
    Returns a plotly Figure with one box per "Model (Modality)"; the category
    order runs from lowest to highest mean `metric`
    '''
    models = [model] if isinstance(model, str) else list(model)
    # Modalities included in the comparison
    subset_modalities = ['RNA - hallmark', 'ATAC - hallmark', 'ATAC_TFIDF - hallmark',
                         'RNA - all', 'ATAC - mvf', 'GENE_SCORES - hallmark']

    # Filtering dataframe to desired dataset
    stats_df = stats_df[stats_df['Dataset'] == dataset]
    # Removing gene scores for lymphoma
    if dataset == "lymphoma":
        stats_df = stats_df[~stats_df['Modality'].isin(['GENE_SCORES - hallmark', 'GENE_SCORES - all'])]
    # Filtering dataframe to desired models
    stats_df = stats_df[np.isin(stats_df['Model'], models)]
    # For scMKL only the best-alpha (alpha star) runs are compared
    if 'scMKL' in models:
        stats_df = stats_df[(stats_df['Alpha Star'] == 'Yes') | (stats_df['Model'] != 'scMKL')]
    # .copy() so the column assignments below do not target a view
    stats_df = stats_df[np.isin(stats_df['Modality'], subset_modalities)].copy()
    stats_df['Model (Modality)'] = stats_df['Model'] + " (" + stats_df['Modality'] + ")"

    # Order of lowest to highest mean performance by model and modality;
    # aggregate only the metric column so the string key is never averaged
    group_order = (stats_df.groupby('Model (Modality)')[metric]
                   .mean()
                   .sort_values()
                   .index
                   .tolist())
    stats_df['Model (Modality)'] = pd.Categorical(stats_df['Model (Modality)'], categories = group_order)

    models_bp = px.box(
        data_frame = stats_df,
        x = 'Model (Modality)',
        y = metric,
        color = 'Model',
        template = 'plotly_white',
        height = 700,
        category_orders = {'Model' : ['scMKL', 'XGBoost', 'MLP'],
                           'Model (Modality)' : group_order},
        color_discrete_map = {
            'scMKL' : 'red',
            'XGBoost' : 'blue',
            'MLP' : 'green'
        }
    ).update_traces(width = 0.75,
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            namelength = 40),
        font = dict(
            size = 20
        )
    )
    return models_bp
def plot_umap(umap_dict, modality, dataset, grouping, label, subset):
    '''
    Builds an interactive 3-D scatter plot of precomputed UMAP embeddings.

    umap_dict: nested dict indexed as
        umap_dict[modality][dataset][feature key] -> {'Embeddings': ..., 'Cell Labels': ...}
    modality: top-level key of umap_dict (e.g. RNA | ATAC)
    dataset: dataset key
    grouping: feature grouping ('Hallmark', 'JASPAR' or another grouping name)
    label: the cell-label column used to colour the points
    subset: feature subset name, or "None" for the most variable features
    Returns a plotly 3-D scatter Figure
    '''
    # Translate the (grouping, subset) selection into the key used in umap_dict
    if subset == "None":
        feature_key = "Most Variable Features"
    elif grouping == 'Hallmark':
        feature_key = grouping.lower() + '_HALLMARK_' + subset.replace(" ", "_").upper()
    elif grouping == 'JASPAR':
        feature_key = 'motifs_' + subset
    else:
        feature_key = grouping.lower() + "_" + subset.replace(" ", "_")

    entry = umap_dict[modality][dataset][feature_key]
    axis_names = {0 : "UMAP_1", 1 : "UMAP_2", 2 : "UMAP_3"}
    embedding_df = pd.DataFrame(entry['Embeddings']).rename(columns = axis_names)
    embedding_df[label] = np.array(entry["Cell Labels"][label])

    figure = px.scatter_3d(
        data_frame = embedding_df,
        x = 'UMAP_1',
        y = 'UMAP_2',
        z = 'UMAP_3',
        color = label,
        template = 'plotly_white',
        height = 650,
    )
    figure = figure.update_layout(hoverlabel = dict(font_size = 16, namelength = 40))
    figure = figure.update_traces(marker = dict(size = 3))
    return figure
def weights_boxplot(norm_df: pd.DataFrame, dataset, modality, shown_groups = 9):
    '''
    Plots normalized group weights across alphas, one facet per group.

    norm_df: a DataFrame with (at least) the columns 'Dataset', 'Modality',
        'Group', 'Alpha', 'Norm' and 'Proportion Selected'
    dataset: dataset to plot
    modality: modality to plot (formatted with format_datatype_grouping)
    shown_groups: either a number or a list-like of group names to display
        - if a number, the groups with the largest total 'Proportion Selected'
          are shown, most frequently selected first
        - if list-like, those groups are shown in the given order
    Returns a plotly Figure
    '''
    modality = format_datatype_grouping(modality)
    norm_df = norm_df[(norm_df['Dataset'] == dataset) & (norm_df['Modality'] == modality)]

    if isinstance(shown_groups, (int, np.integer)):
        # Rank groups by total selection proportion, ascending
        selection_totals = (norm_df.groupby('Group', observed = False)['Proportion Selected']
                            .sum()
                            .sort_values())
        # max(..., 0) keeps shown_groups == 0 from selecting every group
        start = max(len(selection_totals) - shown_groups, 0)
        top_groups = np.array(selection_totals.index)[start:]
        group_order = list(top_groups[::-1])
    else:
        group_order = [shown_groups] if isinstance(shown_groups, str) else list(shown_groups)

    # .copy() so the Alpha conversion below does not write into a view
    norm_df = norm_df[norm_df['Group'].isin(group_order)].copy()

    # Alpha is plotted as a category: order categories by the original
    # (numeric) alpha values, largest first, before converting to str
    alpha_order = (norm_df['Alpha'].drop_duplicates()
                   .sort_values(ascending = False)
                   .astype(str)
                   .tolist())
    norm_df['Alpha'] = norm_df['Alpha'].astype(str)

    norm_plot = px.box(
        data_frame = norm_df,
        x = 'Alpha',
        y = 'Norm',
        color = 'Group',
        template = 'plotly_white',
        height = 700,
        facet_col = 'Group',
        facet_col_wrap = 3,
        category_orders = {'Group' : group_order,
                           'Alpha' : alpha_order}
    ).update_traces(
        width = 0.75,
    ).update_yaxes(title = ''
    ).update_xaxes(title = ''
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            namelength = 40),
        font = dict(
            size = 20
        ),
        showlegend = False,
        # Single shared axis titles placed on one facet each
        yaxis4=dict(title = "Normalized Weight"),
        xaxis2 = dict(title = "Alpha")
    # Facet titles: drop the "Group=" prefix and underscores
    ).for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1].replace("_", " "))
    )
    return norm_plot
def plot_features(selections_df, dataset, modality):
    '''
    Plots the most frequently selected features for one dataset and modality.

    NOTE: if the formatted modality contains 'motif', None is returned.

    selections_df: a DataFrame with the columns 'Dataset', 'Modality',
        'feature', 'selection' and 'Number of Groups Feature in'
    Returns a horizontal plotly bar chart of the 40 rows with the highest
    'selection' values, or None for motif modalities
    '''
    modality = format_datatype_grouping(modality)
    if 'motif' in modality:
        return None

    keep = (selections_df['Dataset'] == dataset) & (selections_df['Modality'] == modality)
    ranked = selections_df[keep].sort_values(by = 'selection', ascending = True)
    # Feature categories follow the ascending selection ranking
    ranked['feature'] = pd.Categorical(ranked['feature'], categories = ranked['feature'])
    top_features = ranked.tail(40)

    figure = px.bar(
        data_frame = top_features,
        orientation = 'h',
        x = 'selection',
        y = 'feature',
        template = 'plotly_white',
        color = 'Number of Groups Feature in',
        height = 700,
        color_continuous_scale = px.colors.sequential.Bluered,
    )
    return figure.update_layout(
        xaxis = dict(title = 'Times Selected by scMKL'),
        yaxis = dict(title = 'Features'),
        font = dict(size = 12),
    )
def create_volcano(vol_df, dataset, modality, grouping, group, grouping_dict):
    '''
    Takes a processed DataFrame and plots -log10(adjusted p-value) by log fold change.

    vol_df: a DataFrame with the columns 'Dataset', 'Enrichment' and
        '-log10(adjusted p-val)'; RNA rows use 'names', 'logfoldchanges' and
        'pvals_adj', ATAC rows use 'feature name', 'log2(fold_change)' and
        'adjusted p-value'
    dataset: dataset name; "song_prostate" and "prostate" are mapped to
        "prostate_rna" and "prostate_atac"
    modality: "RNA" selects the RNA columns; any other value the ATAC columns
    grouping: feature grouping used to look up `group` in grouping_dict
    group: gene set to restrict the plot to, or "None" for all features
    grouping_dict: nested dict grouping_dict[dataset][RNA|ATAC][grouping][group] -> features
    Returns a plotly scatter Figure
    '''
    if dataset == "song_prostate":
        dataset = 'prostate_rna'
    elif dataset == 'prostate':
        dataset = 'prostate_atac'
    reg_colors = {'Up-regulated' : 'green',
                  'Down-regulated' : 'red',
                  'Not significant' : 'blue'}
    # .copy() so the Enrichment relabelling does not write into a view
    vol_df = vol_df[vol_df['Dataset'] == dataset].copy()

    if modality == "RNA":
        lfc = "logfoldchanges"
        label_name = 'names'
        adj_pval = 'pvals_adj'
        if group != "None":
            group = "HALLMARK_" + group.replace(" ", "_").upper()
            vol_df = vol_df[np.isin(vol_df['names'], list(grouping_dict[dataset]['RNA'][grouping][group]))]
    else:
        lfc = "log2(fold_change)"
        label_name = 'feature name'
        adj_pval = 'adjusted p-value'
        # Collapse ATAC enrichment labels onto the keys of reg_colors
        # ('Up' takes precedence, matching the original two-pass relabel)
        vol_df['Enrichment'] = vol_df['Enrichment'].apply(
            lambda x: 'Up-regulated' if 'Up' in x else ('Down-regulated' if 'Down' in x else x))
        if group != "None":
            if grouping == "Hallmark":
                group = "HALLMARK_" + group.upper().replace(" ", "_")
            vol_df = vol_df[np.isin(vol_df['feature name'], list(grouping_dict[dataset]['ATAC'][grouping][group]))]

    vol_plot = px.scatter(
        data_frame = vol_df,
        x = lfc,
        y = '-log10(adjusted p-val)',
        color = 'Enrichment',
        template = 'plotly_white',
        hover_name = label_name,
        hover_data = adj_pval,
        color_discrete_map = reg_colors,
        height = 650,
    ).update_layout(
        hoverlabel=dict(
            font_size=16,
            namelength = 40),
        font = dict(
            size = 20
        )
    )
    return vol_plot
def gene_distribution(freq_df):
    '''
    Plots how many gene sets each gene belongs to.

    freq_df: a DataFrame with a 'Number of Sets' column
    Returns a plotly histogram of 'Number of Sets' with a log-scaled y-axis
    '''
    histogram = px.histogram(
        data_frame = freq_df,
        x = 'Number of Sets',
        log_y = True,
        template = 'plotly_white',
        color_discrete_sequence = ['blue'],
        title = "Distribution of Hallmark Gene Overlap",
    )
    histogram.update_layout(
        font = dict(size = 16),
        yaxis = dict(title = "log(Counts)"),
    )
    return histogram
def GO_plot(GO_df, dataset):
    '''
    Plots the 30 most enriched GO biological-process gene sets for a dataset.

    GO_df: a DataFrame with the columns 'Dataset', 'Group Name' and
        'GSE (-log10(adj. p-val))'
    dataset: dataset to plot
    Returns a horizontal plotly bar chart whose gene-set category order
    follows descending enrichment
    '''
    score_col = 'GSE (-log10(adj. p-val))'
    top_sets = (GO_df[GO_df['Dataset'] == dataset]
                .sort_values(by = score_col, ascending = False)[0:30]
                .reset_index()
                .rename(columns = {"Group Name" : "Gene Sets"}))
    # Drop the parenthesised suffix from each gene set name
    top_sets['Gene Sets'] = top_sets['Gene Sets'].apply(lambda name: name.split(" (")[0])

    figure = px.bar(
        data_frame = top_sets,
        x = score_col,
        y = 'Gene Sets',
        color_discrete_sequence = ['pink'],
        template = 'plotly_white',
        category_orders = {'Gene Sets' : top_sets['Gene Sets']},
        height = 700,
    )
    return figure.update_layout(
        yaxis = dict(dtick = 1),
        font = dict(size = 16),
    )
def hallmark_genesets_plot(hallmark_df, dataset):
    '''
    Plots hallmark gene set statistics side by side for one dataset.

    hallmark_df: a long-format DataFrame with the columns 'Dataset', 'Group',
        'Variable' and 'Value'; 'Variable' is expected to take the values
        'Proportion of DE Features',
        'Gene Set Enrichment (-log10(adjusted p-value))' and
        'scMKL Selection Frequency'
    dataset: dataset to plot
    Returns a plotly Figure with one horizontal bar facet per variable; the
    gene-set category order follows descending 'Proportion of DE Features'
    '''
    hallmark_df = hallmark_df[hallmark_df['Dataset'] == dataset]
    # Gene set order shared by every facet
    order_df = hallmark_df[hallmark_df['Variable'] == 'Proportion of DE Features']
    order = order_df.sort_values(by = 'Value', ascending = False)['Group'].tolist()

    hallmark_plot = px.bar(
        data_frame = hallmark_df,
        orientation = 'h',
        x = 'Value',
        y = 'Group',
        facet_col = 'Variable',
        color = 'Variable',
        template = 'plotly_white',
        height = 900,
        category_orders = {'Variable' : ['Proportion of DE Features', 'Gene Set Enrichment (-log10(adjusted p-value))', 'scMKL Selection Frequency'],
                           'Group' : order},
        hover_name = 'Group',
        color_discrete_sequence = ['blue', "orange", "red"]
    ).update_layout(
        yaxis = dict(title = 'Gene Sets', dtick = 1),
        font = dict(size = 16),
        xaxis1 = dict(title = "Proportion of DEG Overlap with Hallmark Gene Sets"),
        xaxis2 = dict(title = "-log10(adjusted p-value)"),
        xaxis3 = dict(title = "Times selected by scMKL"),
        showlegend = False
    # Each facet keeps its own x scale
    ).update_xaxes(matches=None
    # Facet titles: drop the "Variable=" prefix and underscores
    ).for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1].replace("_", " "))
    )
    return hallmark_plot