# maotao/src/utils.py — uploaded as part of "AA2CDS" (revision 4707555)
import os
import math
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .ernie_rna.tasks.ernie_rna import *
from .ernie_rna.models.ernie_rna import *
from .ernie_rna.criterions.ernie_rna import *
def prepare_input_for_ernierna(index, seq_len):
    '''
    Build the 1-D token tensor and 2-D pairing-prior tensor for ERNIE-RNA.

    input:
        index: np.ndarray, token ids of one sequence, incl. the two special
            tokens (cls/eos per the class docs elsewhere in this module)
        seq_len: int, sequence length without the two special tokens
    return:
        one_d: torch.LongTensor of shape (1, seq_len+2)
        two_d: torch.Tensor of shape (1, seq_len+2, seq_len+2, 1)
    '''
    total_len = seq_len + 2
    trimmed = index[:total_len]
    one_d = torch.from_numpy(trimmed).long().reshape(1, -1)
    # Soft base-pairing prior computed from the token ids (float64, as in
    # the original zeros-buffer version).
    pair_mat = creatmat(trimmed.astype(int), base_range=1, lamda=0.8)
    two_d = torch.from_numpy(pair_mat).reshape(1, total_len, total_len, 1)
    return one_d, two_d
def read_text_file(file_path):
    '''
    input:
        file_path: str, txt file path of input seqs
    return:
        lines: list[str], list of seqs (one per stripped line);
        empty list if the file does not exist
    '''
    try:
        with open(file_path, 'r') as handle:
            return [row.strip() for row in handle]
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
        return []
def read_fasta_file(file_path):
    '''
    Parse a FASTA file into a name -> sequence mapping.

    input:
        file_path: str, fasta file path of input seqs
    return:
        seqs_dict: dict[str, str], header (without leading '>') mapped to the
        concatenated sequence lines; empty dict if the file does not exist
    '''
    try:
        with open(file_path) as fa:
            seqs_dict = {}
            seq_name = None
            for line in fa:
                # rstrip('\r\n') also handles CRLF files, which the original
                # replace('\n','') left with trailing '\r' in the sequences
                line = line.rstrip('\r\n')
                if line.startswith('>'):
                    seq_name = line[1:]
                    seqs_dict[seq_name] = ''
                elif seq_name is not None:
                    # Bug fix: ignore stray lines before the first '>' header
                    # instead of crashing with UnboundLocalError.
                    seqs_dict[seq_name] += line
            return seqs_dict
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
        return {}
def save_rnass_results(file_path, seq_names_lst, ss_results_lst):
    '''
    Write fine-tuned and pre-trained secondary-structure predictions to files.

    input:
        file_path: str, output directory for the result files (created if missing)
        seq_names_lst: list[str], names list of input seqs
        ss_results_lst: list[list[str]], per sequence [fine_tune_ss, pretrain_ss]
    return:
        None; writes 'fine_tune_results.txt' and 'pretrain_results.txt'
    '''
    os.makedirs(file_path, exist_ok=True)
    # Bug fix: os.path.join replaces bare string concatenation, which silently
    # produced paths like 'outfine_tune_results.txt' when file_path had no
    # trailing slash.  Context managers guarantee the files are closed.
    fine_tune_path = os.path.join(file_path, 'fine_tune_results.txt')
    pretrain_path = os.path.join(file_path, 'pretrain_results.txt')
    with open(fine_tune_path, 'w+') as fine_f, open(pretrain_path, 'w+') as pre_f:
        for name, result_lst in zip(seq_names_lst, ss_results_lst):
            fine_f.write(name + '\n')
            pre_f.write(name + '\n')
            fine_f.write(result_lst[0] + '\n')
            pre_f.write(result_lst[1] + '\n')
def save_rnavalue_results(file_path, seq_names_lst, ss_results_lst, label_lst=None):
    '''
    Save per-sequence value predictions (and optional labels) as a CSV file.

    input:
        file_path: str, output directory for the CSV (created if missing)
        seq_names_lst: list[str], names list of input seqs
        ss_results_lst: list[list], per sequence [sequence, pred] pairs
        label_lst: optional list of ground-truth labels, added as a 'label' column
    return:
        None; writes 'pretrain_results.csv' with columns
        sequence, pred, _id[, label]
    '''
    os.makedirs(file_path, exist_ok=True)
    df = pd.DataFrame(ss_results_lst, columns=['sequence', 'pred'])
    df['_id'] = seq_names_lst
    if label_lst is not None:
        df['label'] = label_lst
    # Bug fix: os.path.join replaces bare concatenation, which broke when
    # file_path lacked a trailing slash.
    df.to_csv(os.path.join(file_path, 'pretrain_results.csv'), index=False)
