import torch
import torch.nn as nn

from utils.transformer_modules import *
from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
from utils.hparams import HParams

use_cuda = torch.cuda.is_available()
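
# A single pre-norm self-attention block: LayerNorm -> multi-head attention ->
# dropout + residual, then LayerNorm -> position-wise convolution -> dropout +
# residual. When attention_map is True, the attention weights are returned
# alongside the block output.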
class self_attention_block(nn.Module):
    def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads,
                 bias_mask=None, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0, attention_map=False):
        super(self_attention_block, self).__init__()

        self.attention_map = attention_map
        self.multi_head_attention = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth,
                                                       hidden_size, num_heads, bias_mask, attention_dropout, attention_map)
        self.positionwise_convolution = PositionwiseFeedForward(hidden_size, filter_size, hidden_size,
                                                                layer_config='cc', padding='both', dropout=relu_dropout)
        self.dropout = nn.Dropout(layer_dropout)
        self.layer_norm_mha = LayerNorm(hidden_size)
        self.layer_norm_ffn = LayerNorm(hidden_size)

    def forward(self, inputs):
        x = inputs

        # Layer normalization before the multi-head attention
        x_norm = self.layer_norm_mha(x)

        # Multi-head self-attention (optionally returning the attention weights)
        if self.attention_map:
            y, weights = self.multi_head_attention(x_norm, x_norm, x_norm)
        else:
            y = self.multi_head_attention(x_norm, x_norm, x_norm)

        # Dropout and residual connection
        x = self.dropout(x + y)

        # Layer normalization before the position-wise feed-forward convolution
        x_norm = self.layer_norm_ffn(x)

        # Position-wise feed-forward convolution
        y = self.positionwise_convolution(x_norm)

        # Dropout and residual connection
        y = self.dropout(x + y)

        if self.attention_map:
            return y, weights
        return y
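
# One bi-directional layer: the same input is fed through two masked
# self-attention blocks -- a forward block using the bias mask and a backward
# block using the transposed mask -- and the two outputs are concatenated and
# projected back to hidden_size. Attention weights of both directions are
# appended to the list threaded through the layer stack.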
class bi_directional_self_attention(nn.Module):
    def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, max_length,
                 layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0):

        super(bi_directional_self_attention, self).__init__()

        self.weights_list = list()

        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  _gen_bias_mask(max_length),
                  layer_dropout,
                  attention_dropout,
                  relu_dropout,
                  True)

        self.attn_block = self_attention_block(*params)

        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  torch.transpose(_gen_bias_mask(max_length), dim0=2, dim1=3),
                  layer_dropout,
                  attention_dropout,
                  relu_dropout,
                  True)

        self.backward_attn_block = self_attention_block(*params)

        self.linear = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, inputs):
        x, weights_list = inputs

        # Forward (masked) and backward (transposed mask) self-attention over the same input
        encoder_outputs, weights = self.attn_block(x)
        reverse_outputs, reverse_weights = self.backward_attn_block(x)

        # Concatenate both directions and project back to hidden_size
        outputs = torch.cat((encoder_outputs, reverse_outputs), dim=2)
        y = self.linear(outputs)

        # Accumulate the attention weights of both directions
        self.weights_list = weights_list
        self.weights_list.append(weights)
        self.weights_list.append(reverse_weights)
        return y, self.weights_list
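
# Stack of bi-directional self-attention layers: input dropout, a linear
# projection from embedding_size to hidden_size, an additive sinusoidal timing
# signal, num_layers bi-directional layers, and a final LayerNorm. The
# attention weights collected by every layer are returned as a list.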
class bi_directional_self_attention_layers(nn.Module):
    def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth,
                 filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0,
                 attention_dropout=0.0, relu_dropout=0.0):
        super(bi_directional_self_attention_layers, self).__init__()

        self.timing_signal = _gen_timing_signal(max_length, hidden_size)
        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  max_length,
                  layer_dropout,
                  attention_dropout,
                  relu_dropout)
        self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False)
        self.self_attn_layers = nn.Sequential(*[bi_directional_self_attention(*params) for _ in range(num_layers)])
        self.layer_norm = LayerNorm(hidden_size)
        self.input_dropout = nn.Dropout(input_dropout)

    def forward(self, inputs):
        # Dropout on the input features
        x = self.input_dropout(inputs)

        # Project the input to the hidden size
        x = self.embedding_proj(x)

        # Add the sinusoidal timing (positional) signal
        x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data)

        # Run the layer stack, threading an (initially empty) list of attention weights through it
        y, weights_list = self.self_attn_layers((x, []))

        # Final layer normalization
        y = self.layer_norm(y)
        return y, weights_list
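
# Full model: bi-directional self-attention layers over the input features,
# followed by a softmax output layer over the chord classes. With
# probs_out=True the forward pass returns the per-frame class scores directly;
# otherwise it returns the flattened predictions, the loss, the per-layer
# attention weights, and the output layer's secondary prediction (second).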
class BTC_model(nn.Module):
    def __init__(self, config):
        super(BTC_model, self).__init__()

        self.timestep = config['timestep']
        self.probs_out = config['probs_out']

        params = (config['feature_size'],
                  config['hidden_size'],
                  config['num_layers'],
                  config['num_heads'],
                  config['total_key_depth'],
                  config['total_value_depth'],
                  config['filter_size'],
                  config['timestep'],
                  config['input_dropout'],
                  config['layer_dropout'],
                  config['attention_dropout'],
                  config['relu_dropout'])

        self.self_attn_layers = bi_directional_self_attention_layers(*params)
        self.output_layer = SoftmaxOutputLayer(hidden_size=config['hidden_size'],
                                               output_size=config['num_chords'],
                                               probs_out=config['probs_out'])

    def forward(self, x, labels):
        labels = labels.view(-1, self.timestep)

        # Bi-directional self-attention over the input features
        self_attn_output, weights_list = self.self_attn_layers(x)

        # When probs_out is set, return the per-frame class scores directly
        if self.probs_out:
            logits = self.output_layer(self_attn_output)
            return logits

        # Per-frame prediction and secondary output from the softmax layer
        prediction, second = self.output_layer(self_attn_output)
        prediction = prediction.view(-1)
        second = second.view(-1)

        # Loss computed by the output layer against the reshaped labels
        loss = self.output_layer.loss(self_attn_output, labels)
        return prediction, loss, weights_list, second
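
# Quick smoke test: build the model from run_config.yaml and run one forward
# pass on random features and random chord labels.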
if __name__ == "__main__":
    config = HParams.load("run_config.yaml")
    device = torch.device("cuda" if use_cuda else "cpu")

    batch_size = 2
    timestep = 108
    feature_size = 144
    num_chords = 25

    # Random input features and chord labels of the expected shapes
    features = torch.randn(batch_size, timestep, feature_size, requires_grad=True).to(device)
    chords = torch.randint(num_chords, (batch_size * timestep,)).to(device)

    model = BTC_model(config=config.model).to(device)

    prediction, loss, weights_list, second = model(features, chords)
    print(prediction.size())
    print(loss)
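
    # A minimal inference sketch (an assumption, not part of the original script):
    # with probs_out=True the output layer returns per-frame class scores instead
    # of (prediction, loss, weights, second), so the most likely chord per frame
    # could be read off with argmax. The score shape is assumed, not verified here.
    #
    #   config.model['probs_out'] = True
    #   model = BTC_model(config=config.model).to(device).eval()
    #   with torch.no_grad():
    #       scores = model(features, chords)       # assumed (batch, timestep, num_chords)
    #       frame_chords = scores.argmax(dim=-1)   # most likely chord index per frame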