Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import anndata as ad | |
| import networkx as nx | |
| import scanpy as sc | |
| import scglue | |
| import os | |
| import gzip | |
| import shutil | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from huggingface_hub import hf_hub_download | |
# Identifiers that select which prior graph / model checkpoint files are loaded.
prior_name = 'dcq'
resolution = '10kb'
loop_q = 0.99
suffix = 'ice_{}'.format(loop_q)

# Output locations (not referenced elsewhere in this file).
out_dir = f'pooled_cells_{suffix}'
plot_dir = os.path.join(out_dir, 'ism_loop_plots_distal')

dist_range = 1e7  # same value as 10e6; presumably base pairs — TODO confirm
window_size = 64
eps = 1e-4  # pseudocount added to both reconstructions before the log2 ratio in perturb()
n_neighbors = 5  # k for the cosine kNN graph built on the RNA data
def unzip(file):
    """Decompress a gzip file next to itself and return the decompressed path.

    The output path is *file* with its trailing ``.gz`` removed, e.g.
    ``/cache/rna.h5ad.gz`` -> ``/cache/rna.h5ad``. An existing file at that
    path is overwritten.

    Args:
        file: path (str or path-like) of a gzip-compressed file ending in ``.gz``.

    Returns:
        The path (str) of the decompressed copy.

    Raises:
        ValueError: if *file* does not end in ``.gz``. Without the suffix the
            derived output path would equal the input path, and opening it for
            writing would truncate the data before it could be read.
    """
    file = os.fspath(file)
    gz_ext = '.gz'
    if not file.endswith(gz_ext):
        raise ValueError(f'unzip expects a path ending in {gz_ext!r}, got {file!r}')
    # Strip only the trailing suffix; str.replace would also rewrite any '.gz'
    # appearing earlier in the path (e.g. inside a directory name).
    new_file = file[:-len(gz_ext)]
    with gzip.open(file, 'rb') as f_in, open(new_file, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return new_file
# Fetch the input data (RNA / Hi-C AnnData, prior graph) and the trained model
# from the Hugging Face Hub, authenticating with the DATASET_SECRET token.
_DATA_REPO = "dylanplummer/islet-epigenome"
_MODEL_REPO = "dylanplummer/HiGLUE-islet"

try:
    _hf_token = os.environ['DATASET_SECRET']
except KeyError:
    raise RuntimeError(
        "Environment variable DATASET_SECRET is not set; it must hold the Hugging Face "
        f"token used to download from {_DATA_REPO} and {_MODEL_REPO}."
    ) from None


def _hub_file(repo_id, filename, repo_type="dataset"):
    """Download one file from the Hugging Face Hub with the shared token and return its local path."""
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, token=_hf_token)


rna_file = _hub_file(_DATA_REPO, "rna.h5ad.gz")
hic_file = _hub_file(_DATA_REPO, "hic.h5ad.gz")
# The graph stays gzip-compressed; it is handed to nx.read_graphml as-is.
graph_file = _hub_file(_DATA_REPO, f"{prior_name}_prior_{resolution}_{suffix}.graphml.gz")
rna_file = unzip(rna_file)
hic_file = unzip(hic_file)
model_file = _hub_file(_MODEL_REPO, f"glue_hic_{prior_name}_prior_{resolution}_{suffix}.dill", repo_type="model")
# Load the RNA AnnData, build its cosine kNN graph and UMAP embedding (used by
# the UMAP plot in perturb()), then register it with scGLUE.
rna = ad.read_h5ad(rna_file)
sc.pp.neighbors(rna, n_neighbors=n_neighbors, metric="cosine")
sc.tl.umap(rna)
scglue.models.configure_dataset(
    rna,
    "NB",
    use_highly_variable=True,
    use_layer="counts",
    use_rep="X_pca",
)
# One colour per cell type. color_map is keyed by the bare cell-type name and by
# its '_sorted' and '_rna' variants (all sharing that type's colour), plus a
# gray 'Other' entry.
celltypes = sorted(rna.obs['celltype'].unique())
n_clusters = len(celltypes)
# tab10 is a 10-entry ListedColormap and integer inputs index it directly.
# Use 0..n-1 so every type gets its own colour, cycling only beyond 10 types.
# The previous int32(linspace(0, n + 0.99, n)) indices exceeded 9 as soon as
# n reached 10, and out-of-range indices all clip to the last colour, so
# several cell types ended up sharing it.
colors = list(plt.cm.tab10(np.arange(n_clusters) % plt.cm.tab10.N))
color_map = {
    f'{celltype}{tag}': color
    for tag in ('', '_sorted', '_rna')
    for celltype, color in zip(celltypes, colors)
}
color_map['Other'] = 'gray'
# Load the Hi-C AnnData and expose the prior-specific variable-feature flag
# under the generic "highly_variable" column name.
hic = ad.read_h5ad(hic_file)
hvf_column = f"{prior_name}_highly_variable"
hic.var["highly_variable"] = hic.var[hvf_column]

