# EPCOTv2 / func_gradio.py
# Xin Luo
# 2
# 7f22d0d
import sys
import os, sys
import inspect
import pickle
import argparse, torch,re
import numpy as np
import gradio
import torch
import io
from typing import IO
from curriculum.lora_prompt_model import build_model
from erna.util import load_ref_genome, load_dnase,pad_signal_matrix
import tempfile
class defaultArg:
    """Default settings object handed to ``build_model``.

    Stands in for a parsed command-line namespace: every setting is exposed as
    a plain instance attribute, and all of them are filled in by ``__init__``.
    """

    def __init__(self):
        # ``vars(self)`` is the instance ``__dict__``, so this sets the same
        # instance attributes, in the same order, as one assignment per line.
        vars(self).update(
            bins=600,
            crop=50,
            embed_dim=960,
            epochs=20,
            accum_iter=2,
            lr=1e-5,
            batchsize=1,
            atac_block=True,
            full=False,
            # LoRA ranks; presumably 0 means "no adapter" (TODO confirm in build_model).
            lora_r_pretrain=0,
            lora_r_pretrain_1=0,
            lora_trunk_r=0,
            lora_head_epi_r=0,
            lora_head_rna_r=0,
            lora_head_erna_r=0,
            lora_head_microc_r=0,
            logits_type='dilate',
            prefix='',
            prompt=False,
            teacher=False,
            external=True,
            out='',
            include_scatac=False,
            seq_specific=False,
        )
# Names of the model's output heads. They are paired positionally with
# ``output + external_output`` returned by the model, so this order is assumed
# to match the model definition (TODO confirm against build_model).
_MODALITIES = ('epi', 'rna', 'bru', 'microc', 'hic', 'intacthic', 'rna_strand',
               'external_tf', 'tt', 'groseq', 'grocap', 'proseq', 'netcage', 'starr')


def _check_and_align_region(chrom, start, end, input_locas: np.ndarray) -> np.ndarray:
    """Validate a 600 kb query and return the pre-defined windows covering it.

    Args:
        chrom: chromosome label, used only in error messages.
        start, end: query coordinates (anything ``int()`` accepts).
        input_locas: rows of ``(chrom, start, end)`` strings for one chromosome.
            Window starts and ends are assumed sorted ascending (TODO confirm
            the BED file guarantees this).

    Returns:
        The contiguous rows from the last window starting at or before
        ``start`` through the first window ending at or after ``end``.

    Raises:
        gradio.Error: if the query is not exactly 600 kb or falls outside the
            span covered by ``input_locas``.
    """
    start, end = int(start), int(end)
    if end - start != 600000:
        raise gradio.Error("Please enter a 600kb region!")
    win_starts = input_locas[:, 1].astype(int)
    win_ends = input_locas[:, 2].astype(int)
    min_start = win_starts.min()
    max_end = win_ends.max()
    if start < min_start:
        raise gradio.Error("The start of input region in chromosome %s should be greater than %s!" % (chrom, min_start))
    if end > max_end:
        raise gradio.Error("The end of input region in chromosome %s should be less than %s!" % (chrom, max_end))
    # Last window with start <= query start; also valid when the query starts
    # exactly at the final window's start.
    start_idx = int(np.searchsorted(win_starts, start, side='right')) - 1
    # First window with end >= query end (0 when the first window reaches it).
    end_idx = int(np.searchsorted(win_ends, end, side='left'))
    return input_locas[start_idx:end_idx + 1, :]


def _load_atac_pickle(atac_pickle_file):
    """Unpickle the uploaded ATAC-seq file.

    Accepts either an object exposing the file path as ``.name`` or a plain
    path string.

    Raises:
        gradio.Error: if the file cannot be opened or unpickled.
    """
    pickle_path = getattr(atac_pickle_file, 'name', atac_pickle_file)
    try:
        with open(pickle_path, 'rb') as f:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. This is a user upload, so only deploy where uploaders are trusted.
            return pickle.load(f)
    except Exception as err:
        raise gradio.Error("ATAC-seq data not loaded!") from err


