dylanplummer committed on
Commit
f9fac2f
·
verified ·
1 Parent(s): e5f3ece

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +519 -519
app.py CHANGED
@@ -1,520 +1,520 @@
1
- import gradio as gr
2
- import anndata as ad
3
- import networkx as nx
4
- import scanpy as sc
5
- import scglue
6
- import os
7
- import gzip
8
- import shutil
9
- import cooler
10
- import matplotlib
11
- import numpy as np
12
- import matplotlib.pyplot as plt
13
- from matplotlib.colors import PowerNorm
14
-
15
- from huggingface_hub import hf_hub_download
16
-
17
-
18
- prior_name = 'dcq'
19
- resolution = '10kb'
20
- loop_q = 0.99
21
- suffix = f'ice_{loop_q}'
22
- out_dir = 'pooled_cells_' + suffix
23
- plot_dir = os.path.join(out_dir, 'ism_loop_plots_distal')
24
- dist_range = 10e6
25
- window_size = 64
26
- eps = 1e-4
27
- n_neighbors = 5
28
-
29
-
30
- def unzip(file):
31
- with gzip.open(file, 'rb') as f_in:
32
- new_file = file.replace('.gz', '')
33
- with open(new_file, 'wb') as f_out:
34
- shutil.copyfileobj(f_in, f_out)
35
- return new_file
36
-
37
-
38
- rna_file = hf_hub_download(repo_id="dylanplummer/islet-epigenome", filename="rna.h5ad.gz", repo_type="dataset", token=os.environ['DATASET_SECRET'])
39
- hic_file = hf_hub_download(repo_id="dylanplummer/islet-epigenome", filename="hic.h5ad.gz", repo_type="dataset", token=os.environ['DATASET_SECRET'])
40
- graph_file = hf_hub_download(repo_id="dylanplummer/islet-epigenome", filename=f"{prior_name}_prior_{resolution}_{suffix}.graphml.gz", repo_type="dataset", token=os.environ['DATASET_SECRET'])
41
- rna_file = unzip(rna_file)
42
- hic_file = unzip(hic_file)
43
-
44
- cooler_celltypes = ['Alpha', 'Beta', 'Acinar', 'Duct', 'PSC']
45
- cooler_resolutions = ['deeploop', '50kb', '200kb', '500kb']
46
- coolers = {}
47
- for res in cooler_resolutions:
48
- coolers[res] = []
49
- for celltype in cooler_celltypes:
50
- for res in cooler_resolutions:
51
- cooler_file = hf_hub_download(repo_id="dylanplummer/islet-epigenome", filename=f"pseudobulk_coolers/{celltype}_{res}.cool", repo_type="dataset", token=os.environ['DATASET_SECRET'])
52
- coolers[res].append(cooler.Cooler(cooler_file))
53
- beta_cooler = coolers[cooler_resolutions[0]][1]
54
-
55
- model_file = hf_hub_download(repo_id="dylanplummer/HiGLUE-islet", filename=f"glue_hic_{prior_name}_prior_{resolution}_{suffix}.dill", repo_type="model", token=os.environ['DATASET_SECRET'])
56
-
57
- rna = ad.read_h5ad(rna_file)
58
- # sc.pp.neighbors(rna, metric="cosine", n_neighbors=n_neighbors)
59
- # sc.tl.umap(rna)
60
- scglue.models.configure_dataset(rna, "NB", use_highly_variable=True, use_rep="X_pca", use_layer="counts")
61
-
62
- celltypes = sorted(rna.obs['celltype'].unique())
63
- n_clusters = len(celltypes)
64
- colors = list(plt.cm.tab10(np.int32(np.linspace(0, n_clusters + 0.99, n_clusters))))
65
- color_map = {celltype: colors[i] for i, celltype in enumerate(celltypes)}
66
- sorted_color_map = {celltype + '_sorted': colors[i] for i, celltype in enumerate(celltypes)}
67
- rna_color_map = {celltype + '_rna': colors[i] for i, celltype in enumerate(celltypes)}
68
- color_map = {**color_map, **sorted_color_map, **rna_color_map}
69
- color_map['Other'] = 'gray'
70
-
71
- hic = ad.read_h5ad(hic_file)
72
- hic.var["highly_variable"] = hic.var[f"{prior_name}_highly_variable"]
73
- prior = nx.read_graphml(graph_file)
74
- glue = scglue.models.load_model(model_file)
75
-
76
-
77
- genes = []
78
- loops = []
79
- for e, attr in dict(prior.edges).items():
80
- if attr["type"] == 'overlap':
81
- gene_name = e[0]
82
- if gene_name.startswith('chr'):
83
- gene_name = e[1]
84
- if gene_name not in genes and not gene_name.startswith('chr'):
85
- genes.append(gene_name)
86
- elif attr["type"] == 'hic':
87
- loops.append(e)
88
- rna.var["highly_variable"] = rna.var["highly_variable"] & rna.var["in_hic"]
89
- genes = rna.var.query(f"highly_variable").index
90
- gene_idx_map = {}
91
- for i, gene in enumerate(genes):
92
- gene_idx_map[gene] = i
93
- peaks = hic.var.query("highly_variable").copy()
94
-
95
- rna_recon = glue.decode_data("rna", "rna", rna, prior)
96
-
97
-
98
- def get_closest_peak_to_gene(gene_name, rna, peaks):
99
- try:
100
- loc = rna.var.loc[gene_name]
101
- except KeyError:
102
- print('Could not find loci', gene_name)
103
- return None
104
- chrom = loc["chrom"]
105
- chromStart = loc["chromStart"]
106
- peaks['in_chr'] = peaks['chrom'] == chrom
107
- peaks['dist'] = peaks['chromStart'].apply(lambda s: abs(s - chromStart))
108
- peaks.loc[~peaks['in_chr'], 'dist'] = 1e9 # set distance to 1e9 if not in same chromosome
109
- return peaks['dist'].idxmin()
110
-
111
-
112
- def get_closest_gene_to_peak(peak_name, rna, peaks):
113
- try:
114
- loc = peaks.loc[peak_name]
115
- except KeyError:
116
- print('Could not find peak', peak_name)
117
- return None
118
- chrom = loc["chrom"]
119
- chromStart = loc["chromStart"]
120
- rna.var['in_chr'] = rna.var['chrom'] == chrom
121
- rna.var['dist'] = rna.var['chromStart'].apply(lambda s: abs(s - chromStart))
122
- rna.var.loc[~rna.var['in_chr'], 'dist'] = 1e9 # set distance to 1e9 if not in same chromosome
123
- return rna.var['dist'].idxmin()
124
-
125
-
126
- def get_chromosome_from_filename(filename):
127
- chr_index = filename.find('chr') # index of chromosome name
128
- if chr_index == 0: # if chromosome name is file prefix
129
- return filename[:filename.find('.')]
130
- file_ending_index = filename.rfind('.') # index of file ending
131
- if chr_index > file_ending_index: # if chromosome name is file ending
132
- return filename[chr_index:]
133
- else:
134
- return filename[chr_index: file_ending_index]
135
-
136
-
137
- def draw_heatmap(matrix, color_scale, ax=None, min_val=1.001, return_image=False, return_plt_im=True):
138
- if color_scale != 0:
139
- color_scale = min(color_scale, np.max(matrix))
140
- breaks = np.append(np.arange(min_val, color_scale, (color_scale - min_val) / 18), np.max(matrix))
141
- elif np.max(matrix) < 2:
142
- breaks = np.arange(min_val, np.max(matrix), (np.max(matrix) - min_val) / 19)
143
- else:
144
- step = (np.quantile(matrix, q=0.98) - 1) / 18
145
- up = np.quantile(matrix, q=0.98) + 0.011
146
- if up < 2:
147
- up = 2
148
- step = 0.999 / 18
149
- breaks = np.append(np.arange(min_val, up, step), np.max(matrix) + 0.01)
150
- n_bin = 20 # Discretizes the interpolation into bins
151
- colors = ["#FFFFFF", "#FFE4E4", "#FFD7D7", "#FFC9C9", "#FFBCBC", "#FFAEAE", "#FFA1A1", "#FF9494", "#FF8686",
152
- "#FF7979", "#FF6B6B", "#FF5E5E", "#FF5151", "#FF4343", "#FF3636", "#FF2828", "#FF1B1B", "#FF0D0D",
153
- "#FF0000"]
154
- cmap_name = 'deeploop'
155
- # Create the colormap
156
- cm = matplotlib.colors.LinearSegmentedColormap.from_list(
157
- cmap_name, colors, N=n_bin)
158
- norm = matplotlib.colors.BoundaryNorm(breaks, 20)
159
- # Fewer bins will result in "coarser" colomap interpolation
160
- if ax is None:
161
- _, ax = plt.subplots()
162
- img = ax.imshow(matrix, cmap=cm, norm=norm, interpolation=None)
163
- if return_image:
164
- plt.close()
165
- return img.get_array()
166
- elif return_plt_im:
167
- return img
168
-
169
-
170
- def get_heatmap_locus(gene, locus1, locus2, heatmap_range):
171
- locus1 = locus1.replace(',', '')
172
- locus2 = locus2.replace(',', '')
173
- try:
174
- loc = rna.var.loc[gene]
175
- except KeyError:
176
- print('Could not find loci', gene)
177
- return None
178
- chrom = loc["chrom"]
179
- if locus1.startswith('chr'):
180
- chrom = locus1.split(':')[0]
181
- pos1 = locus1.split(':')[1]
182
- a1 = peaks.query(f"chrom == '{chrom}'")['chromStart'].apply(lambda s: abs(s - int(pos1))).idxmin()
183
- else:
184
- a1 = get_closest_peak_to_gene(locus1, rna, peaks)
185
- if locus2.startswith('chr'):
186
- chrom = locus2.split(':')[0]
187
- pos2 = locus2.split(':')[1]
188
- a2 = peaks.query(f"chrom == '{chrom}'")['chromStart'].apply(lambda s: abs(s - int(pos2))).idxmin()
189
- else:
190
- a2 = get_closest_peak_to_gene(locus2, rna, peaks)
191
- a1_start = peaks.loc[a1, 'chromStart']
192
- a2_start = peaks.loc[a2, 'chromStart']
193
- interaction_dist = abs(a1_start - a2_start)
194
- chrom_size = beta_cooler.chromsizes[chrom]
195
- locus_start = max(0, a1_start - interaction_dist - heatmap_range * 1000)
196
- locus_end = min(chrom_size, a2_start + interaction_dist + heatmap_range * 1000)
197
- heatmap_locus = f'{chrom}:{locus_start}-{locus_end}'
198
- heatmap_size = abs(locus_start - locus_end)
199
- res = 'deeploop'
200
- if heatmap_size > 5000000:
201
- res = '50kb'
202
- if heatmap_size > 10000000:
203
- res = '200kb'
204
- if heatmap_size > 20000000:
205
- res = '500kb'
206
- print(heatmap_locus, res)
207
- return heatmap_locus, res
208
-
209
-
210
- def get_heatmap(celltype, gene, locus1, locus2, heatmap_range):
211
- heatmap_locus, res = get_heatmap_locus(gene, locus1, locus2, heatmap_range)
212
- c = coolers[res][cooler_celltypes.index(celltype)]
213
- mat = c.matrix().fetch(heatmap_locus)
214
- bins = c.bins().fetch(heatmap_locus)
215
- locus1 = locus1.replace(',', '')
216
- locus2 = locus2.replace(',', '')
217
- if locus1.startswith('chr'):
218
- a1_chrom = locus1.split(':')[0]
219
- a1_start = int(locus1.split(':')[1])
220
- else:
221
- try:
222
- loc = rna.var.loc[gene]
223
- except KeyError:
224
- print('Could not find loci', gene)
225
- return None
226
- a1_chrom = loc["chrom"]
227
- a1_start = loc["chromStart"]
228
- a1_idx = bins.query(f"chrom == '{a1_chrom}'")['start'].apply(lambda s: abs(s - a1_start)).argmin()
229
- if locus2.startswith('chr'):
230
- a2_chrom = locus2.split(':')[0]
231
- a2_start = int(locus2.split(':')[1])
232
- else:
233
- try:
234
- loc = rna.var.loc[gene]
235
- except KeyError:
236
- print('Could not find loci', gene)
237
- return None
238
- a2_chrom = loc["chrom"]
239
- a2_start = loc["chromStart"]
240
- a2_idx = bins.query(f"chrom == '{a2_chrom}'")['start'].apply(lambda s: abs(s - a2_start)).argmin()
241
- print(a1_idx, a2_idx)
242
- #img = draw_heatmap(mat, 0, return_image=True)
243
- fig, ax = plt.subplots(figsize=(4, 4))
244
- if res == 'deeploop':
245
- draw_heatmap(mat, 3.0, ax=ax, min_val=0.8)
246
- else:
247
- ax.imshow(mat, cmap='Reds', norm=PowerNorm(gamma=0.3), interpolation=None)
248
-
249
- # remove white padding
250
- plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
251
- plt.axis('off')
252
- plt.axis('image')
253
-
254
- w, h = fig.canvas.get_width_height()
255
- #heatmap_x = (a1_idx / len(bins)) * w
256
- #heatmap_y = (a2_idx / len(bins)) * h
257
- #print(heatmap_x, heatmap_y)
258
- heatmap_x = a1_idx
259
- heatmap_y = a2_idx
260
- ax.scatter(int(heatmap_x), int(heatmap_y), color='green', marker='2', s=150)
261
-
262
- # redraw the canvas
263
- fig = plt.gcf()
264
- fig.canvas.draw()
265
-
266
- # convert canvas to image using numpy
267
- img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
268
- img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
269
- plt.close()
270
- return img
271
-
272
- def get_heatmaps(gene, locus1, locus2, heatmap_range):
273
- res = []
274
- for celltype in cooler_celltypes:
275
- res.append(get_heatmap(celltype, gene, locus1, locus2, heatmap_range))
276
- alpha, beta, acinar, duct, psc = res
277
- return alpha, beta, acinar, duct, psc
278
-
279
-
280
- def perturb(gene, locus1, locus2, use_max_gene, tf_knockout):
281
- if tf_knockout:
282
- guidance_hvf = prior.copy()
283
- try:
284
- guidance_hvf.remove_node(gene)
285
- except:
286
- pass
287
- # find index of tf and set all counts and values to zero
288
- gene_idx = gene_idx_map[gene]
289
- a1 = get_closest_peak_to_gene(gene, rna, peaks)
290
- try:
291
- guidance_hvf.remove_node(a1)
292
- except:
293
- pass
294
- else:
295
- locus1 = locus1.replace(',', '')
296
- locus2 = locus2.replace(',', '')
297
- res = {'feat': [], 'log2FC': [], 'var': [], 'a1_idx': [], 'a2_idx': []}
298
- for c in celltypes:
299
- res[c] = []
300
- res[f'{c}_var'] = []
301
-
302
- links_dict = {}
303
- for c in celltypes:
304
- links_dict[c] = {'chrom1': [], 'chrom1Start': [], 'chrom1End': [], 'chrom2': [], 'chrom2Start': [], 'chrom2End': [], 'score': [], 'strand1': [], 'strand2': []}
305
-
306
- if locus1.startswith('chr'):
307
- chrom = locus1.split(':')[0]
308
- pos1 = locus1.split(':')[1]
309
- a1 = peaks.query(f"chrom == '{chrom}'")['chromStart'].apply(lambda s: abs(s - int(pos1))).idxmin()
310
- else:
311
- a1 = get_closest_peak_to_gene(locus1, rna, peaks)
312
- if locus2.startswith('chr'):
313
- chrom = locus2.split(':')[0]
314
- pos2 = locus2.split(':')[1]
315
- a2 = peaks.query(f"chrom == '{chrom}'")['chromStart'].apply(lambda s: abs(s - int(pos2))).idxmin()
316
- else:
317
- a2 = get_closest_peak_to_gene(locus2, rna, peaks)
318
-
319
- print(a1)
320
- print(a2)
321
-
322
- a1_closest_gene = get_closest_gene_to_peak(a1, rna, peaks)
323
- a2_closest_gene = get_closest_gene_to_peak(a2, rna, peaks)
324
- print(a1_closest_gene, a2_closest_gene)
325
-
326
- guidance_hvf = prior.copy()
327
- try:
328
- path_to_gene_a1 = nx.shortest_path(prior, source=a1, target=gene)
329
- try:
330
- guidance_hvf.remove_edge(path_to_gene_a1[-2], path_to_gene_a1[-1])
331
- except:
332
- pass
333
- try:
334
- guidance_hvf.remove_edge(path_to_gene_a1[-1], path_to_gene_a1[-2])
335
- except:
336
- pass
337
- except:
338
- pass
339
- try:
340
- path_to_gene_a2 = nx.shortest_path(prior, source=a2, target=gene)
341
- try:
342
- guidance_hvf.remove_edge(path_to_gene_a2[-2], path_to_gene_a2[-1])
343
- except:
344
- pass
345
- try:
346
- guidance_hvf.remove_edge(path_to_gene_a2[-1], path_to_gene_a2[-2])
347
- except:
348
- pass
349
- except:
350
- pass
351
- try:
352
- guidance_hvf.remove_edge(a1, a2)
353
- guidance_hvf.remove_edge(a2, a1)
354
- except:
355
- pass
356
- try:
357
- guidance_hvf.remove_edge(a1, a1_closest_gene)
358
- guidance_hvf.remove_edge(a1_closest_gene, a1)
359
- except:
360
- pass
361
- try:
362
- guidance_hvf.remove_edge(a2, a2_closest_gene)
363
- guidance_hvf.remove_edge(a2_closest_gene, a2)
364
- except:
365
- pass
366
-
367
- perterbed_rna_recon = glue.decode_data("rna", "rna", rna, guidance_hvf)
368
- ism = np.log2((perterbed_rna_recon + eps) / (rna_recon + eps))
369
- if tf_knockout:
370
- max_gene_idxs = np.abs(np.mean(ism, axis=0)).argsort()[::-1][:10]
371
- for max_gene_idx in max_gene_idxs:
372
- print(rna.var.query(f"highly_variable").index[max_gene_idx], np.mean(ism[:, max_gene_idx]))
373
- max_gene_idx = max_gene_idxs[1]
374
- else:
375
- max_gene_idx = np.abs(np.mean(ism, axis=0)).argmax()
376
- max_gene = rna.var.query(f"highly_variable").index[max_gene_idx]
377
- print('Max gene:', max_gene)
378
-
379
- # get integer index of gene
380
- if use_max_gene:
381
- gene_idx = gene_idx_map[max_gene]
382
- else:
383
- gene_idx = gene_idx_map[gene]
384
- rna.obs['log2FC'] = ism[:, gene_idx]
385
- rna.obs['old_mean'] = rna_recon[:, gene_idx]
386
- rna.obs['new_mean'] = perterbed_rna_recon[:, gene_idx]
387
- # compute new count based on log2FC
388
- rna.obs['new_count'] = (rna.layers['counts'][:, gene_idx] + eps) * 2 ** rna.obs['log2FC'] - eps
389
-
390
- fig = sc.pl.umap(rna,
391
- color=['celltype', 'log2FC'],
392
- color_map='Spectral',
393
- wspace=0.05,
394
- legend_loc='on data',
395
- legend_fontoutline=2,
396
- frameon=False,
397
- return_fig=True)
398
- if tf_knockout:
399
- fig.suptitle(f'{max_gene if use_max_gene else gene} expression after TF knockout of {gene}')
400
- else:
401
- fig.suptitle(f'{max_gene if use_max_gene else gene} expression after removing ' + a1 + '-' + a2)
402
- #fig.tight_layout()
403
- #fig.patch.set_facecolor('none')
404
- #fig.patch.set_alpha(0.0)
405
-
406
- fig.canvas.draw()
407
- image_from_plot = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
408
- image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (4,))
409
- # convert from argb to rgba
410
- image_from_plot = image_from_plot[:, :, [1, 2, 3, 0]]
411
-
412
- fig, ax = plt.subplots(figsize=(6, 4))
413
- sc.pl.violin(rna, max_gene if use_max_gene else gene, groupby='celltype', layer='counts', show=False, ax=ax, rotation=90)
414
- #sc.pl.violin(rna, 'old_mean', groupby='celltype', show=False, ax=ax, rotation=90)
415
- ax.set_title(f'Normal {max_gene if use_max_gene else gene} expression')
416
- fig.tight_layout()
417
- fig.canvas.draw()
418
- violin_img = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
419
- violin_img = violin_img.reshape(fig.canvas.get_width_height()[::-1] + (4,))
420
- violin_img = violin_img[:, :, [1, 2, 3, 0]]
421
-
422
- fig, ax = plt.subplots(figsize=(6, 4))
423
- sc.pl.violin(rna, 'new_mean', groupby='celltype', show=False, ax=ax, rotation=90)
424
- ax.set_title(f'{max_gene if use_max_gene else gene} normalized expression after perturbation')
425
- fig.tight_layout()
426
- fig.canvas.draw()
427
- violin_img_new = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
428
- violin_img_new = violin_img_new.reshape(fig.canvas.get_width_height()[::-1] + (4,))
429
- violin_img_new = violin_img_new[:, :, [1, 2, 3, 0]]
430
-
431
- fig, ax = plt.subplots(figsize=(6, 4))
432
- sc.pl.violin(rna, 'log2FC', groupby='celltype', show=False, ax=ax, rotation=90)
433
- ax.set_title(f'{max_gene if use_max_gene else gene} expression change')
434
- ax.hlines(0, -1, len(celltypes), linestyles='dashed')
435
- fig.tight_layout()
436
- fig.canvas.draw()
437
- violin_img_fc = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
438
- violin_img_fc = violin_img_fc.reshape(fig.canvas.get_width_height()[::-1] + (4,))
439
- violin_img_fc = violin_img_fc[:, :, [1, 2, 3, 0]]
440
-
441
- return image_from_plot, violin_img, violin_img_new, violin_img_fc
442
-
443
-
444
- def update_on_tf_ko(tf_knockout, use_max_gene):
445
- if tf_knockout:
446
- return gr.update(visible=False), True
447
- else:
448
- return gr.update(visible=True), use_max_gene
449
-
450
-
451
- with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
452
- with gr.Row():
453
- with gr.Column():
454
- in_locus = gr.Textbox(label="Gene/TF", elem_id='in-locus', scale=1)
455
- with gr.Row():
456
- max_gene_checkbox = gr.Checkbox(label="Max Gene", elem_id='max-gene-checkbox', scale=1, checked=False)
457
- tf_ko = gr.Checkbox(label="TF KO", elem_id='tf-ko', scale=1, checked=False)
458
- with gr.Column() as heatmap_col:
459
- anchor1 = gr.Textbox(label="Locus 1 (gene or genomic coords)", elem_id='anchor1', scale=1)
460
- anchor2 = gr.Textbox(label="Locus 2 (gene or genomic coords)", elem_id='anchor2', scale=1)
461
- heatmap_x = gr.Textbox(label="Heatmap X", elem_id='heatmap-x', scale=1, visible=False)
462
- heatmap_y = gr.Textbox(label="Heatmap Y", elem_id='heatmap-y', scale=1, visible=False)
463
- heatmap_size = gr.Slider(label="Heatmap Range", info="kb range aroung input gene locus to expand", minimum=50, maximum=5000, value=250)
464
- heatmap_button = gr.Button(value="Generate Heatmaps", elem_id='heatmap-button', scale=1)
465
- with gr.Row():
466
- out_heatmaps = []
467
- for celltype in cooler_celltypes:
468
- out_heatmaps.append(gr.Image(label=celltype, elem_id=f'out-heatmap-{celltype}', scale=1))
469
- with gr.Row():
470
- run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
471
- with gr.Row():
472
- out_img = gr.Image(elem_id='out-img', scale=1)
473
- with gr.Row():
474
- out_violin = gr.Image(elem_id='out-violin', scale=1)
475
- out_violin_new = gr.Image(elem_id='out-violin-new', scale=1)
476
- out_violin_fc = gr.Image(elem_id='out-violin-fc', scale=1)
477
- #out_plot = gr.Plot(elem_id='out-plot', scale=1)
478
-
479
- inputs = [in_locus, anchor1, anchor2, max_gene_checkbox, tf_ko]
480
- outputs = [out_img, out_violin, out_violin_new, out_violin_fc]
481
-
482
- gr.Examples(examples=[['INS', 'chr11:2,289,895', 'chr11:2,298,840', False, False],
483
- ['BHLHA15', '', '', True, True],
484
- ['FOXA2', '', '', True, True], # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4878272/
485
- # ['OTUD3', 'chr1:20457786', 'chr1:20230413', False],
486
- ['IRX2', 'chr5:2704405', 'chr5:2164879', False, False],
487
- ['TSPAN1', 'TSPAN1', 'PIK3R3', False, False],
488
- # ['IRX1', 'chr5:5397612', 'chr5:4850008', False],
489
- ['GCG', 'GCG', 'FAP', False, False],
490
- # ['GCG', 'chr2:163162903', 'chr2:162852164', True],
491
- ['LOXL4', 'chr10:100186215', 'chr10:99913493', False, False],
492
- # ['KRT19', 'chr17:22,220,637', 'chr17:39,591,813', False],
493
- # ['LPP', 'chr3:188,097,749', 'chr3:197,916,262', False],
494
- ['MAFB', 'chr20:39431654', 'chr20:39368271', True, False],
495
- ['CEL', 'chr9:135,937,365', 'chr9:135,973,107', False, False]],
496
- inputs=inputs,
497
- outputs=outputs,
498
- fn=perturb, cache_examples=os.getenv('SYSTEM') == 'spaces')
499
- run_button.click(perturb, inputs, outputs=outputs)
500
- heatmap_button.click(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)
501
- tf_ko.change(update_on_tf_ko, [tf_ko, max_gene_checkbox], outputs=[heatmap_col, max_gene_checkbox])
502
- anchor1.change(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)
503
- anchor2.change(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)
504
- def set_loop(img, gene, locus1, locus2, heatmap_range, evt: gr.SelectData):
505
- h, w = img.shape[:2]
506
- idx = evt.index
507
- x, y = idx[0], idx[1]
508
- heatmap_locus, res = get_heatmap_locus(gene, locus1, locus2, heatmap_range)
509
- bins = coolers[res][0].bins().fetch(heatmap_locus)
510
- bins_idx_x = int(x / w * len(bins))
511
- bins_idx_y = int(y / h * len(bins))
512
- new_a1 = f"{bins.iloc[bins_idx_x]['chrom']}:{bins.iloc[bins_idx_x]['start']}"
513
- new_a2 = f"{bins.iloc[bins_idx_y]['chrom']}:{bins.iloc[bins_idx_y]['start']}"
514
- return new_a1, new_a2
515
- for out_heatmap in out_heatmaps:
516
- out_heatmap.select(set_loop, [out_heatmap, in_locus, anchor1, anchor2, heatmap_size], outputs=[anchor1, anchor2])
517
-
518
-
519
- if __name__ == "__main__":
520
  demo.launch(share=False)
 
