dylanplummer commited on
Commit
883a832
·
1 Parent(s): c072da3

proof of concept

Browse files
Files changed (1) hide show
  1. app.py +99 -7
app.py CHANGED
@@ -1,7 +1,99 @@
1
- import gradio as gr
2
-
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
-
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ from pathlib import Path
7
+ from tensorflow.keras.models import model_from_json
8
+ from utils.utils import open_anchor_to_anchor, draw_heatmap, load_chr_ratio_matrix_from_sparse
9
+
10
+
11
+ model_depths = ['1.5M', '2M', '2.4M', '4.88M', '5M', '6.29M', '8.5M', '12.5M', '16.5M', '25M', '32M', '50M', '100M', '150M']
12
+ # Load the model
13
+ model_weights = 'DeepLoop_models/CPGZ_trained/12.5M.h5' # Replace with your model weights file
14
+ model_architecture = 'DeepLoop_models/CPGZ_trained/12.5M.json' # Replace with your model architecture file
15
+
16
+ with open(model_architecture, 'r') as f:
17
+ model = model_from_json(f.read())
18
+ model.load_weights(model_weights)
19
+
20
+ # Define the anchor file path
21
+ anchor_file = 'ref/hg19_DPNII_anchor_bed/chr22.bed' # Replace with your anchor file
22
+
23
+ # Define the tile size
24
+ tile_size = 128
25
+
26
+ # Load the input matrix
27
+ # input_file = '../anchor_2_anchor.loop.chr22'
28
+ # chr_name = get_chromosome_from_filename('../anchor_2_anchor.loop.chr22')
29
+ # input_matrix = load_chr_ratio_matrix_from_sparse(os.path.dirname(input_file), os.path.basename(input_file),
30
+ # os.path.dirname(anchor_file), force_symmetry=True)
31
+ input_file = None
32
+ input_matrix = None
33
+
34
+ # Load the anchor list
35
+ anchor_list = pd.read_csv(anchor_file, sep='\t', names=['chr', 'start', 'end', 'anchor'])
36
+
37
+ def predict(depth_idx):
38
+ """Loads the input file, predicts the output, and visualizes the tile."""
39
+ selected_depth = model_depths[depth_idx]
40
+ model_weights = f'DeepLoop_models/CPGZ_trained/{selected_depth}.h5' # Replace with your model weights file
41
+ model_architecture = f'DeepLoop_models/CPGZ_trained/{selected_depth}.json' # Replace with your model architecture file
42
+
43
+ with open(model_architecture, 'r') as f:
44
+ model = model_from_json(f.read())
45
+ model.load_weights(model_weights)
46
+
47
+ # Get the tile
48
+ center_anchor = int(len(anchor_list) / 2)
49
+ i = max(0, center_anchor - int(tile_size / 2))
50
+ j = i + tile_size
51
+ tile = input_matrix[i:j, i:j].A
52
+ tile = np.expand_dims(tile, -1)
53
+ tile = np.expand_dims(tile, 0)
54
+
55
+ # Predict the output
56
+ denoised_tile = model.predict(tile).reshape((tile_size, tile_size))
57
+ denoised_tile[denoised_tile < 0] = 0
58
+
59
+ # Normalize the tiles
60
+ tile = tile[0, ..., 0]
61
+ denoised_tile = (denoised_tile + denoised_tile.T) / 2
62
+
63
+ # Visualize the tiles
64
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
65
+ draw_heatmap(tile, 0, ax=ax1)
66
+ draw_heatmap(denoised_tile, 0, ax=ax2)
67
+ ax1.set_title('Input Tile')
68
+ ax2.set_title(f'{selected_depth} model')
69
+ plt.tight_layout()
70
+
71
+ # return as a numpy array
72
+ fig.canvas.draw()
73
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
74
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
75
+ plt.close(fig)
76
+ return data
77
+
78
+ def upload_file(file):
79
+ global input_file, input_matrix
80
+ print(file)
81
+ input_file = file
82
+ input_matrix = load_chr_ratio_matrix_from_sparse(os.path.dirname(input_file), os.path.basename(input_file),
83
+ os.path.dirname(anchor_file), force_symmetry=True)
84
+
85
+
86
+ with gr.Blocks() as demo:
87
+ with gr.Row():
88
+ upload = gr.UploadButton("Upload a file", file_count="single")
89
+ with gr.Row():
90
+ slider = gr.Slider(minimum=0, maximum=len(model_depths) - 1, step=1, label='Model Depth', interactive=True)
91
+ heatmap = gr.Image(label='Visualization')
92
+
93
+ upload.upload(upload_file, upload)
94
+ slider.change(predict, [slider], heatmap)
95
+
96
+
97
+
98
+ if __name__ == "__main__":
99
+ demo.queue().launch()