import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from attention_module import MultiheadAttention, init_bert_params
from utils import (
    Fp32GroupNorm,
    Fp32LayerNorm,
    LayerNorm,
    SamePad,
    TransposeLast,
    get_activation_fn,
    index_put,
    pad_to_multiple,
)

class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        # normalize over the output channels
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            dim, k, stride = cl  # (out_channels, kernel_size, stride)

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    # in "default" mode, only the first layer gets group norm
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x

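# Usage sketch (illustrative only): the conv stack below mirrors a common
# wav2vec 2.0-style configuration; the exact layer spec is an assumption here,
# not something this module prescribes.
#
#   feature_extractor = ConvFeatureExtractionModel(
#       conv_layers=[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2,
#       dropout=0.0,
#       mode="default",
#       conv_bias=False,
#   )
#   wav = torch.randn(4, 16000)        # (batch, samples) raw waveform
#   feats = feature_extractor(wav)     # (batch, 512, frames)
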
def make_conv_pos(e, k, g, is_batch_norm=False):
    # Grouped 1D convolution used as a (relative) positional embedding over time.
    pos_conv = nn.Conv1d(
        e,
        e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
    )
    dropout = 0
    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
    nn.init.normal_(pos_conv.weight, mean=0, std=std)
    nn.init.constant_(pos_conv.bias, 0)

    if not is_batch_norm:
        # Weight norm over the kernel dimension; SamePad trims the extra frame
        # produced by even kernel sizes so output length matches input length.
        pos_conv = nn.utils.parametrizations.weight_norm(
            pos_conv, name="weight", dim=2
        )
        pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
    else:
        batch_norm = nn.BatchNorm1d(e)
        pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU())

    return pos_conv

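# Minimal sketch (assumed hyperparameters): a positional convolution for
# 768-dimensional features; kernel size 128 and 16 groups are illustrative,
# typical-looking values rather than values mandated by this module.
#
#   pos_conv = make_conv_pos(768, 128, 16)
#   x = torch.randn(2, 50, 768)                         # (B, T, C)
#   pos = pos_conv(x.transpose(1, 2)).transpose(1, 2)   # (B, T, C)
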
class TransformerEncoder(nn.Module):
    def build_encoder_layer(self, args, **kwargs):
        if args.layer_type == "transformer":
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
            )
        elif args.layer_type == "trf_adp":
            use_adp = False
            if args.adp_trf_idx == "all":
                use_adp = True
            else:
                # "start:end" range syntax selecting which layer indices get adapters
                adp_trf_idx = list(
                    range(*[int(g) for g in args.adp_trf_idx.split(":")])
                )
                if kwargs.get("layer_idx", None) in adp_trf_idx:
                    use_adp = True
            if use_adp:
                layer = TransformerSentenceEncoderWithAdapterLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    adapter_num=args.adp_num,
                    adapter_dim=args.adp_dim,
                    adapter_act_fn=args.adp_act_fn,
                )
            else:
                layer = TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                )

        return layer

    def __init__(self, args):
        super().__init__()

        self.args = args  # kept so that max_positions() can read args.max_positions
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.required_seq_len_multiple = args.required_seq_len_multiple

        pos_conv_depth = getattr(args, "pos_conv_depth", 1)
        if pos_conv_depth > 1:
            num_layers = args.pos_conv_depth
            k = max(3, args.conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Conv1d(
                                e,
                                e,
                                kernel_size=k,
                                padding=k // 2,
                                groups=g,
                            ),
                            SamePad(k),
                            TransposeLast(),
                            LayerNorm(e, elementwise_affine=False),
                            TransposeLast(),
                            nn.GELU(),
                        )
                        for _ in range(l)
                    ]
                )

            self.pos_conv = make_conv_block(
                self.embedding_dim, k, args.conv_pos_groups, num_layers
            )

        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                args.conv_pos,
                args.conv_pos_groups,
                is_batch_norm=getattr(args, "conv_pos_batch_norm", False),
            )

        self.layers = nn.ModuleList(
            [
                self.build_encoder_layer(args, layer_idx=ii)
                for ii in range(args.encoder_layers)
            ]
        )
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
        x, layer_results = self.extract_features(
            x, padding_mask, layer, corpus_key=corpus_key
        )

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
        corpus_key=None,
    ):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the required multiple along the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None

        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                layer_check = layer
                # adapter layers additionally need the corpus key; plain layers do not
                if (corpus_key is None) or (
                    not isinstance(
                        layer_check,
                        (TransformerSentenceEncoderWithAdapterLayer,),
                    )
                ):
                    x, (z, lr) = layer(
                        x,
                        self_attn_padding_mask=padding_mask,
                        need_weights=False,
                    )
                else:
                    x, (z, lr) = layer(
                        x,
                        self_attn_padding_mask=padding_mask,
                        need_weights=False,
                        corpus_key=corpus_key,
                    )
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict

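# Usage sketch for TransformerEncoder (illustrative only). `args` can be any
# config object; the attribute names below are the ones this module reads,
# while the numeric values are assumed, typical-looking choices.
#
#   from types import SimpleNamespace
#
#   args = SimpleNamespace(
#       dropout=0.1,
#       encoder_embed_dim=768,
#       encoder_ffn_embed_dim=3072,
#       encoder_attention_heads=12,
#       encoder_layers=12,
#       encoder_layerdrop=0.05,
#       attention_dropout=0.1,
#       activation_dropout=0.0,
#       activation_fn="gelu",
#       layer_norm_first=False,
#       layer_type="transformer",
#       conv_pos=128,
#       conv_pos_groups=16,
#       required_seq_len_multiple=2,
#       max_positions=100000,
#   )
#   encoder = TransformerEncoder(args)
#   feats = torch.randn(2, 100, args.encoder_embed_dim)   # (B, T, C)
#   out, layer_results = encoder(feats, padding_mask=None)
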
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
    ) -> None:
        super().__init__()
        # hyperparameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # submodules
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self-attention block
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position-wise feed-forward block
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules, similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, (attn, layer_result)

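# Minimal sketch of calling a single layer directly. As used by
# TransformerEncoder above, inputs are laid out time-first as (T, B, C); the
# assumption here is that MultiheadAttention from attention_module follows the
# same convention.
#
#   layer = TransformerSentenceEncoderLayer(embedding_dim=768,
#                                           num_attention_heads=12)
#   x = torch.randn(50, 2, 768)                    # (T, B, C)
#   out, (attn, layer_result) = layer(x)
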
class AdapterFast(nn.Module):
    def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
        """
        Implements adapter modules directly with 3D tensor weights as parameters
        and without using ModuleList, to speed up training throughput.
        """
        super().__init__()

        self.adapter_num = adapter_num
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # per-adapter down-projection (W_a, b_a) and up-projection (W_b, b_b)
        self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
        self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
        self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
        self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))

        # per-adapter layer-norm parameters
        self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
        self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "gelu":
            self.act_fn = nn.GELU()
        elif act_fn == "selu":
            self.act_fn = nn.SELU()
        else:
            raise ValueError(f"unsupported {act_fn}")

        self.reset_parameters()

    def reset_parameters(self):
        for ii in range(self.adapter_num):
            nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b_a[ii], -bound, bound)
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b_b[ii], -bound, bound)

        nn.init.ones_(self.ln_W)
        nn.init.zeros_(self.ln_b)

    def forward(self, x, adapter_id):
        ii = adapter_id
        h = x
        h = F.layer_norm(h, (self.input_dim,), self.ln_W[ii], self.ln_b[ii])
        h = F.linear(h, self.W_a[ii], self.b_a[ii])
        h = self.act_fn(h)
        h = F.linear(h, self.W_b[ii], self.b_b[ii])
        outputs = h
        return outputs

    def extra_repr(self):
        return "adapter={}, input_dim={}, hidden_dim={}".format(
            self.adapter_num, self.input_dim, self.hidden_dim
        )

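# Minimal sketch: two adapters over 768-dim features with a 64-dim bottleneck;
# the second argument of forward() selects which adapter's weights are used.
#
#   adapters = AdapterFast(adapter_num=2, input_dim=768, hidden_dim=64, act_fn="relu")
#   x = torch.randn(50, 2, 768)
#   y = adapters(x, adapter_id=0)      # same shape as x
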
class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
    """
    Implements a Transformer Encoder Layer with adapters used in BERT/XLM style
    pre-trained models. An adapter module is added alongside the vanilla
    Transformer modules.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        adapter_num=201,
        adapter_dim=64,
        adapter_act_fn="relu",
    ) -> None:
        super().__init__(
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            activation_fn=activation_fn,
            layer_norm_first=layer_norm_first,
        )

        self.adapter_num = adapter_num
        self.adapter_dim = adapter_dim
        self.adapter_layer = AdapterFast(
            adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
        corpus_key=None,
    ):
        x, (attn, layer_result) = super().forward(
            x=x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
            need_weights=need_weights,
            att_args=att_args,
        )
        assert corpus_key is not None
        assert len(set(corpus_key)) == 1, f"corpus_key items are not same {corpus_key}"
        y = self.adapter_layer(x, corpus_key[0])
        x = x + y
        return x, (attn, layer_result)

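# Usage sketch for the adapter variant (assumed values): corpus_key carries one
# adapter index per batch element, and all elements must share the same index.
#
#   layer = TransformerSentenceEncoderWithAdapterLayer(embedding_dim=768,
#                                                      num_attention_heads=12,
#                                                      adapter_num=4)
#   x = torch.randn(50, 2, 768)                    # (T, B, C)
#   out, _ = layer(x, corpus_key=[3, 3])           # one key per batch element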