def load_pretrained_ernierna(mlm_pretrained_model_path, arg_overrides):
    '''
    Load the pre-trained ERNIE-RNA ensemble and return its first model.

    input:
        mlm_pretrained_model_path: str, checkpoint path(s) joined by os.pathsep
        arg_overrides: dict, argument overrides forwarded to the loader
    return:
        model_pretrained: the first model of the loaded ensemble
    '''
    checkpoint_paths = mlm_pretrained_model_path.split(os.pathsep)
    ensemble, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        checkpoint_paths, arg_overrides=arg_overrides
    )
    return ensemble[0]
def gaussian(x):
    """Unnormalized Gaussian kernel value exp(-x^2 / 2)."""
    squared = x * x
    return math.exp(-0.5 * squared)
def paired(x, y, lamda=0.8):
    '''
    Pairing score between two token-encoded bases.

    NOTE(review): token ids 4-7 appear to encode the four RNA bases —
    confirm against the ERNIE-RNA vocabulary.

    input:
        x, y: int, token ids of the two positions
        lamda: float, score assigned to the (4,6)/(6,4) pair type
    return:
        2, 3, lamda for the three recognized (symmetric) pair types; 0 otherwise
    '''
    score_table = {
        (5, 6): 2,
        (6, 5): 2,
        (4, 7): 3,
        (7, 4): 3,
        (4, 6): lamda,
        (6, 4): lamda,
    }
    return score_table.get((x, y), 0)
# Candidate hyper-parameter values for creatmat (currently a single choice each):
# base_range controls how many neighboring offsets are accumulated per cell,
# lamda is the score given to the (4,6)/(6,4) pair type in paired().
base_range_lst = [1]
lamda_lst = [0.8]
def creatmat(data, base_range=30, lamda=0.8):
    '''
    Build a soft base-pairing score matrix for a token-encoded sequence.

    input:
        data: np.ndarray[int], token ids of one sequence (values must be < 30,
            the size of the precomputed pairing table)
        base_range: int, how many neighboring offsets contribute to each cell,
            each weighted by gaussian(offset); an offset only keeps
            contributing while every closer offset also paired
        lamda: float, score of the weak pair type (see paired())
    return:
        coefficient: np.ndarray of shape (len(data), len(data)), accumulated scores
    '''
    # Lookup table: pairing score for every (token, token) combination.
    paird_map = np.array([[paired(i,j,lamda) for i in range(30)] for j in range(30)])
    data_index = np.arange(0,len(data))
    coefficient = np.zeros([len(data),len(data)])
    score_mask = np.full((len(data),len(data)),True)
    # First pass: extend each candidate pair (i, j) outward as (i-add, j+add).
    for add in range(base_range):
        data_index_x = data_index - add
        data_index_y = data_index + add
        # Stay inside the sequence AND require the previous offset to have paired.
        score_mask = ((data_index_x >= 0)[:,None] & (data_index_y < len(data))[None,:]) & score_mask
        data_index_x,data_index_y = np.meshgrid(data_index_x.clip(0,len(data) - 1),data_index_y.clip(0,len(data) - 1),indexing='ij')
        score = paird_map[data[data_index_x],data[data_index_y]]
        score_mask = score_mask & (score != 0)
        coefficient = coefficient + score * score_mask * gaussian(add)
        # Fix: `not` replaces the original bitwise `~`.  On the NumPy bool
        # returned by .any() both agree, but `~` would silently misfire on a
        # plain Python bool (~True == -2, which is truthy).
        if not score_mask.any():
            break
    # Second pass: extend inward as (i+add, j-add), restricted to cells that
    # scored in the first pass.
    score_mask = coefficient > 0
    for add in range(1,base_range):
        data_index_x = data_index + add
        data_index_y = data_index - add
        score_mask = ((data_index_x < len(data))[:,None] & (data_index_y >= 0)[None,:]) & score_mask
        data_index_x,data_index_y = np.meshgrid(data_index_x.clip(0,len(data) - 1),data_index_y.clip(0,len(data) - 1),indexing='ij')
        score = paird_map[data[data_index_x],data[data_index_y]]
        score_mask = score_mask & (score != 0)
        coefficient = coefficient + score * score_mask * gaussian(add)
        if not score_mask.any():
            break
    return coefficient
