|
|
import os |
|
|
import math |
|
|
import pandas as pd |
|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from collections import OrderedDict |
|
|
from .ernie_rna.tasks.ernie_rna import * |
|
|
from .ernie_rna.models.ernie_rna import * |
|
|
from .ernie_rna.criterions.ernie_rna import * |
|
|
|
|
|
|
|
|
def prepare_input_for_ernierna(index, seq_len):
    '''
    Build the model inputs for one sequence.

    input:
    index: np.ndarray, token-encoded sequence (already includes cls/eos slots)
    seq_len: int, sequence length without the two special tokens

    return:
    one_d: torch.LongTensor of shape (1, seq_len+2), token ids
    two_d: torch.DoubleTensor of shape (1, seq_len+2, seq_len+2, 1),
           pairing-prior matrix computed by creatmat
    '''
    total_len = seq_len + 2
    trimmed = index[:total_len]

    # 1-D input: the token ids as a single-row long tensor.
    one_d = torch.from_numpy(trimmed).long().reshape(1, -1)

    # 2-D input: pairing prior from creatmat, reshaped to (1, L, L, 1).
    prior = np.zeros((1, total_len, total_len))
    prior[0, :, :] = creatmat(trimmed.astype(int), base_range=1, lamda=0.8)
    prior = prior.transpose(1, 2, 0)
    two_d = torch.from_numpy(prior).reshape(1, total_len, total_len, 1)

    return one_d, two_d
|
|
|
|
|
def read_text_file(file_path):
    '''
    input:
    file_path: str, txt file path of input seqs

    return:
    lines: list[str], list of seqs
    '''
    try:
        # One sequence per line; strip surrounding whitespace/newlines.
        with open(file_path, 'r') as handle:
            return [row.strip() for row in handle]
    except FileNotFoundError:
        # Best-effort behavior: report and return an empty list.
        print(f"Error: File '{file_path}' not found.")
        return []
|
|
|
|
|
def read_fasta_file(file_path):
    '''
    input:
    file_path: str, fasta file path of input seqs

    return:
    seqs_dict: dict[str], dict of seqs
    '''
    try:
        with open(file_path) as handle:
            seqs_dict = {}
            for raw in handle:
                record = raw.replace('\n', '')
                if record.startswith('>'):
                    # Header line: start a new (possibly multi-line) record.
                    current_name = record[1:]
                    seqs_dict[current_name] = ''
                else:
                    # Sequence line: append to the most recent header.
                    seqs_dict[current_name] += record
        return seqs_dict
    except FileNotFoundError:
        # Best-effort behavior: report and return an empty mapping.
        print(f"Error: File '{file_path}' not found.")
        return {}
|
|
|
|
|
def save_rnass_results(file_path, seq_names_lst, ss_results_lst):
    '''
    Write fine-tune and pretrain secondary-structure predictions to disk.

    input:
    file_path: str, output directory for the rna ss extracted by ERNIE-RNA
    seq_names_lst: list[str], names list of input seqs
    ss_results_lst: list[list[str]], rna ss predicted by ernie_rna;
                    each entry is [fine_tune_result, pretrain_result]

    return:
    None
    '''
    os.makedirs(file_path, exist_ok=True)
    # os.path.join works whether or not file_path ends with a separator;
    # the previous string concatenation silently wrote "<dir>name.txt"
    # next to the directory when the trailing separator was missing.
    fine_tune_path = os.path.join(file_path, 'fine_tune_results.txt')
    pretrain_path = os.path.join(file_path, 'pretrain_results.txt')
    # Context managers guarantee the files are closed even on error.
    with open(fine_tune_path, 'w') as fine_tune_f, open(pretrain_path, 'w') as pretrain_f:
        for name, result_lst in zip(seq_names_lst, ss_results_lst):
            fine_tune_f.write(name + '\n')
            pretrain_f.write(name + '\n')
            fine_tune_f.write(result_lst[0] + '\n')
            pretrain_f.write(result_lst[1] + '\n')
|
|
|
|
|
|
|
|
def save_rnavalue_results(file_path, seq_names_lst, ss_results_lst,label_lst=None):
    '''
    Write per-sequence predicted values to a CSV file.

    input:
    file_path: str, output directory for the results extracted by ERNIE-RNA
    seq_names_lst: list[str], names list of input seqs
    ss_results_lst: list[list], rows of [sequence, pred] predicted by ernie_rna
    label_lst: optional list, ground-truth labels; added as a 'label' column
               when provided

    return:
    None
    '''
    os.makedirs(file_path, exist_ok=True)
    df = pd.DataFrame(ss_results_lst, columns=['sequence', 'pred'])
    df['_id'] = seq_names_lst
    if label_lst is not None:
        df['label'] = label_lst
    # os.path.join works whether or not file_path ends with a separator;
    # the previous string concatenation silently wrote the CSV next to the
    # directory when the trailing separator was missing.
    df.to_csv(os.path.join(file_path, 'pretrain_results.csv'), index=False)
