rd_hubert / distiller_modules.py
Hguimaraes's picture
Upload model
8ff90c3 verified
"""
Distiller Modules
Author: Heng-Jui Chang (https://github.com/vectominist)
"""
import math
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from .distiller_w2v2_modules import (
MultiheadAttention,
SamePad,
get_activation_fn,
)
def init_bert_params(module):
    """
    BERT-style weight initialisation, intended to be used with ``Module.apply``.

    Depending on the type of ``module``:
        - ``nn.Linear``: weight drawn from N(0, 0.02), bias (if any) zeroed.
        - ``nn.Embedding``: weight drawn from N(0, 0.02); the row at
          ``padding_idx`` (if set) is zeroed afterwards.
        - ``MultiheadAttention``: the weights of ``q_proj``, ``k_proj`` and
          ``v_proj`` are drawn from N(0, 0.02).

    Modules of any other type are left untouched.
    """

    def normal_(data):
        # Tensors on the meta device are skipped; the original author's
        # intent is that real weights get loaded later.
        if data.is_meta:
            return
        # With FSDP the parameters may live on CUDA. Sampling is done on a
        # CPU tensor and copied back so the RNG stream is the same with and
        # without FSDP.
        sampled = data.cpu().normal_(mean=0.0, std=0.02)
        data.copy_(sampled.to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        linear_bias = module.bias
        if linear_bias is not None:
            linear_bias.data.zero_()

    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        pad = module.padding_idx
        if pad is not None:
            module.weight.data[pad].zero_()

    if isinstance(module, MultiheadAttention):
        # Only the query/key/value projections are re-initialised here.
        for proj_name in ("q_proj", "k_proj", "v_proj"):
            normal_(getattr(module, proj_name).weight.data)
class SplitLinear(nn.Module):
    """
    Split Linear Layer.

    The last input dimension (size ``in_split * in_dim``) is cut into
    ``in_split`` consecutive groups of ``in_dim`` features. Each group is
    projected to ``out_dim`` features with its own weight matrix and bias,
    and the results are concatenated again along the last dimension.

    When ``in_split`` is not greater than 1 the module is a plain
    ``nn.Linear(in_dim, out_dim)`` stored in ``self.layer``.

    Args:
        in_dim: number of features per group (Din).
        in_split: number of groups (N).
        out_dim: number of output features per group (Dout).

    Parameters (only when ``in_split > 1``):
        weight: shape ``(N, Din, Dout)``.
        bias: shape ``(1, 1, N, Dout)``.
    """

    def __init__(self, in_dim, in_split, out_dim):
        super().__init__()

        self.in_dim = in_dim  # Din
        self.in_split = in_split  # N
        self.out_dim = out_dim  # Dout

        if in_split > 1:
            # Both parameters are drawn from U(-1/sqrt(Din), 1/sqrt(Din)).
            bound = self.in_dim**-0.5

            weight = torch.zeros((self.in_split, self.in_dim, self.out_dim))
            self.weight = nn.Parameter(weight, requires_grad=True)
            nn.init.uniform_(self.weight, -bound, bound)

            # The (1, 1, N, Dout) shape is kept so existing checkpoints
            # still load; forward() views it as (N, Dout).
            bias = torch.zeros((1, 1, self.in_split, self.out_dim))
            self.bias = nn.Parameter(bias, requires_grad=True)
            nn.init.uniform_(self.bias, -bound, bound)
        else:
            self.layer = nn.Linear(self.in_dim, self.out_dim)

    def forward(self, x: torch.Tensor):
        """
        Apply the per-group projections.

        Args:
            x: tensor of shape ``(..., N * Din)``, typically ``B x T x NDin``.

        Returns:
            Tensor of shape ``(..., N * Dout)`` (``(..., Dout)`` for the
            plain-linear case).
        """
        # Mirror the condition used in __init__ so that the attribute that
        # was actually created is the one that gets used.
        if self.in_split <= 1:
            return self.layer(x)

        lead_shape = x.shape[:-1]
        # x: ... x N x Din
        x = x.reshape(*lead_shape, self.in_split, self.in_dim)
        # Group k is multiplied by weight[k]; out: ... x N x Dout
        out = torch.einsum("...km,kmn->...kn", x, self.weight)
        # Viewing the bias as (N, Dout) lets it broadcast over any number
        # of leading dimensions.
        out = out + self.bias.reshape(self.in_split, self.out_dim)
        # The explicit width (instead of -1) also works for empty batches.
        return out.reshape(*lead_shape, self.in_split * self.out_dim)
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models: a self-attention sub-layer followed by a position-wise
    feed-forward sub-layer, each wrapped in a residual connection.

    Args:
        embedding_dim: model (input/output) feature size.
        ffn_embedding_dim: hidden size of the feed-forward sub-layer.
        num_attention_heads: number of attention heads.
        dropout: dropout applied to the output of both sub-layers.
        attention_dropout: dropout passed to the attention module.
        activation_dropout: dropout applied after the feed-forward activation.
        activation_fn: activation name, resolved through ``get_activation_fn``.
        layer_norm_first: if True, LayerNorm is applied before each sub-layer
            (pre-LN); otherwise after the residual sum (post-LN).
        attention_type: only ``"original"`` is supported.

    Raises:
        NotImplementedError: if ``attention_type`` is not ``"original"``.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        attention_type: str = "original",
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = get_activation_fn(activation_fn)
        self.attention_type = attention_type
        if attention_type == "original":
            self.self_attn = MultiheadAttention(
                self.embedding_dim,
                num_attention_heads,
                dropout=attention_dropout,
                self_attention=True,
            )
        else:
            raise NotImplementedError(f"Unknown attention type {attention_type}")

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward_self_attn(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
    ):
        """
        Run the attention module on ``x`` (used as query, key and value).

        Returns:
            ``(output, attn)`` where ``attn`` is whatever the attention
            module returns as its second value (``None`` for the
            ``"dynamic"`` branch).

        Raises:
            NotImplementedError: for an attention type this method does not
                know how to call.
        """
        if self.attention_type in ["original", "sparse"]:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
            )
        elif self.attention_type == "dynamic":
            x = self.self_attn(x)
            attn = None
        else:
            # Fail with a clear message instead of an UnboundLocalError.
            raise NotImplementedError(
                f"Unknown attention type {self.attention_type}"
            )

        return x, attn

    def _feed_forward(self, x: torch.Tensor) -> torch.Tensor:
        """fc1 -> activation -> dropout -> fc2 -> dropout (no residual, no norm)."""
        x = self.activation_fn(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return self.dropout3(x)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.

        Args:
            x: input features (presumably T x B x C, as passed by
                ``TransformerEncoder`` -- confirm against the attention module).
            self_attn_mask: forwarded to the attention module as ``attn_mask``.
            self_attn_padding_mask: forwarded as ``key_padding_mask``.
            need_weights: forwarded to the attention module in both the
                pre-LN and the post-LN configuration.
            att_args: unused; accepted for interface compatibility.

        Returns:
            ``(x, attn)``: the layer output and the second value returned by
            the attention module.
        """
        residual = x

        if self.layer_norm_first:
            # Pre-LN: norm -> sub-layer -> residual
            x = self.self_attn_layer_norm(x)
            x, attn = self.forward_self_attn(
                x,
                self_attn_mask=self_attn_mask,
                need_weights=need_weights,
                self_attn_padding_mask=self_attn_padding_mask,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self._feed_forward(x)
            x = residual + x
        else:
            # Post-LN: sub-layer -> residual -> norm
            x, attn = self.forward_self_attn(
                x,
                self_attn_mask=self_attn_mask,
                need_weights=need_weights,
                self_attn_padding_mask=self_attn_padding_mask,
            )
            x = self.dropout1(x)
            x = residual + x
            x = self.self_attn_layer_norm(x)

            residual = x
            x = self._feed_forward(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, attn
class TransformerEncoder(nn.Module):
    """
    Stack of ``TransformerSentenceEncoderLayer`` blocks preceded by a
    convolutional positional embedding.

    ``args`` must provide: ``dropout``, ``encoder_embed_dim``, ``conv_pos``,
    ``conv_pos_groups``, ``attention_type``, ``encoder_ffn_embed_dim``,
    ``encoder_attention_heads``, ``attention_dropout``,
    ``activation_dropout``, ``activation_fn``, ``layer_norm_first``,
    ``encoder_layers`` and ``encoder_layerdrop``.
    """

    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Grouped 1-D convolution over time used as positional embedding.
        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        # The init std is evaluated with the dropout term fixed to 0.
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.parametrizations.weight_norm(
            self.pos_conv, name="weight", dim=2
        )
        # NOTE(review): SamePad presumably trims the output back to the input
        # length -- confirm in distiller_w2v2_modules.
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        print(f"[TransformerEncoder] - Attention type = {args.attention_type}")
        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    attention_type=args.attention_type,
                )
                for _ in range(args.encoder_layers)
            ]
        )

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = nn.LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        # Re-initialise every sub-module with the BERT scheme.
        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, attn_mask=None, get_hidden=False):
        """
        Encode ``x`` and return ``(output, layer_results)``.

        In the pre-LN configuration (``layer_norm_first``) the encoder-level
        LayerNorm is applied to the output here; in the post-LN configuration
        it has already been applied to the input inside ``extract_features``.
        """
        x, layer_results = self.extract_features(
            x, padding_mask, attn_mask, get_hidden=get_hidden
        )

        if self.layer_norm_first:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, attn_mask=None, get_hidden=False):
        """
        Add positional information and run the layer stack.

        Args:
            x: input features, B x T x C. The caller's tensor is not modified.
            padding_mask: optional mask used to index ``x``; the selected
                frames are zeroed before the positional convolution
                (presumably a B x T boolean mask -- confirm against callers).
                It is also passed to every layer as the key padding mask.
            attn_mask: optional attention mask passed to every layer.
            get_hidden: if True, collect the hidden state after every layer.

        Returns:
            ``(x, layer_results)`` where ``x`` is B x T x C and
            ``layer_results`` is a list of B x T x C tensors (empty unless
            ``get_hidden`` is set).
        """
        if padding_mask is not None:
            # Work on a copy: zeroing in place would overwrite the caller's
            # tensor and fails outright on leaf tensors that require grad.
            x = x.clone()
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        for layer in self.layers:
            # LayerDrop: during training a layer is skipped when the draw is
            # not above ``self.layerdrop``. The draw is made unconditionally
            # so the NumPy RNG stream does not depend on the mode.
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, _ = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    self_attn_mask=attn_mask,
                )
            if get_hidden:
                # A skipped layer contributes the unchanged hidden state.
                layer_results.append(x.transpose(0, 1))

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results