def weights_init_kaiming(m):
    '''
    Kaiming-style initializer, meant to be applied via module.apply().

    Linear: normal(std=0.001) weights, zero bias (if it is a Parameter);
    Conv: kaiming-normal weights (fan_in), zero bias if present;
    BatchNorm (affine): weight=1, bias=0.
    BasicConv modules (googlenet) are deliberately left untouched.
    '''
    name = m.__class__.__name__
    if 'Linear' in name:
        nn.init.normal_(m.weight, std=0.001)
        if isinstance(m.bias, nn.Parameter):
            nn.init.constant_(m.bias, 0.0)
    elif 'BasicConv' in name:
        pass  # for googlenet
    elif 'Conv' in name:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in name:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
def weights_init_classifier(m):
    '''
    Classifier initializer for module.apply(): Linear layers get
    normal(std=0.001) weights and zero bias; other modules are untouched.
    '''
    if 'Linear' in m.__class__.__name__:
        nn.init.normal_(m.weight, std=0.001)
        if isinstance(m.bias, nn.Parameter):
            nn.init.constant_(m.bias, 0.0)
class ChooseModel(nn.Module):
    '''
    Secondary-structure head on top of an ERNIE-RNA sentence encoder.

    Takes one attention map of the encoder's last layer, refines it through a
    small conv stack and an 8-block dilated residual tower, and returns a
    symmetric pairwise score map.
    '''
    def __init__(self, sentence_encoder):
        super().__init__()
        self.sentence_encoder = sentence_encoder
        # Attention-map refinement: 1 -> 8 -> 63 channels; the raw map is
        # concatenated back in forward() to give 64 channels for the tower.
        self.conv1 = nn.Conv2d(1, 8, 7, 1, 3)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.3)
        self.conv2 = nn.Conv2d(8, 63, 7, 1, 3)
        self.depth = 8
        res_layers = []
        for i in range(self.depth):
            # Dilation cycles through 1, 2, 4.
            dilation = pow(2, (i % 3))
            res_layers.append(MyBasicResBlock(inplanes=64, planes=64, dilation=dilation))
        res_layers = nn.Sequential(*res_layers)
        final_layer = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        layers = OrderedDict()
        layers["resnet"] = res_layers
        layers["final"] = final_layer
        self.proj = nn.Sequential(layers)
        self.proj.apply(weights_init_kaiming)
        self.proj.apply(weights_init_classifier)

    def forward(self, x, twod_input):
        '''
        input:
            x: LongTensor of shape (B, L+2), token ids incl. special tokens
            twod_input: pairing-prior tensor matching x
        return:
            symmetric score map of shape (B, 1, L, L)
        '''
        # Fix: removed the original's dead local `input = x[:, 1:-1]`,
        # which was computed and never used (and shadowed the builtin).
        _, attn_map, out_dict = self.sentence_encoder(
            x, twod_tokens=twod_input, is_twod=True, extra_only=True, masked_only=False)
        # Head 5 of the last attention layer, special-token rows/cols dropped.
        # NOTE(review): the head index 5 looks hand-picked — confirm upstream.
        final_attn = attn_map[-1][:, 5:6, 1:-1, 1:-1]
        out = self.conv1(final_attn)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.conv2(out)
        # Re-attach the raw attention map as the 64th channel.
        out = torch.cat((out, final_attn), dim=1)
        out = self.proj(out)
        # Symmetrize: score(i, j) + score(j, i).
        output = (out + out.permute(0, 1, 3, 2))
        return output