|
|
def load_pretrained_ernierna(mlm_pretrained_model_path,arg_overrides):
    '''
    Load a pre-trained ERNIE-RNA model from a fairseq checkpoint.

    input:
    mlm_pretrained_model_path: str, checkpoint path(s); multiple paths may be
                               joined with os.pathsep
    arg_overrides: dict, fairseq argument overrides applied at load time

    return:
    model_pretrained: the first model of the loaded ensemble
    '''
    checkpoint_paths = mlm_pretrained_model_path.split(os.pathsep)
    ensemble, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        checkpoint_paths, arg_overrides=arg_overrides)
    return ensemble[0]
|
|
|
|
|
def gaussian(x):
    '''Unnormalized standard Gaussian kernel: exp(-x^2 / 2).'''
    return math.exp(-(x * x) / 2.0)
|
|
|
|
|
def paired(x,y,lamda=0.8):
    '''
    Pairing score for two encoded bases (symmetric in x and y).

    Token pairs (5,6)/(6,5) score 2, (4,7)/(7,4) score 3, and the weaker
    (4,6)/(6,4) pairing scores lamda; everything else scores 0.
    '''
    key = (x, y)
    # Fixed-strength pairings (both orientations).
    strong_scores = {(5, 6): 2, (6, 5): 2, (4, 7): 3, (7, 4): 3}
    if key in strong_scores:
        return strong_scores[key]
    # Tunable weaker pairing.
    if key in ((4, 6), (6, 4)):
        return lamda
    return 0
|
|
|
|
|
# Candidate values for creatmat's base_range (neighborhood half-width) and
# lamda (weaker-pairing score) parameters.
base_range_lst = [1]
lamda_lst = [0.8]
|
|
|
|
|
def creatmat(data, base_range=30, lamda=0.8):
    '''
    Build a soft base-pairing prior matrix for an encoded sequence.

    input:
    data: np.ndarray[int], token-encoded sequence; values index into the
          30x30 pairing-score table built from paired()
    base_range: int, number of neighboring offsets whose pairing scores are
                accumulated with Gaussian-decayed weights
    lamda: float, weaker-pairing score forwarded to paired()

    return:
    coefficient: np.ndarray[float], (len(data), len(data)) pairing-score matrix
    '''
    # Precompute pairwise scores for every pair of token ids (0..29).
    paird_map = np.array([[paired(i,j,lamda) for i in range(30)] for j in range(30)])
    data_index = np.arange(0,len(data))

    coefficient = np.zeros([len(data),len(data)])

    # First sweep: offsets (i-add, j+add). The mask is AND-accumulated, so a
    # cell keeps contributing only while every smaller offset also paired.
    score_mask = np.full((len(data),len(data)),True)
    for add in range(base_range):
        data_index_x = data_index - add
        data_index_y = data_index + add
        score_mask = ((data_index_x >= 0)[:,None] & (data_index_y < len(data))[None,:]) & score_mask
        # clip keeps the fancy indexing in bounds; out-of-range cells are
        # already masked out above.
        data_index_x,data_index_y = np.meshgrid(data_index_x.clip(0,len(data) - 1),data_index_y.clip(0,len(data) - 1),indexing='ij')
        score = paird_map[data[data_index_x],data[data_index_y]]
        score_mask = score_mask & (score != 0)

        # Contributions decay with distance from the center pair.
        coefficient = coefficient + score * score_mask * gaussian(add)
        # ~ on a numpy bool is logical not; stop once no cell survives.
        if ~(score_mask.any()) :
            break
    # Second sweep: offsets (i+add, j-add), only where the center pair scored.
    score_mask = coefficient > 0
    for add in range(1,base_range):
        data_index_x = data_index + add
        data_index_y = data_index - add
        score_mask = ((data_index_x < len(data))[:,None] & (data_index_y >= 0)[None,:]) & score_mask
        data_index_x,data_index_y = np.meshgrid(data_index_x.clip(0,len(data) - 1),data_index_y.clip(0,len(data) - 1),indexing='ij')
        score = paird_map[data[data_index_x],data[data_index_y]]
        score_mask = score_mask & (score != 0)
        coefficient = coefficient + score * score_mask * gaussian(add)
        if ~(score_mask.any()) :
            break
    return coefficient
