# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.modules import TransformerSentenceEncoder
from fairseq.modules.sparse_transformer_sentence_encoder_layer import (
    SparseTransformerSentenceEncoderLayer,
)


class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    Sparse implementation of the TransformerSentenceEncoder
    - see SparseMultiheadAttention
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:
        super().__init__(
            padding_idx,
            vocab_size,
            num_encoder_layers,
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            max_seq_len,
            num_segments,
            use_position_embeddings,
            offset_positions_by_padding,
            encoder_normalize_before,
            apply_bert_init,
            activation_fn,
            learned_pos_embedding,
            embed_scale,
            freeze_embeddings,
            n_trans_layers_to_freeze,
            export,
        )
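
        # The parent constructor above has already built a stack of dense
        # TransformerSentenceEncoderLayer modules; replace it with layers
        # whose self-attention is SparseMultiheadAttention, configured by
        # `is_bidirectional`, `stride`, and `expressivity`.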
        self.layers = nn.ModuleList(
            [
                SparseTransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    is_bidirectional=is_bidirectional,
                    stride=stride,
                    expressivity=expressivity,
                )
                for _ in range(num_encoder_layers)
            ]
        )
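
        # Any layer freezing done by the parent constructor targeted the old
        # dense layers, which were just replaced, so re-freeze the first
        # `n_trans_layers_to_freeze` of the new sparse layers here.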
        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
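

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream file): instantiate a small
# sparse encoder and run a forward pass. The sizes below are illustrative
# only; the forward() signature and its return values (a list of inner
# states plus a sentence representation) come from the parent
# TransformerSentenceEncoder.
if __name__ == "__main__":
    import torch

    encoder = SparseTransformerSentenceEncoder(
        padding_idx=0,
        vocab_size=1000,
        num_encoder_layers=2,
        embedding_dim=64,
        ffn_embedding_dim=128,
        num_attention_heads=4,
        max_seq_len=128,
        stride=16,       # l in the fixed factorized attention pattern
        expressivity=4,  # c in the fixed factorized attention pattern
    )
    tokens = torch.randint(1, 1000, (2, 64))  # (batch, seq_len), no padding
    inner_states, sentence_rep = encoder(tokens)
    print(inner_states[-1].shape)  # expected: (seq_len, batch, embedding_dim)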