# NOTE: source captured from a Hugging Face Spaces page (Space status shown: "Sleeping").
import os
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.models import model_from_json

from utils.utils import (
    draw_heatmap,
    load_chr_ratio_matrix_from_sparse,
    open_anchor_to_anchor,
)

# Labels of the available pretrained models. Each label names a file pair under
# DeepLoop_models/CPGZ_trained/: "<label>.json" (architecture) and "<label>.h5"
# (weights). The UI slider picks an index into this list.
model_depths = [
    '1.5M', '2M', '2.4M', '4.88M', '5M', '6.29M', '8.5M',
    '12.5M', '16.5M', '25M', '32M', '50M', '100M', '150M',
]

# Load the 12.5M model eagerly at import time (replace the paths to use another
# model). NOTE(review): predict() builds its own model for the selected depth
# and does not read this global.
_startup_depth = '12.5M'
model_architecture = f'DeepLoop_models/CPGZ_trained/{_startup_depth}.json'
model_weights = f'DeepLoop_models/CPGZ_trained/{_startup_depth}.h5'
with open(model_architecture, 'r') as architecture_fh:
    model = model_from_json(architecture_fh.read())
model.load_weights(model_weights)

# Tab-separated anchor file (columns read as chr, start, end, anchor); its
# directory is also handed to the matrix loader in upload_file().
anchor_file = 'ref/hg19_DPNII_anchor_bed/chr22.bed'

# Side length (rows/columns of the input matrix) of the square tile fed to the model.
tile_size = 128

# Filled in by upload_file() once the user uploads a file.
input_file = None
input_matrix = None

anchor_list = pd.read_csv(
    anchor_file,
    sep='\t',
    names=['chr', 'start', 'end', 'anchor'],
)
# Keras models already loaded from disk, keyed by depth label, so moving the
# slider back and forth does not re-read architecture/weights files each time.
_model_cache = {}


def _load_depth_model(depth):
    """Return the model for *depth*, loading it from DeepLoop_models on first use."""
    if depth not in _model_cache:
        arch_path = f'DeepLoop_models/CPGZ_trained/{depth}.json'
        weights_path = f'DeepLoop_models/CPGZ_trained/{depth}.h5'
        with open(arch_path, 'r') as arch_fh:
            loaded = model_from_json(arch_fh.read())
        loaded.load_weights(weights_path)
        _model_cache[depth] = loaded
    return _model_cache[depth]


def _figure_to_rgb(fig):
    """Draw *fig* and return its pixels as an (H, W, 3) uint8 array."""
    fig.canvas.draw()
    # buffer_rgba() already has the true pixel shape (HiDPI-safe), unlike
    # reshaping a flat byte string with get_width_height().
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return rgba[..., :3].copy()


def predict(depth_idx):
    """Denoise a tile of the uploaded matrix with the selected model and plot it.

    Args:
        depth_idx: Index into ``model_depths`` chosen with the UI slider. It is
            coerced to ``int`` because, depending on the Gradio version, it
            may arrive as a float.

    Returns:
        An (H, W, 3) uint8 RGB image showing the input tile next to the
        model output.

    Raises:
        gr.Error: if no input file has been uploaded yet.
    """
    if input_matrix is None:
        raise gr.Error('Please upload a file before selecting a model depth.')

    selected_depth = model_depths[int(depth_idx)]
    model = _load_depth_model(selected_depth)

    # Square window of tile_size rows/columns around the middle anchor,
    # shifted back if needed so it stays inside the matrix.
    n_bins = input_matrix.shape[0]
    center_anchor = int(len(anchor_list) / 2)
    i = max(0, min(center_anchor - int(tile_size / 2), n_bins - tile_size))
    j = i + tile_size
    window = input_matrix[i:j, i:j]
    # toarray() is the supported dense conversion for scipy sparse types;
    # fall back to np.asarray for inputs that are already dense.
    window = window.toarray() if hasattr(window, 'toarray') else np.asarray(window)

    # Add batch and channel axes: (1, H, W, 1).
    batch = window[np.newaxis, ..., np.newaxis]

    denoised_tile = model.predict(batch).reshape(window.shape)
    denoised_tile[denoised_tile < 0] = 0
    # Average with the transpose so the displayed output is symmetric.
    denoised_tile = (denoised_tile + denoised_tile.T) / 2

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    try:
        draw_heatmap(window, 0, ax=ax1)
        draw_heatmap(denoised_tile, 0, ax=ax2)
        ax1.set_title('Input Tile')
        ax2.set_title(f'{selected_depth} model')
        fig.tight_layout()
        return _figure_to_rgb(fig)
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
def upload_file(file):
    """Load the uploaded file into the module-level ``input_matrix``.

    The matrix is built with ``load_chr_ratio_matrix_from_sparse`` from the
    uploaded file's directory and name. The anchor directory is that of
    ``anchor_file``, and symmetry is forced.

    Args:
        file: Value supplied by ``gr.UploadButton``. Depending on the Gradio
            version this is a path (str / PathLike) or a temp-file wrapper
            exposing the path as ``.name``. ``None`` leaves state untouched.
    """
    global input_file, input_matrix
    if file is None:
        return
    # Gradio 3.x passes a tempfile wrapper; newer versions pass a path string.
    path = os.fspath(file if isinstance(file, (str, os.PathLike)) else file.name)
    input_file = path
    input_matrix = load_chr_ratio_matrix_from_sparse(
        os.path.dirname(path),
        os.path.basename(path),
        os.path.dirname(anchor_file),
        force_symmetry=True,
    )
with gr.Blocks() as demo:
    with gr.Row():
        upload = gr.UploadButton("Upload a file", file_count="single")
    with gr.Row():
        slider = gr.Slider(
            minimum=0,
            maximum=len(model_depths) - 1,
            step=1,
            label='Model Depth',
            interactive=True,
        )
    heatmap = gr.Image(label='Visualization')

    # After the file is loaded, render it right away with the current depth.
    # Otherwise nothing is shown until the slider *changes*, so the initially
    # selected depth could never be viewed directly.
    upload.upload(upload_file, upload).then(predict, [slider], heatmap)
    slider.change(predict, [slider], heatmap)
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start serving it.
    queued_app = demo.queue()
    queued_app.launch()