from dataclasses import dataclass
from typing import Optional

from fairseq2.data import VocabularyInfo
from fairseq2.models.transformer import (
    TransformerEmbeddingFrontend,
    TransformerFrontend,
)
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
from fairseq2.nn.projection import TiedProjection
from fairseq2.nn.transformer import (
    FeedForwardNetwork,
    MultiheadAttention,
    StandardFeedForwardNetwork,
    StandardMultiheadAttention,
    TransformerNormOrder,
    create_default_sdpa,
)
from fairseq2.typing import DataType, Device

from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel
from seamless_communication.models.monotonic_decoder.monotonic_decoder import (
    MonotonicTransformerDecoder,
)
from seamless_communication.models.monotonic_decoder.monotonic_decoder_layer import (
    MonotonicTransformerDecoderLayer,
)
from seamless_communication.models.monotonic_decoder.p_choose import PChooseLayer

@dataclass
class MonotonicDecoderConfig:
    """Holds the configuration of a Monotonic Decoder model."""

    model_dim: int
    """The dimensionality of the model."""

    max_seq_len: int
    """The expected maximum sequence length."""

    vocab_info: VocabularyInfo
    """The vocabulary information."""

    num_decoder_layers: int
    """The number of Transformer decoder layers."""

    num_decoder_attn_heads: int
    """The number of attention heads in Transformer decoder layers."""

    ffn_inner_dim: int
    """The inner dimensionality of Transformer feed-forward networks."""

    dropout_p: float
    """The dropout probability in Transformer layers."""

    energy_bias_value: float
    """The value of the energy bias parameter to be added to the
    monotonic energy in the PChooseLayer."""

    monotonic_temperature: float
    """The parameter by which to divide the monotonic energy
    to compute p_choose."""

    num_monotonic_energy_layers: int
    """The number of layers in the EnergyProjection module."""

    pre_decision_ratio: int
    """The kernel size and stride of the average pooling
    in the PChooseLayer."""

monotonic_decoder_archs = ArchitectureRegistry[MonotonicDecoderConfig](
    "monotonic_decoder"
)

monotonic_decoder_arch = monotonic_decoder_archs.decorator


@monotonic_decoder_arch("dense_1b")
def _dense_1b() -> MonotonicDecoderConfig:
    return MonotonicDecoderConfig(
        model_dim=1024,
        max_seq_len=4096,
        vocab_info=VocabularyInfo(
            size=256102, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
        ),
        num_decoder_layers=24,
        num_decoder_attn_heads=16,
        ffn_inner_dim=1024 * 8,
        dropout_p=0.1,
        energy_bias_value=-0.5,
        monotonic_temperature=0.2,
        num_monotonic_energy_layers=4,
        pre_decision_ratio=2,
    )

class MonotonicDecoderBuilder:
    """Builds modules of a Monotonic Decoder.

    To tweak the architecture, you can derive from this class and override the
    corresponding methods.
    """

    config: MonotonicDecoderConfig
    device: Optional[Device]
    dtype: Optional[DataType]

    def __init__(
        self,
        config: MonotonicDecoderConfig,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param config:
            The configuration to use.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.config = config

        self.device, self.dtype = device, dtype

    def build_model(self) -> MonotonicDecoderModel:
        text_embed = self.build_embedding()

        text_decoder_frontend = self.build_frontend(text_embed)

        text_decoder = self.build_decoder()

        # Tie the output projection weights to the input embedding table.
        final_proj = TiedProjection(text_embed.weight, bias=None)

        return MonotonicDecoderModel(
            text_decoder_frontend,
            text_decoder,
            final_proj,
        )

    def build_embedding(self) -> StandardEmbedding:
        """Build an embedding table."""
        return StandardEmbedding(
            num_embeddings=self.config.vocab_info.size,
            embedding_dim=self.config.model_dim,
            pad_idx=self.config.vocab_info.pad_idx,
            init_fn=init_scaled_embedding,
            device=self.device,
            dtype=self.dtype,
        )

    def build_frontend(self, embed: Embedding) -> TransformerFrontend:
        """Build a Transformer decoder front-end."""
        pos_encoder = SinusoidalPositionEncoder(
            self.config.model_dim,
            self.config.max_seq_len,
            _legacy_pad_idx=1,
            device=self.device,
        )

        return TransformerEmbeddingFrontend(
            embed,
            pos_encoder,
            dropout_p=self.config.dropout_p,
            device=self.device,
            dtype=self.dtype,
        )

    def build_decoder(self) -> MonotonicTransformerDecoder:
        """Build a Transformer decoder."""
        num_layers = self.config.num_decoder_layers

        layers = [self.build_decoder_layer() for _ in range(num_layers)]

        return MonotonicTransformerDecoder(
            layers,
            device=self.device,
            dtype=self.dtype,
        )

    def build_decoder_layer(self) -> MonotonicTransformerDecoderLayer:
        """Build a Transformer decoder layer."""
        self_attn = self.build_attention(self.config.num_decoder_attn_heads)

        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)

        p_choose_layer = self.build_p_choose_layer(self.config.num_decoder_attn_heads)

        ffn = self.build_ffn()

        return MonotonicTransformerDecoderLayer(
            self_attn,
            encoder_decoder_attn,
            p_choose_layer,
            ffn,
            dropout_p=self.config.dropout_p,
            device=self.device,
            dtype=self.dtype,
        )

    def build_attention(self, num_heads: int) -> MultiheadAttention:
        """Build a Transformer multi-head attention layer."""
        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)

        return StandardMultiheadAttention(
            self.config.model_dim,
            num_heads,
            sdpa=sdpa,
            device=self.device,
            dtype=self.dtype,
        )

    def build_p_choose_layer(self, num_heads: int) -> PChooseLayer:
        """Build a PChoose layer."""
        return PChooseLayer(
            self.config.model_dim,
            num_heads,
            self.config.energy_bias_value,
            self.config.monotonic_temperature,
            self.config.num_monotonic_energy_layers,
            self.config.pre_decision_ratio,
            device=self.device,
            dtype=self.dtype,
        )

    def build_ffn(self) -> FeedForwardNetwork:
        """Build a Transformer feed-forward network."""
        return StandardFeedForwardNetwork(
            self.config.model_dim,
            self.config.ffn_inner_dim,
            bias=True,
            norm_order=TransformerNormOrder.PRE,
            device=self.device,
            dtype=self.dtype,
        )

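# Illustrative subclass (a sketch, not part of the original module) showing
# the override pattern described in the MonotonicDecoderBuilder docstring:
# derive from the builder and override the build method you want to change.
# The class name and the choice to disable attention dropout are hypothetical.
class NoAttnDropoutMonotonicDecoderBuilder(MonotonicDecoderBuilder):
    """Builds multi-head attention layers without SDPA dropout."""

    def build_attention(self, num_heads: int) -> MultiheadAttention:
        # Identical to the base implementation, except that dropout is
        # disabled in the scaled dot-product attention.
        sdpa = create_default_sdpa(attn_dropout_p=0.0)

        return StandardMultiheadAttention(
            self.config.model_dim,
            num_heads,
            sdpa=sdpa,
            device=self.device,
            dtype=self.dtype,
        )
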
def create_monotonic_decoder_model(
    config: MonotonicDecoderConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> MonotonicDecoderModel:
    """Create a Monotonic Decoder model.

    :param config:
        The configuration to use.
    :param device:
        The device on which to initialize modules.
    :param dtype:
        The data type of module parameters and buffers.
    """
    return MonotonicDecoderBuilder(config, device=device, dtype=dtype).build_model()
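
# Example usage (a sketch, not part of the original module). It assumes that
# fairseq2's ArchitectureRegistry exposes a ``get_config`` lookup, which is
# how registered architectures such as "dense_1b" are typically retrieved:
#
#     import torch
#
#     config = monotonic_decoder_archs.get_config("dense_1b")
#     model = create_monotonic_decoder_model(config, dtype=torch.float16)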