1
+ import gradio as gr
2
+ import anndata as ad
3
+ import networkx as nx
4
+ import scanpy as sc
5
+ import scglue
6
+ import os
7
+ import gzip
8
+ import shutil
9
+ import cooler
10
+ import matplotlib
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ from matplotlib.colors import PowerNorm
14
+
15
+ from huggingface_hub import hf_hub_download
16
+
17
+
18
+ prior_name = 'dcq'
19
+ resolution = '10kb'
20
+ loop_q = 0.99
21
+ suffix = f'ice_{loop_q}'
22
+ out_dir = 'pooled_cells_' + suffix
23
+ plot_dir = os.path.join(out_dir, 'ism_loop_plots_distal')
24
+ dist_range = 10e6
25
+ window_size = 64
26
+ eps = 1e-4
27
+ n_neighbors = 5
28
+
29
+
30
+ def unzip(file):
31
+ with gzip.open(file, 'rb') as f_in:
32
+ new_file = file.replace('.gz', '')
33
+ with open(new_file, 'wb') as f_out:
34
+ shutil.copyfileobj(f_in, f_out)
35
+ return new_file
36
+
37
+
38
+ rna_file = hf_hub_download(repo_id="dylanplummer/islet-epigenome", filename="rna.h5ad.gz", repo_type="dataset", token=os.environ['DATASET_SECRET'])
39
+ hic_file = hf_hub_download(repo_id="dylanplummer/islet-epigenome", filename="hic.h5ad.gz", repo_type="dataset", token=os.environ['DATASET_SECRET'])
40
+ graph_file = hf_hub_download(repo_id="dylanplummer/islet-epigenome", filename=f"{prior_name}_prior_{resolution}_{suffix}.graphml.gz", repo_type="dataset", token=os.environ['DATASET_SECRET'])
41
+ rna_file = unzip(rna_file)
42
+ hic_file = unzip(hic_file)
43
+
44
+ cooler_celltypes = ['Alpha', 'Beta', 'Acinar', 'Duct', 'PSC']
45
+ cooler_resolutions = ['50kb', '200kb', '500kb']
46
+ coolers = {}
47
+ for res in cooler_resolutions:
48
+ coolers[res] = []
49
+ for celltype in cooler_celltypes:
50
+ for res in cooler_resolutions:
51
+ cooler_file = hf_hub_download(repo_id="dylanplummer/islet-epigenome", filename=f"pseudobulk_coolers/{celltype}_{res}.cool", repo_type="dataset", token=os.environ['DATASET_SECRET'])
52
+ coolers[res].append(cooler.Cooler(cooler_file))
53
+ beta_cooler = coolers[cooler_resolutions[0]][1]
54
+
55
+ model_file = hf_hub_download(repo_id="dylanplummer/HiGLUE-islet", filename=f"glue_hic_{prior_name}_prior_{resolution}_{suffix}.dill", repo_type="model", token=os.environ['DATASET_SECRET'])
56
+
57
+ rna = ad.read_h5ad(rna_file)
58
+ # sc.pp.neighbors(rna, metric="cosine", n_neighbors=n_neighbors)
59
+ # sc.tl.umap(rna)
60
+ scglue.models.configure_dataset(rna, "NB", use_highly_variable=True, use_rep="X_pca", use_layer="counts")
61
+
62
+ celltypes = sorted(rna.obs['celltype'].unique())
63
+ n_clusters = len(celltypes)
64
+ colors = list(plt.cm.tab10(np.int32(np.linspace(0, n_clusters + 0.99, n_clusters))))
65
+ color_map = {celltype: colors[i] for i, celltype in enumerate(celltypes)}
66
+ sorted_color_map = {celltype + '_sorted': colors[i] for i, celltype in enumerate(celltypes)}
67
+ rna_color_map = {celltype + '_rna': colors[i] for i, celltype in enumerate(celltypes)}
68
+ color_map = {**color_map, **sorted_color_map, **rna_color_map}
69
+ color_map['Other'] = 'gray'
70
+
71
+ hic = ad.read_h5ad(hic_file)
72
+ hic.var["highly_variable"] = hic.var[f"{prior_name}_highly_variable"]
73
+ prior = nx.read_graphml(graph_file)
74
+ glue = scglue.models.load_model(model_file)
75
+
76
+
77
+ genes = []
78
+ loops = []
79
+ for e, attr in dict(prior.edges).items():
80
+ if attr["type"] == 'overlap':
81
+ gene_name = e[0]
82
+ if gene_name.startswith('chr'):
83
+ gene_name = e[1]
84
+ if gene_name not in genes and not gene_name.startswith('chr'):
85
+ genes.append(gene_name)
86
+ elif attr["type"] == 'hic':
87
+ loops.append(e)
88
+ rna.var["highly_variable"] = rna.var["highly_variable"] & rna.var["in_hic"]
89
+ genes = rna.var.query(f"highly_variable").index
90
+ gene_idx_map = {}
91
+ for i, gene in enumerate(genes):
92
+ gene_idx_map[gene] = i
93
+ peaks = hic.var.query("highly_variable").copy()
94
+
95
+ rna_recon = glue.decode_data("rna", "rna", rna, prior)
96
+
97
+
98
+ def get_closest_peak_to_gene(gene_name, rna, peaks):
99
+ try:
100
+ loc = rna.var.loc[gene_name]
101
+ except KeyError:
102
+ print('Could not find loci', gene_name)
103
+ return None
104
+ chrom = loc["chrom"]
105
+ chromStart = loc["chromStart"]
106
+ peaks['in_chr'] = peaks['chrom'] == chrom
107
+ peaks['dist'] = peaks['chromStart'].apply(lambda s: abs(s - chromStart))
108
+ peaks.loc[~peaks['in_chr'], 'dist'] = 1e9 # set distance to 1e9 if not in same chromosome
109
+ return peaks['dist'].idxmin()
110
+
111
+
112
+ def get_closest_gene_to_peak(peak_name, rna, peaks):
113
+ try:
114
+ loc = peaks.loc[peak_name]
115
+ except KeyError:
116
+ print('Could not find peak', peak_name)
117
+ return None
118
+ chrom = loc["chrom"]
119
+ chromStart = loc["chromStart"]
120
+ rna.var['in_chr'] = rna.var['chrom'] == chrom
121
+ rna.var['dist'] = rna.var['chromStart'].apply(lambda s: abs(s - chromStart))
122
+ rna.var.loc[~rna.var['in_chr'], 'dist'] = 1e9 # set distance to 1e9 if not in same chromosome
123
+ return rna.var['dist'].idxmin()
124
+
125
+
126
+ def get_chromosome_from_filename(filename):
127
+ chr_index = filename.find('chr') # index of chromosome name
128
+ if chr_index == 0: # if chromosome name is file prefix
129
+ return filename[:filename.find('.')]
130
+ file_ending_index = filename.rfind('.') # index of file ending
131
+ if chr_index > file_ending_index: # if chromosome name is file ending
132
+ return filename[chr_index:]
133
+ else:
134
+ return filename[chr_index: file_ending_index]
135
+
136
+
137
+ def draw_heatmap(matrix, color_scale, ax=None, min_val=1.001, return_image=False, return_plt_im=True):
138
+ if color_scale != 0:
139
+ color_scale = min(color_scale, np.max(matrix))
140
+ breaks = np.append(np.arange(min_val, color_scale, (color_scale - min_val) / 18), np.max(matrix))
141
+ elif np.max(matrix) < 2:
142
+ breaks = np.arange(min_val, np.max(matrix), (np.max(matrix) - min_val) / 19)
143
+ else:
144
+ step = (np.quantile(matrix, q=0.98) - 1) / 18
145
+ up = np.quantile(matrix, q=0.98) + 0.011
146
+ if up < 2:
147
+ up = 2
148
+ step = 0.999 / 18
149
+ breaks = np.append(np.arange(min_val, up, step), np.max(matrix) + 0.01)
150
+ n_bin = 20 # Discretizes the interpolation into bins
151
+ colors = ["#FFFFFF", "#FFE4E4", "#FFD7D7", "#FFC9C9", "#FFBCBC", "#FFAEAE", "#FFA1A1", "#FF9494", "#FF8686",
152
+ "#FF7979", "#FF6B6B", "#FF5E5E", "#FF5151", "#FF4343", "#FF3636", "#FF2828", "#FF1B1B", "#FF0D0D",
153
+ "#FF0000"]
154
+ cmap_name = 'deeploop'
155
+ # Create the colormap
156
+ cm = matplotlib.colors.LinearSegmentedColormap.from_list(
157
+ cmap_name, colors, N=n_bin)
158
+ norm = matplotlib.colors.BoundaryNorm(breaks, 20)
159
+ # Fewer bins will result in "coarser" colomap interpolation
160
+ if ax is None:
161
+ _, ax = plt.subplots()
162
+ img = ax.imshow(matrix, cmap=cm, norm=norm, interpolation=None)
163
+ if return_image:
164
+ plt.close()
165
+ return img.get_array()
166
+ elif return_plt_im:
167
+ return img
168
+
169
+
170
+ def get_heatmap_locus(gene, locus1, locus2, heatmap_range):
171
+ locus1 = locus1.replace(',', '')
172
+ locus2 = locus2.replace(',', '')
173
+ try:
174
+ loc = rna.var.loc[gene]
175
+ except KeyError:
176
+ print('Could not find loci', gene)
177
+ return None
178
+ chrom = loc["chrom"]
179
+ if locus1.startswith('chr'):
180
+ chrom = locus1.split(':')[0]
181
+ pos1 = locus1.split(':')[1]
182
+ a1 = peaks.query(f"chrom == '{chrom}'")['chromStart'].apply(lambda s: abs(s - int(pos1))).idxmin()
183
+ else:
184
+ a1 = get_closest_peak_to_gene(locus1, rna, peaks)
185
+ if locus2.startswith('chr'):
186
+ chrom = locus2.split(':')[0]
187
+ pos2 = locus2.split(':')[1]
188
+ a2 = peaks.query(f"chrom == '{chrom}'")['chromStart'].apply(lambda s: abs(s - int(pos2))).idxmin()
189
+ else:
190
+ a2 = get_closest_peak_to_gene(locus2, rna, peaks)
191
+ a1_start = peaks.loc[a1, 'chromStart']
192
+ a2_start = peaks.loc[a2, 'chromStart']
193
+ interaction_dist = abs(a1_start - a2_start)
194
+ chrom_size = beta_cooler.chromsizes[chrom]
195
+ locus_start = max(0, a1_start - interaction_dist - heatmap_range * 1000)
196
+ locus_end = min(chrom_size, a2_start + interaction_dist + heatmap_range * 1000)
197
+ heatmap_locus = f'{chrom}:{locus_start}-{locus_end}'
198
+ heatmap_size = abs(locus_start - locus_end)
199
+ res = 'deeploop'
200
+ if heatmap_size > 5000000 or res not in cooler_resolutions:
201
+ res = '50kb'
202
+ if heatmap_size > 10000000:
203
+ res = '200kb'
204
+ if heatmap_size > 20000000:
205
+ res = '500kb'
206
+ print(heatmap_locus, res)
207
+ return heatmap_locus, res
208
+
209
+
210
def _heatmap_anchor_position(locus, gene):
    """Return ``(chrom, start)`` for one heatmap anchor, or ``None`` if unknown.

    ``locus`` is used directly when it looks like ``chr<N>:<position>``;
    otherwise the position of ``gene`` is read from ``rna.var``.
    """
    if locus.startswith('chr'):
        parts = locus.split(':')
        return parts[0], int(parts[1])
    # NOTE(review): the fallback looks up ``gene`` rather than the locus text,
    # so a gene name typed into a locus box is ignored here (``perturb``
    # resolves the locus name instead) -- confirm which one is intended.
    try:
        loc = rna.var.loc[gene]
    except KeyError:
        print('Could not find loci', gene)
        return None
    return loc["chrom"], loc["chromStart"]


