Spaces:
Sleeping
Sleeping
| import math | |
| import copy | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
class MolTrans(nn.Module):
    """Interaction network with a 2D drug/protein interaction map.

    Token ids for a drug and a protein are embedded (token + position),
    passed through separate two-layer transformer encoders, combined into a
    pairwise interaction map, convolved, flattened and fed to an MLP.

    Args:
        input_dim_drug: Vocabulary size of the drug token embedding (the
            original signature carried ``23532`` as an annotation).
        input_dim_target: Vocabulary size of the protein token embedding (the
            original signature carried ``16693`` as an annotation).
        max_drug_seq: Padded drug sequence length; also the size of the drug
            position-embedding table.
        max_protein_seq: Padded protein sequence length; also the size of the
            protein position-embedding table.
        emb_size: Embedding width, reused as the encoder hidden size.
        dropout_rate: Dropout probability for the embeddings and for the
            interaction map.
        scale_down_ratio: Stored on the instance; not used by this class.
        growth_rate: Stored on the instance; not used by this class.
        transition_rate: Stored on the instance; not used by this class.
        num_dense_blocks: Stored on the instance; not used by this class.
        kernal_dense_size: Stored on the instance; not used by this class.
        intermediate_size: Width of the encoder feed-forward layer.
        num_attention_heads: Number of attention heads per encoder layer.
        attention_probs_dropout_prob: Dropout on attention probabilities.
        hidden_dropout_prob: Dropout on encoder sub-layer outputs.
    """

    def __init__(
        self,
        input_dim_drug: int,
        input_dim_target: int,
        max_drug_seq: int,
        max_protein_seq: int,
        emb_size: int = 384,
        dropout_rate: float = 0.1,
        # DenseNet
        scale_down_ratio: float = 0.25,
        growth_rate: int = 20,
        transition_rate: float = 0.5,
        num_dense_blocks: int = 4,
        kernal_dense_size: int = 3,
        # Encoder
        intermediate_size: int = 1536,
        num_attention_heads: int = 12,
        attention_probs_dropout_prob: float = 0.1,
        hidden_dropout_prob: float = 0.1,
    ):
        super().__init__()
        self.max_d = max_drug_seq
        self.max_p = max_protein_seq
        self.emb_size = emb_size
        self.dropout_rate = dropout_rate

        # DenseNet settings: kept as attributes for callers that read them.
        self.scale_down_ratio = scale_down_ratio
        self.growth_rate = growth_rate
        self.transition_rate = transition_rate
        self.num_dense_blocks = num_dense_blocks
        self.kernal_dense_size = kernal_dense_size

        self.input_dim_drug = input_dim_drug
        self.input_dim_target = input_dim_target
        self.n_layer = 2

        # Encoder settings.
        self.hidden_size = emb_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob

        # Token + position embeddings, one table pair per modality.
        self.demb = Embeddings(self.input_dim_drug, self.emb_size, self.max_d, self.dropout_rate)
        self.pemb = Embeddings(self.input_dim_target, self.emb_size, self.max_p, self.dropout_rate)

        self.d_encoder = EncoderMultipleLayers(
            self.n_layer,
            self.hidden_size,
            self.intermediate_size,
            self.num_attention_heads,
            self.attention_probs_dropout_prob,
            self.hidden_dropout_prob,
        )
        self.p_encoder = EncoderMultipleLayers(
            self.n_layer,
            self.hidden_size,
            self.intermediate_size,
            self.num_attention_heads,
            self.attention_probs_dropout_prob,
            self.hidden_dropout_prob,
        )

        self.icnn = nn.Conv2d(1, 3, 3, padding=0)

        # LazyLinear infers its input width on the first forward pass, so the
        # flattened conv size does not have to be hard-coded. No final
        # Linear(32, 1) head is attached: forward() returns 32 features.
        self.decoder = nn.Sequential(
            nn.LazyLinear(512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 64),
            nn.ReLU(True),
            nn.BatchNorm1d(64),
            nn.Linear(64, 32),
            nn.ReLU(True),
        )

    @staticmethod
    def _additive_attention_mask(mask):
        """Convert a keep-mask (1 = keep, 0 = pad) into an additive bias.

        Two singleton axes are inserted after the batch axis so the result
        broadcasts against per-head attention scores. The cast to float lets
        bool and integer masks go through the same arithmetic.
        """
        extended = mask.unsqueeze(1).unsqueeze(2).float()
        return (1.0 - extended) * -10000.0

    def forward(self, v_d, v_p):
        """Compute the decoder features for a batch of drug/protein pairs.

        Args:
            v_d: Pair ``(token_ids, mask)`` for the drug.
            v_p: Pair ``(token_ids, mask)`` for the protein.

        Returns:
            The output of ``self.decoder`` on the flattened convolved
            interaction map (32 features per sample with the default decoder).
        """
        d, d_mask = v_d
        p, p_mask = v_p

        ex_d_mask = self._additive_attention_mask(d_mask)
        ex_p_mask = self._additive_attention_mask(p_mask)

        d_emb = self.demb(d)  # batch_size x seq_length x embed_size
        p_emb = self.pemb(p)
        batch_size = d_emb.size(0)

        d_encoded = self.d_encoder(d_emb.float(), ex_d_mask)
        p_encoded = self.p_encoder(p_emb.float(), ex_p_mask)

        # expand() yields broadcast views with the same shapes repeat() gave,
        # without materialising two (batch, max_d, max_p, emb) copies.
        d_aug = d_encoded.unsqueeze(2).expand(-1, -1, self.max_p, -1)
        p_aug = p_encoded.unsqueeze(1).expand(-1, self.max_d, -1, -1)
        interaction = d_aug * p_aug

        # NOTE(review): `interaction` is laid out (batch, max_d, max_p, emb);
        # view() reinterprets that buffer as (batch, emb, max_d, max_p) rather
        # than permuting axes, so the sum below is not a sum over the
        # embedding axis. Kept unchanged because existing trained weights
        # depend on this exact computation -- confirm before changing.
        i_v = interaction.view(batch_size, -1, self.max_d, self.max_p)
        i_v = torch.sum(i_v, dim=1)
        i_v = torch.unsqueeze(i_v, 1)

        # Tie dropout to the module mode; the functional form defaults to
        # training=True and would otherwise stay active under eval().
        i_v = F.dropout(i_v, p=self.dropout_rate, training=self.training)

        flat = self.icnn(i_v).view(batch_size, -1)
        return self.decoder(flat)
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and shift.

    Args:
        hidden_size: Length of the last axis; sizes ``gamma`` and ``beta``.
        variance_epsilon: Constant added to the variance before the square
            root.
    """

    def __init__(self, hidden_size, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        """Normalise ``x`` along its last axis, then apply ``gamma`` and ``beta``."""
        centered = x - x.mean(-1, keepdim=True)
        # Population variance: the mean of the squared deviations.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalized + self.beta
class Embeddings(nn.Module):
    """Combine token embeddings with learned position embeddings.

    The two lookups are summed, layer-normalised and passed through dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Width of every embedding vector.
        max_position_size: Number of rows in the position embedding table.
        dropout_rate: Dropout probability applied to the normalised sum.
    """

    def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids):
        """Embed ``input_ids``; positions run 0..seq_length-1 along axis 1."""
        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        # Same position row for every sample in the batch.
        positions = positions.unsqueeze(0).expand_as(input_ids)
        summed = self.word_embeddings(input_ids) + self.position_embeddings(positions)
        return self.dropout(self.LayerNorm(summed))
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        hidden_size: Width of the input and output features.
        num_attention_heads: Number of heads; must divide ``hidden_size``.
        attention_probs_dropout_prob: Dropout applied to attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
        super().__init__()
        if hidden_size % num_attention_heads:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last axis into heads and move the head axis to position 1."""
        split = x.view(*x.shape[:-1], self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Attend over ``hidden_states``; ``attention_mask`` is added to the raw scores."""
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores: query-key dot products scaled by sqrt(head size).
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask

        # Softmax over the key axis turns scores into probabilities. Dropout
        # here removes whole attended-to tokens, as in the original
        # Transformer paper.
        probs = self.dropout(torch.softmax(scores, dim=-1))

        context = torch.matmul(probs, values)
        # Put the head axis back next to the per-head features and merge them.
        context = context.permute(0, 2, 1, 3).contiguous()
        return context.view(*context.shape[:-2], self.all_head_size)
class SelfOutput(nn.Module):
    """Project the attention output, then add the residual and layer-normalise.

    Args:
        hidden_size: Width of both the input and the output features.
        hidden_dropout_prob: Dropout applied after the projection.
    """

    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class Attention(nn.Module):
    """Self-attention followed by its residual output projection.

    Args:
        hidden_size: Feature width shared by both sub-modules.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout on the attention probabilities.
        hidden_dropout_prob: Dropout inside the output projection.
    """

    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
        super().__init__()
        self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
        self.output = SelfOutput(hidden_size, hidden_dropout_prob)

    def forward(self, input_tensor, attention_mask):
        """Attend over ``input_tensor`` and merge the result back with it."""
        attended = self.self(input_tensor, attention_mask)
        # The untouched input doubles as the residual branch.
        return self.output(attended, input_tensor)
class Intermediate(nn.Module):
    """Feed-forward expansion: a linear layer followed by ReLU.

    Args:
        hidden_size: Width of the incoming features.
        intermediate_size: Width of the expanded features.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)

    def forward(self, hidden_states):
        """Return ``relu(dense(hidden_states))``."""
        return F.relu(self.dense(hidden_states))
