# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
import math
from functools import partial
from typing import Callable, Literal, Optional
import torch
from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import TransformerNormOrder
from torch.nn import Module

SUPPORTED_INIT_TYPES = Literal[
"xavier",
"sonar",
"zero",
"trunc_normal",
"kaiming_uniform",
"none",
]

# Most SONAR embeddings have a distribution with mean close to 0 and std close
# to 0.006. Initializing embedding-like parameters (e.g. the end-of-text vector)
# from a similar distribution is recommended, to minimize their disruption of
# model training.
SONAR_STD = 0.006
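# For illustration only (`eot_vector` and `model_dim` are hypothetical names,
# not part of this module): such a parameter could be initialized with
#
#   eot_vector = torch.nn.Parameter(torch.empty(model_dim))
#   torch.nn.init.normal_(eot_vector, mean=0.0, std=SONAR_STD)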


def get_init_fn(
    style: SUPPORTED_INIT_TYPES = "xavier", sonar_std: float = SONAR_STD
) -> Optional[Callable[[Linear], None]]:
    if style == "xavier":
        return init_linear_xavier
    if style == "kaiming_uniform":
        return init_linear_kaiming_uniform
    if style == "sonar":
        return partial(init_linear_to_sonar, sonar_std=sonar_std)
    if style == "zero":
        return init_linear_zero
    if style == "trunc_normal":
        return init_linear_trunc_normal
    if style == "none":
        return None
    raise ValueError(f"Could not recognize initialization function {style}")


def init_linear_to_sonar(layer: Linear, sonar_std: float) -> None:
    """
    Initialize the post-LCM projection such that, if it is fed layer-normed
    LCM outputs (with zero mean and unit variance), its outputs have zero
    mean and the variance of SONAR embeddings.
    """
if layer.bias is not None:
torch.nn.init.zeros_(layer.bias)
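    # For weights drawn from U(-a, a), each entry has variance a^2 / 3, so an
    # output fed input_dim unit-variance inputs has variance input_dim * a^2 / 3;
    # solving for sonar_std^2 gives the bound a = sonar_std * sqrt(3 / input_dim).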
    bound = sonar_std * (3 / layer.input_dim) ** 0.5
    torch.nn.init.uniform_(layer.weight, a=-bound, b=bound)


def init_linear_xavier(layer: Linear) -> None:
torch.nn.init.xavier_uniform_(layer.weight)
if layer.bias is not None:
torch.nn.init.zeros_(layer.bias)


def init_linear_zero(layer: Linear) -> None:
torch.nn.init.zeros_(layer.weight)
if layer.bias is not None:
torch.nn.init.zeros_(layer.bias)


def init_linear_trunc_normal(layer: Linear) -> None:
torch.nn.init.trunc_normal_(layer.weight, std=1e-3)
if layer.bias is not None:
torch.nn.init.zeros_(layer.bias)


def init_linear_kaiming_uniform(layer: Linear) -> None:
torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
if layer.bias is not None:
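        # fan_in mirrors PyTorch's fan-in computation for nn.Linear's default
        # reset_parameters: the number of input features, multiplied by the
        # receptive-field size when the weight has more than two dimensions.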
fan_in = layer.weight.size(1)
m = 1
if layer.weight.ndim > 2:
for s in layer.weight.shape[2:]:
m *= s
fan_in *= m
# We do not calculate the true standard deviation of the uniform
# distribution (i.e. multiply with sqrt(3)). See
# https://github.com/pytorch/pytorch/issues/57109#issuecomment-828847575.
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
torch.nn.init.uniform_(layer.bias, -bound, bound)


def parse_norm_order(var: str) -> TransformerNormOrder:
norm_order: TransformerNormOrder
if var == "pre":
norm_order = TransformerNormOrder.PRE
elif var == "post":
norm_order = TransformerNormOrder.POST
elif var == "normformer":
norm_order = TransformerNormOrder.PRE_WITH_NORMFORMER
else:
raise ValueError(f"Unknown normalization order {var}")
return norm_order


def parse_activation_fn(var: Optional[str] = None) -> Optional[Module]:
    if var is None:
        return None
activ_fn: Module
if var == "relu":
activ_fn = torch.nn.ReLU()
elif var == "tanh":
activ_fn = torch.nn.Tanh()
elif var == "elu":
activ_fn = torch.nn.ELU()
elif var == "leaky_relu":
activ_fn = torch.nn.LeakyReLU()
elif var == "prelu":
activ_fn = torch.nn.PReLU()
elif var == "selu":
activ_fn = torch.nn.SELU()
elif var == "gelu":
activ_fn = torch.nn.GELU()
elif var == "silu":
activ_fn = torch.nn.SiLU()
elif var == "softsign":
activ_fn = torch.nn.Softsign()
elif var == "sigmoid":
activ_fn = torch.nn.Sigmoid()
elif var == "hardsigmoid":
activ_fn = torch.nn.Hardsigmoid()
else:
raise ValueError(f"Unknown activation function {var}")
return activ_fn
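

# A minimal self-check, for illustration only (it assumes fairseq2's Linear can
# be constructed as Linear(input_dim, output_dim, bias=...); the dimensions and
# the "sonar" style below are arbitrary choices, not part of this module's API).
if __name__ == "__main__":
    layer = Linear(1024, 1024, bias=True)
    init_fn = get_init_fn("sonar")
    if init_fn is not None:
        init_fn(layer)
    # Stand-in for layer-normed LCM outputs: zero mean, unit variance.
    x = torch.randn(4096, 1024)
    print(f"output std: {layer(x).std().item():.4f}")  # should be close to SONAR_STD
    print(parse_norm_order("pre"))  # TransformerNormOrder.PRE
    print(parse_activation_fn("gelu"))  # a torch.nn.GELU instance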