def _nearest_bin_index(bins, chrom, start):
    """Return the position of the bin on ``chrom`` whose start is closest to ``start``."""
    starts = bins.query(f"chrom == '{chrom}'")['start']
    return (starts - start).abs().argmin()


def get_heatmap(celltype, gene, locus1, locus2, heatmap_range):
    """Render the contact heatmap of one cell type around the selected loop.

    Args:
        celltype: Entry of ``cooler_celltypes`` selecting which cooler to read.
        gene: Gene name; used for the window and as the anchor fallback.
        locus1: First anchor, either ``chr<N>:<position>`` (commas allowed)
            or any other text, in which case ``gene`` is looked up.
        locus2: Second anchor, same format as ``locus1``.
        heatmap_range: Passed through to ``get_heatmap_locus``.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        RGB image, or ``None`` when ``gene`` is not present in ``rna.var``.
    """
    heatmap_locus, res = get_heatmap_locus(gene, locus1, locus2, heatmap_range)
    c = coolers[res][cooler_celltypes.index(celltype)]
    mat = c.matrix().fetch(heatmap_locus)
    bins = c.bins().fetch(heatmap_locus)

    anchor1 = _heatmap_anchor_position(locus1.replace(',', ''), gene)
    if anchor1 is None:
        return None
    a1_idx = _nearest_bin_index(bins, *anchor1)

    anchor2 = _heatmap_anchor_position(locus2.replace(',', ''), gene)
    if anchor2 is None:
        return None
    a2_idx = _nearest_bin_index(bins, *anchor2)
    print(a1_idx, a2_idx)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        if res == 'deeploop':
            draw_heatmap(mat, 3.0, ax=ax, min_val=0.8)
        else:
            ax.imshow(mat, cmap='Reds', norm=PowerNorm(gamma=0.3), interpolation=None)

        # Remove the white padding so the image spans exactly the fetched bins.
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        ax.axis('off')
        ax.axis('image')

        # Mark the selected loop in bin (data) coordinates.
        ax.scatter(int(a1_idx), int(a2_idx), color='green', marker='2', s=150)

        fig.canvas.draw()
        width, height = fig.canvas.get_width_height()
        # np.fromstring on binary data is deprecated and canvas.tostring_rgb()
        # is gone from recent Matplotlib, so read the RGBA buffer and drop alpha.
        rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        img = rgba.reshape(height, width, 4)[:, :, :3].copy()
    finally:
        # Close this specific figure even when rendering fails; pyplot keeps
        # every open figure alive otherwise.
        plt.close(fig)
    return img
