import math

import numpy as np
import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper, wrap

from vlmo.torchscale.architecture.utils import init_bert_params
from vlmo.torchscale.component.droppath import DropPath
from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

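# A single Transformer decoder block: (optionally pre-normalized) self-attention,
# optional cross-attention when used inside an encoder-decoder model, and a
# feed-forward sublayer that is either a dense FFN or a mixture-of-experts
# (X-MoE) layer. DeepNorm scales the residual branch by `alpha`, and DropPath
# applies stochastic depth with a rate that increases linearly with depth.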
class DecoderLayer(nn.Module):
    def __init__(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = torch.nn.Dropout(args.dropout)

        if args.drop_path_rate > 0:
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.self_attn = self.build_self_attention(self.embed_dim, args)

        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        if not is_encoder_decoder:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.decoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = self.build_ffn(
                self.embed_dim,
                self.args,
            )
        else:
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)

        self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)

        if args.deepnorm:
            if is_encoder_decoder:
                self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
            else:
                self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self, embed_dim, args):
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.layernorm_eps,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
        )

    def build_encoder_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=False,
            encoder_decoder_attention=True,
            subln=args.subln,
        )

    def residual_connection(self, x, residual):
        return residual * self.alpha + x
    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        self_attn_rel_pos=None,
        cross_attn_rel_pos=None,
    ):
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask,
            rel_pos=self_attn_rel_pos,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=None,
                rel_pos=cross_attn_rel_pos,
            )
            x = self.dropout_module(x)

            if self.drop_path is not None:
                x = self.drop_path(x)

            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            x, l_aux = self.moe_layer(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        return x, attn, None, l_aux

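# Full decoder stack: token/positional embedding, a stack of DecoderLayer blocks
# (every `moe_freq`-th layer uses MoE when moe_freq > 0), optional relative
# position bias shared across layers, a final LayerNorm for pre-norm models, and
# an optional output projection (tied to the input embedding when
# `share_decoder_input_output_embed` is set). DeepNorm/SubLN rescale selected
# projection weights at initialization.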
class Decoder(nn.Module):
    def __init__(
        self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
    ):
        super().__init__(**kwargs)
        self.args = args

        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        moe_freq = args.moe_freq
        for i in range(args.decoder_layers):
            is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
            self.layers.append(
                self.build_decoder_layer(
                    args,
                    depth=i,
                    is_moe_layer=is_moe_layer,
                    is_encoder_decoder=is_encoder_decoder,
                )
            )

        self.num_layers = len(self.layers)

        if args.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
        else:
            self.layer_norm = None

        self.self_attn_relative_position = None
        self.cross_attn_relative_position = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.self_attn_relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.decoder_attention_heads,
            )
            if is_encoder_decoder:
                self.cross_attn_relative_position = RelativePositionBias(
                    num_buckets=args.rel_pos_buckets,
                    max_distance=args.max_rel_pos,
                    n_heads=args.decoder_attention_heads,
                )

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            if is_encoder_decoder:
                init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
            else:
                init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.div_(init_scale)

        if args.subln:
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(args.decoder_layers * 3))
            else:
                init_scale = math.sqrt(math.log(args.decoder_layers * 2))
            for name, p in self.named_parameters():
                if "encoder_attn" in name:
                    continue
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.mul_(init_scale)
    def build_output_projection(
        self,
        args,
    ):
        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(args.decoder_embed_dim, args.vocab_size, bias=False)
            torch.nn.init.normal_(output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5)
        return output_projection

    def build_decoder_layer(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
        layer = DecoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def forward_embedding(
        self,
        tokens,
        token_embedding=None,
        incremental_state=None,
    ):
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(tokens, incremental_state=incremental_state)

        if incremental_state is not None:
            tokens = tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        if token_embedding is None:
            token_embedding = self.embed_tokens(tokens)

        x = embed = self.embed_scale * token_embedding

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, embed
    def forward(
        self,
        prev_output_tokens,
        self_attn_padding_mask=None,
        encoder_out=None,
        incremental_state=None,
        features_only=False,
        return_all_hiddens=False,
        token_embeddings=None,
        **kwargs
    ):
        # embed tokens and positions
        x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)

        # relative position bias (only the last query row is needed during incremental decoding)
        self_attn_rel_pos_bias = None
        slen = prev_output_tokens.size(1)
        if self.self_attn_relative_position is not None:
            self_attn_rel_pos_bias = self.self_attn_relative_position(batch_size=x.size(0), qlen=slen, klen=slen)
            if incremental_state is not None:
                self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
        cross_attn_rel_pos_bias = None
        if self.cross_attn_relative_position is not None:
            cross_attn_rel_pos_bias = self.cross_attn_relative_position(
                batch_size=x.size(0),
                qlen=slen,
                klen=encoder_out["encoder_out"].size(1),
            )
            if incremental_state is not None:
                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]

        # decoder layers
        inner_states = [x]

        if encoder_out is None:
            l_aux = []
        else:
            l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []

        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                # causal mask: upper-triangular -inf above the diagonal
                self_attn_mask = torch.triu(
                    torch.zeros([x.size(1), x.size(1)]).float().fill_(float("-inf")).type_as(x),
                    1,
                )
            else:
                self_attn_mask = None
                if idx not in incremental_state:
                    incremental_state[idx] = {}

            x, layer_attn, _, l_aux_i = layer(
                x,
                encoder_out["encoder_out"] if encoder_out is not None else None,
                encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
                incremental_state[idx] if incremental_state is not None else None,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                self_attn_rel_pos=self_attn_rel_pos_bias,
                cross_attn_rel_pos=cross_attn_rel_pos_bias,
            )
            l_aux.append(l_aux_i)
            inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only:
            x = self.output_layer(x)

        return x, {
            "inner_states": inner_states,
            "l_aux": l_aux,
            "attn": None,
        }

    def output_layer(self, features):
        return self.output_projection(features)
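# Usage sketch (illustrative, not part of this module): the decoder expects an
# `args` namespace exposing the fields referenced above (decoder_embed_dim,
# decoder_layers, dropout, moe_freq, rel_pos_buckets, ...). Assuming the vendored
# package also ships torchscale's DecoderConfig, a minimal causal decoder could
# be built roughly like this:
#
#     from vlmo.torchscale.architecture.config import DecoderConfig  # assumed available
#
#     config = DecoderConfig(vocab_size=64000)
#     embed = nn.Embedding(config.vocab_size, config.decoder_embed_dim)
#     decoder = Decoder(config, embed_tokens=embed, is_encoder_decoder=False)
#     tokens = torch.randint(0, config.vocab_size, (2, 16))
#     logits, extra = decoder(tokens)  # logits: (2, 16, vocab_size); extra holds inner_states/l_aux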