|
|
import math |
|
|
import logging |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
import partialsmiles as ps |
|
|
|
|
|
from SCMG.config import varables as VBS |
|
|
from torch.autograd import Variable |
|
|
import partialsmiles as ps |
|
|
from SCMG.utils.utils_rsd import * |
|
|
from rdkit import Chem |
|
|
from rdkit import RDLogger |
|
|
RDLogger.DisableLog('rdApp.*') |
|
|
|
|
|
class PositionalEncoder(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    The table is precomputed for ``config[VBS.SIZE_BLOCK]`` positions and kept
    in the ``pe`` buffer with shape ``(1, SIZE_BLOCK, DIM_EMBEDDING)``.
    """

    def __init__(self, config):
        super(PositionalEncoder, self).__init__()
        self.Dropout = nn.Dropout(p=config[VBS.RATE_DROPOUT])
        pe = self._sinusoid_table(config[VBS.SIZE_BLOCK], config[VBS.DIM_EMBEDDING])
        # Leading axis of size 1 so the slice taken in forward() broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    @staticmethod
    def _sinusoid_table(max_len, dim):
        """Return a ``(max_len, dim)`` sinusoidal position table.

        Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
        with ``w_k = exp(-2k * ln(10000) / dim)``. An odd ``dim`` is supported:
        the last (even) column gets a sine without a matching cosine.
        """
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are dim // 2 odd columns, one fewer than div_term entries when dim is odd.
        pe[:, 1::2] = torch.cos(position * div_term[: dim // 2])
        return pe

    def forward(self, T):
        """Return the (dropout-applied) encodings for the first ``T`` positions.

        Args:
            T: sequence length to encode.

        Returns:
            Tensor of shape ``(1, T, DIM_EMBEDDING)``.

        Raises:
            ValueError: if ``T`` exceeds the precomputed table length
                (``config[VBS.SIZE_BLOCK]``).
        """
        max_len = self.pe.size(1)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds positional table size {max_len}"
            )
        return self.Dropout(self.pe[:, :T, :])
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head scaled dot-product attention with an output projection.

    Keys and values are projected from ``X_1``; queries are projected from
    ``X_2`` (``X_2=None`` means self-attention over ``X_1``).
    """

    def __init__(self, config):
        super().__init__()

        if config[VBS.DIM_ATTENTION] % config[VBS.NUM_HEADS] != 0:
            raise ValueError(
                f"DIM_ATTENTION ({config[VBS.DIM_ATTENTION]}) must be divisible "
                f"by NUM_HEADS ({config[VBS.NUM_HEADS]})"
            )

        self.Key = nn.Linear(config[VBS.DIM_EMBEDDING], config[VBS.DIM_ATTENTION])
        self.Query = nn.Linear(config[VBS.DIM_EMBEDDING], config[VBS.DIM_ATTENTION])
        self.Value = nn.Linear(config[VBS.DIM_EMBEDDING], config[VBS.DIM_ATTENTION])
        self.Dropout_Attention = nn.Dropout(config[VBS.RATE_DROPOUT])
        self.Dropout_Residue = nn.Dropout(config[VBS.RATE_DROPOUT])
        self.Projection = nn.Linear(config[VBS.DIM_ATTENTION], config[VBS.DIM_EMBEDDING])

        self.NumberOfHeads = config[VBS.NUM_HEADS]
        self.DimHead = config[VBS.DIM_ATTENTION] // self.NumberOfHeads
        self.DimAttention = config[VBS.DIM_ATTENTION]

    def forward(self, X_1, X_2=None, mask=None):
        """Attend from ``X_2`` (queries) over ``X_1`` (keys/values).

        Args:
            X_1: key/value source of shape ``(B, T_Encoder, DIM_EMBEDDING)``.
            X_2: query source of shape ``(B, T_Decoder, DIM_EMBEDDING)``, or
                ``None`` to use ``X_1`` (self-attention).
            mask: optional tensor broadcastable to
                ``(B, NUM_HEADS, T_Decoder, T_Encoder)``; positions where it
                equals 0/False are excluded from attention. ``None`` disables masking.

        Returns:
            Tensor of shape ``(B, T_Decoder, DIM_EMBEDDING)``.
        """
        if X_2 is None:
            X_2 = X_1

        _, T_Encoder, _ = X_1.size()
        BatchSize, T_Decoder, _ = X_2.size()

        # (B, T, DIM_ATTENTION) -> (B, heads, T, DimHead)
        K = self.Key(X_1).view(BatchSize, T_Encoder, self.NumberOfHeads, self.DimHead).transpose(1, 2)
        Q = self.Query(X_2).view(BatchSize, T_Decoder, self.NumberOfHeads, self.DimHead).transpose(1, 2)
        V = self.Value(X_1).view(BatchSize, T_Encoder, self.NumberOfHeads, self.DimHead).transpose(1, 2)

        ScoreAttention = (Q @ K.transpose(-2, -1)) / math.sqrt(self.DimHead)

        if mask is not None:
            # -1e9 as before; clamped to the dtype's minimum so half precision
            # does not overflow (value is unchanged for float32/bfloat16/float64).
            NegativeFill = max(-1e9, torch.finfo(ScoreAttention.dtype).min)
            ScoreAttention = ScoreAttention.masked_fill(mask == 0, NegativeFill)

        ScoreAttention = F.softmax(ScoreAttention, dim=-1)
        ScoreAttention = self.Dropout_Attention(ScoreAttention)

        Z = ScoreAttention @ V
        # Merge heads back: (B, heads, T_Decoder, DimHead) -> (B, T_Decoder, DIM_ATTENTION)
        Z = Z.transpose(1, 2).contiguous().view(BatchSize, T_Decoder, self.DimAttention)

        Z = self.Dropout_Residue(self.Projection(Z))
        return Z
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    The hidden width is ``config[VBS.DIM_FEEDFORWARD]``; a value of ``0``
    selects ``4 * config[VBS.DIM_ATTENTION]`` instead.
    """

    def __init__(self, config):
        super(FeedForward, self).__init__()
        Hidden = (
            config[VBS.DIM_ATTENTION] * 4
            if config[VBS.DIM_FEEDFORWARD] == 0
            else config[VBS.DIM_FEEDFORWARD]
        )
        Width = config[VBS.DIM_EMBEDDING]
        # Creation order is kept as-is: it fixes parameter order and RNG use.
        self.Linear1 = nn.Linear(Width, Hidden)
        self.GELU = nn.GELU()
        self.Linear2 = nn.Linear(Hidden, Width)
        self.Dropout = nn.Dropout(config[VBS.RATE_DROPOUT])

    def forward(self, x):
        """Apply the two-layer MLP to the last axis of ``x``."""
        # Dropout sits between the activation and the output projection only.
        return self.Linear2(self.Dropout(self.GELU(self.Linear1(x))))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EncoderBlock(nn.Module):
    """Post-norm Transformer encoder layer.

    Two residual sub-layers, each followed by its own LayerNorm: self-attention,
    then a feed-forward network. Dropout is applied to the sub-layer input;
    the residual path adds the undropped activations.
    """

    def __init__(self, config):
        super(EncoderBlock, self).__init__()
        Width = config[VBS.DIM_EMBEDDING]
        Rate = config[VBS.RATE_DROPOUT]
        # Registration order matters for state_dict/parameter order; keep it.
        self.LayerNorm1 = nn.LayerNorm(Width)
        self.LayerNorm2 = nn.LayerNorm(Width)
        self.Dropout1 = nn.Dropout(Rate)
        self.Dropout2 = nn.Dropout(Rate)
        self.Attention = Attention(config)
        self.FeedForward = FeedForward(config)

    def forward(self, X_Encoder, Mask_Encoder):
        """Run self-attention then feed-forward, each as norm(x + sublayer(dropout(x)))."""
        Attended = self.Attention(self.Dropout1(X_Encoder), None, Mask_Encoder)
        Hidden = self.LayerNorm1(X_Encoder + Attended)
        Transformed = self.FeedForward(self.Dropout2(Hidden))
        return self.LayerNorm2(Hidden + Transformed)
|
|
|
|
|
class DecoderBlock(nn.Module):
    """Post-norm Transformer decoder layer.

    Three residual sub-layers, each followed by its own LayerNorm: masked
    self-attention, cross-attention over the encoder output, then a
    feed-forward network. Dropout is applied to the decoder-side input of each
    sub-layer (the encoder output passed to cross-attention is not dropped here).
    """

    def __init__(self, config):
        super(DecoderBlock, self).__init__()
        Width = config[VBS.DIM_EMBEDDING]
        Rate = config[VBS.RATE_DROPOUT]
        # Registration order matters for state_dict/parameter order; keep it.
        self.LayerNorm1 = nn.LayerNorm(Width)
        self.LayerNorm2 = nn.LayerNorm(Width)
        self.LayerNorm3 = nn.LayerNorm(Width)
        self.Dropout1 = nn.Dropout(Rate)
        self.Dropout2 = nn.Dropout(Rate)
        self.Dropout3 = nn.Dropout(Rate)
        self.AttentionMasked = Attention(config)
        self.AttentionCross = Attention(config)
        self.FeedForward = FeedForward(config)

    def forward(self, X_Encoder, X_Decoder, Mask_Cross, Mask_Decoder):
        """Masked self-attention, cross-attention (keys/values from ``X_Encoder``), feed-forward."""
        SelfAttended = self.AttentionMasked(self.Dropout1(X_Decoder), None, Mask_Decoder)
        Hidden = self.LayerNorm1(X_Decoder + SelfAttended)

        # Queries come from the decoder state, keys/values from the encoder output.
        CrossAttended = self.AttentionCross(X_Encoder, self.Dropout2(Hidden), Mask_Cross)
        Hidden = self.LayerNorm2(Hidden + CrossAttended)

        Transformed = self.FeedForward(self.Dropout3(Hidden))
        return self.LayerNorm3(Hidden + Transformed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Encoder-decoder Transformer over token ids with a label-smoothed loss.

    The training loss can additionally be masked per target position by a 0/1
    "truth table" (``Auxiliary``) built in :meth:`customize_model_fn`, which
    marks the decoder tokens that keep the partial SMILES parseable.

    NOTE(review): ``tokenizer``, ``vocab_encoder``, ``vocab_decoder``,
    ``List_Vocab_Decoder`` and a non-None ``optimizer`` are read by methods
    below but never assigned in this class; presumably the caller sets them
    after construction — confirm.
    """

    def __init__(self, config):
        super().__init__()

        self.Dim_Embedding = config[VBS.DIM_EMBEDDING]
        self.Token_Padding_Encoder = config["Token_Padding_Encoder"]
        self.Token_Padding_Decoder = config["Token_Padding_Decoder"]

        self.Embedding_Encoder = nn.Embedding(len(config["vocab_encoder"]), config[VBS.DIM_EMBEDDING])
        self.Embedding_Decoder = nn.Embedding(len(config["vocab_decoder"]), config[VBS.DIM_EMBEDDING])
        self.pos_emb = PositionalEncoder(config)

        self.Dropout1 = nn.Dropout(config[VBS.RATE_DROPOUT])
        self.Dropout2 = nn.Dropout(config[VBS.RATE_DROPOUT])
        # NOTE(review): LayerNorm1/LayerNorm2 are not used by forward(); kept so
        # the state_dict layout (and existing checkpoints) stays unchanged.
        self.LayerNorm1 = nn.LayerNorm(config[VBS.DIM_EMBEDDING])
        self.LayerNorm2 = nn.LayerNorm(config[VBS.DIM_EMBEDDING])

        self.encoder_blocks = nn.ModuleList([EncoderBlock(config) for _ in range(config[VBS.NUM_LAYERS])])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(config) for _ in range(config[VBS.NUM_LAYERS])])

        # Projection from hidden states to decoder-vocabulary logits.
        self.head = nn.Linear(config[VBS.DIM_EMBEDDING], len(config["vocab_decoder"]), bias=False)

        self.apply(self._init_weights)

        # Training-time state. Alpha_LabelSmoothing/TokenWeight are set by
        # _set_train_params; optimizer is presumably assigned by the training code.
        self.optimizer = None
        self.Alpha_LabelSmoothing = None
        self.TokenWeight = None

    def _set_train_params(self, Config):
        """Store the loss hyper-parameters read by the training loss."""
        self.Alpha_LabelSmoothing = Config["Alpha_LabelSmoothing"]
        self.TokenWeight = Config["TokenWeight"]

    def _init_weights(self, module):
        """Xavier-uniform initialise every parameter of ``module`` with more than one dim.

        NOTE: ``self.apply`` visits every submodule and ``module.parameters()``
        recurses, so weight matrices are re-initialised once per ancestor. The
        final distribution is the same; this is only redundant work.
        """
        for p in module.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def init_optimizers(self, train_config):
        """Return an Adam optimizer over all parameters (not stored on ``self``)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=train_config[VBS.RATE_LEARNING])
        return optimizer

    def init_scheduler(self, train_config):
        """Return a StepLR scheduler for ``self.optimizer``.

        ``self.optimizer`` must already be set (it is ``None`` after ``__init__``).
        """
        scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=train_config[VBS.SIZE_STEP],
            gamma=train_config[VBS.GAMMA],
        )
        return scheduler

    def get_collate_fn(self, vocab_encoder, vocab_decoder):
        """Return a collate function that pads a batch of ``(x, y, aux)`` examples.

        The returned ``collate(results)`` right-pads encoder and decoder id lists
        with the respective ``VBS.TOKEN_PAD`` id into LongTensors. When the
        per-example ``aux`` items are lists (truth tables), they are padded with
        all-zero rows to a common length and stacked into a tensor of shape
        ``(B, max_rows, len(self.List_Vocab_Decoder))``; otherwise the raw list
        is returned unchanged.
        """

        def collate(results):
            X_Encoder = [a[0] for a in results]
            X_Decoder = [a[1] for a in results]
            Auxiliary = [a[2] for a in results]

            max_len_x = max(len(a) for a in X_Encoder)
            max_len_y = max(len(a) for a in X_Decoder)

            x = torch.tensor(
                [a + [vocab_encoder[VBS.TOKEN_PAD] for _ in range(max_len_x - len(a))] for a in X_Encoder],
                dtype=torch.long,
            )
            y = torch.tensor(
                [a + [vocab_decoder[VBS.TOKEN_PAD] for _ in range(max_len_y - len(a))] for a in X_Decoder],
                dtype=torch.long,
            )

            if isinstance(Auxiliary[0], list):
                MaxLen_Auxiliary = max(len(TruthTable) for TruthTable in Auxiliary)
                Len_Vocab = len(self.List_Vocab_Decoder)
                # All-zero rows for padding positions: no token is a valid target there.
                Auxiliary = torch.tensor(
                    [
                        TruthTable + [[0] * Len_Vocab for _ in range(MaxLen_Auxiliary - len(TruthTable))]
                        for TruthTable in Auxiliary
                    ]
                )

            return x, y, Auxiliary

        return collate

    def customize_model_fn(self, diex):
        """Return a per-example preprocessing function ``fn(diex)``.

        The outer ``diex`` argument is unused (kept for interface compatibility).

        ``fn`` tokenizes ``diex[VBS.COLUMN_ENCODER]`` and ``diex[VBS.COLUMN_DECODER]``,
        wraps them as ``[task-type token] + tokens + [VBS.TOKEN_END]`` (an empty
        encoder sequence stays empty), maps tokens to ids (unknown -> ``"<unk>"``)
        and builds a truth table with ``len(y_in) - 1`` rows. Each row has one 0/1
        entry per token in ``self.List_Vocab_Decoder`` for the next-token prediction:

        * all zeros if the target token is ``"|"`` or contains ``"<"`` (other than
          ``VBS.TOKEN_END``);
        * otherwise 1 where ``prefix + token`` parses as a partial SMILES, and
          for ``VBS.TOKEN_END`` also 1 when ``join_scaf_deco`` of the scaffold
          and the prefix returns a non-empty string.

        ``fn`` returns ``(x_ids, y_ids, truth_table)``.
        """

        def fn(diex):
            bos_token = diex[VBS.COLUMN_TASK_TYPE]

            x_in = self.tokenizer(diex[VBS.COLUMN_ENCODER])
            if len(x_in) > 0:
                x_in = [bos_token] + x_in + [VBS.TOKEN_END]
            x_in = [
                self.vocab_encoder[a] if a in self.vocab_encoder else self.vocab_encoder["<unk>"]
                for a in x_in
            ]

            y_in = self.tokenizer(diex[VBS.COLUMN_DECODER])
            y_in = [bos_token] + y_in + [VBS.TOKEN_END]

            Len_Vocab = len(self.List_Vocab_Decoder)
            TruthTable = []
            for CurrentIndex in range(1, len(y_in)):
                Target = y_in[CurrentIndex]
                if (Target == "|" or "<" in Target) and Target != VBS.TOKEN_END:
                    TruthTable.append([0] * Len_Vocab)
                    continue

                # Decoded text so far (BOS excluded); invariant across candidate tokens.
                Prefix = "".join(y_in[1:CurrentIndex])
                CurrentTruthTable = []
                for CurrentToken in self.List_Vocab_Decoder:
                    try:
                        ps.ParseSmiles(Prefix + CurrentToken, partial=True)
                        IsValid = 1
                    except Exception:
                        # Any parse failure means the candidate is not a valid continuation.
                        IsValid = 0
                    if CurrentToken == VBS.TOKEN_END:
                        CurrentSMI = join_scaf_deco(diex[VBS.COLUMN_ENCODER], Prefix)
                        if len(CurrentSMI) > 0:
                            IsValid = 1
                    CurrentTruthTable.append(IsValid)
                TruthTable.append(CurrentTruthTable)

            y_in = [
                self.vocab_decoder[a] if a in self.vocab_decoder else self.vocab_decoder["<unk>"]
                for a in y_in
            ]
            return x_in, y_in, TruthTable

        return fn

    def generate_masks(self, X_Encoder, X_Decoder):
        """Build boolean attention masks (True = may attend).

        Returns:
            ``(Mask_Encoder, Mask_Decoder, Mask_Cross)`` where
            ``Mask_Encoder`` and ``Mask_Cross`` are ``(B, 1, 1, T_enc)`` key-padding
            masks and ``Mask_Decoder`` is ``(B, 1, T_dec, T_dec)``: decoder padding
            combined with a lower-triangular (causal) mask.
        """
        with torch.no_grad():
            T = X_Decoder.shape[1]
            Mask_Encoder = (X_Encoder != self.Token_Padding_Encoder).unsqueeze(-2).unsqueeze(-2)
            Mask_Cross = (X_Encoder != self.Token_Padding_Encoder).unsqueeze(-2).unsqueeze(-2)
            Mask_Padding_Decoder = (X_Decoder != self.Token_Padding_Decoder).unsqueeze(-2).unsqueeze(-2)
            Mask_Causal = torch.tril(
                torch.ones(T, T, dtype=torch.bool, device=X_Decoder.device)
            ).view(1, 1, T, T)
            # (B,1,1,T) & (1,1,T,T) broadcasts to (B,1,T,T).
            Mask_Decoder = Mask_Padding_Decoder & Mask_Causal
        return Mask_Encoder, Mask_Decoder, Mask_Cross

    def _label_smoothed_loss(self, Y_Decoder_Logits, Y_Decoder_Ref, Auxiliary=None):
        """Return the mean label-smoothed, truth-table-masked cross-entropy.

        Targets are ``(1 - alpha) * one_hot + alpha / (V - 1)`` on every class
        (so the reference class also receives the smoothing term), optionally
        multiplied by ``Auxiliary`` (0/1 per class) and by ``self.TokenWeight``,
        and zeroed at padding positions. ``V`` is the logits' last dimension.

        NOTE(review): the mean is over all positions, including padding
        positions (which contribute 0), so the loss scale depends on padding.

        Args:
            Y_Decoder_Logits: logits whose last axis indexes the decoder vocabulary.
            Y_Decoder_Ref: integer class ids with the logits' leading shape.
            Auxiliary: optional truth table broadcastable to the logits' shape;
                ``None`` skips truth-table masking.
        """
        NumClasses = Y_Decoder_Logits.size(-1)
        with torch.no_grad():
            Y_OneHot = F.one_hot(Y_Decoder_Ref, num_classes=NumClasses) * (1 - self.Alpha_LabelSmoothing)
            LabelSmooth = torch.ones(NumClasses, device=Y_Decoder_Ref.device) * self.Alpha_LabelSmoothing / (NumClasses - 1)
            Y_OneHot = Y_OneHot + LabelSmooth

            if Auxiliary is not None:
                TruthTables = torch.as_tensor(Auxiliary, device=Y_OneHot.device)
                Y_OneHot = Y_OneHot * TruthTables

            if self.TokenWeight is not None:
                Weight = torch.tensor(self.TokenWeight, device=Y_Decoder_Ref.device).unsqueeze(0).unsqueeze(0)
                Y_OneHot = Y_OneHot * Weight

            Y_OneHot[Y_Decoder_Ref == self.Token_Padding_Decoder] = 0.

        Y_Decoder_Logits_LogSoftmax = F.log_softmax(Y_Decoder_Logits, dim=-1)
        loss = -(Y_OneHot * Y_Decoder_Logits_LogSoftmax).sum(dim=-1)
        return loss.mean()

    def forward(self, X_Encoder, X_Decoder, Y_Decoder_Ref=None, Auxiliary=None):
        """Encode ``X_Encoder``, decode ``X_Decoder`` and optionally compute the loss.

        Args:
            X_Encoder: encoder token ids ``(B, T_enc)``.
            X_Decoder: decoder input token ids ``(B, T_dec)``.
            Y_Decoder_Ref: optional reference ids aligned with the decoder
                outputs; when given, the loss is computed.
            Auxiliary: optional per-position truth table (see
                ``customize_model_fn``); ``None`` disables that masking.

        Returns:
            ``(Y_Decoder_Logits, loss)``; ``loss`` is ``None`` when no reference is given.
        """
        Mask_Encoder, Mask_Decoder, Mask_Cross = self.generate_masks(X_Encoder, X_Decoder)

        # Scaled token embeddings plus positional encodings.
        X_Encoder = self.Dropout1(self.Embedding_Encoder(X_Encoder) * math.sqrt(self.Dim_Embedding) + self.pos_emb(X_Encoder.size(1)))
        X_Decoder = self.Dropout2(self.Embedding_Decoder(X_Decoder) * math.sqrt(self.Dim_Embedding) + self.pos_emb(X_Decoder.size(1)))

        for encoder_block in self.encoder_blocks:
            X_Encoder = encoder_block(X_Encoder, Mask_Encoder)

        for decoder_block in self.decoder_blocks:
            X_Decoder = decoder_block(X_Encoder, X_Decoder, Mask_Cross, Mask_Decoder)

        Y_Decoder_Logits = self.head(X_Decoder)

        loss = None
        if Y_Decoder_Ref is not None:
            loss = self._label_smoothed_loss(Y_Decoder_Logits, Y_Decoder_Ref, Auxiliary)

        return Y_Decoder_Logits, loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|