# DeepLoop / app.py
# Last commit by dylanplummer: 883a832 "proof of concept" (3.76 kB)
import os
from functools import lru_cache
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow.keras.models import model_from_json

from utils.utils import open_anchor_to_anchor, draw_heatmap, load_chr_ratio_matrix_from_sparse
# Labels of the available CPGZ-trained models. Each label is used as the stem of
# a <label>.json architecture file and a <label>.h5 weights file under
# DeepLoop_models/CPGZ_trained/.
model_depths = ['1.5M', '2M', '2.4M', '4.88M', '5M', '6.29M', '8.5M', '12.5M', '16.5M', '25M', '32M', '50M', '100M', '150M']

# Load the default 12.5M model at startup. Nothing else in this file reads the
# global `model`, so in practice this makes missing or corrupt model files fail
# at launch rather than on first use.
model_weights = 'DeepLoop_models/CPGZ_trained/12.5M.h5'  # Replace with your model weights file
model_architecture = 'DeepLoop_models/CPGZ_trained/12.5M.json'  # Replace with your model architecture file
# JSON is UTF-8 by spec; be explicit so loading does not depend on the host locale.
with open(model_architecture, 'r', encoding='utf-8') as f:
    model = model_from_json(f.read())
model.load_weights(model_weights)

# Anchor BED file, read below as tab-separated chr/start/end/anchor columns.
anchor_file = 'ref/hg19_DPNII_anchor_bed/chr22.bed'  # Replace with your anchor file

# Side length of the square diagonal tile cut from the input matrix and fed to the model.
tile_size = 128

# Path and matrix of the most recent upload; both stay None until the upload
# handler loads a file.
input_file = None
input_matrix = None

# Anchor list. In this file only its length is used (to locate the centre anchor).
anchor_list = pd.read_csv(anchor_file, sep='\t', names=['chr', 'start', 'end', 'anchor'])
@lru_cache(maxsize=None)
def _load_model(depth):
    """Build the CPGZ-trained model for ``depth`` from its JSON architecture and H5 weights.

    Cached so that revisiting a slider position does not re-read the files from
    disk. ``predict`` only passes labels from ``model_depths``, so the cache holds
    at most one model per label.
    """
    model_weights = f'DeepLoop_models/CPGZ_trained/{depth}.h5'  # Replace with your model weights file
    model_architecture = f'DeepLoop_models/CPGZ_trained/{depth}.json'  # Replace with your model architecture file
    with open(model_architecture, 'r', encoding='utf-8') as f:
        model = model_from_json(f.read())
    model.load_weights(model_weights)
    return model


def _center_tile(matrix, n_anchors, size):
    """Return the dense ``size`` x ``size`` diagonal block of ``matrix`` around anchor ``n_anchors // 2``.

    The block starts ``size // 2`` rows/columns before that anchor, clamped at 0.
    SciPy sparse inputs are densified with ``toarray()`` (the ``.A`` shorthand is
    deprecated in recent SciPy). Anything else goes through ``np.asarray``.
    """
    start = max(0, n_anchors // 2 - size // 2)
    block = matrix[start:start + size, start:start + size]
    if hasattr(block, 'toarray'):
        return block.toarray()
    return np.asarray(block)


def _postprocess(prediction, size):
    """Reshape raw model output to ``size`` x ``size``, zero out negatives, and symmetrize."""
    denoised = np.maximum(np.asarray(prediction).reshape((size, size)), 0)
    # Average with the transpose so the returned tile is exactly symmetric.
    return (denoised + denoised.T) / 2


def _render(tile, denoised_tile, selected_depth):
    """Plot the input and denoised tiles side by side; return the figure as an (H, W, 3) uint8 array."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    try:
        draw_heatmap(tile, 0, ax=ax1)
        draw_heatmap(denoised_tile, 0, ax=ax2)
        ax1.set_title('Input Tile')
        ax2.set_title(f'{selected_depth} model')
        fig.tight_layout()
        fig.canvas.draw()
        # tostring_rgb() is deprecated (Matplotlib 3.8) and removed in 3.10. Read the
        # RGBA buffer instead, drop alpha, and copy so the array outlives the figure.
        return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure even if drawing fails, so pyplot does not accumulate open figures.
        plt.close(fig)


def predict(depth_idx):
    """Denoise the centre tile of the uploaded matrix with the selected model and visualize it.

    Args:
        depth_idx: slider position, used as an index into ``model_depths``.
            Passed through ``int()`` in case the slider delivers a float such as ``3.0``.

    Returns:
        An RGB image as a ``(height, width, 3)`` uint8 array showing the input
        tile and the model output side by side.

    Raises:
        gr.Error: if no file has been uploaded yet, or if the uploaded matrix is
            too small to yield a full ``tile_size`` x ``tile_size`` tile.
    """
    if input_matrix is None:
        # Without this guard the slicing below fails with an opaque TypeError.
        raise gr.Error('Please upload a file before choosing a model depth.')
    selected_depth = model_depths[int(depth_idx)]
    model = _load_model(selected_depth)
    tile = _center_tile(input_matrix, len(anchor_list), tile_size)
    if tile.shape != (tile_size, tile_size):
        raise gr.Error(f'Expected a {tile_size}x{tile_size} tile but got {tile.shape}; the uploaded matrix is too small.')
    # The model is fed a batch of one single-channel tile: (1, tile_size, tile_size, 1).
    prediction = model.predict(tile[np.newaxis, ..., np.newaxis])
    denoised_tile = _postprocess(prediction, tile_size)
    return _render(tile, denoised_tile, selected_depth)
def _upload_path(file):
    """Return the filesystem path of a Gradio upload as a ``str``.

    Depending on the Gradio version, ``UploadButton`` hands its callback either a
    path (string, 4.x) or a temp-file wrapper whose ``.name`` is the path (3.x).
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def upload_file(file):
    """Load an uploaded sparse contact file into the module-level ``input_matrix``.

    Args:
        file: the value Gradio passes from ``UploadButton`` (a path or a temp-file
            wrapper), or ``None`` when the upload is cleared. In the ``None``
            case nothing is changed.

    The file is read with ``load_chr_ratio_matrix_from_sparse`` using
    ``force_symmetry=True``. The directory of ``anchor_file`` is passed as the
    third argument (presumably where the helper looks up anchors; verify in
    utils). The globals are assigned only after loading succeeds, so a failed
    upload leaves the previously loaded file and matrix in place.
    """
    global input_file, input_matrix
    if file is None:
        return None
    path = _upload_path(file)
    print(path)
    matrix = load_chr_ratio_matrix_from_sparse(os.path.dirname(path), os.path.basename(path),
                                               os.path.dirname(anchor_file), force_symmetry=True)
    input_file = path
    input_matrix = matrix
# Gradio UI: an upload button, a slider that selects the model depth by index
# into model_depths, and an image showing the input tile next to the denoised tile.
with gr.Blocks() as demo:
    with gr.Row():
        upload = gr.UploadButton("Upload a file", file_count="single")
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=len(model_depths) - 1, step=1, label='Model Depth', interactive=True)
    heatmap = gr.Image(label='Visualization')
    # Load the upload, then render it at the current slider depth right away.
    # Otherwise nothing is shown until the slider is moved, and moving it back to
    # the same value does not re-trigger the change event.
    upload.upload(upload_file, upload).then(predict, [slider], heatmap)
    slider.change(predict, [slider], heatmap)

if __name__ == "__main__":
    # Queueing lets Gradio serialize the (slow) model calls across users.
    demo.queue().launch()