# HiGLUE / app.py
# Author: dylanplummer — last commit "Update app.py" (f9fac2f, verified)
import gradio as gr
import anndata as ad
import networkx as nx
import scanpy as sc
import scglue
import os
import gzip
import shutil
import cooler
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm
from huggingface_hub import hf_hub_download
# Settings that select which prior graph / trained model files are loaded;
# prior_name, resolution and suffix are interpolated into download filenames.
prior_name = 'dcq'
resolution = '10kb'
loop_q = 0.99
suffix = 'ice_' + str(loop_q)
out_dir = f'pooled_cells_{suffix}'
plot_dir = os.path.join(out_dir, 'ism_loop_plots_distal')
# NOTE(review): dist_range, window_size, out_dir and plot_dir are not read
# anywhere else in the visible code.
dist_range = 10_000_000.0
window_size = 64
# Pseudocount added before taking log2 fold changes in perturb().
eps = 1e-4
n_neighbors = 5
def unzip(file):
    """Decompress a gzip file next to the original and return the new path.

    The output path is ``file`` with its trailing ``.gz`` removed. Only that
    final suffix is stripped, so ``.gz`` elsewhere in the path (for example in
    a directory name) is left alone.

    Raises:
        ValueError: if ``file`` does not end in ``.gz``. The old behaviour would
            have opened the input path itself for writing and truncated it.
    """
    if not file.endswith('.gz'):
        raise ValueError(f'Expected a .gz file, got {file!r}')
    new_file = file[:-len('.gz')]
    with gzip.open(file, 'rb') as f_in, open(new_file, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return new_file
# All assets are downloaded from Hugging Face with the token stored in the
# DATASET_SECRET environment variable (os.environ[...] raises KeyError if unset).
hf_token = os.environ['DATASET_SECRET']
dataset_repo = "dylanplummer/islet-epigenome"
model_repo = "dylanplummer/HiGLUE-islet"


def _download_dataset_file(filename):
    """Download ``filename`` from the islet-epigenome dataset repo and return its local path."""
    return hf_hub_download(repo_id=dataset_repo, filename=filename, repo_type="dataset", token=hf_token)


rna_file = _download_dataset_file("rna.h5ad.gz")
hic_file = _download_dataset_file("hic.h5ad.gz")
# Left gzipped: nx.read_graphml decompresses filenames ending in .gz itself.
graph_file = _download_dataset_file(f"{prior_name}_prior_{resolution}_{suffix}.graphml.gz")
rna_file = unzip(rna_file)
hic_file = unzip(hic_file)

# Pseudobulk contact maps: coolers[res][i] belongs to cooler_celltypes[i]
# (get_heatmap relies on this ordering).
cooler_celltypes = ['Alpha', 'Beta', 'Acinar', 'Duct', 'PSC']
cooler_resolutions = ['50kb', '200kb', '500kb']
coolers = {res: [] for res in cooler_resolutions}
for celltype in cooler_celltypes:
    for res in cooler_resolutions:
        cooler_file = _download_dataset_file(f"pseudobulk_coolers/{celltype}_{res}.cool")
        coolers[res].append(cooler.Cooler(cooler_file))
# Used only for chromosome sizes when clamping heatmap windows.
beta_cooler = coolers[cooler_resolutions[0]][cooler_celltypes.index('Beta')]
model_file = hf_hub_download(repo_id=model_repo, filename=f"glue_hic_{prior_name}_prior_{resolution}_{suffix}.dill",
                             repo_type="model", token=hf_token)

rna = ad.read_h5ad(rna_file)
scglue.models.configure_dataset(rna, "NB", use_highly_variable=True, use_rep="X_pca", use_layer="counts")

# One tab10 colour per cell type; the "_sorted" and "_rna" label variants reuse it.
celltypes = sorted(rna.obs['celltype'].unique())
n_clusters = len(celltypes)
colors = list(plt.cm.tab10(np.int32(np.linspace(0, n_clusters + 0.99, n_clusters))))
color_map = {celltype: colors[i] for i, celltype in enumerate(celltypes)}
sorted_color_map = {celltype + '_sorted': colors[i] for i, celltype in enumerate(celltypes)}
rna_color_map = {celltype + '_rna': colors[i] for i, celltype in enumerate(celltypes)}
color_map = {**color_map, **sorted_color_map, **rna_color_map}
color_map['Other'] = 'gray'

hic = ad.read_h5ad(hic_file)
hic.var["highly_variable"] = hic.var[f"{prior_name}_highly_variable"]
prior = nx.read_graphml(graph_file)
glue = scglue.models.load_model(model_file)

# Edges of the prior graph whose "type" attribute is 'hic'.
loops = [e for e, attr in dict(prior.edges).items() if attr["type"] == 'hic']

# Keep only highly-variable genes that are also flagged ``in_hic``
# (presumably genes covered by the Hi-C features — confirm).
# NOTE(review): configure_dataset above ran before this narrowing; confirm
# that the columns returned by glue.decode_data line up with ``genes``.
rna.var["highly_variable"] = rna.var["highly_variable"] & rna.var["in_hic"]
genes = rna.var.query("highly_variable").index
# Gene name -> column in decode_data output / ``ism`` (HVG positions only).
gene_idx_map = {gene: i for i, gene in enumerate(genes)}
peaks = hic.var.query("highly_variable").copy()
# Baseline reconstruction with the unperturbed prior; perturb() compares to it.
rna_recon = glue.decode_data("rna", "rna", rna, prior)
def get_closest_peak_to_gene(gene_name, rna, peaks):
    """Return the index label of the peak nearest to a gene's start.

    Parameters
    ----------
    gene_name : str
        Label looked up in ``rna.var``.
    rna : AnnData-like
        Object whose ``var`` table has ``chrom`` and ``chromStart`` columns.
    peaks : pandas.DataFrame
        Peak table with ``chrom`` and ``chromStart`` columns. It is not modified.

    Returns
    -------
    The ``peaks`` index label minimising ``|chromStart - gene chromStart|``,
    or ``None`` if ``gene_name`` is not in ``rna.var``. Peaks on other
    chromosomes count as 1e9 bp away. If the gene's chromosome has no peaks,
    the first peak in the table is returned.
    """
    try:
        loc = rna.var.loc[gene_name]
    except KeyError:
        print('Could not find loci', gene_name)
        return None
    # Work on a local Series instead of adding 'in_chr'/'dist' columns to the
    # shared ``peaks`` table, which other callers (and concurrent requests) read.
    dist = (peaks['chromStart'] - loc["chromStart"]).abs()
    dist = dist.where(peaks['chrom'] == loc["chrom"], 1e9)
    return dist.idxmin()
def get_closest_gene_to_peak(peak_name, rna, peaks):
    """Return the ``rna.var`` index label of the gene nearest to a peak's start.

    Parameters
    ----------
    peak_name : hashable
        Label looked up in ``peaks``.
    rna : AnnData-like
        Object whose ``var`` table has ``chrom`` and ``chromStart`` columns. It is not modified.
    peaks : pandas.DataFrame
        Peak table with ``chrom`` and ``chromStart`` columns.

    Returns
    -------
    The gene label minimising ``|chromStart - peak chromStart|``, or ``None``
    if ``peak_name`` is not in ``peaks``. Genes on other chromosomes count as
    1e9 bp away.
    """
    try:
        loc = peaks.loc[peak_name]
    except KeyError:
        print('Could not find peak', peak_name)
        return None
    # Local Series: do not write helper columns into the shared AnnData var table.
    dist = (rna.var['chromStart'] - loc["chromStart"]).abs()
    dist = dist.where(rna.var['chrom'] == loc["chrom"], 1e9)
    return dist.idxmin()
def get_chromosome_from_filename(filename):
    """Extract a chromosome name such as ``'chr1'`` from a filename.

    Rules, based on where the first ``'chr'`` occurs:
    * at the start: everything up to the first ``'.'`` (or the whole name if
      there is no dot);
    * after the last ``'.'``: everything from ``'chr'`` to the end;
    * otherwise: from ``'chr'`` up to, but not including, the last ``'.'``.

    Returns ``''`` when the filename contains no ``'chr'``.
    """
    chr_index = filename.find('chr')
    if chr_index == -1:
        return ''
    if chr_index == 0:
        first_dot = filename.find('.')
        # Without a dot, slicing to find()'s -1 would drop the last character.
        return filename if first_dot == -1 else filename[:first_dot]
    file_ending_index = filename.rfind('.')
    if chr_index > file_ending_index:
        return filename[chr_index:]
    return filename[chr_index:file_ending_index]
def draw_heatmap(matrix, color_scale, ax=None, min_val=1.001, return_image=False, return_plt_im=True):
    """Draw ``matrix`` with a 19-colour white-to-red ("deeploop") palette.

    Colour boundaries (for a BoundaryNorm with 20 colours) are chosen as follows:
    * ``color_scale != 0``: even steps from ``min_val`` toward
      ``min(color_scale, max(matrix))`` in increments of 1/18 of that span,
      plus ``max(matrix)`` as the last boundary;
    * ``color_scale == 0`` and ``max(matrix) < 2``: even steps from ``min_val``
      toward ``max(matrix)`` in increments of 1/19 of that span;
    * otherwise: steps from ``min_val`` toward the 98th percentile + 0.011
      (raised to 2, with step 0.999/18, if that is below 2), plus
      ``max(matrix) + 0.01``.

    Args:
        matrix: array passed to ``ax.imshow``.
        color_scale: upper colour limit; 0 selects the automatic rules above.
        ax: axes to draw on; a new figure/axes is created when ``None``.
        min_val: lowest colour boundary.
        return_image: if true, close the current pyplot figure and return the
            image's data array.
        return_plt_im: if true (and ``return_image`` is false), return the
            AxesImage.

    Returns:
        The image data array, the AxesImage, or ``None``, depending on the flags.

    NOTE(review): get_heatmap only calls this when ``res == 'deeploop'``, and
    get_heatmap_locus never returns that resolution.
    """
    matrix_max = np.max(matrix)
    if color_scale == 0:
        if matrix_max < 2:
            breaks = np.arange(min_val, matrix_max, (matrix_max - min_val) / 19)
        else:
            q98 = np.quantile(matrix, q=0.98)
            step = (q98 - 1) / 18
            upper = q98 + 0.011
            if upper < 2:
                upper = 2
                step = 0.999 / 18
            breaks = np.append(np.arange(min_val, upper, step), matrix_max + 0.01)
    else:
        capped = min(color_scale, matrix_max)
        breaks = np.append(np.arange(min_val, capped, (capped - min_val) / 18), matrix_max)

    palette = [
        "#FFFFFF", "#FFE4E4", "#FFD7D7", "#FFC9C9", "#FFBCBC", "#FFAEAE", "#FFA1A1",
        "#FF9494", "#FF8686", "#FF7979", "#FF6B6B", "#FF5E5E", "#FF5151", "#FF4343",
        "#FF3636", "#FF2828", "#FF1B1B", "#FF0D0D", "#FF0000",
    ]
    # The 19 anchor colours are interpolated into 20 discrete bins.
    deeploop_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('deeploop', palette, N=20)
    boundary_norm = matplotlib.colors.BoundaryNorm(breaks, 20)

    if ax is None:
        ax = plt.subplots()[1]
    # interpolation=None makes Matplotlib fall back to its rcParams default.
    drawn = ax.imshow(matrix, cmap=deeploop_cmap, norm=boundary_norm, interpolation=None)
    if return_image:
        plt.close()
        return drawn.get_array()
    if return_plt_im:
        return drawn
    return None
def get_heatmap_locus(gene, locus1, locus2, heatmap_range):
    """Choose the genomic window and cooler resolution for a loop heatmap.

    Each anchor is snapped to the nearest Hi-C peak in ``peaks``. The window
    spans both anchors, padded on each side by the anchor distance plus
    ``heatmap_range`` kb, and is clamped to the chromosome size.

    Parameters
    ----------
    gene : str
        Gene from the "Gene/TF" box. It is used in place of an empty anchor.
    locus1, locus2 : str
        Either ``"chrN:position"`` (commas allowed) or a gene name.
    heatmap_range : int or float
        Extra padding in kb.

    Returns
    -------
    (str, str)
        ``("chrom:start-end", res)``, where ``res`` is a key of ``coolers``:
        '50kb' up to 10 Mb, '200kb' up to 20 Mb, '500kb' beyond.

    Raises
    ------
    ValueError
        If an anchor cannot be resolved to a peak.
    """
    def _anchor_peak(locus):
        # Empty boxes (e.g. TF-knockout examples) fall back to the target gene.
        locus = locus.replace(',', '').strip() or gene
        if locus.startswith('chr'):
            parts = locus.split(':')
            chrom, pos = parts[0], int(parts[1])
            # Boolean mask rather than DataFrame.query: user text is never evaluated.
            same_chrom = peaks.loc[peaks['chrom'] == chrom, 'chromStart']
            return (same_chrom - pos).abs().idxmin()
        return get_closest_peak_to_gene(locus, rna, peaks)

    a1 = _anchor_peak(locus1)
    a2 = _anchor_peak(locus2)
    if a1 is None or a2 is None:
        raise ValueError(f'Could not resolve heatmap anchors {locus1!r} / {locus2!r} (gene {gene!r})')

    # Take the chromosome from the anchor peak so coordinates and chromosome agree.
    chrom = peaks.loc[a1, 'chrom']
    a1_start = peaks.loc[a1, 'chromStart']
    a2_start = peaks.loc[a2, 'chromStart']
    left, right = min(a1_start, a2_start), max(a1_start, a2_start)
    interaction_dist = right - left
    padding = int(heatmap_range * 1000)
    chrom_size = beta_cooler.chromsizes[chrom]
    # int() keeps the region string parseable even if positions arrive as floats.
    locus_start = int(max(0, left - interaction_dist - padding))
    locus_end = int(min(chrom_size, right + interaction_dist + padding))
    heatmap_locus = f'{chrom}:{locus_start}-{locus_end}'
    heatmap_size = abs(locus_end - locus_start)

    # No 'deeploop' cooler is loaded, so 50kb is the finest available resolution;
    # wider windows use coarser maps.
    if heatmap_size > 20_000_000:
        res = '500kb'
    elif heatmap_size > 10_000_000:
        res = '200kb'
    else:
        res = '50kb'
    print(heatmap_locus, res)
    return heatmap_locus, res
def get_heatmap(celltype, gene, locus1, locus2, heatmap_range):
    """Render one cell type's contact map around a candidate loop.

    The window and resolution come from get_heatmap_locus. A green marker is
    drawn at column = bin nearest locus1 and row = bin nearest locus2.

    Parameters
    ----------
    celltype : str
        One of ``cooler_celltypes``.
    gene : str
        Target gene. It is used in place of an empty locus box.
    locus1, locus2 : str
        ``"chrN:position"`` (commas allowed) or a gene name from ``rna.var``.
    heatmap_range : int or float
        Padding in kb, forwarded to get_heatmap_locus.

    Returns
    -------
    numpy.ndarray or None
        An (H, W, 3) uint8 RGB image, or ``None`` if the window or an anchor
        gene cannot be resolved.
    """
    locus_info = get_heatmap_locus(gene, locus1, locus2, heatmap_range)
    if locus_info is None:
        return None
    heatmap_locus, res = locus_info
    c = coolers[res][cooler_celltypes.index(celltype)]
    mat = c.matrix().fetch(heatmap_locus)
    bins = c.bins().fetch(heatmap_locus)

    def _anchor_bin(locus):
        # Positional index (within ``bins`` on the anchor's chromosome) of the
        # bin whose start is nearest the anchor; None if a gene name is unknown.
        locus = locus.replace(',', '').strip() or gene
        if locus.startswith('chr'):
            parts = locus.split(':')
            chrom, start = parts[0], int(parts[1])
        else:
            # Look up the anchor's own gene, not the target gene.
            try:
                loc = rna.var.loc[locus]
            except KeyError:
                print('Could not find loci', locus)
                return None
            chrom, start = loc["chrom"], loc["chromStart"]
        starts = bins.loc[bins['chrom'] == chrom, 'start']
        return (starts - start).abs().argmin()

    a1_idx = _anchor_bin(locus1)
    if a1_idx is None:
        return None
    a2_idx = _anchor_bin(locus2)
    if a2_idx is None:
        return None
    print(a1_idx, a2_idx)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        if res == 'deeploop':
            draw_heatmap(mat, 3.0, ax=ax, min_val=0.8)
        else:
            ax.imshow(mat, cmap='Reds', norm=PowerNorm(gamma=0.3), interpolation=None)
        # Remove white padding; use the figure/axes objects instead of pyplot's
        # global "current" state, which concurrent requests would share.
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        ax.axis('off')
        ax.axis('image')
        ax.scatter(int(a1_idx), int(a2_idx), color='green', marker='2', s=150)
        fig.canvas.draw()
        # buffer_rgba replaces the deprecated tostring_rgb + np.fromstring pair
        # and already has the (H, W, 4) pixel shape; copy before the figure closes.
        img = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    finally:
        plt.close(fig)
    return img
def get_heatmaps(gene, locus1, locus2, heatmap_range):
    """Render the loop heatmap for every pseudobulk cell type.

    Returns one get_heatmap result (an image or ``None``) per entry of
    ``cooler_celltypes``, in that order. This matches the ``out_heatmaps``
    outputs the function is wired to.
    """
    images = [
        get_heatmap(cell, gene, locus1, locus2, heatmap_range)
        for cell in cooler_celltypes
    ]
    # Exactly five cell types are expected (Alpha, Beta, Acinar, Duct, PSC).
    (alpha, beta, acinar, duct, psc) = images
    return (alpha, beta, acinar, duct, psc)
def _perturb_remove_edge(graph, u, v):
    """Remove one ``u -> v`` edge from ``graph`` if present; otherwise do nothing."""
    try:
        graph.remove_edge(u, v)
    except nx.NetworkXError:
        pass


def _perturb_anchor_peak(locus):
    """Resolve ``"chrN:pos"`` or a gene name to the nearest peak label (``None`` if the gene is unknown)."""
    locus = locus.replace(',', '')
    if locus.startswith('chr'):
        parts = locus.split(':')
        chrom, pos = parts[0], int(parts[1])
        # Boolean mask rather than DataFrame.query: user text is never evaluated.
        same_chrom = peaks.loc[peaks['chrom'] == chrom, 'chromStart']
        return (same_chrom - pos).abs().idxmin()
    return get_closest_peak_to_gene(locus, rna, peaks)


def _perturb_path_to_gene(anchor, gene):
    """Return the shortest path from ``anchor`` to ``gene`` in the unperturbed ``prior``, or ``None``."""
    if anchor is None:
        return None
    try:
        return nx.shortest_path(prior, source=anchor, target=gene)
    except nx.NetworkXException:
        # Missing node or no path: nothing to cut.
        return None


def _perturb_tf_knockout_graph(gene):
    """Return a copy of ``prior`` without ``gene``'s node and the peak closest to ``gene``."""
    graph = prior.copy()
    for node in (gene, get_closest_peak_to_gene(gene, rna, peaks)):
        if node is not None and graph.has_node(node):
            graph.remove_node(node)
    return graph


def _perturb_loop_removal_graph(gene, locus1, locus2):
    """Return ``(graph, a1, a2)``: a copy of ``prior`` with the a1-a2 loop cut, plus the anchor peaks."""
    a1 = _perturb_anchor_peak(locus1)
    a2 = _perturb_anchor_peak(locus2)
    print(a1)
    print(a2)
    a1_closest_gene = get_closest_gene_to_peak(a1, rna, peaks)
    a2_closest_gene = get_closest_gene_to_peak(a2, rna, peaks)
    print(a1_closest_gene, a2_closest_gene)
    graph = prior.copy()
    # Cut the last hop of each anchor's shortest path to ``gene`` in both directions.
    # The paths are computed on the unperturbed prior.
    for anchor in (a1, a2):
        path = _perturb_path_to_gene(anchor, gene)
        if path is not None and len(path) >= 2:
            _perturb_remove_edge(graph, path[-2], path[-1])
            _perturb_remove_edge(graph, path[-1], path[-2])
    # Then cut the loop itself and each anchor's link to its nearest gene. Each
    # direction is attempted independently, so a missing forward edge does not
    # prevent removing the reverse one.
    for u, v in ((a1, a2), (a1, a1_closest_gene), (a2, a2_closest_gene)):
        _perturb_remove_edge(graph, u, v)
        _perturb_remove_edge(graph, v, u)
    return graph, a1, a2


def _perturb_rgba(fig):
    """Draw ``fig`` and return a copy of its pixels as an (H, W, 4) uint8 RGBA array."""
    fig.canvas.draw()
    return np.asarray(fig.canvas.buffer_rgba()).copy()


def _perturb_violin(key, title, layer=None, zero_line=False):
    """Plot ``key`` per cell type as a violin, rasterise it to RGBA, and always close the figure."""
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sc.pl.violin(rna, key, groupby='celltype', layer=layer, show=False, ax=ax, rotation=90)
        ax.set_title(title)
        if zero_line:
            ax.hlines(0, -1, len(celltypes), linestyles='dashed')
        fig.tight_layout()
        return _perturb_rgba(fig)
    finally:
        plt.close(fig)


def perturb(gene, locus1, locus2, use_max_gene, tf_knockout):
    """Perturb the prior graph in silico and plot the effect on decoded RNA expression.

    Modes:
    * ``tf_knockout`` true: remove ``gene``'s node and its closest peak's node.
    * otherwise: cut the loop between the peaks nearest ``locus1``/``locus2``,
      the last edge of each anchor's shortest path to ``gene``, and each
      anchor's edges to its closest gene.

    RNA is re-decoded with ``glue.decode_data`` on the perturbed graph and
    compared with the cached ``rna_recon`` as ``log2((new + eps) / (old + eps))``.

    Parameters
    ----------
    gene : str
        Target gene (loop mode) or the TF to knock out.
    locus1, locus2 : str
        Loop anchors as ``"chrN:pos"`` (commas allowed) or gene names; ignored
        in TF-knockout mode.
    use_max_gene : bool
        Plot the most-affected gene instead of ``gene``.
    tf_knockout : bool
        Select the mode described above.

    Returns
    -------
    tuple of four (H, W, 4) uint8 RGBA arrays
        The UMAP (cell type and log2FC), a violin of the original counts, a
        violin of the perturbed decoded mean, and a violin of log2FC.
    """
    if tf_knockout:
        guidance_hvf = _perturb_tf_knockout_graph(gene)
        a1 = a2 = None
    else:
        guidance_hvf, a1, a2 = _perturb_loop_removal_graph(gene, locus1, locus2)

    perterbed_rna_recon = glue.decode_data("rna", "rna", rna, guidance_hvf)
    ism = np.log2((perterbed_rna_recon + eps) / (rna_recon + eps))
    hvg_names = rna.var.query("highly_variable").index
    mean_abs_effect = np.abs(np.mean(ism, axis=0))
    if tf_knockout:
        max_gene_idxs = mean_abs_effect.argsort()[::-1][:10]
        for idx in max_gene_idxs:
            print(hvg_names[idx], np.mean(ism[:, idx]))
        # NOTE(review): index 1 skips the top-ranked gene, presumably the
        # knocked-out TF itself — confirm.
        max_gene_idx = max_gene_idxs[1]
    else:
        max_gene_idx = mean_abs_effect.argmax()
    max_gene = hvg_names[max_gene_idx]
    print('Max gene:', max_gene)

    target = max_gene if use_max_gene else gene
    gene_idx = gene_idx_map[target]  # column among highly-variable genes

    # NOTE(review): these columns are written into the shared global ``rna.obs``
    # because scanpy plots read them from there; concurrent requests can race.
    rna.obs['log2FC'] = ism[:, gene_idx]
    rna.obs['old_mean'] = rna_recon[:, gene_idx]
    rna.obs['new_mean'] = perterbed_rna_recon[:, gene_idx]
    # The counts layer covers every gene, so index it by the gene's position in
    # var_names rather than its position among highly-variable genes.
    counts = rna.layers['counts'][:, rna.var_names.get_loc(target)]
    if hasattr(counts, 'toarray'):  # sparse layer
        counts = counts.toarray()
    counts = np.asarray(counts).ravel()
    rna.obs['new_count'] = (counts + eps) * 2 ** rna.obs['log2FC'].to_numpy() - eps

    fig = sc.pl.umap(rna,
                     color=['celltype', 'log2FC'],
                     color_map='Spectral',
                     wspace=0.05,
                     legend_loc='on data',
                     legend_fontoutline=2,
                     frameon=False,
                     return_fig=True)
    try:
        if tf_knockout:
            fig.suptitle(f'{target} expression after TF knockout of {gene}')
        else:
            fig.suptitle(f'{target} expression after removing {a1}-{a2}')
        image_from_plot = _perturb_rgba(fig)
    finally:
        plt.close(fig)

    violin_img = _perturb_violin(target, f'Normal {target} expression', layer='counts')
    violin_img_new = _perturb_violin('new_mean', f'{target} normalized expression after perturbation')
    violin_img_fc = _perturb_violin('log2FC', f'{target} expression change', zero_line=True)
    return image_from_plot, violin_img, violin_img_new, violin_img_fc
def update_on_tf_ko(tf_knockout, use_max_gene):
    """React to the "TF KO" checkbox.

    Returns a visibility update for the loop-anchor column (hidden while TF KO
    is ticked) and the new "Max Gene" value (forced to True while TF KO is
    ticked, otherwise left as ``use_max_gene``).
    """
    show_loop_controls = not tf_knockout
    max_gene_value = use_max_gene if show_loop_controls else True
    return gr.update(visible=show_loop_controls), max_gene_value
with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
    # --- Inputs: target gene/TF, mode toggles, loop anchors -------------------
    with gr.Row():
        with gr.Column():
            in_locus = gr.Textbox(label="Gene/TF", elem_id='in-locus', scale=1)
            with gr.Row():
                # Gradio sets a checkbox's initial state with ``value``; ``checked`` is not a parameter.
                max_gene_checkbox = gr.Checkbox(label="Max Gene", elem_id='max-gene-checkbox', scale=1, value=False)
                tf_ko = gr.Checkbox(label="TF KO", elem_id='tf-ko', scale=1, value=False)
        # Hidden by update_on_tf_ko while "TF KO" is ticked.
        with gr.Column() as heatmap_col:
            anchor1 = gr.Textbox(label="Locus 1 (gene or genomic coords)", elem_id='anchor1', scale=1)
            anchor2 = gr.Textbox(label="Locus 2 (gene or genomic coords)", elem_id='anchor2', scale=1)
            heatmap_x = gr.Textbox(label="Heatmap X", elem_id='heatmap-x', scale=1, visible=False)
            heatmap_y = gr.Textbox(label="Heatmap Y", elem_id='heatmap-y', scale=1, visible=False)
            heatmap_size = gr.Slider(label="Heatmap Range", info="kb range around input gene locus to expand",
                                     minimum=50, maximum=5000, value=250)
            heatmap_button = gr.Button(value="Generate Heatmaps", elem_id='heatmap-button', scale=1)
    # --- One contact-map image per pseudobulk cell type -----------------------
    with gr.Row():
        out_heatmaps = []
        for celltype in cooler_celltypes:
            out_heatmaps.append(gr.Image(label=celltype, elem_id=f'out-heatmap-{celltype}', scale=1))
    # --- Perturbation results -------------------------------------------------
    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
    with gr.Row():
        out_img = gr.Image(elem_id='out-img', scale=1)
    with gr.Row():
        out_violin = gr.Image(elem_id='out-violin', scale=1)
        out_violin_new = gr.Image(elem_id='out-violin-new', scale=1)
        out_violin_fc = gr.Image(elem_id='out-violin-fc', scale=1)

    # Order must match perturb(gene, locus1, locus2, use_max_gene, tf_knockout).
    inputs = [in_locus, anchor1, anchor2, max_gene_checkbox, tf_ko]
    outputs = [out_img, out_violin, out_violin_new, out_violin_fc]
    gr.Examples(examples=[['INS', 'chr11:2,289,895', 'chr11:2,298,840', False, False],
                          ['BHLHA15', '', '', True, True],
                          ['FOXA2', '', '', True, True],  # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4878272/
                          # ['OTUD3', 'chr1:20457786', 'chr1:20230413', False],
                          ['IRX2', 'chr5:2704405', 'chr5:2164879', False, False],
                          ['TSPAN1', 'TSPAN1', 'PIK3R3', False, False],
                          # ['IRX1', 'chr5:5397612', 'chr5:4850008', False],
                          ['GCG', 'GCG', 'FAP', False, False],
                          # ['GCG', 'chr2:163162903', 'chr2:162852164', True],
                          ['LOXL4', 'chr10:100186215', 'chr10:99913493', False, False],
                          # ['KRT19', 'chr17:22,220,637', 'chr17:39,591,813', False],
                          # ['LPP', 'chr3:188,097,749', 'chr3:197,916,262', False],
                          ['MAFB', 'chr20:39431654', 'chr20:39368271', True, False],
                          ['CEL', 'chr9:135,937,365', 'chr9:135,973,107', False, False]],
                inputs=inputs,
                outputs=outputs,
                fn=perturb, cache_examples=os.getenv('SYSTEM') == 'spaces')

    run_button.click(perturb, inputs, outputs=outputs)
    heatmap_inputs = [in_locus, anchor1, anchor2, heatmap_size]
    heatmap_button.click(get_heatmaps, heatmap_inputs, outputs=out_heatmaps)
    tf_ko.change(update_on_tf_ko, [tf_ko, max_gene_checkbox], outputs=[heatmap_col, max_gene_checkbox])
    # NOTE(review): .change fires on every edit (including each keystroke), and
    # each firing fetches five coolers.
    anchor1.change(get_heatmaps, heatmap_inputs, outputs=out_heatmaps)
    anchor2.change(get_heatmaps, heatmap_inputs, outputs=out_heatmaps)

    def set_loop(img, gene, locus1, locus2, heatmap_range, evt: gr.SelectData):
        """Move both anchors to the bins under a click on a heatmap image.

        The clicked pixel (x, y) is mapped proportionally onto the window's
        bins: x picks the Locus 1 bin and y the Locus 2 bin. Returns new
        ``"chrom:start"`` strings for the two anchor boxes. The boxes are left
        unchanged if there is no image or no window can be computed.
        """
        if img is None:
            return gr.update(), gr.update()
        h, w = img.shape[:2]
        x, y = evt.index[0], evt.index[1]
        locus_info = get_heatmap_locus(gene, locus1, locus2, heatmap_range)
        if locus_info is None:
            return gr.update(), gr.update()
        heatmap_locus, res = locus_info
        # All cell types share the same bins at a given resolution.
        bins = coolers[res][0].bins().fetch(heatmap_locus)
        last_bin = len(bins) - 1
        bins_idx_x = min(int(x / w * len(bins)), last_bin)
        bins_idx_y = min(int(y / h * len(bins)), last_bin)
        new_a1 = f"{bins.iloc[bins_idx_x]['chrom']}:{bins.iloc[bins_idx_x]['start']}"
        new_a2 = f"{bins.iloc[bins_idx_y]['chrom']}:{bins.iloc[bins_idx_y]['start']}"
        return new_a1, new_a2

    for out_heatmap in out_heatmaps:
        out_heatmap.select(set_loop, [out_heatmap, in_locus, anchor1, anchor2, heatmap_size], outputs=[anchor1, anchor2])
if __name__ == "__main__":
    # Start the Gradio server without creating a public share link.
    demo.launch(share=False)