| | import sys |
| | import os, sys |
| | import inspect |
| | import pickle |
| | import argparse, torch,re |
| | import numpy as np |
| | import gradio |
| | import torch |
| | import io |
| | from typing import IO |
| | from curriculum.lora_prompt_model import build_model |
| | from erna.util import load_ref_genome, load_dnase,pad_signal_matrix |
| | import tempfile |
class defaultArg:
    """Fixed option namespace passed to ``build_model``.

    Stands in for a parsed-arguments object (presumably the argparse namespace
    used when the model was trained — confirm against the training script).
    Each option is stored as a plain instance attribute.
    """

    def __init__(self):
        settings = {
            'bins': 600,
            'crop': 50,
            'embed_dim': 960,
            'epochs': 20,
            'accum_iter': 2,
            'lr': 1e-5,
            'batchsize': 1,
            'atac_block': True,
            'full': False,
            # LoRA ranks; a rank of 0 presumably disables that adapter —
            # TODO confirm in build_model.
            'lora_r_pretrain': 0,
            'lora_r_pretrain_1': 0,
            'lora_trunk_r': 0,
            'lora_head_epi_r': 0,
            'lora_head_rna_r': 0,
            'lora_head_erna_r': 0,
            'lora_head_microc_r': 0,
            'logits_type': 'dilate',
            'prefix': '',
            'prompt': False,
            'teacher': False,
            'external': True,
            'out': '',
            'include_scatac': False,
            'seq_specific': False,
        }
        # dicts keep insertion order, so attributes are created in this order.
        for option, value in settings.items():
            setattr(self, option, value)
| |
|
# Names assigned, in order, to the concatenated (output + external_output)
# sequence returned by the model in run_epcotv2.
_MODALITIES = ('epi', 'rna', 'bru', 'microc', 'hic', 'intacthic', 'rna_strand',
               'external_tf', 'tt', 'groseq', 'grocap', 'proseq', 'netcage', 'starr')

# Exact region length (bp) a user request must span.
_REGION_LEN = 600000


def _check_and_align_region(chrom, start, end, input_locas: np.ndarray) -> np.ndarray:
    """Validate a user region and return the BED rows that cover it.

    Args:
        chrom: Chromosome identifier; used only in error messages.
        start, end: Region bounds in bp; must differ by exactly ``_REGION_LEN``.
        input_locas: String array of BED rows (chrom, start, end, ...) for one
            chromosome; assumed sorted by position.

    Returns:
        The contiguous slice of ``input_locas`` from the last window starting
        at or before ``start`` through the first window ending at or after
        ``end``.

    Raises:
        gradio.Error: wrong region length, no windows for the chromosome, or
            region outside the available windows.
    """
    start, end = int(start), int(end)
    if end - start != _REGION_LEN:
        raise gradio.Error("Please enter a 600kb region!")
    if input_locas.shape[0] == 0:
        # Without this, min()/max() below would raise a bare ValueError.
        raise gradio.Error("Chromosome %s is not available in the input regions!" % chrom)

    starts = input_locas[:, 1].astype('int')
    ends = input_locas[:, 2].astype('int')
    min_start = np.min(starts)
    max_end = np.max(ends)
    if start < min_start:
        raise gradio.Error("The start of input region in chromosome %s should be greater than %s!" % (chrom, min_start))
    if end > max_end:
        raise gradio.Error("The end of input region in chromosome %s should be less than %s!" % (chrom, max_end))

    # Both lookups are non-empty thanks to the bound checks above. Using
    # "last start <= start" (instead of "first start > start, minus one")
    # stays valid when `start` equals the last window's start.
    start_idx = np.flatnonzero(starts <= start)[-1]
    end_idx = np.flatnonzero(ends >= end)[0]
    return input_locas[start_idx:end_idx + 1, :]


def _load_atac_pickle(atac_pickle_file):
    """Unpickle the uploaded ATAC-seq file.

    Accepts either a file-like upload exposing ``.name`` or a plain path.

    Raises:
        gradio.Error: if the file cannot be opened or unpickled.
    """
    # SECURITY(review): pickle.load can execute arbitrary code embedded in the
    # file. This loads a user upload, so only expose this app to trusted users
    # or switch to a data-only format.
    try:
        pickle_path = getattr(atac_pickle_file, 'name', atac_pickle_file)
        with open(pickle_path, 'rb') as f:
            return pickle.load(f)
    except Exception as err:
        raise gradio.Error("ATAC-seq data not loaded!") from err


