# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass, field
from typing import Optional

import torch
from fairseq2.config_registry import ConfigRegistry
from fairseq2.logging import get_log_writer
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.transformer import AttentionMaskFactory, CausalAttentionMaskFactory
from fairseq2.typing import DataType, Device

from lcm.datasets.batch import EmbeddingsBatch
from lcm.models.abstract_lcm import (
    AbstractLCModel,
    AbstractLCModelBuilder,
    AbstractLCModelConfig,
)
from lcm.models.base_lcm.frontend import LCMFrontend, LCMFrontendConfig
from lcm.nn.initialization import parse_norm_order
from lcm.nn.normalization import parse_layer_norm_factory
from lcm.nn.projection import Projection, ProjectionConfig
from lcm.nn.transformer import (
    LCMTransformerDecoder,
    TransformerConfig,
    TransformerFactory,
)

logger = get_log_writer(__name__)

BASE_LCM_MODEL_TYPE = "base_lcm"


@dataclass
class BaseLCModelConfig(AbstractLCModelConfig):
    model_type: str = BASE_LCM_MODEL_TYPE

    max_seq_len: int = 2048

    model_dim: int = 1024

    model_output_dim: Optional[int] = None
    """If ``None`` use SONAR dimension as output_dim."""

    frontend: LCMFrontendConfig = field(default_factory=lambda: LCMFrontendConfig())
    """The fronted config. This module maps from `sonar_embed_dim` to `model_dim`
        and potentially adds positional embeddings"""

    lcm: TransformerConfig = field(default_factory=lambda: TransformerConfig())
    """The core lcm config. This is causal Transformer decoder"""

    postnet: ProjectionConfig = field(default_factory=lambda: ProjectionConfig())
    """The postnet config. A module mapping the output of the core lcm
       back to `sonar_embed_dim`"""


lcm_archs = ConfigRegistry[BaseLCModelConfig]()
lcm_arch = lcm_archs.decorator
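
# Example of registering a named architecture with the decorator above
# (an illustrative sketch: the arch name and the overridden values are
# hypothetical, not defined in this file):
#
#   @lcm_arch("toy_lcm")
#   def _toy_lcm() -> BaseLCModelConfig:
#       config = BaseLCModelConfig()
#       config.model_dim = 256
#       config.lcm.num_layers = 2
#       return config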


class BaseLCModel(AbstractLCModel):
    """Base class for LCM models"""

    config: BaseLCModelConfig

    def __init__(
        self,
        config: BaseLCModelConfig,
        lcm: LCMTransformerDecoder,
        frontend: LCMFrontend,
        postnet: Projection,
    ) -> None:
        """
        Basic LCM model with :
            - fronted
            - lcm
            - postnet
        """
        super().__init__(config)

        self.frontend = frontend

        self.lcm = lcm

        self.postnet = postnet

        self.model_dim = lcm.model_dim

        self.sonar_embed_dim = config.sonar_embed_dim

    def forward(
        self,
        batch: EmbeddingsBatch,
        state_bag: Optional[IncrementalStateBag] = None,
        **kwargs,
    ) -> EmbeddingsBatch:
        """
        Scaling + Positions
        If a normalizer is provided, the features will be normalized in the
        frontend's pre_forward (e.g. MSE LCM) or in the criterion (Diffusion LCM)
        """
        seqs, padding_mask = self.frontend(
            batch.seqs,
            batch.padding_mask,
            diffusion_timesteps=batch.diffusion_timesteps,
            state_bag=state_bag,
            **kwargs,
        )

        # Core LCM
        seqs, padding_mask = self.lcm(
            seqs,
            padding_mask,
            state_bag=state_bag,
            **kwargs,
        )

        # Postnet: map the decoder output back to the output embedding dimension
        seqs = self.postnet(seqs)  # type: ignore

        return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)

    def predict_next_sentence(
        self,
        batch: EmbeddingsBatch,
        sample: bool = False,
        temperature: float = 1.0,
        state_bag: Optional[IncrementalStateBag] = None,
        **kwargs,
    ) -> EmbeddingsBatch:
        """
        Predict the next sentence embedding.
        In the basic LCM, this is equivalent to the forward method,
        but derived architectures may implement it differently;
        e.g., VAE LCM runs the VAE decoder on top of the `forward` results.

        Args:
            batch (EmbeddingsBatch): the sequence of concepts which
                the model should continue.
            sample (bool): whether to predict the single most probable next sentence
                or to sample from the predicted distribution.
            temperature (float): a positive float controlling the diversity of
                the sampling (used only if `sample` is True).
        Returns:
            EmbeddingsBatch: the batch with predicted SONAR sentences.
        """
        # Normalize the input embeddings if normalization is expected to
        # happen outside of the model's forward pass
        if self.frontend.sonar_normalizer is not None:
            batch = batch.normalize_seqs(self.frontend.sonar_normalizer)

        # TODO: implement efficient sampling of multiple candidates
        predicted_means = self.forward(batch, state_bag=state_bag, **kwargs)

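        # Sampling: treat the predicted embedding as the mean of an isotropic
        # Gaussian and perturb it with noise of standard deviation `temperature`.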
        if sample and temperature > 0:
            noise = torch.randn_like(predicted_means.seqs) * temperature
            predicted_means.seqs = predicted_means.seqs + noise

        if self.frontend.sonar_normalizer is not None:
            predicted_means = predicted_means.denormalize_seqs(
                self.frontend.sonar_normalizer
            )

        return predicted_means


class BaseLCModelBuilder(AbstractLCModelBuilder):
    """Builds modules of a base LCM model"""

    config: BaseLCModelConfig
    device: Optional[Device]
    dtype: Optional[DataType]

    def __init__(
        self,
        config: BaseLCModelConfig,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__(config=config, device=device, dtype=dtype)
        self.lcm_factory = TransformerFactory(
            model_dim=self.config.model_dim,
            max_seq_len=self.config.max_seq_len,
            config=self.config.lcm,
            device=device,
            dtype=dtype,
        )
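
        # The same factory also provides the positional encoder used by the
        # frontend (`build_pos_encoder`) and the decoder layers of the core
        # LCM (`build_layer`).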

        if config.model_output_dim is None:
            self.model_output_dim = self.config.sonar_embed_dim
        else:
            self.model_output_dim = config.model_output_dim

    def build_model(self) -> BaseLCModel:
        """Build a model."""

        frontend = self.build_frontend()

        lcm = self.build_core_lcm()

        postnet = self.build_postnet()

        return BaseLCModel(
            config=self.config,
            frontend=frontend,
            lcm=lcm,
            postnet=postnet,
        )

    def build_frontend(self) -> LCMFrontend:
        """Build the LCM front-end (i.e., prenet)."""

        return LCMFrontend(
            sonar_embed_dim=self.config.sonar_embed_dim,
            model_dim=self.config.model_dim,
            config=self.config.frontend,
            pos_encoder=self.lcm_factory.build_pos_encoder(),
            sonar_normalizer=self.build_sonar_normalizer(),
            device=self.device,
            dtype=self.dtype,
        )

    def build_postnet(self) -> Projection:
        return Projection(
            output_dim=self.model_output_dim,
            input_dim=self.config.model_dim,
            config=self.config.postnet,
            device=self.device,
            dtype=self.dtype,
        )

    def build_attention_mask_factory(self) -> AttentionMaskFactory:
        """Build the self-attention mask factory (causal by default)."""
        return CausalAttentionMaskFactory()

    def build_core_lcm(self) -> LCMTransformerDecoder:
        """Build the core LCM module."""

        config = self.config.lcm

        layers = [self.lcm_factory.build_layer() for _ in range(config.num_layers)]

        self_attn_mask_factory = self.build_attention_mask_factory()

        if config.final_norm_order_style is None:
            # The final norm order style will be that of the layer-level norm order
            final_norm_order = parse_norm_order(config.norm_order_style)
        else:
            final_norm_order = parse_norm_order(config.final_norm_order_style)

        layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style)

        return LCMTransformerDecoder(
            layers,  # type: ignore
            self_attn_mask_factory=self_attn_mask_factory,
            norm_order=final_norm_order,
            layer_norm_factory=layer_norm_factory,
            dropout_p=config.final_dropout_p,
            device=self.device,
            dtype=self.dtype,
        )


def create_base_lcm_model(
    config: BaseLCModelConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> BaseLCModel:
    """Create an LCM model.
    :param config:
        The configuration.
    :param device:
        The device on which to initialize modules.
    :param dtype:
        The data type of module parameters and buffers.
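
    Example (a minimal sketch; assumes the default sub-configs are buildable
    in the current environment)::

        config = BaseLCModelConfig(model_dim=1024, max_seq_len=2048)
        model = create_base_lcm_model(config, dtype=torch.float32)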
    """
    return BaseLCModelBuilder(config, device=device, dtype=dtype).build_model()