|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import abstractmethod |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
from fairseq2.config_registry import ConfigRegistry |
|
|
from fairseq2.logging import get_log_writer |
|
|
from fairseq2.typing import DataType, Device |
|
|
from torch.nn import Module |
|
|
|
|
|
from lcm.models.sonar_normalizer import SonarNormalizer, load_sonar_normalizer_model |
|
|
|
|
|
# Module-level log writer, named after this module.
logger = get_log_writer(__name__)
|
|
|
|
|
|
|
|
""" |
|
|
Bare-minimum abstract building blocks for LCM models: a config, a model base class, and a builder.
|
|
""" |
|
|
|
|
|
# Default value of `AbstractLCModelConfig.model_type`.
ABSTRACT_LCM_MODEL_TYPE = "abstract_lcm"
|
|
|
|
|
|
|
|
@dataclass
class AbstractLCModelConfig:
    """Bare-minimum configuration of an LCM model."""

    # Identifier of the model type; defaults to the abstract model type.
    model_type: str = ABSTRACT_LCM_MODEL_TYPE

    # Presumably the dimensionality of the SONAR embeddings -- confirm
    # against the concrete models that consume this config.
    sonar_embed_dim: int = 1024

    # Name handed to `load_sonar_normalizer_model` by the builder; when
    # `None`, the builder does not load a normalizer.
    sonar_normalizer_name: Optional[str] = None
|
|
|
|
|
|
|
|
# Registry holding `AbstractLCModelConfig` entries.
lcm_archs = ConfigRegistry[AbstractLCModelConfig]()

# Shorthand for the registry's decorator; presumably used to register
# architecture config factories by name -- confirm against fairseq2.
lcm_arch = lcm_archs.decorator
|
|
|
|
|
|
|
|
class AbstractLCModel(Module):
    """Abstract class for LCM models."""

    def __init__(
        self,
        config: AbstractLCModelConfig,
    ) -> None:
        """
        :param config:
            The model configuration; stored as ``self.config``.
        """
        super().__init__()

        self.config = config

    @property
    def dtype(self):
        """The data type of the first parameter yielded by ``parameters()``.

        Raises ``StopIteration`` if the module has no parameters.
        """
        first_param = next(self.parameters())
        return first_param.dtype

    @property
    def device(self):
        """The device of the first parameter yielded by ``parameters()``.

        Raises ``StopIteration`` if the module has no parameters.
        """
        first_param = next(self.parameters())
        return first_param.device
|
|
|
|
|
|
|
|
class AbstractLCModelBuilder:
    """Builds modules of an LCM.

    Subclasses are expected to override :meth:`build_model`.
    """

    config: AbstractLCModelConfig
    device: Optional[Device]
    dtype: Optional[DataType]

    def __init__(
        self,
        config: AbstractLCModelConfig,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param config:
            The configuration.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.config = config

        self.device = device
        self.dtype = dtype

    def build_sonar_normalizer(
        self,
    ) -> Optional[SonarNormalizer]:
        """Load the SONAR normalizer named in the configuration, if any.

        :returns:
            The result of ``load_sonar_normalizer_model`` called with the
            configured name and this builder's device and dtype, or
            ``None`` when ``config.sonar_normalizer_name`` is ``None``.
        """
        normalizer_name = self.config.sonar_normalizer_name

        # No normalizer configured: nothing to load.
        if normalizer_name is None:
            return None

        logger.info(f"Building sonar_normalizer = {normalizer_name}")

        return load_sonar_normalizer_model(
            normalizer_name,
            device=self.device,
            dtype=self.dtype,
        )

    @abstractmethod
    def build_model(self) -> AbstractLCModel:
        """Build a model.

        :raises NotImplementedError:
            Always, in this base class; subclasses must override.
        """
        # This class does not derive from ``abc.ABC``, so ``@abstractmethod``
        # alone does not block instantiation. Fail loudly here instead of
        # silently returning ``None`` from an un-overridden method.
        raise NotImplementedError(
            f"{type(self).__name__} must implement build_model()."
        )
|
|
|