# Prior (guidance) graph and the trained scGLUE model.
prior = nx.read_graphml(graph_file)
glue = scglue.models.load_model(model_file)
# Collect the prior-graph edges whose "type" attribute is 'hic' (the loop edges).
# No gene list is built here. `genes` is assigned from rna.var's highly
# variable genes immediately below, and that is the list the app actually uses.
# The earlier 'overlap'-edge scan built a list with O(n^2) membership checks
# and was always overwritten before use.
loops = [edge for edge, attr in prior.edges.items() if attr["type"] == 'hic']
# Keep only highly variable genes that are also flagged "in_hic".
rna.var["highly_variable"] = rna.var["highly_variable"] & rna.var["in_hic"]
genes = rna.var.loc[rna.var["highly_variable"]].index
# Column position of each highly variable gene (used to index ISM results).
gene_idx_map = {name: position for position, name in enumerate(genes)}
# Highly variable Hi-C features are the candidate loci. The copy keeps later
# modifications of `peaks` away from hic.var.
peaks = hic.var.loc[hic.var["highly_variable"]].copy()
# Baseline reconstruction with the unperturbed prior; perturb() compares against it.
rna_recon = glue.decode_data("rna", "rna", rna, prior)
def get_closest_peak_to_gene(gene_name, rna, peaks):
    """Return the index label of the peak nearest to *gene_name*'s start coordinate.

    Args:
        gene_name: label looked up in ``rna.var``, which must provide ``chrom``
            and ``chromStart`` columns.
        rna: object whose ``.var`` DataFrame holds gene coordinates (AnnData here).
        peaks: DataFrame of candidate peaks with ``chrom`` and ``chromStart``
            columns. It is not modified.

    Returns:
        The ``peaks`` index label minimising ``|peak chromStart - gene chromStart|``
        among peaks on the gene's chromosome. Returns ``None`` (after printing a
        message) when *gene_name* is not in ``rna.var``. Peaks on other chromosomes
        get a 1e9 placeholder distance, so if the gene's chromosome has no peaks a
        peak from another chromosome is returned.
    """
    try:
        loc = rna.var.loc[gene_name]
    except KeyError:
        print('Could not find loci', gene_name)
        return None
    # Distances live in a local Series rather than helper columns written into
    # `peaks`. perturb() passes the module-level `peaks` shared by every request,
    # so mutating it would be a cross-request side effect.
    dist = (peaks['chromStart'] - loc['chromStart']).abs()
    dist = dist.where(peaks['chrom'] == loc['chrom'], 1e9)
    return dist.idxmin()
