HiGLUE / app.py
dylanplummer's picture
no plotly yet!
b0256eb
raw
history blame
9.59 kB
import gradio as gr
import anndata as ad
import networkx as nx
import scanpy as sc
import scglue
import os
import gzip
import shutil
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
# Identifiers that select which prior graph / model artefacts are fetched
# (they are interpolated into the Hub file names further down).
prior_name = 'dcq'
resolution = '10kb'  # presumably the Hi-C bin size — confirm
loop_q = 0.99        # encoded into the artefact suffix below
suffix = 'ice_' + str(loop_q)

# NOTE(review): out_dir, plot_dir, dist_range and window_size are not read
# anywhere else in this file.
out_dir = f'pooled_cells_{suffix}'
plot_dir = os.path.join(out_dir, 'ism_loop_plots_distal')
dist_range = 1e7
window_size = 64

# Pseudo-count added to both expression matrices before the log2 ratio in perturb.
eps = 1e-4
# Neighbourhood size for the cosine kNN graph built on the RNA data.
n_neighbors = 5
def unzip(file):
    """Decompress a gzip file next to itself and return the decompressed path.

    The output path is *file* with its trailing ``.gz`` removed; an existing
    file at that path is overwritten.

    Args:
        file: path (str or path-like) of a file ending in ``.gz``.

    Returns:
        The path of the decompressed file, as a ``str``.

    Raises:
        ValueError: if *file* does not end in ``.gz``. Without the suffix the
            output path would equal the input path, and opening it for writing
            would truncate the data being decompressed.
    """
    file = os.fspath(file)
    if not file.endswith('.gz'):
        raise ValueError(f'Expected a .gz file, got {file!r}')
    # Strip only the trailing suffix: str.replace would also rewrite any '.gz'
    # occurring earlier in the path (e.g. in a directory name).
    new_file = file[:-len('.gz')]
    with gzip.open(file, 'rb') as f_in, open(new_file, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return new_file
# Fetch data and model artefacts from the Hugging Face Hub. Every request uses
# the token in the DATASET_SECRET environment variable (KeyError if unset).
_hub_token = os.environ['DATASET_SECRET']
_data_repo = "dylanplummer/islet-epigenome"

rna_file = hf_hub_download(repo_id=_data_repo, filename="rna.h5ad.gz",
                           repo_type="dataset", token=_hub_token)
hic_file = hf_hub_download(repo_id=_data_repo, filename="hic.h5ad.gz",
                           repo_type="dataset", token=_hub_token)
graph_file = hf_hub_download(repo_id=_data_repo,
                             filename=f"{prior_name}_prior_{resolution}_{suffix}.graphml.gz",
                             repo_type="dataset", token=_hub_token)

# The two AnnData files are decompressed before ad.read_h5ad reads them; the
# graphml path is handed to nx.read_graphml still gzipped.
rna_file = unzip(rna_file)
hic_file = unzip(hic_file)

model_file = hf_hub_download(repo_id="dylanplummer/HiGLUE-islet",
                             filename=f"glue_hic_{prior_name}_prior_{resolution}_{suffix}.dill",
                             repo_type="model", token=_hub_token)
# Load the scRNA-seq AnnData, build its cosine kNN graph and UMAP embedding
# (the embedding is what perturb plots), and register it with scglue.
rna = ad.read_h5ad(rna_file)
sc.pp.neighbors(rna, metric="cosine", n_neighbors=n_neighbors)
sc.tl.umap(rna)
scglue.models.configure_dataset(rna, "NB", use_highly_variable=True,
                                use_rep="X_pca", use_layer="counts")

# One tab10 colour per cell type, also registered under '<type>_sorted' and
# '<type>_rna' keys, plus grey for 'Other'.
# NOTE(review): the linspace indexing can skip palette entries and, with ten or
# more cell types, yields indices past tab10's last colour. color_map is not read
# anywhere else in this file, so the expression is left as-is.
celltypes = sorted(rna.obs['celltype'].unique())
n_clusters = len(celltypes)
colors = list(plt.cm.tab10(np.int32(np.linspace(0, n_clusters + 0.99, n_clusters))))
color_map = dict(zip(celltypes, colors))
sorted_color_map = {celltype + '_sorted': colour for celltype, colour in zip(celltypes, colors)}
rna_color_map = {celltype + '_rna': colour for celltype, colour in zip(celltypes, colors)}
color_map = {**color_map, **sorted_color_map, **rna_color_map}
color_map['Other'] = 'gray'

# Hi-C AnnData: its highly-variable flag is taken from the prior-specific column.
hic = ad.read_h5ad(hic_file)
hic.var["highly_variable"] = hic.var[f"{prior_name}_highly_variable"]

# Guidance graph and trained GLUE model.
prior = nx.read_graphml(graph_file)
glue = scglue.models.load_model(model_file)
# Hi-C loop edges of the guidance graph (edges whose "type" attribute is 'hic').
# Iterating the edge view's items keeps the same edge keys as dict(prior.edges).
# NOTE(review): `loops` is not read anywhere else in this file.
loops = [edge for edge, attr in prior.edges.items() if attr["type"] == 'hic']

# Genes modelled here: highly variable AND flagged as covered by Hi-C.
rna.var["highly_variable"] = rna.var["highly_variable"] & rna.var["in_hic"]
genes = rna.var.query("highly_variable").index

# Column position of each modelled gene; perturb uses it to index the log2FC
# matrix. Assumes decode_data returns genes in this same order — confirm.
gene_idx_map = {gene: i for i, gene in enumerate(genes)}

# Highly variable Hi-C features; copied so callers can add columns safely.
peaks = hic.var.query("highly_variable").copy()

# Baseline RNA reconstruction with the unperturbed guidance graph.
rna_recon = glue.decode_data("rna", "rna", rna, prior)
def get_closest_peak_to_gene(gene_name, rna, peaks):
    """Return the label of the peak on the gene's chromosome nearest the gene start.

    Distance is ``|peak.chromStart - gene.chromStart|``, restricted to peaks on
    the same chromosome; ties resolve to the first such peak in *peaks* order.
    *peaks* is not modified.

    Args:
        gene_name: index label looked up in ``rna.var``.
        rna: object whose ``.var`` DataFrame has ``chrom`` and ``chromStart``
            columns indexed by gene name.
        peaks: DataFrame with ``chrom`` and ``chromStart`` columns.

    Returns:
        The index label of the closest peak, or ``None`` if *gene_name* is not
        in ``rna.var`` or no peak lies on its chromosome.
    """
    try:
        loc = rna.var.loc[gene_name]
    except KeyError:
        print('Could not find loci', gene_name)
        return None
    chrom = loc["chrom"]
    # Filter instead of writing helper columns into the caller's frame.
    same_chrom_starts = peaks.loc[peaks['chrom'] == chrom, 'chromStart']
    if same_chrom_starts.empty:
        print('No peaks on chromosome', chrom, 'for', gene_name)
        return None
    return (same_chrom_starts - loc["chromStart"]).abs().idxmin()
def get_chromosome_from_filename(filename):
    """Extract a chromosome name (e.g. ``chr1``) from a file name.

    The name starts at the first ``'chr'`` in *filename*:

    * if *filename* starts with ``'chr'``, it runs up to the first ``'.'``
      (or to the end when there is no ``'.'``);
    * otherwise it runs up to the last ``'.'``, or to the end when ``'chr'``
      appears after the last ``'.'`` (or there is no ``'.'``).

    Returns ``''`` when *filename* contains no ``'chr'``.

    >>> get_chromosome_from_filename('chr1.10kb.cool')
    'chr1'
    >>> get_chromosome_from_filename('sample.chr2.cool')
    'chr2'
    >>> get_chromosome_from_filename('chrX')
    'chrX'
    """
    chr_index = filename.find('chr')
    if chr_index == -1:
        return ''
    if chr_index == 0:
        # find() returns -1 when there is no '.', and slicing with [:-1] would
        # silently drop the last character, so handle that case explicitly.
        dot_index = filename.find('.')
        return filename if dot_index == -1 else filename[:dot_index]
    file_ending_index = filename.rfind('.')
    if chr_index > file_ending_index:  # chromosome name is the file ending
        return filename[chr_index:]
    return filename[chr_index:file_ending_index]
def _resolve_locus(locus):
    """Map a user-supplied locus (gene symbol or 'chrom:position') to the nearest peak label in `peaks`.

    Thousands separators (',') and surrounding whitespace are ignored.

    Raises:
        gr.Error: if the locus cannot be parsed or mapped to any peak.
    """
    locus = locus.replace(',', '').strip()
    if locus.startswith('chr'):
        chrom, _, pos = locus.partition(':')
        try:
            pos = int(pos)
        except ValueError:
            raise gr.Error(f"Could not parse '{locus}'; expected a gene symbol or 'chrom:position'.") from None
        # Boolean mask rather than DataFrame.query so user text is never
        # evaluated as a query expression.
        starts = peaks.loc[peaks['chrom'] == chrom, 'chromStart']
        if starts.empty:
            raise gr.Error(f"No peaks found on chromosome '{chrom}'.")
        return (starts - pos).abs().idxmin()
    anchor = get_closest_peak_to_gene(locus, rna, peaks)
    if anchor is None:
        raise gr.Error(f"Could not map '{locus}' to a peak.")
    return anchor


def _remove_edge_if_present(graph, u, v):
    """Remove the edge u -> v from *graph* if it exists; otherwise do nothing."""
    if graph.has_edge(u, v):
        graph.remove_edge(u, v)


def _cut_last_edge_on_path(graph, reference, source, target):
    """Delete from *graph* the final edge of the shortest source->target path in *reference*.

    Both orientations of that edge are removed when present. Nothing happens if
    either node is absent from *reference* or no path connects them.
    """
    try:
        path = nx.shortest_path(reference, source=source, target=target)
    except (nx.NodeNotFound, nx.NetworkXNoPath):
        return
    if len(path) < 2:
        return
    _remove_edge_if_present(graph, path[-2], path[-1])
    _remove_edge_if_present(graph, path[-1], path[-2])


def _fig_to_rgba(fig):
    """Rasterize a Matplotlib figure to an (H, W, 4) uint8 RGBA array, then close it."""
    try:
        fig.canvas.draw()
        # Copy out of the canvas buffer so the array outlives the figure.
        return np.array(fig.canvas.buffer_rgba(), dtype=np.uint8)
    finally:
        # Close so figures do not accumulate in pyplot across requests.
        plt.close(fig)


def perturb(gene, locus1, locus2):
    """In-silico perturbation: cut the guidance links between two loci and a target gene.

    On a copy of the guidance graph ``prior`` this removes the last edge on the
    shortest path from each locus' nearest peak to *gene* (both orientations),
    plus any direct edge between the two peaks. RNA is then decoded with the
    perturbed graph and compared with the baseline ``rna_recon`` as
    ``log2((perturbed + eps) / (baseline + eps))``.

    Args:
        gene: target gene symbol; must be one of the model's highly variable genes.
        locus1, locus2: gene symbols or 'chrom:position' strings (commas allowed).

    Returns:
        Three RGBA uint8 images: UMAP coloured by cell type and by *gene*'s
        log2FC, a violin of *gene* per cell type, and a violin of log2FC per
        cell type.

    Raises:
        gr.Error: if *gene* is not modelled or a locus cannot be resolved.
    """
    gene = gene.strip()
    # Validate the target before running the expensive decode.
    if gene not in gene_idx_map:
        raise gr.Error(f"'{gene}' is not among the model's highly variable genes.")
    a1 = _resolve_locus(locus1)
    a2 = _resolve_locus(locus2)
    print(a1)
    print(a2)

    guidance_hvf = prior.copy()
    # Paths are computed on the unmodified prior, as before.
    _cut_last_edge_on_path(guidance_hvf, prior, a1, gene)
    _cut_last_edge_on_path(guidance_hvf, prior, a2, gene)
    # Each orientation is removed independently, so a missing a1->a2 edge no
    # longer prevents removal of a2->a1.
    _remove_edge_if_present(guidance_hvf, a1, a2)
    _remove_edge_if_present(guidance_hvf, a2, a1)

    perturbed_rna_recon = glue.decode_data("rna", "rna", rna, guidance_hvf)
    ism = np.log2((perturbed_rna_recon + eps) / (rna_recon + eps))
    # Gene with the largest absolute mean log2FC across cells (logged only).
    max_gene_idx = np.abs(np.mean(ism, axis=0)).argmax()
    print('Max gene:', genes[max_gene_idx])

    # NOTE(review): writes into the shared module-level `rna`; concurrent
    # requests could overwrite each other's column.
    rna.obs['log2FC'] = ism[:, gene_idx_map[gene]]
    fig = sc.pl.umap(rna,
                     color=['celltype', 'log2FC'],
                     color_map='Spectral',
                     wspace=0.05,
                     legend_loc='on data',
                     legend_fontoutline=2,
                     frameon=False,
                     return_fig=True)
    fig.suptitle(f'{gene} expression after removing {a1}-{a2}')
    fig.tight_layout()
    umap_img = _fig_to_rgba(fig)
    violin_img = _fig_to_rgba(sc.pl.violin(rna, gene, groupby='celltype', return_fig=True))
    violin_img_fc = _fig_to_rgba(sc.pl.violin(rna, 'log2FC', groupby='celltype', return_fig=True))
    return umap_img, violin_img, violin_img_fc
# Gradio UI: target gene + two loci in, UMAP and two violin plots out.
# NOTE(review): container nesting reconstructed (inputs in a column, then one
# row each for the button, the UMAP image and the two violin images) — confirm
# against the intended layout.
with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
    with gr.Row():
        with gr.Column():
            in_locus = gr.Textbox(label="Target Gene", elem_id='in-locus', scale=1)
            anchor1 = gr.Textbox(label="Locus 1 (gene or genomic coords)", elem_id='anchor1', scale=1)
            anchor2 = gr.Textbox(label="Locus 2 (gene or genomic coords)", elem_id='anchor2', scale=1)
    with gr.Row():
        # A Button's text is its `value`; `label` is not what gr.Button displays.
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
    with gr.Row():
        out_img = gr.Image(elem_id='out-img', scale=1)
    with gr.Row():
        out_violin = gr.Image(elem_id='out-violin', scale=1)
        out_violin_fc = gr.Image(elem_id='out-violin-fc', scale=1)
    # TODO: add a gr.Plot output (elem_id='out-plot') once perturb returns an
    # interactive figure.

    inputs = [in_locus, anchor1, anchor2]
    outputs = [out_img, out_violin, out_violin_fc]
    gr.Examples(examples=[['INS', 'chr11:2,289,895', 'chr11:2,298,840'],
                          ['INS', 'INS', 'IGF1'],
                          ['TSPAN1', 'TSPAN1', 'PIK3R3'],
                          ['GCG', 'GCG', 'FAP'],
                          ['KRT19', 'chr17:22,220,637', 'chr17:39,591,813'],
                          ['LPP', 'chr3:188,097,749', 'chr3:197,916,262'],
                          ['CEL', 'chr9:135,937,365', 'chr9:135,973,107']],
                inputs=inputs,
                outputs=outputs,
                fn=perturb,
                # Pre-compute example outputs only when running on HF Spaces.
                cache_examples=os.getenv('SYSTEM') == 'spaces')
    run_button.click(perturb, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    demo.launch(share=False)