maotao / fairseq /modules /sparse_transformer_sentence_encoder.py
julse's picture
Upload 551 files
be611b4 verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from fairseq.modules import TransformerSentenceEncoder
from fairseq.modules.sparse_transformer_sentence_encoder_layer import SparseTransformerSentenceEncoderLayer
class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    Sparse implementation of the TransformerSentenceEncoder
    - see SparseMultiheadAttention

    Construction runs the parent ``TransformerSentenceEncoder`` initializer
    and then replaces ``self.layers`` with a stack of
    ``SparseTransformerSentenceEncoderLayer`` modules.

    Args:
        padding_idx, vocab_size, num_encoder_layers, embedding_dim,
        ffn_embedding_dim, num_attention_heads, dropout, attention_dropout,
        activation_dropout, max_seq_len, num_segments,
        use_position_embeddings, offset_positions_by_padding,
        encoder_normalize_before, apply_bert_init, activation_fn,
        learned_pos_embedding, embed_scale, freeze_embeddings,
        n_trans_layers_to_freeze, export: forwarded unchanged to the parent
            initializer; see ``TransformerSentenceEncoder`` for their
            meaning. A subset is also reused below to build the sparse
            layers.
        is_bidirectional: forwarded to every sparse layer.
        stride: forwarded to every sparse layer.
        expressivity: forwarded to every sparse layer.
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:
        # Forward everything by keyword: with 21 arguments, positional
        # forwarding silently misbinds values as soon as the parent signature
        # gains or reorders a parameter.
        super().__init__(
            padding_idx=padding_idx,
            vocab_size=vocab_size,
            num_encoder_layers=num_encoder_layers,
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            max_seq_len=max_seq_len,
            num_segments=num_segments,
            use_position_embeddings=use_position_embeddings,
            offset_positions_by_padding=offset_positions_by_padding,
            encoder_normalize_before=encoder_normalize_before,
            apply_bert_init=apply_bert_init,
            activation_fn=activation_fn,
            learned_pos_embedding=learned_pos_embedding,
            embed_scale=embed_scale,
            freeze_embeddings=freeze_embeddings,
            n_trans_layers_to_freeze=n_trans_layers_to_freeze,
            export=export,
        )

        # Replace whatever layer stack the parent installed with sparse
        # layers. ``self.embedding_dim`` is read from the instance, so it is
        # expected to have been set by the parent initializer above.
        self.layers = nn.ModuleList(
            [
                SparseTransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    is_bidirectional=is_bidirectional,
                    stride=stride,
                    expressivity=expressivity,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        def freeze_module_params(m):
            """Turn off gradient tracking for every parameter of ``m``."""
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        # The sparse layers were created fresh above, so the requested
        # freezing must be applied to them here. Indexing (rather than
        # slicing) keeps the IndexError when more layers are requested
        # than exist.
        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])