def run_epcotv2(chrom: int, start: int, end: int, user_modalities: list, atac_pickle_file) -> str:
    """Predict the selected modalities for a 600 kb region and pickle them.

    Args:
        chrom: Chromosome (e.g. 1, or 'X'). It is compared as ``str(chrom)``
            against the first column of ``erna/input_region_600kb.bed`` after
            "chr" is stripped. It is also the key into the uploaded ATAC-seq
            pickle and the argument to ``load_ref_genome``.
        start, end: Region bounds in bp; ``end - start`` must be 600 000.
        user_modalities: Subset of ``_MODALITIES`` to keep in the output.
        atac_pickle_file: Uploaded pickle indexed by chromosome. Either a
            file-like object with ``.name`` or a path string.

    Returns:
        Path to a temporary ``.pkl`` file (created with ``delete=False``, so
        the caller owns cleanup). It holds ``{modality: np.ndarray}``, where
        outputs of successive windows are stacked with ``np.vstack``.

    Raises:
        gradio.Error: missing/unreadable ATAC-seq data, invalid region or
            chromosome, or an unknown modality.
    """
    if atac_pickle_file is None:
        raise gradio.Error("ATAC-seq pickle file not loaded!")

    # De-duplicate while keeping the user's order.
    pred_modalities = list(dict.fromkeys(user_modalities or ()))
    unknown = [mod for mod in pred_modalities if mod not in _MODALITIES]
    if unknown:
        raise gradio.Error("Unknown modalities: %s" % ", ".join(map(str, unknown)))

    # NOTE(review): sys.path[0] is the entry script's directory; data files
    # are resolved relative to it.
    base_dir = sys.path[0]

    # Validate the region before the expensive model/genome loads.
    input_locs = np.loadtxt(os.path.join(base_dir, "erna/input_region_600kb.bed"),
                            dtype=str, delimiter="\t", ndmin=2)
    input_locs[:, 0] = np.char.replace(input_locs[:, 0], "chr", "")
    locas_chrom = input_locs[input_locs[:, 0] == str(chrom)]
    input_aligned_doc = _check_and_align_region(chrom, start, end, locas_chrom)

    atacseq = _load_atac_pickle(atac_pickle_file)
    try:
        atac_chrom = atacseq[chrom]
    except (KeyError, IndexError) as err:
        raise gradio.Error("No ATAC-seq signal for chromosome %s in the uploaded file!" % chrom) from err

    # All aligned rows belong to `chrom`, so both tracks are loaded once and
    # sliced directly. This avoids re-deriving a dict key from the BED text,
    # which could mismatch the type of `chrom`.
    ref_seq = load_ref_genome(chrom)
    atac_sig = load_dnase(atac_chrom.astype('float32'))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = build_model(defaultArg())
    model.to(device)
    model.load_state_dict(torch.load(os.path.join(base_dir, "models/human_model.pt"),
                                     map_location=torch.device('cpu')))
    # Inference: disable training-only behaviour such as dropout.
    model.eval()

    collected = {mod: [] for mod in pred_modalities}
    with torch.no_grad():
        for row in input_aligned_doc:
            # Positions are divided by 1000 — presumably the tracks are
            # binned at 1 kb; TODO confirm with load_ref_genome/load_dnase.
            lo, hi = int(row[1]) // 1000, int(row[2]) // 1000
            window = torch.cat((ref_seq[lo:hi], atac_sig[lo:hi]), dim=1).unsqueeze(0).to(device)
            output, external_output = model(window)
            mix_output = [out.cpu().detach().numpy() for out in (output + external_output)]
            out_dic = dict(zip(_MODALITIES, mix_output))
            for mod in pred_modalities:
                collected[mod].append(out_dic[mod])

    # Stack once (avoids quadratic repeated vstack). A single window is kept
    # as-is so its shape matches the un-stacked per-window output.
    pred_outputs = {
        mod: arrays[0] if len(arrays) == 1 else np.vstack(arrays)
        for mod, arrays in collected.items()
    }

    with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as tmp:
        pickle.dump(pred_outputs, tmp)
    return tmp.name
| |
|