File size: 2,485 Bytes
3d79eb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional

from fairseq2.config_registry import ConfigRegistry
from fairseq2.logging import get_log_writer
from fairseq2.typing import DataType, Device
from torch.nn import Module

from lcm.models.sonar_normalizer import SonarNormalizer, load_sonar_normalizer_model

# Module-level log writer, obtained from fairseq2 under this module's name.
logger = get_log_writer(__name__)


"""
An abstract LCM model class for the bare minimum
"""

# Model-type identifier; used as the default of `AbstractLCModelConfig.model_type`.
ABSTRACT_LCM_MODEL_TYPE = "abstract_lcm"


@dataclass
class AbstractLCModelConfig:
    """Base configuration for LCM models."""

    # Model-type identifier; defaults to ABSTRACT_LCM_MODEL_TYPE.
    model_type: str = ABSTRACT_LCM_MODEL_TYPE

    # Presumably the dimensionality of the SONAR embeddings the model works
    # with (default 1024) -- TODO confirm against the concrete models.
    sonar_embed_dim: int = 1024

    # Name handed to `load_sonar_normalizer_model` by
    # `AbstractLCModelBuilder.build_sonar_normalizer`; when None, no
    # normalizer is built.
    sonar_normalizer_name: Optional[str] = None


# Registry of `AbstractLCModelConfig` entries, plus a short alias for its
# `decorator` attribute (presumably used to register named architectures --
# confirm against fairseq2's ConfigRegistry).
lcm_archs = ConfigRegistry[AbstractLCModelConfig]()
lcm_arch = lcm_archs.decorator


class AbstractLCModel(Module):
    """Abstract base class for LCM models.

    Keeps the model configuration on the instance and exposes the ``dtype``
    and ``device`` of the module, both read from its first parameter.
    """

    def __init__(self, config: AbstractLCModelConfig) -> None:
        """
        :param config:
            The configuration; stored as ``self.config``.
        """
        super().__init__()

        self.config = config

    @property
    def dtype(self):
        """Data type of the first parameter yielded by ``self.parameters()``.

        Raises ``StopIteration`` when the module has no parameters.
        """
        first_param = next(self.parameters())
        return first_param.dtype

    @property
    def device(self):
        """Device of the first parameter yielded by ``self.parameters()``.

        Raises ``StopIteration`` when the module has no parameters.
        """
        first_param = next(self.parameters())
        return first_param.device


class AbstractLCModelBuilder:
    """Base class for builders that assemble the modules of an LCM.

    Subclasses must override :meth:`build_model`.
    """

    config: AbstractLCModelConfig
    device: Optional[Device]
    dtype: Optional[DataType]

    def __init__(
        self,
        config: AbstractLCModelConfig,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param config:
            The configuration.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.config = config

        self.device, self.dtype = device, dtype

    def build_sonar_normalizer(
        self,
    ) -> Optional[SonarNormalizer]:
        """Load the SONAR normalizer named in the configuration.

        :returns:
            The result of ``load_sonar_normalizer_model`` called with the
            builder's device and dtype, or ``None`` when
            ``config.sonar_normalizer_name`` is ``None``.
        """
        normalizer_name = self.config.sonar_normalizer_name
        if normalizer_name is None:
            return None

        logger.info(f"Building sonar_normalizer = {normalizer_name}")
        return load_sonar_normalizer_model(
            normalizer_name,
            device=self.device,
            dtype=self.dtype,
        )

    @abstractmethod
    def build_model(self) -> AbstractLCModel:
        """Build a model.

        :raises NotImplementedError:
            Always, in this base class; subclasses must override.
        """
        # This class does not use ``ABCMeta``, so ``@abstractmethod`` is not
        # enforced at instantiation time. Fail loudly here instead of
        # silently returning ``None`` in place of a model.
        raise NotImplementedError(
            f"{type(self).__name__} must implement build_model()"
        )