File size: 12,974 Bytes
883a832
 
f4051e1
 
883a832
 
f4051e1
883a832
 
 
1f7dddc
 
1403428
 
 
f4051e1
6b6e502
 
f4051e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883a832
 
 
 
 
 
 
 
 
 
 
 
bb6d255
 
883a832
 
 
 
 
 
1f7dddc
 
 
 
883a832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ecd8d4
883a832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import os
from functools import lru_cache
from pathlib import Path

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from huggingface_hub import hf_hub_download
from scipy.sparse import tril, triu
from tensorflow.keras.models import model_from_json

# Alternative input file (arima_beta.chr22), currently disabled:
#input_file = hf_hub_download(repo_id="dylanplummer/hicorr", filename="arima_beta.chr22", repo_type="dataset", token=os.environ['DATASET_SECRET'])
# Default input: chr22 anchor-to-anchor data fetched at import time from the Hugging Face
# dataset repo; hf_hub_download returns the local path of the downloaded file.
# The access token comes from the DATASET_SECRET environment variable (os.environ[...]
# raises KeyError at import if it is unset).
input_file = hf_hub_download(repo_id="dylanplummer/hicorr", filename="ORC2.chr22", repo_type="dataset", token=os.environ['DATASET_SECRET'])


# Local data root and the directory under which precompiled sparse (.npz) matrices are cached.
# NOTE: load_chr_ratio_matrix_from_sparse rebinds sparse_data_dir when called with sparse_dir.
data_dir = 'data/'
sparse_data_dir = 'data/sparse_data/'


def get_chromosome_from_filename(filename):
    """

    Extract the chromosome string from any of the file name formats we use.

    Supported layouts: chromosome as prefix (``chr22.bed``), as the final
    component (``ORC2.chr22``) or as a middle component (``sample.chr22.txt``).



    Args:

        filename (:obj:`str`) : name of anchor to anchor file



    Returns:

        Chromosome string of form chr<>; an empty string if ``filename`` does not contain ``chr``

    """
    chr_index = filename.find('chr')  # index of chromosome name
    if chr_index == 0:  # chromosome name is the file prefix
        dot_index = filename.find('.')
        # With no extension the whole name is the chromosome; slicing up to find()'s -1
        # would silently drop the last character ('chr22' -> 'chr2').
        return filename if dot_index == -1 else filename[:dot_index]
    file_ending_index = filename.rfind('.')  # index of file ending
    if chr_index > file_ending_index:  # chromosome name is the file ending
        return filename[chr_index:]
    # Chromosome name sits between the last '.' and whatever precedes it.
    # (When 'chr' is absent, chr_index is -1 and this slice is empty.)
    return filename[chr_index:file_ending_index]


def draw_heatmap(matrix, color_scale, ax=None, return_image=False):
    """

    Display ratio heatmap containing only strong signals (values > 1, scaled up to the 0.95 quantile)

    Values are binned with a ``BoundaryNorm`` into a white-to-red colormap. Values below the
    lowest break (1.001) fall outside the bins and take the colormap's under color, which by
    default is its first (white) entry.



    Args:

        matrix (:obj:`numpy.array`) : ratio matrix to be displayed

        color_scale (:obj:`int`) : max ratio value to be considered strongest by color mapping; 0 derives the scale from the data

        ax (:obj:`matplotlib.axes.Axes`) : axes which will contain the heatmap.  If None, new axes are created

        return_image (:obj:`bool`) : set to True to return the array held by the drawn image and close its figure



    Returns:

        ``numpy.array`` : if ``return_image`` is set to True, the image's data array (``AxesImage.get_array()``, i.e. the plotted values, not RGB pixels); otherwise None

    """
    matrix_max = np.max(matrix)
    if color_scale != 0:
        # Fixed scale: evenly spaced breaks from 1.001 to color_scale, closed by the data maximum.
        breaks = np.append(np.arange(1.001, color_scale, (color_scale - 1.001) / 18), matrix_max)
    elif matrix_max < 2:
        if matrix_max > 1.001:
            breaks = np.arange(1.001, matrix_max, (matrix_max - 1.001) / 19)
        else:
            # Nothing exceeds the lowest break (e.g. an all-zero tile): the arange above would
            # get a non-positive step. Use a nominal 1.001..2 scale; every cell lies below it
            # and is drawn with the under (white) color.
            breaks = np.linspace(1.001, 2, 20)
    else:
        q95 = np.quantile(matrix, q=0.95)
        step = (q95 - 1) / 18
        up = q95 + 0.011
        if up < 2:
            up = 2
            step = 0.999 / 18
        breaks = np.append(np.arange(1.001, up, step), matrix_max)

    if len(breaks) > 1 and breaks[-1] <= breaks[-2]:
        # The appended data maximum can sit at or below the previous break (color_scale above
        # the data, or the 0.95 quantile equal to the max). BoundaryNorm requires strictly
        # increasing boundaries, so drop it; larger values then use the colormap's over color.
        breaks = breaks[:-1]

    n_bin = 20  # number of discrete colors the colormap is resampled to
    colors = ["#FFFFFF", "#FFE4E4", "#FFD7D7", "#FFC9C9", "#FFBCBC", "#FFAEAE", "#FFA1A1", "#FF9494", "#FF8686",
              "#FF7979", "#FF6B6B", "#FF5E5E", "#FF5151", "#FF4343", "#FF3636", "#FF2828", "#FF1B1B", "#FF0D0D",
              "#FF0000"]
    cmap_name = 'my_list'
    # Create the colormap; fewer bins result in a coarser interpolation.
    cm = matplotlib.colors.LinearSegmentedColormap.from_list(
        cmap_name, colors, N=n_bin)
    norm = matplotlib.colors.BoundaryNorm(breaks, 20)
    if ax is None:
        _, ax = plt.subplots()
    img = ax.imshow(matrix, cmap=cm, norm=norm, interpolation='nearest')
    if return_image:
        # Close the figure holding this heatmap, not whichever figure happens to be current.
        plt.close(ax.figure)
        return img.get_array()
    return None
    

