from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module

from fairseq2.nn.projection import Linear
from fairseq2.typing import DataType, Device

from lcm.nn.initialization import (
    SUPPORTED_INIT_TYPES,
    get_init_fn,
    parse_activation_fn,
)
from lcm.nn.normalization import SUPPORTED_LN_TYPES


@dataclass
class ProjectionConfig:
    dropout_p: float = 0.0
    """The dropout probability applied to the module's output."""

    linear_bias: bool = True
    """Whether or not the linear layer has a bias term."""

    linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform"
    """The initialization method for the linear layer's weights."""

    weight_normalization: bool = False
    """Whether to wrap the linear layer with weight normalization."""

    layer_normalization_style: SUPPORTED_LN_TYPES = "standard"
    """The style of layer normalization to use."""

    activation_name: Optional[str] = None
    """The activation function to apply after the linear layer, if any."""


class Projection(Module):
    """
    An output projection module.
    """

    def __init__(
        self,
        output_dim: int,
        input_dim: int,
        config: ProjectionConfig,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__()

        self.dtype = dtype

        init_fn = get_init_fn(config.linear_init_fn)

        lin = Linear(
            input_dim,
            output_dim,
            bias=config.linear_bias,
            device=device,
            dtype=dtype,
            init_fn=init_fn,
        )
        if config.weight_normalization:
            # Reparametrize the linear weights into magnitude and direction.
            self.fc = torch.nn.utils.parametrizations.weight_norm(lin)
        else:
            self.fc = lin

        self.activation_fn = parse_activation_fn(config.activation_name)

        if self.activation_fn is not None:
            # Keep the activation module on the same device as the projection.
            self.activation_fn.to(device)

    def forward(self, seqs: Tensor) -> Tensor:
        seqs = self.fc(seqs)

        if self.activation_fn is not None:
            seqs = self.activation_fn(seqs)

        return seqs
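

if __name__ == "__main__":
    # Minimal usage sketch, assuming fairseq2 and lcm are importable; this is
    # an illustrative addition, not part of the original module. It projects
    # 16-dim features to 32 dims with the default config (no activation, so no
    # assumptions about supported activation names are needed).
    cfg = ProjectionConfig()
    proj = Projection(output_dim=32, input_dim=16, config=cfg)

    x = torch.randn(2, 4, 16)  # (batch, seq_len, input_dim)
    y = proj(x)
    print(y.shape)  # expected: torch.Size([2, 4, 32])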