|
|
import torch |
|
|
from torch import nn |
|
|
from .LMConfig import LMConfig |
|
|
from typing import Any, Optional, Tuple, List |
|
|
from .model_ribo import MiniMindLM |
|
|
|
|
|
|
|
|
class ConvNetCodon(nn.Module):
    """MLP regression head over frame-wise pooled codon features.

    The input sequence is split into the three codon reading frames
    (positions 0, 1, 2 mod 3 along the sequence axis). Each frame is
    reduced over its positions with both max- and mean-pooling, the six
    pooled vectors are concatenated, and the result is passed through a
    small MLP producing ``out_dim`` values per example.

    Args:
        in_dim: feature dimension of the input (size of the last axis).
        hid_dim: width of the hidden linear layers.
        out_dim: number of output values per example.
        dropout: dropout probability applied before the output layer.
    """

    def __init__(self,
                 in_dim: int,
                 hid_dim: int,
                 out_dim: int,
                 dropout: float = 0.):
        super().__init__()
        self.nodes = hid_dim
        # inplace=True is safe here: dropout is only ever applied to an
        # intermediate activation that is not reused elsewhere.
        self.dropout = nn.Dropout(dropout, inplace=True)
        self.relu = nn.ReLU()
        # NOTE(review): the pooled tensor fed to this is already 2-D,
        # so Flatten is a no-op; kept for interface/state compatibility.
        self.flatten = nn.Flatten()

        # Input is six pooled vectors (3 frames x {max, mean}), each of
        # size in_dim, concatenated along the feature axis.
        self.linear = nn.Linear(in_features=in_dim * 6, out_features=self.nodes)
        self.linear_2 = nn.Linear(in_features=self.nodes, out_features=self.nodes * 4)
        self.linear_3 = nn.Linear(in_features=self.nodes * 4, out_features=self.nodes)
        self.output = nn.Linear(in_features=self.nodes, out_features=out_dim)

    def forward(self, x, self_attn_padding_mask=None):
        """Compute regression outputs from per-position features.

        Args:
            x: input tensor; assumed shape ``(batch, seq_len, in_dim)``
                with reading frames taken along dim 1 — TODO confirm
                with caller.
            self_attn_padding_mask: unused; kept for signature
                compatibility with callers that pass a mask.

        Returns:
            Tensor of shape ``(batch, out_dim)``.
        """
        # Split the sequence into its three codon reading frames.
        frame_1 = x[:, 0::3, :]
        frame_2 = x[:, 1::3, :]
        frame_3 = x[:, 2::3, :]

        # Max-pool each frame over its positions -> (batch, in_dim).
        frame_1_max = torch.max(frame_1, dim=1)[0]
        frame_2_max = torch.max(frame_2, dim=1)[0]
        frame_3_max = torch.max(frame_3, dim=1)[0]

        # Mean-pool each frame over its positions -> (batch, in_dim).
        frame_1_avg = torch.mean(frame_1, dim=1)
        frame_2_avg = torch.mean(frame_2, dim=1)
        frame_3_avg = torch.mean(frame_3, dim=1)

        # Concatenate all six pooled vectors -> (batch, 6 * in_dim).
        pooled_output = torch.cat([frame_1_max, frame_1_avg, frame_2_max,
                                   frame_2_avg, frame_3_max, frame_3_avg], dim=1)
        x_pooled = self.flatten(pooled_output)

        # NOTE(review): no activation between these three Linear layers,
        # so mathematically they collapse to a single affine map.
        # Left unchanged to preserve the trained architecture; consider
        # inserting non-linearities if retraining from scratch.
        o_linear = self.linear(x_pooled)
        o_linear_2 = self.linear_2(o_linear)
        o_linear_3 = self.linear_3(o_linear_2)

        o_relu = self.relu(o_linear_3)
        o_dropout = self.dropout(o_relu)
        o = self.output(o_dropout)

        return o
|
|
class MiniMindLMForRegression(MiniMindLM):
    """MiniMindLM with a codon-aware regression head.

    Runs the base language model, takes its per-position embeddings,
    and feeds them through a :class:`ConvNetCodon` head to predict
    ``output_dim`` regression target(s) per sequence.
    """

    def __init__(self,
                 params: Optional[LMConfig] = None,
                 output_dim: int = 1,
                 head_in_dim: int = 256,
                 head_hid_dim: int = 128):
        """Build the base model and attach the regression head.

        Args:
            params: base model configuration forwarded to ``MiniMindLM``.
            output_dim: number of regression targets per sequence.
            head_in_dim: feature size of the embeddings fed to the head
                (defaults to the previously hard-coded 256; presumably
                must match the base model's embedding dim — TODO confirm).
            head_hid_dim: hidden width of the regression head
                (defaults to the previously hard-coded 128).
        """
        super().__init__(params)
        self.regression_head = ConvNetCodon(head_in_dim, head_hid_dim, output_dim)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Run the base model and the regression head.

        Args:
            input_ids: token ids for the base model.
            twod_tokens: secondary token input forwarded to the base model.
            past_key_values: cached attention key/value pairs.
            use_cache: whether the base model should return a KV cache.
            **args: extra keyword arguments forwarded to the base model.

        Returns:
            dict with keys:
                'te': regression head output.
                'aux_loss': auxiliary loss from the base model.
                'past_key_values': KV cache from the base model.
                'zero_shot': parameter-free baseline — global mean of the
                    embeddings over sequence and feature dims, shape
                    ``(batch, 1)``.
        """
        base_output = super().forward(input_ids=input_ids, twod_tokens=twod_tokens,
                                      past_key_values=past_key_values,
                                      use_cache=use_cache, **args)
        # Per-position embeddings from the base model; assumed shape
        # (batch, seq_len, dim) — TODO confirm against MiniMindLM.
        sentence_representation = base_output.embeddings

        regression_output = self.regression_head(sentence_representation)

        return {
            'te': regression_output,
            'aux_loss': base_output.aux_loss,
            'past_key_values': base_output.past_key_values,
            'zero_shot': sentence_representation.mean(dim=(1, 2)).reshape(-1, 1)
        }
|
|
|
|
|
|
|
|
|
|
|
|