|
|
|
|
|
def weights_init_kaiming(m):
    '''
    Initialization hook for nn.Module.apply: Kaiming-normal for conv layers,
    small-normal for Linear, constants for affine BatchNorm. Modules whose
    class name contains 'BasicConv' are intentionally left untouched.
    '''
    layer_type = m.__class__.__name__
    if 'Linear' in layer_type:
        nn.init.normal_(m.weight, std=0.001)
        if isinstance(m.bias, nn.Parameter):
            nn.init.constant_(m.bias, 0.0)
    elif 'BasicConv' in layer_type:
        # Skipped on purpose: checked before the generic 'Conv' branch.
        pass
    elif 'Conv' in layer_type:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in layer_type:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
|
|
|
|
|
def weights_init_classifier(m):
    '''
    Initialization hook for nn.Module.apply: small-normal weights and zero
    bias for Linear layers; every other module type is left unchanged.
    '''
    layer_type = m.__class__.__name__
    if 'Linear' in layer_type:
        nn.init.normal_(m.weight, std=0.001)
        if isinstance(m.bias, nn.Parameter):
            nn.init.constant_(m.bias, 0.0)
|
|
|
|
|
class ChooseModel(nn.Module):
    '''
    Secondary-structure head on top of an ERNIE-RNA sentence encoder: one
    attention map from the encoder is refined by a small conv stack plus a
    dilated residual tower into a symmetric pairwise output map.
    '''
    def __init__(self, sentence_encoder):
        super().__init__()
        # Pre-trained encoder; must accept twod_tokens / is_twod kwargs.
        self.sentence_encoder = sentence_encoder

        self.conv1 = nn.Conv2d(1, 8, 7, 1, 3)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.3)
        # 63 output channels so the concat with the 1-channel attention map
        # yields the 64 channels the residual tower expects.
        self.conv2 = nn.Conv2d(8, 63, 7, 1, 3)
        self.depth = 8
        res_layers = []
        for i in range(self.depth):
            # Dilations cycle 1, 2, 4 to grow the receptive field.
            dilation = pow(2, (i % 3))
            res_layers.append(MyBasicResBlock(inplanes=64, planes=64, dilation=dilation))
        res_layers = nn.Sequential(*res_layers)
        final_layer = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        layers = OrderedDict()
        layers["resnet"] = res_layers
        layers["final"] = final_layer
        self.proj = nn.Sequential(layers)
        self.proj.apply(weights_init_kaiming)
        self.proj.apply(weights_init_classifier)

    def forward(self,x, twod_input):
        '''
        input:
        x: token ids including the cls/eos special tokens
        twod_input: pairing-prior tensor matching x

        return:
        output: symmetric pairwise map of shape (B, 1, L, L)
        '''
        # Fix: removed the unused local `input = x[:,1:-1]`, which also
        # shadowed the builtin `input`.
        _,attn_map,out_dict = self.sentence_encoder(x,twod_tokens=twod_input,is_twod=True,extra_only=True, masked_only=False)
        # Head 5 of the last attention layer, with the cls/eos rows and
        # columns stripped. NOTE(review): the head index was presumably
        # chosen empirically — confirm against the training setup.
        final_attn = attn_map[-1][:,5:6,1:-1,1:-1]

        out = self.conv1(final_attn)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.conv2(out)
        # Concatenate the raw attention map back in: 63 + 1 = 64 channels.
        out = torch.cat((out, final_attn), dim=1)

        out = self.proj(out)

        # Symmetrize the prediction: average-free sum with its transpose.
        output = (out + out.permute(0,1,3,2))
        return output
