import math
from functools import partial
from typing import Literal, Optional

import torch
from torch.nn import Module

from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import TransformerNormOrder

SUPPORTED_INIT_TYPES = Literal[
    "xavier",
    "sonar",
    "zero",
    "trunc_normal",
    "kaiming_uniform",
    "none",
]

SONAR_STD = 0.006


def get_init_fn(style: str = "xavier", sonar_std: float = SONAR_STD):
    if style == "xavier":
        return init_linear_xavier
    if style == "kaiming_uniform":
        return init_linear_kaiming_uniform
    if style == "sonar":
        return partial(init_linear_to_sonar, sonar_std=sonar_std)
    if style == "zero":
        return init_linear_zero
    if style == "trunc_normal":
        return init_linear_trunc_normal
    if style == "none":
        return None

    raise ValueError(f"Could not recognize initialization function {style}")


def init_linear_to_sonar(layer: Linear, sonar_std: float) -> None:
    """
    Initialize the post-LCM projection so that, when it is fed layer-normed
    LCM outputs (with zero mean and unit variance), its outputs have zero
    mean and the variance of SONAR embeddings.
    """
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)

    # A uniform distribution on [-bound, bound] has variance bound**2 / 3, so the
    # layer's output variance is input_dim * bound**2 / 3; the bound below makes
    # that equal to sonar_std**2.
    bound = sonar_std * (3 / layer.input_dim) ** 0.5

    torch.nn.init.uniform_(layer.weight, a=-bound, b=bound)
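
# Worked example (hypothetical numbers): with input_dim = 1024 and the default
# sonar_std = 0.006, the uniform bound is 0.006 * sqrt(3 / 1024) ≈ 3.2e-4.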


def init_linear_xavier(layer: Linear) -> None:
    torch.nn.init.xavier_uniform_(layer.weight)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_zero(layer: Linear) -> None:
    torch.nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_trunc_normal(layer: Linear) -> None:
    torch.nn.init.trunc_normal_(layer.weight, std=1e-3)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_kaiming_uniform(layer: Linear) -> None:
    torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))

    if layer.bias is not None:
        # Compute fan_in the same way torch.nn.Linear.reset_parameters does: the
        # number of input features, times the receptive field size for weights
        # with more than two dimensions.
        fan_in = layer.weight.size(1)

        if layer.weight.ndim > 2:
            for s in layer.weight.shape[2:]:
                fan_in *= s

        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        torch.nn.init.uniform_(layer.bias, -bound, bound)


def parse_norm_order(var: str) -> TransformerNormOrder:
    norm_order: TransformerNormOrder
    if var == "pre":
        norm_order = TransformerNormOrder.PRE
    elif var == "post":
        norm_order = TransformerNormOrder.POST
    elif var == "normformer":
        norm_order = TransformerNormOrder.PRE_WITH_NORMFORMER
    else:
        raise ValueError(f"Unknown normalization order {var}")

    return norm_order


def parse_activation_fn(var: Optional[str] = None) -> Optional[Module]:
    if var is None:
        return None

    activ_fn: Module

    if var == "relu":
        activ_fn = torch.nn.ReLU()
    elif var == "tanh":
        activ_fn = torch.nn.Tanh()
    elif var == "elu":
        activ_fn = torch.nn.ELU()
    elif var == "leaky_relu":
        activ_fn = torch.nn.LeakyReLU()
    elif var == "prelu":
        activ_fn = torch.nn.PReLU()
    elif var == "selu":
        activ_fn = torch.nn.SELU()
    elif var == "gelu":
        activ_fn = torch.nn.GELU()
    elif var == "silu":
        activ_fn = torch.nn.SiLU()
    elif var == "softsign":
        activ_fn = torch.nn.Softsign()
    elif var == "sigmoid":
        activ_fn = torch.nn.Sigmoid()
    elif var == "hardsigmoid":
        activ_fn = torch.nn.Hardsigmoid()
    else:
        raise ValueError(f"Unknown activation function {var}")

    return activ_fn
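
# Usage sketch (illustrative; the strings below are just examples of the
# supported values): parse_activation_fn("gelu") returns torch.nn.GELU(), and
# parse_norm_order("normformer") returns TransformerNormOrder.PRE_WITH_NORMFORMER.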