271
+
272
def get_heatmaps(gene, locus1, locus2, heatmap_range):
    """Render one heatmap per entry of ``cooler_celltypes``, in list order.

    Returns a 5-tuple of whatever ``get_heatmap`` returns for each cell type.
    """
    images = [
        get_heatmap(cell_name, gene, locus1, locus2, heatmap_range)
        for cell_name in cooler_celltypes
    ]
    # Unpacking enforces that exactly five images were produced.
    alpha, beta, acinar, duct, psc = images
    return alpha, beta, acinar, duct, psc
278
+
279
+
280
def _remove_edge_both_ways(graph, u, v):
    """Delete the ``u -> v`` and ``v -> u`` edges from ``graph`` where present."""
    for src, dst in ((u, v), (v, u)):
        if graph.has_edge(src, dst):
            graph.remove_edge(src, dst)


def _detach_peak_from_gene(graph, reference, peak, gene):
    """Remove from ``graph`` the last hop of the shortest ``peak`` -> ``gene`` path in ``reference``."""
    try:
        path = nx.shortest_path(reference, source=peak, target=gene)
    except nx.NetworkXException:
        # Unknown node or no connecting path: nothing to cut.
        return
    if len(path) >= 2:
        _remove_edge_both_ways(graph, path[-2], path[-1])