def _load_model(device: torch.device):
    """Build the model with ``defaultArg`` settings and load the human weights."""
    model = build_model(defaultArg())
    model.to(device)
    state_dict = torch.load(os.path.join(sys.path[0], "models/human_model.pt"),
                            map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Inference only: disable training-time behaviour (e.g. dropout, if the
    # model has any); torch.no_grad() alone does not do this.
    model.eval()
    return model


def run_epcotv2(chrom : int, start : int, end : int, user_modalities: list, atac_pickle_file) -> str:
    """Predict the requested modalities for one 600 kb region and pickle them.

    Args:
        chrom: chromosome, matched against the first column of
            ``erna/input_region_600kb.bed`` after stripping any "chr" prefix,
            and used as the key into the uploaded ATAC-seq mapping.
        start, end: region coordinates; ``end - start`` must equal 600000.
        user_modalities: names from ``_MODALITIES`` to include in the output
            (duplicates are ignored).
        atac_pickle_file: uploaded pickle (object with ``.name`` or a path)
            whose content is indexed by ``chrom``.

    Returns:
        Path of a temporary ``.pkl`` file (created with ``delete=False``). It
        holds a dict mapping each requested modality to the per-window
        predictions stacked along axis 0.

    Raises:
        gradio.Error: on a missing or unreadable upload, unknown modality,
            unsupported chromosome, or an invalid region.
    """
    if atac_pickle_file is None:
        raise gradio.Error("ATAC-seq pickle file not loaded!")
    # De-duplicate while keeping order; a repeated name used to stack the
    # same window twice.
    pred_modalities = list(dict.fromkeys(user_modalities or []))
    unknown = [mod for mod in pred_modalities if mod not in _MODALITIES]
    if unknown:
        raise gradio.Error("Unknown modalities: %s" % ", ".join(map(str, unknown)))

    # Cheap validation first, so bad input fails before loading model/genome.
    input_locs = np.loadtxt(os.path.join(sys.path[0], "erna/input_region_600kb.bed"),
                            dtype=str, delimiter="\t")
    input_locs[:, 0] = np.char.replace(input_locs[:, 0], "chr", "")
    locas_chrom = input_locs[input_locs[:, 0] == str(chrom)]
    if locas_chrom.shape[0] == 0:
        raise gradio.Error("Chromosome %s is not supported!" % (chrom,))
    input_aligned_loc = _check_and_align_region(chrom, start, end, locas_chrom)

    atacseq = _load_atac_pickle(atac_pickle_file)
    try:
        chrom_atac = atacseq[chrom]
    except KeyError as err:
        raise gradio.Error("Chromosome %s not found in ATAC-seq data!" % (chrom,)) from err

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = _load_model(device)
    # Every row of input_aligned_loc belongs to `chrom`, so these per-chromosome
    # tensors are used directly. This avoids re-keying by the row's string label.
    ref_seq = load_ref_genome(chrom)
    atac_signal = load_dnase(chrom_atac.astype('float32'))

    chunks = {mod: [] for mod in pred_modalities}
    with torch.no_grad():
        for _, win_s, win_e in input_aligned_loc:
            # BED coordinates are divided by 1000 to index the tensors
            # (presumably one row per 1 kb bin - TODO confirm in erna.util).
            lo, hi = int(win_s) // 1000, int(win_e) // 1000
            window = torch.cat((ref_seq[lo:hi], atac_signal[lo:hi]), dim=1).unsqueeze(0).to(device)
            output, external_output = model(window)
            heads = dict(zip(_MODALITIES, output + external_output))
            for mod in pred_modalities:
                chunks[mod].append(heads[mod].detach().cpu().numpy())

    # Stack once (repeated vstack is quadratic). A single window is kept as-is,
    # matching the shape produced when no stacking happens.
    pred_outputs = {
        mod: parts[0] if len(parts) == 1 else np.vstack(parts)
        for mod, parts in chunks.items()
        if parts
    }
    # delete=False: the file must outlive this call; presumably it is served
    # by a Gradio File output (TODO confirm the caller cleans it up).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as tmp:
        pickle.dump(pred_outputs, tmp)
    return tmp.name