def get_chromosome_from_filename(filename):
    """Extract the chromosome name (e.g. ``'chr1'``) from a file name.

    Handles three layouts:
      * prefix:  ``'chr1.10kb.cool'`` -> ``'chr1'`` (up to the first ``.``,
        or the whole name when there is no ``.``);
      * ending:  ``'sample.chr1'`` -> ``'chr1'`` (``chr`` after the last ``.``);
      * infix:   ``'sample_chrX.cool'`` -> ``'chrX'`` (up to the last ``.``).

    Returns:
        The chromosome substring, or ``''`` when *filename* contains no ``'chr'``.
    """
    chr_index = filename.find('chr')
    if chr_index == -1:
        return ''
    if chr_index == 0:
        # find() returns -1 when there is no '.', and slicing to -1 would drop
        # the last character ('chr1' -> 'chr'); keep the whole name instead.
        dot_index = filename.find('.')
        return filename if dot_index == -1 else filename[:dot_index]
    file_ending_index = filename.rfind('.')
    if chr_index > file_ending_index:
        # 'chr' sits in the extension (or there is no '.').
        return filename[chr_index:]
    return filename[chr_index:file_ending_index]
def _locus_to_peak(locus):
    """Resolve a user locus (gene symbol or 'chrN:position') to the nearest index label of `peaks`."""
    locus = locus.replace(',', '').strip()
    if not locus.startswith('chr'):
        peak = get_closest_peak_to_gene(locus, rna, peaks)
        if peak is None:
            raise gr.Error(f'Could not find gene {locus!r}')
        return peak
    chrom, _, pos = locus.partition(':')
    try:
        position = int(pos)
    except ValueError:
        raise gr.Error(
            f'Could not parse {locus!r}; expected a gene name or coordinates like chr11:2,289,895'
        ) from None
    # Boolean mask instead of DataFrame.query: `chrom` is raw user input and
    # must not be spliced into an evaluated expression string.
    starts = peaks.loc[peaks['chrom'] == chrom, 'chromStart']
    if starts.empty:
        raise gr.Error(f'No highly variable Hi-C features on {chrom!r}')
    return (starts - position).abs().idxmin()


def _remove_edge_if_present(graph, u, v):
    """Best-effort graph.remove_edge(u, v); a missing edge is ignored."""
    try:
        graph.remove_edge(u, v)
    except nx.NetworkXError:
        pass


def _detach_gene_from_peak(graph, peak, gene):
    """Remove from *graph*, in both directions, the last edge of the shortest `prior` path peak -> gene."""
    try:
        path = nx.shortest_path(prior, source=peak, target=gene)
    except nx.NetworkXException:
        # NodeNotFound / NetworkXNoPath: nothing connects this anchor to the gene.
        return
    if len(path) < 2:
        return
    _remove_edge_if_present(graph, path[-2], path[-1])
    _remove_edge_if_present(graph, path[-1], path[-2])


def _figure_to_rgba(fig):
    """Render a Matplotlib figure to an (height, width, 4) uint8 RGBA array, then close it."""
    try:
        fig.canvas.draw()
        # np.array copies: the canvas buffer is released when the figure closes.
        return np.array(fig.canvas.buffer_rgba())
    finally:
        # Each request creates three figures; without closing them pyplot keeps
        # them all alive for the life of the process.
        plt.close(fig)


