import torch
import torch.nn as nn

from encoder.mha import MultiHeadAttention
from encoder.attentive_pooling import SelfAttentionPooling


class FlippedReLU(nn.Module):
    """Keeps only the negative part of the input, i.e. computes min(x, 0) elementwise."""

    def __init__(self):
        super(FlippedReLU, self).__init__()

    def forward(self, x):
        # Pass negative values through unchanged; zero out non-negative values.
        return torch.where(x < 0, x, torch.zeros_like(x))


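# Illustrative sketch (not part of the original module): a hypothetical helper that
# shows FlippedReLU's behaviour. It is never called by the model code below.
def _flipped_relu_example():
    act = FlippedReLU()
    x = torch.tensor([-2.0, 0.0, 3.0])
    # Negative entries pass through, non-negative entries are zeroed: tensor([-2., 0., 0.])
    return act(x)

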
class PositionWiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward network: d_model -> d_ff -> d_model."""

    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


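# Illustrative sketch (not part of the original module): a hypothetical usage example
# for the feed-forward block. The shapes below are assumptions for demonstration only.
def _position_wise_ffn_example():
    ffn = PositionWiseFeedForward(d_model=256, d_ff=1024)
    x = torch.randn(4, 50, 256)  # (batch, seq_len, d_model)
    # Output has the same shape as the input: (4, 50, 256).
    return ffn(x)

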
class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention followed by a feed-forward block."""

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, prob_phn=None, mask=None, lambda_val=None):
        # Self-attention sub-layer with residual connection and layer norm.
        attn_output, attn_mask = self.self_attn(x, x, x, prob_phn=prob_phn, mask=mask, lambda_val=lambda_val)
        x = self.norm1(x + self.dropout(attn_output))
        # Position-wise feed-forward sub-layer with residual connection and layer norm.
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x, attn_mask


class TransformerSelfAttention(nn.Module):
    def __init__(self, input_dim, num_heads, dim_feedforward, number_Of_spks, dropout=0.0):
        """Transformer encoder block with attentive statistics pooling and a speaker classifier.

        Args:
            input_dim: Dimensionality of the input
            num_heads: Number of heads to use in the attention block
            dim_feedforward: Base width of the hidden layers (scaled by 8 internally)
            number_Of_spks: Number of speaker classes for the output classifier
            dropout: Dropout probability to use in the dropout layers
        """
        super().__init__()

        self.self_mha_attn = EncoderLayer(input_dim, num_heads, dim_feedforward * 8, dropout)
        self.attn_pooling = SelfAttentionPooling(input_dim)
        # Two parallel projections over the pooled [mean, std] statistics;
        # emb2 is initialised with a copy of emb1's weights and bias.
        self.emb1 = nn.Linear(input_dim * 2, dim_feedforward * 8)
        self.emb2 = nn.Linear(input_dim * 2, dim_feedforward * 8)
        self.emb2.weight.data = self.emb1.weight.data.clone()
        self.emb2.bias.data = self.emb1.bias.data.clone()
        self.bn = nn.BatchNorm1d(dim_feedforward * 8)
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(dim_feedforward * 8, number_Of_spks)
        self.flipped_relu = FlippedReLU()

    def forward(self, x, prob_phn=None, mask=None, lambda_val=None):
        # Encode the sequence, then pool it into utterance-level mean and std statistics.
        attn_out, attn_mask = self.self_mha_attn(x, prob_phn=prob_phn, mask=mask, lambda_val=lambda_val)
        attn_mask = attn_mask.squeeze(1)
        attn_out_mean, attn_out_std = self.attn_pooling(attn_out, attn_mask)
        attn_concat = torch.cat((attn_out_mean, attn_out_std), dim=1).to(dtype=torch.float32)

        # Positive branch: standard ReLU over the first projection.
        emb1 = self.emb1(attn_concat).to(dtype=torch.float32)
        emb1 = self.act(emb1)

        # Negative branch: FlippedReLU over the second projection.
        emb2 = self.emb2(attn_concat).to(dtype=torch.float32)
        emb2 = self.flipped_relu(emb2)

        # Combine both branches, normalise, and classify speakers.
        emb = emb1 + emb2
        emb = self.bn(emb)
        x = self.classifier(emb)
        return x, emb

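
# Minimal smoke-test sketch, guarded so it never runs on import. It assumes the
# project-local encoder.mha.MultiHeadAttention returns (output, attention_mask) even
# when mask is None, and that encoder.attentive_pooling.SelfAttentionPooling returns
# (mean, std), as used in the forward passes above. Shapes and hyper-parameters below
# are illustrative assumptions, not values from the original code.
if __name__ == "__main__":
    batch, seq_len, d_model = 4, 50, 256
    model = TransformerSelfAttention(
        input_dim=d_model, num_heads=4, dim_feedforward=32, number_Of_spks=10
    )
    dummy = torch.randn(batch, seq_len, d_model)
    logits, emb = model(dummy)
    # Expected: logits of shape (batch, number_Of_spks), emb of shape (batch, dim_feedforward * 8).
    print(logits.shape, emb.shape)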