# LexaLCM_Pre0/lcm/nn/initialization.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
import math
from functools import partial
from typing import Literal, Optional

import torch
from torch.nn import Module

from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import TransformerNormOrder

SUPPORTED_INIT_TYPES = Literal[
    "xavier",
    "sonar",
    "zero",
    "trunc_normal",
    "kaiming_uniform",
    "none",
]

SONAR_STD = 0.006
# Most SONAR embeddings have a distribution with mean close to 0 and std close to 0.006.
# Initializing embedding-like parameters (e.g. the end-of-text vector) from a similar
# distribution is recommended, to minimize their disruption of model training.
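# Illustrative sketch (not part of the original module): an embedding-like
# parameter such as a learned end-of-text vector could be drawn from a matching
# distribution, e.g.
#
#     eot_vector = torch.nn.Parameter(torch.empty(1024))   # 1024 = SONAR embedding dim
#     torch.nn.init.trunc_normal_(eot_vector, std=SONAR_STD)
#
# The name `eot_vector` and the dimension are assumptions for illustration only.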


def get_init_fn(style: str = "xavier", sonar_std: float = SONAR_STD):
    if style == "xavier":
        return init_linear_xavier
    if style == "kaiming_uniform":
        return init_linear_kaiming_uniform
    if style == "sonar":
        return partial(init_linear_to_sonar, sonar_std=sonar_std)
    if style == "zero":
        return init_linear_zero
    if style == "trunc_normal":
        return init_linear_trunc_normal
    if style == "none":
        return None
    raise ValueError(f"Could not recognize initialization function {style}")


def init_linear_to_sonar(layer: Linear, sonar_std: float) -> None:
    """
    Initialize the post-LCM projection in such a way that, if it is fed
    layer-normed LCM outputs (with zero mean and unit variance), its outputs
    have zero mean and the variance of SONAR embeddings.
    """
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)
    # A uniform distribution on [-bound, bound] has std bound / sqrt(3), so each
    # weight gets std sonar_std / sqrt(input_dim) and, for unit-variance inputs,
    # the layer output std is approximately sonar_std.
    bound = sonar_std * (3 / layer.input_dim) ** 0.5
    torch.nn.init.uniform_(layer.weight, a=-bound, b=bound)


def init_linear_xavier(layer: Linear) -> None:
    torch.nn.init.xavier_uniform_(layer.weight)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_zero(layer: Linear) -> None:
    torch.nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_trunc_normal(layer: Linear) -> None:
    torch.nn.init.trunc_normal_(layer.weight, std=1e-3)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_kaiming_uniform(layer: Linear) -> None:
    torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
    if layer.bias is not None:
        # fan_in is in_features, scaled by the receptive-field size for >2D weights.
        fan_in = layer.weight.size(1)
        if layer.weight.ndim > 2:
            for s in layer.weight.shape[2:]:
                fan_in *= s
        # We do not calculate the true standard deviation of the uniform
        # distribution (i.e. multiply with sqrt(3)). See
        # https://github.com/pytorch/pytorch/issues/57109#issuecomment-828847575.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        torch.nn.init.uniform_(layer.bias, -bound, bound)
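# Note (added): this mirrors PyTorch's own `torch.nn.Linear.reset_parameters`,
# which uses kaiming_uniform_ with a=sqrt(5) for the weight and a uniform bias
# bound of 1/sqrt(fan_in).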


def parse_norm_order(var: str) -> TransformerNormOrder:
    norm_order: TransformerNormOrder
    if var == "pre":
        norm_order = TransformerNormOrder.PRE
    elif var == "post":
        norm_order = TransformerNormOrder.POST
    elif var == "normformer":
        norm_order = TransformerNormOrder.PRE_WITH_NORMFORMER
    else:
        raise ValueError(f"Unknown normalization order {var}")
    return norm_order
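# Usage sketch (assumed, not from the original file): the string typically comes
# from a model config when building the Transformer stack, e.g.
#
#     norm_order = parse_norm_order(cfg.norm_order_style)
#
# where `cfg.norm_order_style` is a hypothetical config field used for illustration.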


def parse_activation_fn(var: Optional[str] = None) -> Optional[Module]:
    if var is None:
        return None
    activ_fn: Module
    if var == "relu":
        activ_fn = torch.nn.ReLU()
    elif var == "tanh":
        activ_fn = torch.nn.Tanh()
    elif var == "elu":
        activ_fn = torch.nn.ELU()
    elif var == "leaky_relu":
        activ_fn = torch.nn.LeakyReLU()
    elif var == "prelu":
        activ_fn = torch.nn.PReLU()
    elif var == "selu":
        activ_fn = torch.nn.SELU()
    elif var == "gelu":
        activ_fn = torch.nn.GELU()
    elif var == "silu":
        activ_fn = torch.nn.SiLU()
    elif var == "softsign":
        activ_fn = torch.nn.Softsign()
    elif var == "sigmoid":
        activ_fn = torch.nn.Sigmoid()
    elif var == "hardsigmoid":
        activ_fn = torch.nn.Hardsigmoid()
    else:
        raise ValueError(f"Unknown activation function {var}")
    return activ_fn
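

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original module): build a
    # small projection with each supported init style and check that its weights
    # are finite. This assumes fairseq2's `Linear` accepts an `init_fn` callback.
    for style in ("xavier", "sonar", "zero", "trunc_normal", "kaiming_uniform"):
        layer = Linear(8, 8, bias=True, init_fn=get_init_fn(style))
        assert torch.isfinite(layer.weight).all(), style
    print(parse_norm_order("pre"), parse_activation_fn("gelu"))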