def anchor_list_to_dict(anchors):
    """

    Build a lookup table from anchor name to its position in ``anchors``.



    Args:

        anchors (:obj:`numpy.array`) : array of anchor name values, in chromosomal order



    Returns:

        `dict` : anchor name -> index in ``anchors`` (for repeated names the last position wins)

    """
    return {name: position for position, name in enumerate(anchors)}


def anchor_to_locus(anchor_dict):
    """

    Wrap an anchor -> index mapping in a one-argument callable suitable for ``np.vectorize``.



    Args:

        anchor_dict (:obj:`dict`) : dictionary mapping each anchor to its chromosomal index



    Returns:

        `function` : callable taking an anchor name and returning its index (KeyError for unknown anchors)

    """
    def lookup(anchor_name):
        # Plain indexing on purpose: an unknown anchor raises KeyError instead of yielding a default.
        return anchor_dict[anchor_name]

    return lookup
    


def _read_anchor_to_anchor_values(path, dummy, use_raw):
    """
    Read a tab-separated anchor to anchor file and return ``(anchor1, anchor2, values)``.

    The layout is chosen from the column count rather than by trial and error:
    ``<a1> <a2> <obs> <exp>`` gives ``(obs + dummy) / (exp + dummy)`` (or ``obs`` when
    ``use_raw``); ``<a1> <a2> <ratio>`` gives the ratio column unchanged. Columns after the
    fourth are ignored.

    Raises:
        ValueError : unsupported column count, or ``use_raw`` requested for a 3-column file
    """
    table = pd.read_csv(path, delimiter='\t', header=None)
    n_columns = table.shape[1]
    if n_columns >= 4:
        obs, exp = table[2], table[3]
        values = obs if use_raw else (obs + dummy) / (exp + dummy)
    elif n_columns == 3:
        if use_raw:
            raise ValueError('use_raw requires an <anchor1> <anchor2> <obs> <exp> file, '
                             'but %s has only 3 columns' % path)
        values = table[2]
    else:
        raise ValueError('Expected 3 or 4 tab-separated columns in %s, found %d' % (path, n_columns))
    return table[0].to_numpy(), table[1].to_numpy(), values.to_numpy()


def _symmetrize_from_upper(sparse_matrix):
    """Return ``triu(M) + triu(M).T``, first folding M onto itself if one off-diagonal triangle is empty."""
    upper_sum = triu(sparse_matrix, k=1).sum()
    lower_sum = tril(sparse_matrix, k=-1).sum()
    if upper_sum == 0 or lower_sum == 0:
        # Only one triangle was supplied: fold it across the diagonal so the upper triangle
        # holds every off-diagonal contact.
        sparse_matrix = sparse_matrix + sparse_matrix.transpose()
    sparse_triu = triu(sparse_matrix)
    # NOTE(review): triu() keeps the diagonal, so adding its transpose doubles the diagonal
    # (quadruples it when the fold above ran). Preserved as-is because the models may have
    # been trained on matrices built this way -- confirm before changing.
    return sparse_triu + sparse_triu.transpose()