class RNASSResnet32(nn.Module):
    '''
    ResNet-style head that maps per-token embeddings (B, T, 768) to a
    symmetric (B, 1, T, T) secondary-structure score map.
    '''
    def __init__(self, depth=32):
        super().__init__()
        self.depth = depth
        # Before the ResNet tower: (1) project embeddings 768 -> 128.
        self.fc1 = nn.Linear(in_features=768, out_features=128)
        self.fc1.apply(weights_init_kaiming)
        self.fc1.apply(weights_init_classifier)
        # (2) 1x1 conv on pairwise-concatenated features: 256 -> 64 channels.
        self.Conv2d_1 = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1)
        self.Conv2d_1.apply(weights_init_kaiming)
        # Dilated residual tower (dilation cycles 1, 2, 4) plus output conv.
        blocks = [
            MyBasicResBlock(inplanes=64, planes=64, dilation=pow(2, i % 3))
            for i in range(self.depth)
        ]
        stages = OrderedDict()
        stages["resnet"] = nn.Sequential(*blocks)
        stages["final"] = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        self.proj = nn.Sequential(stages)
        self.proj.apply(weights_init_kaiming)
        self.proj.apply(weights_init_classifier)

    def forward(self, x):
        x = self.fc1(x)  # -> (B, T, 128)
        batch_size, seqlen, hiddendim = x.size()
        # Outer concatenation: the feature of cell (i, j) is [emb_i, emb_j].
        expanded = x.unsqueeze(2).expand(batch_size, seqlen, seqlen, hiddendim)
        pairwise = torch.cat([expanded, expanded.permute(0, 2, 1, 3)], dim=3)  # (B,T,T,256)
        feat = self.Conv2d_1(pairwise.permute(0, 3, 1, 2))  # (B,64,T,T)
        feat = self.proj(feat)
        # Mirror the upper triangle (incl. diagonal) onto the lower triangle
        # so the output is symmetric.
        upper = torch.triu(feat)
        lower = torch.triu(feat, diagonal=1).permute(0, 1, 3, 2)
        # returns shape (B, 1, T, T)
        return upper + lower
class MyBasicResBlock(nn.Module):
    '''
    Pre-activation residual block: BN -> ReLU -> 3x3 conv -> dropout -> ReLU
    -> dilated 3x3 conv, with an identity skip connection.

    The skip is a plain identity, so callers must keep inplanes == planes and
    stride == 1 for the shapes to match at the addition.
    '''
    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
    ) -> None:
        super(MyBasicResBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        # Fix: the original hard-coded 64 in/out channels here, ignoring the
        # inplanes/planes arguments.  All callers in this file pass 64, so
        # behavior is unchanged for them, but other widths now work.
        self.conv1 = nn.Conv2d(in_channels=inplanes, out_channels=planes, kernel_size=3, padding=1, bias=False)
        self.dropout = nn.Dropout(p=0.3)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)
        out = self.dropout(out)
        out = self.relu2(out)
        out = self.conv2(out)
        # Identity skip; requires the residual branch to preserve shape.
        out += identity
        return out
class ErnieRNAOnestage(nn.Module):
    '''
    Thin wrapper that extracts embeddings / attention maps from a pre-trained
    ERNIE-RNA sentence encoder.

    input x: shape (B, L+2) — one "cls" token plus one "eos" token
    input twod_input: shape (1, L+2, L+2, 1), computed from x
    output: shape (12-or-1, B, L+2, 768) stacked layer embeddings, or an
    attention-map slice when return_attn_map is True
    '''
    def __init__(self, sentence_encoder):
        super().__init__()
        self.sentence_encoder = sentence_encoder

    def forward(self, x, twod_input, return_attn_map=False, i=12, j=5, layer_idx=12):
        '''
        i / j select an attention layer / head (their defaults 13-sentinels
        are 13 and 12 respectively); layer_idx selects a single hidden layer
        (12 means "keep all 12").
        '''
        _, attn_map_lst, out_dict = self.sentence_encoder(
            x, twod_tokens=twod_input, is_twod=True, extra_only=True, masked_only=False)
        # Stack the 12 layers' hidden states: (12, T, B, C) -> (12, B, T, C).
        hidden = torch.stack(out_dict['inner_states'][1:]).transpose(1, 2)
        if layer_idx != 12:
            hidden = hidden[layer_idx, :, :, :].unsqueeze(0)
        if not return_attn_map:
            return hidden
        L = attn_map_lst[0].shape[2]
        # (13, B, 12, T, T) -> (B, 13, 12, T, T).
        attnmap = torch.stack(attn_map_lst).transpose(0, 1)
        if i == 13 and j == 12:
            # All 13 layers x 12 heads flattened; assumes batch size 1.
            return attnmap.view(156, L, L)
        if i == 13:
            return attnmap[0, :, j, :, :]  # head j across every layer
        if j == 12:
            return attnmap[0, i, :, :, :]  # every head of layer i
        return attnmap[0, i, j, :, :]