"""Gradio demo: in-silico perturbation of islet Hi-C loops / TFs with scGLUE.

On import the module downloads the RNA and Hi-C AnnData files, the prior
guidance graph, pseudobulk coolers and the trained GLUE model from the
Hugging Face Hub, then builds a Gradio UI with two actions:

* draw pseudobulk Hi-C heatmaps around a pair of loci, and
* remove a loop (or knock out a TF) from the guidance graph and plot the
  predicted change in RNA reconstruction.
"""
import gzip
import os
import shutil

import gradio as gr
import anndata as ad
import networkx as nx
import scanpy as sc
import scglue
import cooler
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm
from huggingface_hub import hf_hub_download

prior_name = 'dcq'
resolution = '10kb'
loop_q = 0.99
suffix = f'ice_{loop_q}'
out_dir = 'pooled_cells_' + suffix
plot_dir = os.path.join(out_dir, 'ism_loop_plots_distal')
dist_range = 10e6
window_size = 64
eps = 1e-4  # pseudo-count added before taking the log2 fold change
n_neighbors = 5

DATA_REPO = "dylanplummer/islet-epigenome"
MODEL_REPO = "dylanplummer/HiGLUE-islet"


def unzip(file):
    """Decompress a ``.gz`` file next to itself and return the new path.

    Args:
        file: Path to a gzip-compressed file; must end in ``.gz``.

    Returns:
        The path of the decompressed file (``file`` without the suffix).

    Raises:
        ValueError: If ``file`` does not end in ``.gz``. Without this check
            the output path would equal the input path and opening it for
            writing would truncate the source.
    """
    if not file.endswith('.gz'):
        raise ValueError(f'Expected a .gz file, got {file!r}')
    # Strip only the trailing suffix; str.replace would also rewrite any
    # '.gz' that happens to appear earlier in the path.
    new_file = file[:-len('.gz')]
    with gzip.open(file, 'rb') as f_in, open(new_file, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return new_file


def _fetch(filename, repo_id=DATA_REPO, repo_type="dataset"):
    """Download one file from the Hub using the DATASET_SECRET token."""
    return hf_hub_download(repo_id=repo_id,
                           filename=filename,
                           repo_type=repo_type,
                           token=os.environ['DATASET_SECRET'])


rna_file = _fetch("rna.h5ad.gz")
hic_file = _fetch("hic.h5ad.gz")
graph_file = _fetch(f"{prior_name}_prior_{resolution}_{suffix}.graphml.gz")
rna_file = unzip(rna_file)
hic_file = unzip(hic_file)

cooler_celltypes = ['Alpha', 'Beta', 'Acinar', 'Duct', 'PSC']
cooler_resolutions = ['50kb', '200kb', '500kb']
# coolers[res] is a list of Cooler objects ordered like cooler_celltypes.
coolers = {res: [] for res in cooler_resolutions}
for celltype in cooler_celltypes:
    for res in cooler_resolutions:
        cooler_file = _fetch(f"pseudobulk_coolers/{celltype}_{res}.cool")
        coolers[res].append(cooler.Cooler(cooler_file))
beta_cooler = coolers[cooler_resolutions[0]][1]

model_file = _fetch(f"glue_hic_{prior_name}_prior_{resolution}_{suffix}.dill",
                    repo_id=MODEL_REPO, repo_type="model")

rna = ad.read_h5ad(rna_file)
# Neighbors/UMAP are not recomputed here (the calls below were already
# commented out); sc.pl.umap in perturb() presumably relies on an embedding
# stored in the .h5ad -- TODO confirm.
# sc.pp.neighbors(rna, metric="cosine", n_neighbors=n_neighbors)
# sc.tl.umap(rna)
scglue.models.configure_dataset(rna, "NB", use_highly_variable=True,
                                use_rep="X_pca", use_layer="counts")

celltypes = sorted(rna.obs['celltype'].unique())
n_clusters = len(celltypes)
colors = list(plt.cm.tab10(np.int32(np.linspace(0, n_clusters + 0.99, n_clusters))))
color_map = {celltype: colors[i] for i, celltype in enumerate(celltypes)}
sorted_color_map = {celltype + '_sorted': colors[i] for i, celltype in enumerate(celltypes)}
rna_color_map = {celltype + '_rna': colors[i] for i, celltype in enumerate(celltypes)}
color_map = {**color_map, **sorted_color_map, **rna_color_map}
color_map['Other'] = 'gray'

hic = ad.read_h5ad(hic_file)
hic.var["highly_variable"] = hic.var[f"{prior_name}_highly_variable"]

prior = nx.read_graphml(graph_file)
glue = scglue.models.load_model(model_file)

# Edges of type 'hic' in the prior graph. (A gene list used to be collected
# from the 'overlap' edges in the same pass, but it was overwritten below
# before ever being read.)
loops = [e for e, attr in dict(prior.edges).items() if attr["type"] == 'hic']

# NOTE(review): configure_dataset() above ran before this mask is narrowed;
# confirm the decoder's feature order matches `genes`.
rna.var["highly_variable"] = rna.var["highly_variable"] & rna.var["in_hic"]
genes = rna.var.query("highly_variable").index
# Column index of each gene in the decoded matrices.
gene_idx_map = {gene: i for i, gene in enumerate(genes)}
peaks = hic.var.query("highly_variable").copy()
# Reconstruction under the unperturbed prior; the baseline for fold changes.
rna_recon = glue.decode_data("rna", "rna", rna, prior)


def _closest_on_chrom(features, chrom, start):
    """Return the index label in ``features`` closest to ``chrom:start``.

    ``features`` needs ``chrom`` and ``chromStart`` columns. Rows on other
    chromosomes are given a distance of 1e9 so they only win when nothing
    lies on ``chrom``. The frame is not modified.
    """
    dist = (features['chromStart'] - start).abs()
    dist = dist.where(features['chrom'] == chrom, 1e9)
    return dist.idxmin()


def get_closest_peak_to_gene(gene_name, rna, peaks):
    """Return the id of the peak whose start is nearest to a gene's start.

    Args:
        gene_name: Label looked up in ``rna.var``.
        rna: AnnData whose ``var`` has ``chrom`` / ``chromStart`` columns.
        peaks: DataFrame of peaks with ``chrom`` / ``chromStart`` columns.

    Returns:
        The index label of the closest peak, or ``None`` (after printing a
        message) when ``gene_name`` is not in ``rna.var``.
    """
    try:
        loc = rna.var.loc[gene_name]
    except KeyError:
        print('Could not find loci', gene_name)
        return None
    return _closest_on_chrom(peaks, loc["chrom"], loc["chromStart"])


def get_closest_gene_to_peak(peak_name, rna, peaks):
    """Return the gene in ``rna.var`` whose start is nearest to a peak.

    Args:
        peak_name: Label looked up in ``peaks``.
        rna: AnnData whose ``var`` has ``chrom`` / ``chromStart`` columns.
        peaks: DataFrame of peaks with ``chrom`` / ``chromStart`` columns.

    Returns:
        The index label of the closest gene, or ``None`` (after printing a
        message) when ``peak_name`` is not in ``peaks``.
    """
    try:
        loc = peaks.loc[peak_name]
    except KeyError:
        print('Could not find peak', peak_name)
        return None
    return _closest_on_chrom(rna.var, loc["chrom"], loc["chromStart"])


def get_chromosome_from_filename(filename):
    """Extract the ``chr...`` token from a file name.

    If the name starts with ``chr`` the text up to the first dot is
    returned; otherwise the text from ``chr`` up to the last dot, or to the
    end of the name when ``chr`` comes after the last dot.
    """
    chr_index = filename.find('chr')  # index of chromosome name
    if chr_index == 0:  # chromosome name is the file prefix
        return filename[:filename.find('.')]
    file_ending_index = filename.rfind('.')  # index of file ending
    if chr_index > file_ending_index:  # chromosome name is the file ending
        return filename[chr_index:]
    else:
        return filename[chr_index: file_ending_index]


def draw_heatmap(matrix, color_scale, ax=None, min_val=1.001, return_image=False, return_plt_im=True):
    """Draw ``matrix`` with a discrete white-to-red colormap.

    Args:
        matrix: 2-D array passed to ``imshow``.
        color_scale: Upper bound of the color breaks; ``0`` derives the
            breaks from the matrix (its maximum, or its 98th percentile).
        ax: Axes to draw on; a new figure is created when ``None``.
        min_val: Lower bound of the color breaks.
        return_image: If true, close the current figure and return the
            image's data array.
        return_plt_im: If true (and ``return_image`` is false), return the
            image object created by ``imshow``.

    Returns:
        See ``return_image`` / ``return_plt_im``; ``None`` if both are false.
    """
    if color_scale != 0:
        color_scale = min(color_scale, np.max(matrix))
        breaks = np.append(np.arange(min_val, color_scale, (color_scale - min_val) / 18), np.max(matrix))
    elif np.max(matrix) < 2:
        breaks = np.arange(min_val, np.max(matrix), (np.max(matrix) - min_val) / 19)
    else:
        step = (np.quantile(matrix, q=0.98) - 1) / 18
        up = np.quantile(matrix, q=0.98) + 0.011
        if up < 2:
            up = 2
            step = 0.999 / 18
        breaks = np.append(np.arange(min_val, up, step), np.max(matrix) + 0.01)

    n_bin = 20  # Discretizes the interpolation into bins
    colors = ["#FFFFFF", "#FFE4E4", "#FFD7D7", "#FFC9C9", "#FFBCBC", "#FFAEAE", "#FFA1A1", "#FF9494", "#FF8686",
              "#FF7979", "#FF6B6B", "#FF5E5E", "#FF5151", "#FF4343", "#FF3636", "#FF2828", "#FF1B1B", "#FF0D0D",
              "#FF0000"]
    cmap_name = 'deeploop'
    # Create the colormap
    cm = matplotlib.colors.LinearSegmentedColormap.from_list(
        cmap_name, colors, N=n_bin)
    norm = matplotlib.colors.BoundaryNorm(breaks, 20)
    # Fewer bins will result in "coarser" colormap interpolation
    if ax is None:
        _, ax = plt.subplots()
    img = ax.imshow(matrix, cmap=cm, norm=norm, interpolation=None)
    if return_image:
        plt.close()
        return img.get_array()
    elif return_plt_im:
        return img


def _parse_coordinate(locus):
    """Split a ``chrN:pos`` string (commas already removed) into (chrom, int pos)."""
    parts = locus.split(':')
    try:
        return parts[0], int(parts[1])
    except (IndexError, ValueError):
        raise gr.Error(f'Could not parse genomic coordinate {locus!r}; expected e.g. chr11:2289895') from None


def _resolve_anchor(locus):
    """Map a locus (gene name or ``chrN:pos``) to its closest peak.

    Returns:
        ``(peak_id, chrom)`` where ``chrom`` is the chromosome named in a
        coordinate locus and ``None`` for a gene-name locus.

    Raises:
        gr.Error: If the locus cannot be parsed or resolved.
    """
    if locus.startswith('chr'):
        chrom, pos = _parse_coordinate(locus)
        # Boolean mask instead of DataFrame.query so that text typed into the
        # UI is never evaluated as a query expression.
        starts = peaks.loc[peaks['chrom'] == chrom, 'chromStart']
        if starts.empty:
            raise gr.Error(f'No Hi-C peaks found on {chrom!r}')
        return (starts - pos).abs().idxmin(), chrom
    peak = get_closest_peak_to_gene(locus, rna, peaks)
    if peak is None:
        raise gr.Error(f'Could not find locus {locus!r}')
    return peak, None


def _locus_position(locus):
    """Return ``(chrom, start)`` for a gene-name or ``chrN:pos`` locus."""
    if locus.startswith('chr'):
        return _parse_coordinate(locus)
    try:
        loc = rna.var.loc[locus]
    except KeyError:
        print('Could not find loci', locus)
        raise gr.Error(f'Could not find locus {locus!r}') from None
    return loc["chrom"], loc["chromStart"]


def _closest_bin_index(bins, chrom, position):
    """Return the positional index (within ``chrom``) of the bin starting nearest ``position``."""
    starts = bins.loc[bins['chrom'] == chrom, 'start']
    if starts.empty:
        raise gr.Error(f'{chrom}:{position} lies outside the displayed heatmap region')
    return int((starts - position).abs().to_numpy().argmin())


def get_heatmap_locus(gene, locus1, locus2, heatmap_range):
    """Compute the genomic window and cooler resolution for a heatmap.

    Args:
        gene: Gene name; its chromosome is used unless a locus names one.
        locus1: First anchor, a gene name or ``chrN:pos`` (commas allowed).
        locus2: Second anchor, same format as ``locus1``.
        heatmap_range: Padding added on each side, in kb.

    Returns:
        ``(region, res)`` where ``region`` is a ``chrom:start-end`` string
        and ``res`` is a key of ``coolers``.

    Raises:
        gr.Error: If ``gene`` or either locus cannot be resolved.
    """
    locus1 = locus1.replace(',', '')
    locus2 = locus2.replace(',', '')
    try:
        loc = rna.var.loc[gene]
    except KeyError:
        print('Could not find loci', gene)
        # Callers unpack the result, so fail with a readable message rather
        # than returning None.
        raise gr.Error(f'Could not find gene {gene!r}') from None
    chrom = loc["chrom"]
    a1, chrom1 = _resolve_anchor(locus1)
    a2, chrom2 = _resolve_anchor(locus2)
    # A chromosome given in a coordinate overrides the gene's; locus2 wins
    # over locus1 when both are coordinates.
    for override in (chrom1, chrom2):
        if override is not None:
            chrom = override

    a1_start = peaks.loc[a1, 'chromStart']
    a2_start = peaks.loc[a2, 'chromStart']
    interaction_dist = abs(a1_start - a2_start)
    chrom_size = beta_cooler.chromsizes[chrom]
    # int() keeps the region string free of float formatting when the range
    # arrives as a float.
    locus_start = int(max(0, a1_start - interaction_dist - heatmap_range * 1000))
    locus_end = int(min(chrom_size, a2_start + interaction_dist + heatmap_range * 1000))
    heatmap_locus = f'{chrom}:{locus_start}-{locus_end}'
    heatmap_size = abs(locus_start - locus_end)

    # 'deeploop' is only kept if it is listed in cooler_resolutions; with the
    # current list the finest resolution used is 50kb.
    res = 'deeploop'
    if heatmap_size > 5000000 or res not in cooler_resolutions:
        res = '50kb'
    if heatmap_size > 10000000:
        res = '200kb'
    if heatmap_size > 20000000:
        res = '500kb'
    print(heatmap_locus, res)
    return heatmap_locus, res


def get_heatmap(celltype, gene, locus1, locus2, heatmap_range):
    """Render one cell type's contact map with a marker at (locus1, locus2).

    Args:
        celltype: One of ``cooler_celltypes``.
        gene, locus1, locus2, heatmap_range: See ``get_heatmap_locus``.

    Returns:
        An ``(H, W, 3)`` uint8 RGB array of the rendered figure.

    Raises:
        gr.Error: If the gene or a locus cannot be resolved.
    """
    heatmap_locus, res = get_heatmap_locus(gene, locus1, locus2, heatmap_range)
    c = coolers[res][cooler_celltypes.index(celltype)]
    mat = c.matrix().fetch(heatmap_locus)
    bins = c.bins().fetch(heatmap_locus)

    # The marker is placed from each locus itself (gene-name loci were
    # previously both looked up via `gene`, collapsing the marker onto one
    # point regardless of the loci entered).
    a1_idx = _closest_bin_index(bins, *_locus_position(locus1.replace(',', '')))
    a2_idx = _closest_bin_index(bins, *_locus_position(locus2.replace(',', '')))
    print(a1_idx, a2_idx)

    fig, ax = plt.subplots(figsize=(4, 4))
    if res == 'deeploop':
        draw_heatmap(mat, 3.0, ax=ax, min_val=0.8)
    else:
        ax.imshow(mat, cmap='Reds', norm=PowerNorm(gamma=0.3), interpolation=None)
    # remove white padding
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    ax.axis('off')
    ax.axis('image')
    ax.scatter(int(a1_idx), int(a2_idx), color='green', marker='2', s=150)

    # Rasterize and copy the pixels out before the figure is closed.
    fig.canvas.draw()
    img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return img


def get_heatmaps(gene, locus1, locus2, heatmap_range):
    """Render heatmaps for every cell type, ordered like ``cooler_celltypes``.

    Returns one image per cell type. While any of the three text inputs is
    still blank (the change events fire in that state) the existing images
    are left untouched instead of raising.
    """
    if not (gene and locus1 and locus2):
        return tuple(gr.update() for _ in cooler_celltypes)
    return tuple(get_heatmap(celltype, gene, locus1, locus2, heatmap_range)
                 for celltype in cooler_celltypes)


def _remove_edge_both_ways(graph, u, v):
    """Remove u->v and v->u from ``graph``, each only if present."""
    for a, b in ((u, v), (v, u)):
        if graph.has_edge(a, b):
            graph.remove_edge(a, b)


def _detach_gene_from_anchor(graph, anchor, gene):
    """Remove from ``graph`` the last edge on the prior's shortest anchor->gene path."""
    try:
        path = nx.shortest_path(prior, source=anchor, target=gene)
    except nx.NetworkXException:
        # No path, or one of the nodes is not in the prior: nothing to cut.
        return
    if len(path) >= 2:
        _remove_edge_both_ways(graph, path[-2], path[-1])


def _figure_to_rgba(fig):
    """Rasterize ``fig`` to an (H, W, 4) uint8 RGBA array and close it."""
    fig.canvas.draw()
    image = np.asarray(fig.canvas.buffer_rgba()).copy()
    # Close explicitly: pyplot keeps every figure alive until closed, which
    # would leak memory on each request.
    plt.close(fig)
    return image


def _violin_image(key, title, zero_line=False, **violin_kwargs):
    """Render a per-celltype violin plot of ``key`` and return it as RGBA."""
    fig, ax = plt.subplots(figsize=(6, 4))
    sc.pl.violin(rna, key, groupby='celltype', show=False, ax=ax, rotation=90, **violin_kwargs)
    ax.set_title(title)
    if zero_line:
        ax.hlines(0, -1, len(celltypes), linestyles='dashed')
    fig.tight_layout()
    return _figure_to_rgba(fig)


def perturb(gene, locus1, locus2, use_max_gene, tf_knockout):
    """Perturb the guidance graph and plot the predicted expression change.

    With ``tf_knockout`` the gene's node and its closest peak's node are
    removed from a copy of the prior. Otherwise the edges linking the two
    anchors to each other, to their closest genes and (via the shortest
    path) to ``gene`` are removed. RNA is then re-decoded under the edited
    graph and compared with the baseline reconstruction.

    Args:
        gene: Target gene (or TF to knock out).
        locus1: First loop anchor, gene name or ``chrN:pos``; unused for a
            TF knockout.
        locus2: Second loop anchor; unused for a TF knockout.
        use_max_gene: Plot the gene with the largest mean absolute log2 fold
            change instead of ``gene``.
        tf_knockout: Knock out ``gene`` instead of removing a loop.

    Returns:
        Four RGBA images: UMAP (cell type and log2FC), violin of observed
        counts, violin of the perturbed reconstruction, violin of log2FC.

    Raises:
        gr.Error: If ``gene`` is needed but not among the modeled genes, or
            a locus cannot be resolved.

    Side effects:
        Overwrites ``log2FC``, ``old_mean``, ``new_mean`` and ``new_count``
        in the shared ``rna.obs``.
    """
    # Same condition under which the gene_idx_map lookups below would have
    # raised KeyError, checked before the expensive decode.
    if gene not in gene_idx_map and (tf_knockout or not use_max_gene):
        raise gr.Error(f'{gene!r} is not one of the highly variable genes in the model')

    guidance_hvf = prior.copy()
    if tf_knockout:
        a1 = get_closest_peak_to_gene(gene, rna, peaks)
        for node in (gene, a1):
            if guidance_hvf.has_node(node):
                guidance_hvf.remove_node(node)
    else:
        a1, _ = _resolve_anchor(locus1.replace(',', ''))
        a2, _ = _resolve_anchor(locus2.replace(',', ''))
        print(a1)
        print(a2)
        a1_closest_gene = get_closest_gene_to_peak(a1, rna, peaks)
        a2_closest_gene = get_closest_gene_to_peak(a2, rna, peaks)
        print(a1_closest_gene, a2_closest_gene)

        _detach_gene_from_anchor(guidance_hvf, a1, gene)
        _detach_gene_from_anchor(guidance_hvf, a2, gene)
        _remove_edge_both_ways(guidance_hvf, a1, a2)
        _remove_edge_both_ways(guidance_hvf, a1, a1_closest_gene)
        _remove_edge_both_ways(guidance_hvf, a2, a2_closest_gene)

    perterbed_rna_recon = glue.decode_data("rna", "rna", rna, guidance_hvf)
    ism = np.log2((perterbed_rna_recon + eps) / (rna_recon + eps))

    mean_abs_change = np.abs(np.mean(ism, axis=0))
    if tf_knockout:
        max_gene_idxs = mean_abs_change.argsort()[::-1][:10]
        for idx in max_gene_idxs:
            print(genes[idx], np.mean(ism[:, idx]))
        # Second-ranked gene; presumably rank 0 is the knocked-out TF
        # itself -- TODO confirm.
        max_gene_idx = max_gene_idxs[1]
    else:
        max_gene_idx = mean_abs_change.argmax()
    max_gene = genes[max_gene_idx]
    print('Max gene:', max_gene)

    shown_gene = max_gene if use_max_gene else gene
    gene_idx = gene_idx_map[shown_gene]
    rna.obs['log2FC'] = ism[:, gene_idx]
    rna.obs['old_mean'] = rna_recon[:, gene_idx]
    rna.obs['new_mean'] = perterbed_rna_recon[:, gene_idx]
    # compute new count based on log2FC
    # NOTE(review): gene_idx indexes the highly-variable subset; confirm it
    # is also the right column of the 'counts' layer.
    rna.obs['new_count'] = (rna.layers['counts'][:, gene_idx] + eps) * 2 ** rna.obs['log2FC'] - eps

    fig = sc.pl.umap(rna,
                     color=['celltype', 'log2FC'],
                     color_map='Spectral',
                     wspace=0.05,
                     legend_loc='on data',
                     legend_fontoutline=2,
                     frameon=False,
                     return_fig=True)
    if tf_knockout:
        fig.suptitle(f'{shown_gene} expression after TF knockout of {gene}')
    else:
        fig.suptitle(f'{shown_gene} expression after removing ' + a1 + '-' + a2)
    image_from_plot = _figure_to_rgba(fig)

    violin_img = _violin_image(shown_gene, f'Normal {shown_gene} expression', layer='counts')
    violin_img_new = _violin_image('new_mean', f'{shown_gene} normalized expression after perturbation')
    violin_img_fc = _violin_image('log2FC', f'{shown_gene} expression change', zero_line=True)
    return image_from_plot, violin_img, violin_img_new, violin_img_fc


def update_on_tf_ko(tf_knockout, use_max_gene):
    """Hide the loop-anchor column and force "Max Gene" on while TF KO is ticked.

    Returns:
        ``(update for the anchor column, new value of the Max Gene checkbox)``.
    """
    if tf_knockout:
        return gr.update(visible=False), True
    else:
        return gr.update(visible=True), use_max_gene


def set_loop(img, gene, locus1, locus2, heatmap_range, evt: gr.SelectData):
    """Turn a click on a heatmap into new anchor coordinates.

    The click's x/y pixel position is scaled to a bin index of the region
    currently shown, and each bin's ``chrom:start`` becomes the new value of
    the corresponding anchor textbox.

    Returns:
        ``(new_anchor1, new_anchor2)`` as ``chrom:start`` strings.
    """
    h, w = img.shape[:2]
    x, y = evt.index[0], evt.index[1]
    heatmap_locus, res = get_heatmap_locus(gene, locus1, locus2, heatmap_range)
    bins = coolers[res][0].bins().fetch(heatmap_locus)
    # Clamp so a click on the last pixel cannot index past the final bin.
    last = len(bins) - 1
    bin_x = bins.iloc[min(int(x / w * len(bins)), last)]
    bin_y = bins.iloc[min(int(y / h * len(bins)), last)]
    new_a1 = f"{bin_x['chrom']}:{bin_x['start']}"
    new_a2 = f"{bin_y['chrom']}:{bin_y['start']}"
    return new_a1, new_a2


with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
    with gr.Row():
        with gr.Column():
            in_locus = gr.Textbox(label="Gene/TF", elem_id='in-locus', scale=1)
            with gr.Row():
                # Checkbox takes its initial state via `value`.
                max_gene_checkbox = gr.Checkbox(label="Max Gene", elem_id='max-gene-checkbox', scale=1, value=False)
                tf_ko = gr.Checkbox(label="TF KO", elem_id='tf-ko', scale=1, value=False)
        with gr.Column() as heatmap_col:
            anchor1 = gr.Textbox(label="Locus 1 (gene or genomic coords)", elem_id='anchor1', scale=1)
            anchor2 = gr.Textbox(label="Locus 2 (gene or genomic coords)", elem_id='anchor2', scale=1)
            heatmap_x = gr.Textbox(label="Heatmap X", elem_id='heatmap-x', scale=1, visible=False)
            heatmap_y = gr.Textbox(label="Heatmap Y", elem_id='heatmap-y', scale=1, visible=False)
            heatmap_size = gr.Slider(label="Heatmap Range",
                                     info="kb range aroung input gene locus to expand",
                                     minimum=50, maximum=5000, value=250)
            heatmap_button = gr.Button(value="Generate Heatmaps", elem_id='heatmap-button', scale=1)
    with gr.Row():
        out_heatmaps = []
        for celltype in cooler_celltypes:
            out_heatmaps.append(gr.Image(label=celltype, elem_id=f'out-heatmap-{celltype}', scale=1))
    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
    with gr.Row():
        out_img = gr.Image(elem_id='out-img', scale=1)
    with gr.Row():
        out_violin = gr.Image(elem_id='out-violin', scale=1)
        out_violin_new = gr.Image(elem_id='out-violin-new', scale=1)
        out_violin_fc = gr.Image(elem_id='out-violin-fc', scale=1)
        #out_plot = gr.Plot(elem_id='out-plot', scale=1)

    inputs = [in_locus, anchor1, anchor2, max_gene_checkbox, tf_ko]
    outputs = [out_img, out_violin, out_violin_new, out_violin_fc]

    gr.Examples(examples=[['INS', 'chr11:2,289,895', 'chr11:2,298,840', False, False],
                          ['BHLHA15', '', '', True, True],
                          ['FOXA2', '', '', True, True],
                          # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4878272/
                          # ['OTUD3', 'chr1:20457786', 'chr1:20230413', False],
                          ['IRX2', 'chr5:2704405', 'chr5:2164879', False, False],
                          ['TSPAN1', 'TSPAN1', 'PIK3R3', False, False],
                          # ['IRX1', 'chr5:5397612', 'chr5:4850008', False],
                          ['GCG', 'GCG', 'FAP', False, False],
                          # ['GCG', 'chr2:163162903', 'chr2:162852164', True],
                          ['LOXL4', 'chr10:100186215', 'chr10:99913493', False, False],
                          # ['KRT19', 'chr17:22,220,637', 'chr17:39,591,813', False],
                          # ['LPP', 'chr3:188,097,749', 'chr3:197,916,262', False],
                          ['MAFB', 'chr20:39431654', 'chr20:39368271', True, False],
                          ['CEL', 'chr9:135,937,365', 'chr9:135,973,107', False, False]],
                inputs=inputs,
                outputs=outputs,
                fn=perturb,
                cache_examples=os.getenv('SYSTEM') == 'spaces')

    run_button.click(perturb, inputs, outputs=outputs)
    heatmap_button.click(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)
    tf_ko.change(update_on_tf_ko, [tf_ko, max_gene_checkbox], outputs=[heatmap_col, max_gene_checkbox])
    anchor1.change(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)
    anchor2.change(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)

    for out_heatmap in out_heatmaps:
        out_heatmap.select(set_loop,
                           [out_heatmap, in_locus, anchor1, anchor2, heatmap_size],
                           outputs=[anchor1, anchor2])


if __name__ == "__main__":
    demo.launch(share=False)