def load_chr_ratio_matrix_from_sparse(dir_name, file_name, anchor_dir, sparse_dir=None, anchor_list=None, chr_name=None, dummy=5, ignore_sparse=False, force_symmetry=True, use_raw=False):
    """

    Loads data as a sparse matrix by either reading a precompiled sparse matrix or an anchor to anchor file which is converted to sparse CSR format.

    Ratio values are computed using the observed (obs) and expected (exp) values:



    .. math::

       ratio = \\frac{obs + dummy}{exp + dummy}



    Args:

        dir_name (:obj:`str`) : directory containing the anchor to anchor or precompiled (.npz) sparse matrix file

        file_name (:obj:`str`) : name of anchor to anchor or precompiled (.npz) sparse matrix file

        anchor_dir (:obj:`str`) : directory containing the reference anchor ``.bed`` files (``<chr_name>.bed``)

        sparse_dir (:obj:`str`) : cache directory for precompiled matrices; when given it also replaces the module-level ``sparse_data_dir`` for later calls

        anchor_list (:obj:`pandas.DataFrame`) : anchor reference table with an ``anchor`` column; read from ``anchor_dir`` when None

        chr_name (:obj:`str`) : chromosome used to pick the anchor ``.bed`` file; derived from ``file_name`` when None

        dummy (:obj:`int`) : dummy value to use when computing ratio values

        ignore_sparse (:obj:`bool`) : set to True to ignore precompiled sparse matrices even if they exist (and not write a new one)

        force_symmetry (:obj:`bool`) : mirror the upper triangle so the freshly built matrix is symmetric

        use_raw (:obj:`bool`) : use the observed counts instead of ratios (requires the 4-column ``<a1> <a2> <obs> <exp>`` format)



    Returns:

        ``scipy.sparse.csr_matrix``: sparse matrix of ratio values



    Raises:

        ValueError : no anchor reference available, unsupported column count, or ``use_raw`` with a 3-column file

    """
    global sparse_data_dir
    if chr_name is None:
        chr_name = get_chromosome_from_filename(file_name)
    sparse_rep_dir = dir_name[dir_name[: -1].rfind('/') + 1:]  # directory where the pre-compiled sparse matrices are saved
    if sparse_dir is not None:
        sparse_data_dir = sparse_dir
    cache_dir = os.path.join(sparse_data_dir, sparse_rep_dir)
    os.makedirs(cache_dir, exist_ok=True)

    if file_name.endswith('.npz'):  # loading pre-combined and pre-compiled sparse data
        # os.path.join: dir_name from os.path.dirname has no trailing separator.
        return scipy.sparse.load_npz(os.path.join(dir_name, file_name))

    cached_path = os.path.join(cache_dir, file_name + '.npz')
    if not ignore_sparse and os.path.isfile(cached_path):  # pre-compiled data already exists
        return scipy.sparse.load_npz(cached_path)

    # Otherwise build the matrix from the anchor to anchor file.
    if anchor_list is None:
        if anchor_dir is None:
            raise ValueError('You must supply either an anchor reference list or the directory containing one')
        anchor_list = pd.read_csv(os.path.join(anchor_dir, '%s.bed' % chr_name), sep='\t',
                                  names=['chr', 'start', 'end', 'anchor'])  # read anchor list file
    matrix_size = len(anchor_list)  # matrix size is needed to construct sparse CSR matrix
    anchor_dict = anchor_list_to_dict(anchor_list['anchor'].values)  # convert to anchor --> index dictionary

    anchor1, anchor2, values = _read_anchor_to_anchor_values(os.path.join(dir_name, file_name), dummy, use_raw)
    locate = np.vectorize(anchor_to_locus(anchor_dict))  # unknown anchors raise KeyError
    rows = locate(anchor1)  # convert anchor names to row indices
    cols = locate(anchor2)  # convert anchor names to column indices
    sparse_matrix = scipy.sparse.csr_matrix((values, (rows, cols)), shape=(matrix_size, matrix_size))

    if force_symmetry:
        sparse_matrix = _symmetrize_from_upper(sparse_matrix)
    if not ignore_sparse:
        # save_npz appends '.npz', matching cached_path above.
        scipy.sparse.save_npz(os.path.join(cache_dir, file_name), sparse_matrix)
    return sparse_matrix


# Selectable model names (the UI slider value indexes into this list). Each name is expected
# to have a DeepLoop_models/CPGZ_trained/<name>.json architecture and <name>.h5 weights file.
model_depths = ['1.5M', '2M', '2.4M', '4.88M', '5M', '6.29M', '8.5M', '12.5M', '16.5M', '25M', '32M', '50M', '100M', '150M']
# Load the model
# NOTE(review): this default (12.5M) model is loaded at import time, but predict() builds its
# own model for the selected depth, so the module-level `model` is not used elsewhere in this file.
model_weights = 'DeepLoop_models/CPGZ_trained/12.5M.h5'  # Replace with your model weights file
model_architecture = 'DeepLoop_models/CPGZ_trained/12.5M.json'  # Replace with your model architecture file

with open(model_architecture, 'r') as f:
    model = model_from_json(f.read())
model.load_weights(model_weights)