def perturb(gene, locus1, locus2):
    """Predict how removing the prior-graph links between two loci changes *gene*'s expression.

    Both loci are mapped to their nearest highly variable Hi-C feature (anchors
    a1, a2). A copy of the prior graph is then edited: for each anchor, the last
    edge on its shortest `prior` path to *gene* is removed, and the a1-a2 edge is
    removed, each in both directions. RNA is decoded with the edited graph, and
    the per-cell ``log2((perturbed + eps) / (baseline + eps))`` for *gene* is
    stored in ``rna.obs['log2FC']`` and plotted.

    Args:
        gene: target gene; must be one of the highly variable genes in `gene_idx_map`.
        locus1, locus2: gene symbol or ``'chrN:position'`` (commas allowed).

    Returns:
        Tuple ``(umap_img, violin_img, violin_fc_img)`` of RGBA uint8 arrays: a
        UMAP coloured by cell type and log2FC, a violin of *gene* by cell type, and
        a violin of log2FC by cell type.

    Raises:
        gr.Error: unknown target gene, unparsable locus, or no candidate feature.
    """
    gene = gene.strip()
    # Validate before the expensive decode; the ISM column lookup needs this key.
    if gene not in gene_idx_map:
        raise gr.Error(f'{gene!r} is not one of the highly variable genes used by the model')
    a1 = _locus_to_peak(locus1)
    a2 = _locus_to_peak(locus2)
    print(a1)
    print(a2)

    guidance_hvf = prior.copy()
    _detach_gene_from_peak(guidance_hvf, a1, gene)
    _detach_gene_from_peak(guidance_hvf, a2, gene)
    # Drop the a1-a2 link itself. Each direction is attempted independently, so a
    # missing a1->a2 edge does not prevent removing a2->a1.
    _remove_edge_if_present(guidance_hvf, a1, a2)
    _remove_edge_if_present(guidance_hvf, a2, a1)

    perturbed_rna_recon = glue.decode_data("rna", "rna", rna, guidance_hvf)
    ism = np.log2((perturbed_rna_recon + eps) / (rna_recon + eps))
    max_gene_idx = np.abs(np.mean(ism, axis=0)).argmax()
    print('Max gene:', genes[max_gene_idx])

    rna.obs['log2FC'] = ism[:, gene_idx_map[gene]]

    fig = sc.pl.umap(rna,
                     color=['celltype', 'log2FC'],
                     color_map='Spectral',
                     wspace=0.05,
                     legend_loc='on data',
                     legend_fontoutline=2,
                     frameon=False,
                     return_fig=True)
    fig.suptitle(f'{gene} expression after removing {a1}-{a2}')
    fig.tight_layout()
    umap_img = _figure_to_rgba(fig)

    # sc.pl.violin has no return_fig option (extra kwargs are forwarded to
    # seaborn), so draw into an explicit Axes and keep its Figure.
    fig, ax = plt.subplots()
    sc.pl.violin(rna, gene, groupby='celltype', ax=ax, show=False)
    violin_img = _figure_to_rgba(fig)

    fig, ax = plt.subplots()
    sc.pl.violin(rna, 'log2FC', groupby='celltype', ax=ax, show=False)
    violin_img_fc = _figure_to_rgba(fig)

    return umap_img, violin_img, violin_img_fc
# Gradio UI: three text inputs (target gene, two loci), a Run button, and three
# image outputs filled by perturb(). Examples are pre-computed when SYSTEM == 'spaces'.
with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
    with gr.Row():
        with gr.Column():
            in_locus = gr.Textbox(label="Target Gene", elem_id='in-locus', scale=1)
            anchor1 = gr.Textbox(label="Locus 1 (gene or genomic coords)", elem_id='anchor1', scale=1)
            anchor2 = gr.Textbox(label="Locus 2 (gene or genomic coords)", elem_id='anchor2', scale=1)
    with gr.Row():
        # A Button's caption is its `value`. gr.Button has no `label` parameter
        # in Gradio 4, so label="Run" raises TypeError there.
        run_button = gr.Button("Run", elem_id='run-button', scale=1)
    with gr.Row():
        out_img = gr.Image(elem_id='out-img', scale=1)
    with gr.Row():
        out_violin = gr.Image(elem_id='out-violin', scale=1)
        out_violin_fc = gr.Image(elem_id='out-violin-fc', scale=1)

    inputs = [in_locus, anchor1, anchor2]
    outputs = [out_img, out_violin, out_violin_fc]
    gr.Examples(examples=[['INS', 'chr11:2,289,895', 'chr11:2,298,840'],
                          ['INS', 'INS', 'IGF1'],
                          ['TSPAN1', 'TSPAN1', 'PIK3R3'],
                          ['GCG', 'GCG', 'FAP'],
                          ['KRT19', 'chr17:22,220,637', 'chr17:39,591,813'],
                          ['LPP', 'chr3:188,097,749', 'chr3:197,916,262'],
                          ['CEL', 'chr9:135,937,365', 'chr9:135,973,107']],
                inputs=inputs,
                outputs=outputs,
                fn=perturb,
                cache_examples=os.getenv('SYSTEM') == 'spaces')
    run_button.click(perturb, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    demo.launch(share=False)