class Output(nn.Module):
    """Project feed-forward features back down, add the residual, normalise.

    Args:
        intermediate_size: Width of the incoming feed-forward features.
        hidden_size: Width of the output and of the residual input.
        hidden_dropout_prob: Dropout applied after the projection.
    """

    def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        reduced = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(reduced + input_tensor)
class Encoder(nn.Module):
    """One transformer encoder layer: attention, then a feed-forward block.

    Args:
        hidden_size: Feature width entering and leaving the layer.
        intermediate_size: Width of the feed-forward expansion.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout on the attention probabilities.
        hidden_dropout_prob: Dropout on the two sub-layer outputs.
    """

    def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        self.attention = Attention(hidden_size, num_attention_heads, attention_probs_dropout_prob,
                                   hidden_dropout_prob)
        self.intermediate = Intermediate(hidden_size, intermediate_size)
        self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask):
        """Run attention, expand, and project back with the attention output as residual."""
        attended = self.attention(hidden_states, attention_mask)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)
class EncoderMultipleLayers(nn.Module):
    """A stack of ``n_layer`` encoder layers applied one after another.

    A single ``Encoder`` is built and deep-copied ``n_layer`` times, so every
    layer starts from the same initial parameter values.

    Args:
        n_layer: Number of encoder layers in the stack.
        hidden_size: Feature width entering and leaving each layer.
        intermediate_size: Width of each layer's feed-forward expansion.
        num_attention_heads: Number of attention heads per layer.
        attention_probs_dropout_prob: Dropout on the attention probabilities.
        hidden_dropout_prob: Dropout on the sub-layer outputs.
    """

    def __init__(self, n_layer, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        template = Encoder(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                           hidden_dropout_prob)
        self.layer = nn.ModuleList([copy.deepcopy(template) for _ in range(n_layer)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """Feed ``hidden_states`` through every layer and return the final states.

        ``output_all_encoded_layers`` is accepted for call compatibility but
        has no effect: only the last layer's output is ever returned.
        """
        states = hidden_states
        for encoder_layer in self.layer:
            states = encoder_layer(states, attention_mask)
        return states