| | """ |
| | Distiller Modules |
| | Author: Heng-Jui Chang (https://github.com/vectominist) |
| | """ |
| |
|
| | import math |
| | import numpy as np |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor |
| | import torch.nn.functional as F |
| |
|
| | from .distiller_w2v2_modules import ( |
| | MultiheadAttention, |
| | SamePad, |
| | get_activation_fn, |
| | ) |
| |
|
| |
|
def init_bert_params(module):
    """
    Initialize the weights specific to the BERT model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set, the weights of the linear
       layer are initialized from a normal distribution and the bias is
       set to the specified value.
    2. If normal_init_embed_weights is set, the weights of the embedding
       layer are initialized from a normal distribution.
    3. If normal_init_proj_weights is set, the in_project_weight of
       MultiHeadAttention is initialized from a normal distribution
       (to be validated).
    """

    def normal_(data):
        # Meta tensors have no storage, so there is nothing to initialize.
        if data.is_meta:
            return
        # Sample on CPU so the random stream is independent of the device,
        # then copy the result back to the original device.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)

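# Illustrative usage (sketch): init_bert_params is meant to be applied
# recursively via nn.Module.apply; TransformerEncoder.__init__ below does
# exactly this with `self.apply(init_bert_params)`.
#
#     some_module = nn.Linear(768, 768)      # any sub-module of a model
#     some_module.apply(init_bert_params)
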
class SplitLinear(nn.Module):
    """Split Linear Layer.

    Applies `in_split` independent linear transforms to consecutive
    in_dim-sized chunks of the last input dimension.
    """

    def __init__(self, in_dim, in_split, out_dim):
        super().__init__()

        self.in_dim = in_dim  # input dimension of each split
        self.in_split = in_split  # number of parallel splits
        self.out_dim = out_dim  # output dimension of each split

        if in_split > 1:
            # One independent (in_dim x out_dim) weight matrix per split.
            weight = torch.zeros((self.in_split, self.in_dim, self.out_dim))
            self.weight = nn.Parameter(weight, requires_grad=True)
            nn.init.uniform_(self.weight, -(self.in_dim**-0.5), self.in_dim**-0.5)

            bias = torch.zeros((1, 1, self.in_split, self.out_dim))
            self.bias = nn.Parameter(bias, requires_grad=True)
            nn.init.uniform_(self.bias, -(self.in_dim**-0.5), self.in_dim**-0.5)
        else:
            self.layer = nn.Linear(self.in_dim, self.out_dim)

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, in_split * in_dim)
        if self.in_split == 1:
            return self.layer(x)
        else:
            x = x.reshape(x.shape[0], x.shape[1], self.in_split, 1, self.in_dim)
            # (batch, seq_len, in_split, 1, in_dim)

            out = torch.einsum("...klm,kmn->...kln", x, self.weight).squeeze(3)
            # (batch, seq_len, in_split, out_dim)
            out = out + self.bias

            return out.reshape(x.shape[0], x.shape[1], -1)

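# Illustrative usage (sketch; sizes are example values, not defaults):
#
#     split_linear = SplitLinear(in_dim=768, in_split=4, out_dim=256)
#     x = torch.randn(8, 100, 4 * 768)       # (batch, time, in_split * in_dim)
#     y = split_linear(x)                    # -> (8, 100, 4 * 256)
#
# With in_split == 1 the module falls back to a single nn.Linear.
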
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        attention_type: str = "original",
    ) -> None:
        super().__init__()

        # Hyper-parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Self-attention block
        self.activation_fn = get_activation_fn(activation_fn)
        self.attention_type = attention_type
        if attention_type == "original":
            self.self_attn = MultiheadAttention(
                self.embedding_dim,
                num_attention_heads,
                dropout=attention_dropout,
                self_attention=True,
            )
        else:
            raise NotImplementedError(f"Unknown attention type {attention_type}")

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # Layer norm associated with the self-attention block
        self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # Layer norm associated with the position-wise feed-forward block
        self.final_layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward_self_attn(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
    ):
        if self.attention_type in ["original", "sparse"]:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
            )
        elif self.attention_type == "dynamic":
            x = self.self_attn(x)
            attn = None

        return x, attn

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules, similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn = self.forward_self_attn(
                x,
                self_attn_mask=self_attn_mask,
                need_weights=False,
                self_attn_padding_mask=self_attn_padding_mask,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.forward_self_attn(
                x,
                self_attn_mask=self_attn_mask,
                need_weights=need_weights,
                self_attn_padding_mask=self_attn_padding_mask,
            )

            x = self.dropout1(x)
            x = residual + x
            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, attn

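# Illustrative usage (sketch; sizes are example values). The layer expects
# sequence-first input (seq_len, batch, embed_dim), matching the transpose in
# TransformerEncoder.extract_features below:
#
#     layer = TransformerSentenceEncoderLayer(
#         embedding_dim=768,
#         ffn_embedding_dim=3072,
#         num_attention_heads=8,
#         layer_norm_first=True,
#     )
#     x = torch.randn(100, 8, 768)           # (seq_len, batch, embed_dim)
#     y, attn = layer(x, need_weights=False)
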
class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Convolutional relative positional encoding
        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.parametrizations.weight_norm(
            self.pos_conv, name="weight", dim=2
        )
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        print(f"[TransformerEncoder] - Attention type = {args.attention_type}")
        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    attention_type=args.attention_type,
                )
                for _ in range(args.encoder_layers)
            ]
        )

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = nn.LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, attn_mask=None, get_hidden=False):
        x, layer_results = self.extract_features(
            x, padding_mask, attn_mask, get_hidden=get_hidden
        )

        if self.layer_norm_first:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, attn_mask=None, get_hidden=False):
        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        for i, layer in enumerate(self.layers):
            # LayerDrop: during training, randomly skip a layer with
            # probability self.layerdrop.
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    self_attn_mask=attn_mask,
                )
                if get_hidden:
                    layer_results.append(x.transpose(0, 1))

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
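
# Illustrative usage (sketch): `args` may be any object exposing the attributes
# read in TransformerEncoder.__init__; the values below are placeholders, not
# the distiller's actual configuration.
#
#     from argparse import Namespace
#
#     args = Namespace(
#         encoder_layers=2,
#         encoder_embed_dim=768,
#         encoder_ffn_embed_dim=3072,
#         encoder_attention_heads=8,
#         encoder_layerdrop=0.0,
#         dropout=0.1,
#         attention_dropout=0.1,
#         activation_dropout=0.1,
#         activation_fn="gelu",
#         layer_norm_first=False,
#         attention_type="original",
#         conv_pos=128,
#         conv_pos_groups=16,
#     )
#     encoder = TransformerEncoder(args)
#     feats = torch.randn(8, 100, args.encoder_embed_dim)  # (batch, time, dim)
#     out, hiddens = encoder(feats, get_hidden=True)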