# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
from dataclasses import dataclass
from typing import Optional

import torch
from fairseq2.nn.projection import Linear
from fairseq2.typing import DataType, Device
from torch import Tensor
from torch.nn import Module

from lcm.nn.initialization import (
    SUPPORTED_INIT_TYPES,
    get_init_fn,
    parse_activation_fn,
)
from lcm.nn.normalization import SUPPORTED_LN_TYPES


@dataclass
class ProjectionConfig:
    dropout_p: float = 0.0
    """The dropout probability applied to the module's output."""

    linear_bias: bool = True
    """Whether or not the linear layer has a bias term."""

    linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform"
    """The initialization function used for the linear layer's weights."""

    weight_normalization: bool = False
    """Whether to wrap the linear layer with weight normalization."""

    layer_normalization_style: SUPPORTED_LN_TYPES = "standard"
    """The layer normalization style to use (see ``SUPPORTED_LN_TYPES``)."""

    activation_name: Optional[str] = None
    """The activation function to apply to the output, if any."""


class Projection(Module):
    """
    An output projection module: a linear layer, optionally weight-normalized,
    followed by an optional activation function.
    """

    def __init__(
        self,
        output_dim: int,
        input_dim: int,
        config: ProjectionConfig,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        super().__init__()

        self.dtype = dtype

        init_fn = get_init_fn(config.linear_init_fn)

        lin = Linear(
            input_dim,
            output_dim,
            bias=config.linear_bias,
            device=device,
            dtype=dtype,
            init_fn=init_fn,
        )

        if config.weight_normalization:
            # reparametrize the linear layer's weight as magnitude * direction
            self.fc = torch.nn.utils.parametrizations.weight_norm(lin)
        else:
            self.fc = lin

        self.activation_fn = parse_activation_fn(config.activation_name)

        if self.activation_fn is not None:
            # some activation functions (e.g., PReLU) have parameters
            # and so we need to move them to the right device
            self.activation_fn.to(device)

    def forward(self, seqs: Tensor) -> Tensor:
        seqs = self.fc(seqs)

        if self.activation_fn is not None:
            seqs = self.activation_fn(seqs)

        return seqs
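

if __name__ == "__main__":
    # Minimal usage sketch: build a Projection with the default config and run
    # a forward pass. This assumes fairseq2 and the lcm package are importable;
    # the dimensions and tensor shape below are arbitrary illustrative values.
    config = ProjectionConfig()  # biased linear, kaiming_uniform init, no activation
    proj = Projection(output_dim=1024, input_dim=512, config=config)

    seqs = torch.randn(2, 8, 512)  # (batch, seq_len, input_dim)
    out = proj(seqs)
    print(out.shape)  # expected: torch.Size([2, 8, 1024])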