# maotao/src/mRNA2vec/mrna2vec_model.py
from transformers import AutoModel, AutoConfig, AutoTokenizer
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from EMA import EMA
import torch.optim as optim
import transformers
from typing import Callable, Optional, Union, Tuple, List
from transformers import T5EncoderModel, T5Config

# DropPath (stochastic depth) is used by ConvNeXtBlock1D below but is not defined
# in this file; the timm implementation is assumed here.
try:
    from timm.layers import DropPath  # timm >= 0.9
except ImportError:
    from timm.models.layers import DropPath  # older timm releases

class mRNA2vec(nn.Module):
def __init__(self, encoder, **kwargs):
super(mRNA2vec, self).__init__()
self.embed_dim = 256
self.encoder = encoder
self.__dict__.update(kwargs)
self.ema = EMA(self.encoder) # EMA acts as the teacher
self.regression_head = self._build_regression_head()
#self.clf = nn.Linear(2048*2, 1)
self.clf = nn.Sequential(nn.Linear(256, 200),
nn.GELU(),
nn.Linear(200, 1),)
self.clf_ss = self._build_ss_head()
#self.ema_decay = self.cfg.model.ema_decay
#self.ema_end_decay = self.cfg.model.ema_end_decay
#self.ema_anneal_end_step = self.cfg.model.ema_anneal_end_step
def _build_ss_head(self):
return nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim * 2),
nn.GELU(),
nn.Linear(self.embed_dim * 2, 28))
def _build_regression_head(self):
return nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim * 2),
nn.GELU(),
nn.Linear(self.embed_dim * 2, self.embed_dim))
    def _clf(self, x):
x = self.clf(x)
return x
def ema_step(self):
self.ema.step(self.encoder)
    def forward(self, src, trg=None, mask=None, label_mask=None, **kwargs):
        # model forward in online mode (student)
        outputs = self.encoder(src, mask)  # fetch the last layer outputs
        x = outputs['encoder_out']
        # Secondary-structure logits for the inner positions (first/last token dropped),
        # keeping only positions whose token id is > 1 (i.e. not padding).
        mask_ = src[:, 1:-1].clone().gt(1)
        x_ss_clf = x[:, 1:-1, :].clone()
        x_ss_clf = x_ss_clf[mask_]
        x_ss_clf = self.clf_ss(x_ss_clf)
if trg is None:
return x
        # Teacher forward pass: the EMA copy of the encoder sees the unmasked target.
        with torch.no_grad():
            self.ema.model.eval()
            y = self.ema.model(trg, mask)['encoder_states']
            y = y[-3:]  # take the last k transformer layers
            y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
            y = sum(y) / len(y)  # average the normalized layers to build the regression target
            y = F.layer_norm(y.float(), y.shape[-1:])
        e = self._clf(x.mean(dim=1))  # sequence-level prediction from mean-pooled student states
        # Keep only the masked positions (token id 4) for the student-to-teacher regression.
        masked_indices = src.eq(4)
        x = x[masked_indices]
        y = y[masked_indices]
        x = self.regression_head(x)
return x, y, e, x_ss_clf
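
# Illustrative only: a minimal sketch of how the four outputs of mRNA2vec.forward
# might be combined during data2vec-style pretraining. The loss choices below
# (Smooth L1 against the EMA-teacher targets at masked positions, MSE for the
# pooled prediction `e`, cross-entropy for the 28-way secondary-structure logits)
# and the label names are assumptions for illustration, not the project's actual
# training script.
#
#   x, y, e, ss_logits = model(src, trg=trg, mask=attn_mask)
#   loss_d2v = F.smooth_l1_loss(x, y)                    # student vs. EMA-teacher targets
#   loss_seq = F.mse_loss(e.reshape(-1), seq_labels)     # hypothetical sequence-level target
#   loss_ss  = F.cross_entropy(ss_logits, ss_labels)     # hypothetical per-token structure labels
#   loss = loss_d2v + loss_seq + loss_ss
#   loss.backward(); optimizer.step(); model.ema_step()  # update student, then refresh EMA teacher
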
class T5_encoder(nn.Module):
def __init__(self,
vocab_size=69,
hidden_size=256,
num_hidden_layers=4,
num_attention_heads=8,
pad_token_id=1
):
super().__init__()
self.embed_dim = hidden_size
        model_config = T5Config()
        model_config.d_model = hidden_size
        model_config.num_attention_heads = num_attention_heads
        model_config.d_kv = hidden_size // num_attention_heads
        model_config.pad_token_id = pad_token_id
        model_config.num_layers = num_hidden_layers
        model_config.d_ff = hidden_size * 4
        model_config.vocab_size = vocab_size
        self.encoder = T5EncoderModel(config=model_config)
def forward(self, x, attention_mask):
        outputs = self.encoder(
            x,
            attention_mask=attention_mask,
            return_dict=True,
            output_attentions=True,
            output_hidden_states=True,
        )
        encoder_states = outputs.hidden_states[:]
encoder_out = outputs.hidden_states[-1]
attentions = outputs.attentions
return {
'encoder_states': encoder_states,
'encoder_out': encoder_out,
'attentions': attentions
}
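
# Example (illustrative only): with the defaults above, T5_encoder maps token ids
# of shape (B, L) to hidden states of shape (B, L, 256), and 'encoder_states'
# holds the embedding output plus one tensor per layer (num_layers + 1 entries).
#
#   enc = T5_encoder()
#   ids = torch.randint(5, 69, (2, 128))
#   out = enc(ids, attention_mask=torch.ones_like(ids))
#   out['encoder_out'].shape       # torch.Size([2, 128, 256])
#   len(out['encoder_states'])     # 5 with num_hidden_layers=4
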
class ConvMlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.ReLU,
norm_layer=None,
bias=True,
drop=0.2,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv1d(in_features, hidden_features, kernel_size=1, )
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
self.act = act_layer()
self.drop = nn.Dropout(drop)
self.fc2 = nn.Conv1d(hidden_features, out_features, kernel_size=1,)
def forward(self, x):
x = self.fc1(x)
x = self.norm(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
return x
class LayerScale1d(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
gamma = self.gamma.view(1, -1, 1)
return x.mul_(gamma) if self.inplace else x * gamma
class ConvNeXtBlock1D(nn.Module):
def __init__(
self,
in_chs: int,
out_chs: Optional[int] = None,
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
conv_mlp: bool = True,
drop_path: float = 0.
):
super().__init__()
out_chs = out_chs or in_chs
act_layer = nn.GELU
norm_layer = nn.LayerNorm
        mlp_layer = ConvMlp(in_features=out_chs, hidden_features=2 * out_chs, act_layer=act_layer)  # ConvMlp above already operates on 1-D (B, C, L) tensors
self.use_conv_mlp = conv_mlp
if stride == 2:
self.shortcut = nn.AvgPool1d(3, stride=2, padding=1)
elif in_chs != out_chs:
self.shortcut = nn.Conv1d(in_chs, out_chs, kernel_size=1)
else:
self.shortcut = nn.Identity()
if stride == 2:
self.down = nn.AvgPool1d(3, stride=2, padding=1)
else:
self.down = nn.Identity()
        self.conv_dw = nn.Conv1d(
            in_chs, out_chs, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=1)  # groups=1: a full convolution despite the depthwise-style name
self.norm = norm_layer(normalized_shape=out_chs)
        self.mlp = mlp_layer  # pointwise ConvMlp defined above
self.ls = LayerScale1d(out_chs, 1e-6)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x
x = self.down(x)
x = self.conv_dw(x)
x = self.norm(x.transpose(1, 2)).transpose(1, 2) # LayerNorm expects the channel as the last dimension
x = self.mlp(x)
x = self.ls(x)
x = self.drop_path(x) + shortcut
return x
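
# Shape check (illustrative only): the block operates channels-first on (B, C, L).
# With stride=2 the sequence length is halved; note that the residual add only
# lines up when stride=2 is used with in_chs == out_chs, which is how the blocks
# are configured in Regression_Model below.
#
#   blk = ConvNeXtBlock1D(64, 64, stride=2)
#   blk(torch.randn(2, 64, 100)).shape   # torch.Size([2, 64, 50])
#   blk = ConvNeXtBlock1D(64, 32, stride=1)
#   blk(torch.randn(2, 64, 100)).shape   # torch.Size([2, 32, 100])
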
class Regression_Model(nn.Module):
def __init__(self,
vocab_size=69,
hidden_size=512,
num_hidden_layers=12,
num_attention_heads=8,
pad_token_id=1
):
super().__init__()
        self.T5_encoder = T5_encoder(vocab_size=vocab_size,
                                     hidden_size=hidden_size,
                                     num_attention_heads=num_attention_heads,
                                     num_hidden_layers=num_hidden_layers,
                                     pad_token_id=pad_token_id)
self.proj1 = nn.Linear(hidden_size, hidden_size // 32)
self.proj2 = nn.Linear(32 * hidden_size // 32, 128)
self.conv_blocks = nn.Sequential(
ConvNeXtBlock1D(hidden_size, hidden_size // 2, stride=1),
ConvNeXtBlock1D(hidden_size // 2, hidden_size // 2, stride=2),
ConvNeXtBlock1D(hidden_size // 2, hidden_size // 4, stride=1),
ConvNeXtBlock1D(hidden_size // 4, hidden_size // 4, stride=1),
ConvNeXtBlock1D(hidden_size // 4, hidden_size // 4, stride=1),
ConvNeXtBlock1D(hidden_size // 4, hidden_size // 4, stride=2)
)
self.cls = nn.Sequential(nn.Linear(hidden_size, 1),
#nn.GELU(),
#nn.Linear(200, 1),
)
self.loss_fn = nn.MSELoss()
#self.loss_fn = nn.CrossEntropyLoss()
    def forward_logit_linear(self, x, mask):
        #with torch.no_grad():
        x = self.T5_encoder.encoder(x, attention_mask=mask, output_hidden_states=True, return_dict=True)
        x = x.hidden_states[-2].mean(dim=1)  # second-to-last layer, mean-pooled over the sequence
        x = x.reshape(x.size(0), -1)
        x = self.cls(x)
        return x
    def forward_logit_cov(self, x, mask):
        x = self.T5_encoder.encoder(x, attention_mask=mask, output_hidden_states=True, return_dict=True)
        x = x.hidden_states[-2]  # [:,:12,:]
        x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L) for the 1-D conv blocks
        x = self.conv_blocks(x)
        x = x.reshape(x.size(0), -1)
        x = self.cls(x)
        return x
    def forward(self, x, mask, label):
        x = self.forward_logit_linear(x, mask)
loss = self.loss_fn(x.reshape(-1), label)
return loss
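

if __name__ == "__main__":
    # Minimal smoke test with dummy inputs (illustrative only): the real token ids,
    # masking convention, and labels come from the project's tokenizer and training
    # pipeline, so the shapes and ids here are assumptions based on the code above.
    # It also assumes the local EMA module imported at the top is available.
    torch.manual_seed(0)
    ids = torch.randint(5, 69, (2, 64))   # dummy ids within the default vocab_size=69
    attn = torch.ones_like(ids)

    # Pretraining-style forward: student encoder plus EMA teacher.
    encoder = T5_encoder(hidden_size=256, num_hidden_layers=4, num_attention_heads=8)
    model = mRNA2vec(encoder)
    src = ids.clone()
    src[:, 10:20] = 4                     # forward() treats token id 4 as the mask token
    x, y, e, ss_logits = model(src, trg=ids, mask=attn)
    print(x.shape, y.shape, e.shape, ss_logits.shape)

    # Downstream regression model on a scalar label per sequence.
    reg = Regression_Model(hidden_size=512, num_hidden_layers=12)
    loss = reg(ids, attn, torch.randn(2))
    print(float(loss))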