def _closest_peak_to_locus(locus):
    """Return the index label in ``peaks`` closest to ``locus`` (``chr<N>:<pos>`` or a gene name)."""
    if locus.startswith('chr'):
        parts = locus.split(':')
        chrom, position = parts[0], int(parts[1])
        starts = peaks.query(f"chrom == '{chrom}'")['chromStart']
        return (starts - position).abs().idxmin()
    return get_closest_peak_to_gene(locus, rna, peaks)


def _figure_to_rgba(fig):
    """Rasterise ``fig`` into an ``(height, width, 4)`` uint8 RGBA array, then close it."""
    try:
        fig.canvas.draw()
        argb = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
        argb = argb.reshape(fig.canvas.get_width_height()[::-1] + (4,))
        # Reorder the channels from ARGB to RGBA (fancy indexing also copies).
        return argb[:, :, [1, 2, 3, 0]]
    finally:
        # Figures made through pyplot stay alive until closed explicitly.
        plt.close(fig)


def _violin_rgba(key, title, zero_line=False, **violin_kwargs):
    """Draw a per-celltype violin plot of ``key`` from ``rna`` and return it as an RGBA image."""
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sc.pl.violin(rna, key, groupby='celltype', show=False, ax=ax, rotation=90, **violin_kwargs)
        ax.set_title(title)
        if zero_line:
            ax.hlines(0, -1, len(celltypes), linestyles='dashed')
        fig.tight_layout()
    except Exception:
        plt.close(fig)
        raise
    return _figure_to_rgba(fig)