|
|
|
|
|
class RNASSResnet32(nn.Module):
    '''
    Secondary-structure head operating on per-token embeddings: projects
    768-dim embeddings to 128, forms a pairwise outer-concatenation feature
    map, and refines it with a dilated residual tower into a symmetric map.
    '''
    def __init__(self, depth=32):
        super().__init__()
        # Number of residual blocks in the tower.
        self.depth = depth

        # 768 -> 128 projection (768 matches the encoder embedding size).
        self.fc1 = nn.Linear(in_features=768, out_features=128)
        self.fc1.apply(weights_init_kaiming)
        self.fc1.apply(weights_init_classifier)

        # 256 = 2 * 128 channels from concatenating the pairwise features.
        self.Conv2d_1 = nn.Conv2d(in_channels = 256, out_channels = 64, kernel_size = 1)
        self.Conv2d_1.apply(weights_init_kaiming)

        res_layers = []
        for i in range(self.depth):
            # Dilations cycle 1, 2, 4 to grow the receptive field.
            dilation = pow(2, (i % 3))
            res_layers.append(MyBasicResBlock(inplanes=64, planes=64, dilation=dilation))
        res_layers = nn.Sequential(*res_layers)

        final_layer = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        layers = OrderedDict()
        layers["resnet"] = res_layers
        layers["final"] = final_layer
        self.proj = nn.Sequential(layers)
        self.proj.apply(weights_init_kaiming)
        self.proj.apply(weights_init_classifier)

    def forward(self,x):
        # x: (batch, seqlen, 768) — last dim fixed by fc1's in_features.
        x = self.fc1(x)

        # Outer concatenation: pair (i, j) gets [emb_i, emb_j] -> 256 channels.
        batch_size, seqlen, hiddendim = x.size()
        x = x.unsqueeze(2).expand(batch_size, seqlen, seqlen, hiddendim)
        x_T = x.permute(0,2,1,3)
        x_concat = torch.cat([x,x_T],dim=3)
        x = x_concat.permute(0,3,1,2)
        x = self.Conv2d_1(x)

        x = self.proj(x)
        # Symmetrize: keep the upper triangle (incl. diagonal) and mirror the
        # strictly-upper part onto the lower triangle.
        upper_triangular_x = torch.triu(x)
        lower_triangular_x = torch.triu(x,diagonal=1).permute(0,1,3,2)
        output = upper_triangular_x + lower_triangular_x

        return output
|
|
|
|
|
class MyBasicResBlock(nn.Module):
    '''
    Pre-activation residual block: BN -> ReLU -> Conv3x3 -> Dropout -> ReLU
    -> dilated Conv3x3, plus an identity skip connection.

    Note: there is no downsample/projection on the skip path, so the residual
    addition requires inplanes == planes and stride == 1 for shapes to match.
    '''
    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
    ) -> None:
        super(MyBasicResBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        # Fix: was hard-coded to 64 in/out channels; use the constructor
        # arguments so the block works for any inplanes/planes. Backward
        # compatible: all callers in this file pass inplanes=planes=64.
        self.conv1 = nn.Conv2d(in_channels = inplanes, out_channels = planes, kernel_size = 3, padding = 1, bias = False)
        self.dropout = nn.Dropout(p=0.3)
        self.relu2 = nn.ReLU(inplace=True)

        # padding=dilation keeps the spatial size constant for any dilation.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)

        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)
        out = self.dropout(out)
        out = self.relu2(out)
        out = self.conv2(out)

        # Identity skip connection (no projection).
        out += identity
        return out
|
|
|
|
|
class ErnieRNAOnestage(nn.Module):
    '''
    Thin wrapper that runs the ERNIE-RNA sentence encoder and exposes either
    its hidden states or selected attention maps.

    input_x: shape like: [B,L+2], one for "cls" token, another for "eos" token
    input_twod_input: shape like: [1,L+2,L+2,1], calculated from input_x
    outpus: shape like: [B,L+2,768], 768 is the dim of embedding extracted by pre-train model
    '''
    def __init__(self, sentence_encoder):
        super().__init__()
        # Pre-trained encoder; must accept twod_tokens / is_twod kwargs.
        self.sentence_encoder = sentence_encoder
    def forward(self,x,twod_input,return_attn_map = False,i=12,j=5,layer_idx=12):
        # i selects an encoder layer and j an attention head for the
        # return_attn_map path (i == 13 means "all layers", j == 12 means
        # "all heads"); layer_idx selects the hidden state otherwise.
        _,attn_map_lst,out_dict = self.sentence_encoder(x,twod_tokens=twod_input,is_twod=True,extra_only=True, masked_only=False)
        # Stack per-layer hidden states, skipping the embedding layer.
        # NOTE(review): assumes inner_states are (T, B, C) so the transpose
        # yields (layers, B, T, C) — confirm against the encoder.
        x = torch.stack(out_dict['inner_states'][1:]).transpose(1,2)
        if layer_idx != 12:
            # Keep only the requested layer; the default 12 keeps all layers.
            x = x[layer_idx,:,:,:].unsqueeze(0)

        if return_attn_map:
            L = attn_map_lst[0].shape[2]
            attnmap = torch.stack(attn_map_lst).transpose(0,1)

            if i == 13 and j == 12:
                # All layers x all heads; 156 implies 13 layers * 12 heads.
                attnmap = attnmap.view(156,L,L)
                return attnmap
            elif i == 13:
                # All layers, single head j (first batch element only).
                return attnmap[0,:,j,:,:]
            elif j == 12:
                # Single layer i, all heads (first batch element only).
                return attnmap[0,i,:,:,:]
            else:
                # Single layer, single head (first batch element only).
                return attnmap[0,i,j,:,:]
        return x
|
|
|