# Define the anchor file path
# Anchor reference .bed for chr22 (read below as chr, start, end, anchor); its directory is
# also where load_chr_ratio_matrix_from_sparse looks up <chr>.bed files.
anchor_file = 'ref/hg19_DPNII_anchor_bed/chr22.bed'
#anchor_file = 'ref/hg19_Arima_anchor_bed/chr22.bed'

# Define the tile size
# Side length, in anchors, of the square window cut from the input matrix and fed to the model.
tile_size = 128

# Load the input matrix
# Ratio matrix for input_file (built from the anchor-to-anchor file, or read from the sparse
# .npz cache if one exists); predict() and upload_file() read/replace this global.
# input_file = '../anchor_2_anchor.loop.chr22'
input_matrix = load_chr_ratio_matrix_from_sparse(os.path.dirname(input_file), os.path.basename(input_file),
                                                    os.path.dirname(anchor_file), force_symmetry=True)
# input_file = None
# input_matrix = None

# Load the anchor list
anchor_list = pd.read_csv(anchor_file, sep='\t', names=['chr', 'start', 'end', 'anchor'])

@lru_cache(maxsize=None)
def _load_model(selected_depth):
    """
    Build the DeepLoop model named ``selected_depth`` and load its weights.

    Cached per name (bounded by the number of entries in ``model_depths`` when called from
    predict), so revisiting a slider position does not re-read the files from disk.
    """
    model_weights = f'DeepLoop_models/CPGZ_trained/{selected_depth}.h5'  # Replace with your model weights file
    model_architecture = f'DeepLoop_models/CPGZ_trained/{selected_depth}.json'  # Replace with your model architecture file
    with open(model_architecture, 'r') as f:
        model = model_from_json(f.read())
    model.load_weights(model_weights)
    return model


def predict(depth_idx):
    """
    Denoise the central tile of the current input matrix and render input vs. output.

    Args:
        depth_idx (:obj:`int`) : index into ``model_depths`` selecting the trained model
            (converted with ``int`` since the slider value may arrive as a float)

    Returns:
        ``numpy.array`` : uint8 RGB image (height x width x 3) of the two heatmaps side by side
    """
    selected_depth = model_depths[int(depth_idx)]
    model = _load_model(selected_depth)

    # Window of tile_size anchors starting tile_size / 2 before the middle of the matrix
    # being sliced (clamped at 0).
    n_anchors = input_matrix.shape[0]
    i = max(0, n_anchors // 2 - tile_size // 2)
    j = i + tile_size
    tile = input_matrix[i:j, i:j].toarray()

    # The model takes a batch axis and a channel axis: (1, tile_size, tile_size, 1).
    denoised_tile = model.predict(tile[np.newaxis, :, :, np.newaxis]).reshape((tile_size, tile_size))
    denoised_tile[denoised_tile < 0] = 0  # clip negative predictions
    denoised_tile = (denoised_tile + denoised_tile.T) / 2  # symmetrize the prediction

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    try:
        draw_heatmap(tile, 0, ax=ax1)
        draw_heatmap(denoised_tile, 0, ax=ax2)
        ax1.set_title('Input Tile')
        ax2.set_title(f'{selected_depth} model')
        plt.tight_layout()
        fig.canvas.draw()
        # tostring_rgb() is deprecated in newer Matplotlib; buffer_rgba() exposes the rendered
        # pixels with the canvas's real dimensions. Copy the RGB channels out before closing.
        data = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    finally:
        # Always release the figure, even if drawing fails.
        plt.close(fig)
    return data

def upload_file(file):
    """
    Replace the active input matrix with one loaded from an uploaded file.

    Args:
        file : value passed by the Gradio upload event -- a path string, or (presumably, in
            older Gradio versions) a temp-file object exposing the path as ``.name``

    The globals ``input_file`` / ``input_matrix`` are only replaced once loading succeeds,
    so a failed upload leaves the previous input in place.
    """
    global input_file, input_matrix
    print(file)
    path = getattr(file, 'name', file)  # accept both a path string and a file-like object
    new_matrix = load_chr_ratio_matrix_from_sparse(os.path.dirname(path), os.path.basename(path),
                                                   os.path.dirname(anchor_file), force_symmetry=True)
    input_file = path
    input_matrix = new_matrix


# Gradio UI: an upload button plus a model-depth slider whose changes re-render the heatmap.
with gr.Blocks() as demo:
    with gr.Row():
        upload = gr.UploadButton("Upload a file", file_count="single")
    with gr.Row():
        # Slider value is an index into model_depths.
        slider = gr.Slider(minimum=0, maximum=len(model_depths) - 1, step=1, label='Model Depth', interactive=True)
        heatmap = gr.Image(label='Visualization')
    
    # Uploading swaps the global input matrix. NOTE(review): no outputs are wired here, so the
    # heatmap only refreshes on the next slider change.
    upload.upload(upload_file, upload)
    slider.change(predict, [slider], heatmap)



if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()