def perturb(gene, locus1, locus2, use_max_gene, tf_knockout):
    """Simulate a perturbation in the guidance graph and plot the predicted effect.

    With ``tf_knockout`` set, ``gene`` and its closest peak are removed from a
    copy of ``prior``. Otherwise the edges linking the two loci to each other,
    to ``gene`` and to their closest genes are removed. RNA is then decoded
    with the edited graph and compared with ``rna_recon`` as a log2 ratio.

    Side effects: overwrites ``rna.obs`` columns ``log2FC``, ``old_mean``,
    ``new_mean`` and ``new_count``, and prints diagnostics.

    Args:
        gene: Gene (or TF) of interest.
        locus1: First loop anchor, ``chr<N>:<pos>`` or a gene name; unused for
            TF knockouts.
        locus2: Second loop anchor, same format; unused for TF knockouts.
        use_max_gene: Report the most affected gene instead of ``gene``.
        tf_knockout: Knock out ``gene`` instead of removing a loop.

    Returns:
        Four RGBA ``uint8`` images: the UMAP, the violin of original counts,
        the violin of perturbed expression, and the violin of the log2 change.
    """
    if tf_knockout:
        guidance_hvf = prior.copy()
        if guidance_hvf.has_node(gene):
            guidance_hvf.remove_node(gene)
        # Lookup kept from the original so an unknown gene still fails here,
        # before any decoding is attempted.
        gene_idx = gene_idx_map[gene]
        a1 = get_closest_peak_to_gene(gene, rna, peaks)
        if guidance_hvf.has_node(a1):
            guidance_hvf.remove_node(a1)
    else:
        locus1 = locus1.replace(',', '')
        locus2 = locus2.replace(',', '')

        a1 = _closest_peak_to_locus(locus1)
        a2 = _closest_peak_to_locus(locus2)
        print(a1)
        print(a2)

        a1_closest_gene = get_closest_gene_to_peak(a1, rna, peaks)
        a2_closest_gene = get_closest_gene_to_peak(a2, rna, peaks)
        print(a1_closest_gene, a2_closest_gene)

        guidance_hvf = prior.copy()
        # Paths are searched in the untouched prior so the first cut cannot
        # change the route found for the second anchor.
        _detach_peak_from_gene(guidance_hvf, prior, a1, gene)
        _detach_peak_from_gene(guidance_hvf, prior, a2, gene)
        _remove_edge_both_ways(guidance_hvf, a1, a2)
        _remove_edge_both_ways(guidance_hvf, a1, a1_closest_gene)
        _remove_edge_both_ways(guidance_hvf, a2, a2_closest_gene)

    perterbed_rna_recon = glue.decode_data("rna", "rna", rna, guidance_hvf)
    ism = np.log2((perterbed_rna_recon + eps) / (rna_recon + eps))

    hv_genes = rna.var.query("highly_variable").index
    mean_abs_change = np.abs(np.mean(ism, axis=0))
    if tf_knockout:
        max_gene_idxs = mean_abs_change.argsort()[::-1][:10]
        for candidate_idx in max_gene_idxs:
            print(hv_genes[candidate_idx], np.mean(ism[:, candidate_idx]))
        # NOTE(review): the second-ranked gene is taken, presumably because the
        # top hit is the knocked-out TF itself -- confirm.
        max_gene_idx = max_gene_idxs[1]
    else:
        max_gene_idx = mean_abs_change.argmax()
    max_gene = hv_genes[max_gene_idx]
    print('Max gene:', max_gene)

    target_gene = max_gene if use_max_gene else gene
    gene_idx = gene_idx_map[target_gene]
    rna.obs['log2FC'] = ism[:, gene_idx]
    rna.obs['old_mean'] = rna_recon[:, gene_idx]
    rna.obs['new_mean'] = perterbed_rna_recon[:, gene_idx]
    # Compute the new count implied by the log2 fold change.
    rna.obs['new_count'] = (rna.layers['counts'][:, gene_idx] + eps) * 2 ** rna.obs['log2FC'] - eps

    fig = sc.pl.umap(rna,
                     color=['celltype', 'log2FC'],
                     color_map='Spectral',
                     wspace=0.05,
                     legend_loc='on data',
                     legend_fontoutline=2,
                     frameon=False,
                     return_fig=True)
    if tf_knockout:
        fig.suptitle(f'{target_gene} expression after TF knockout of {gene}')
    else:
        fig.suptitle(f'{target_gene} expression after removing {a1}-{a2}')
    image_from_plot = _figure_to_rgba(fig)

    violin_img = _violin_rgba(target_gene, f'Normal {target_gene} expression', layer='counts')
    violin_img_new = _violin_rgba('new_mean', f'{target_gene} normalized expression after perturbation')
    violin_img_fc = _violin_rgba('log2FC', f'{target_gene} expression change', zero_line=True)

    return image_from_plot, violin_img, violin_img_new, violin_img_fc
442
+
443
+
444
def update_on_tf_ko(tf_knockout, use_max_gene):
    """React to the "TF KO" checkbox.

    Returns a visibility update that is hidden while the knockout is enabled
    and shown otherwise, together with the "Max Gene" value: forced to
    ``True`` for a knockout, left as ``use_max_gene`` otherwise.
    """
    if not tf_knockout:
        return gr.update(visible=True), use_max_gene
    return gr.update(visible=False), True
449
+
450
+
451
# Gradio UI: inputs on top, one heatmap per cell type, then the perturbation plots.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of the file that had lost its indentation -- confirm that the
# Row/Column structure matches the intended layout.
with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
    with gr.Row():
        with gr.Column():
            in_locus = gr.Textbox(label="Gene/TF", elem_id='in-locus', scale=1)
            with gr.Row():
                # ``value`` is the Checkbox argument for the initial state;
                # ``checked`` is not a Checkbox parameter.
                max_gene_checkbox = gr.Checkbox(label="Max Gene", elem_id='max-gene-checkbox', scale=1, value=False)
                tf_ko = gr.Checkbox(label="TF KO", elem_id='tf-ko', scale=1, value=False)
        # This column is hidden by ``update_on_tf_ko`` while "TF KO" is ticked.
        with gr.Column() as heatmap_col:
            anchor1 = gr.Textbox(label="Locus 1 (gene or genomic coords)", elem_id='anchor1', scale=1)
            anchor2 = gr.Textbox(label="Locus 2 (gene or genomic coords)", elem_id='anchor2', scale=1)
            heatmap_x = gr.Textbox(label="Heatmap X", elem_id='heatmap-x', scale=1, visible=False)
            heatmap_y = gr.Textbox(label="Heatmap Y", elem_id='heatmap-y', scale=1, visible=False)
            heatmap_size = gr.Slider(label="Heatmap Range", info="kb range around input gene locus to expand", minimum=50, maximum=5000, value=250)
            heatmap_button = gr.Button(value="Generate Heatmaps", elem_id='heatmap-button', scale=1)
    with gr.Row():
        out_heatmaps = []
        for celltype in cooler_celltypes:
            out_heatmaps.append(gr.Image(label=celltype, elem_id=f'out-heatmap-{celltype}', scale=1))
    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
    with gr.Row():
        out_img = gr.Image(elem_id='out-img', scale=1)
    with gr.Row():
        out_violin = gr.Image(elem_id='out-violin', scale=1)
        out_violin_new = gr.Image(elem_id='out-violin-new', scale=1)
        out_violin_fc = gr.Image(elem_id='out-violin-fc', scale=1)
        #out_plot = gr.Plot(elem_id='out-plot', scale=1)

    # Order matches the parameters and return values of ``perturb``.
    inputs = [in_locus, anchor1, anchor2, max_gene_checkbox, tf_ko]
    outputs = [out_img, out_violin, out_violin_new, out_violin_fc]

    gr.Examples(examples=[['INS', 'chr11:2,289,895', 'chr11:2,298,840', False, False],
                          ['BHLHA15', '', '', True, True],
                          ['FOXA2', '', '', True, True],  # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4878272/
                          # ['OTUD3', 'chr1:20457786', 'chr1:20230413', False],
                          ['IRX2', 'chr5:2704405', 'chr5:2164879', False, False],
                          ['TSPAN1', 'TSPAN1', 'PIK3R3', False, False],
                          # ['IRX1', 'chr5:5397612', 'chr5:4850008', False],
                          ['GCG', 'GCG', 'FAP', False, False],
                          # ['GCG', 'chr2:163162903', 'chr2:162852164', True],
                          ['LOXL4', 'chr10:100186215', 'chr10:99913493', False, False],
                          # ['KRT19', 'chr17:22,220,637', 'chr17:39,591,813', False],
                          # ['LPP', 'chr3:188,097,749', 'chr3:197,916,262', False],
                          ['MAFB', 'chr20:39431654', 'chr20:39368271', True, False],
                          ['CEL', 'chr9:135,937,365', 'chr9:135,973,107', False, False]],
                inputs=inputs,
                outputs=outputs,
                fn=perturb, cache_examples=os.getenv('SYSTEM') == 'spaces')
    run_button.click(perturb, inputs, outputs=outputs)
    heatmap_button.click(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)
    tf_ko.change(update_on_tf_ko, [tf_ko, max_gene_checkbox], outputs=[heatmap_col, max_gene_checkbox])
    # Editing either anchor re-renders all heatmaps.
    anchor1.change(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)
    anchor2.change(get_heatmaps, [in_locus, anchor1, anchor2, heatmap_size], outputs=out_heatmaps)

    def set_loop(img, gene, locus1, locus2, heatmap_range, evt: gr.SelectData):
        """Turn a click on a heatmap image into two new anchor strings.

        The click position is scaled from image pixels to a bin index on each
        axis; each anchor is returned as ``<chrom>:<bin start>``.
        """
        h, w = img.shape[:2]
        x, y = evt.index[0], evt.index[1]
        heatmap_locus, res = get_heatmap_locus(gene, locus1, locus2, heatmap_range)
        # Bins are read from the first cooler at this resolution.
        bins = coolers[res][0].bins().fetch(heatmap_locus)
        bin_x = bins.iloc[int(x / w * len(bins))]
        bin_y = bins.iloc[int(y / h * len(bins))]
        new_a1 = f"{bin_x['chrom']}:{bin_x['start']}"
        new_a2 = f"{bin_y['chrom']}:{bin_y['start']}"
        return new_a1, new_a2

    for out_heatmap in out_heatmaps:
        out_heatmap.select(set_loop, [out_heatmap, in_locus, anchor1, anchor2, heatmap_size], outputs=[anchor1, anchor2])


if __name__ == "__main__":
    demo.launch(share=False)