# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

import math
from functools import partial
from typing import Callable, Literal, Optional

import torch
from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import TransformerNormOrder
from torch.nn import Module

SUPPORTED_INIT_TYPES = Literal[
    "xavier",
    "sonar",
    "zero",
    "trunc_normal",
    "kaiming_uniform",
    "none",
]


SONAR_STD = 0.006
# Most SONAR embeddings have a distribution with a mean close to 0 and a std close to 0.006.
# Initializing embedding-like parameters (e.g. the end-of-text vector) from a similar
# distribution is recommended, to minimize their disruption of model training.


def get_init_fn(
    style: SUPPORTED_INIT_TYPES = "xavier", sonar_std: float = SONAR_STD
) -> Optional[Callable[[Linear], None]]:
    if style == "xavier":
        return init_linear_xavier

    if style == "kaiming_uniform":
        return init_linear_kaiming_uniform

    if style == "sonar":
        return partial(init_linear_to_sonar, sonar_std=sonar_std)

    if style == "zero":
        return init_linear_zero

    if style == "trunc_normal":
        return init_linear_trunc_normal

    if style == "none":
        return None

    else:
        raise ValueError(f"Could not recognize initialization function {style}")
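

# Hedged usage sketch (not part of the original module): the dimensions and the
# style name below are illustrative. Note that the "none" style maps to None,
# so the returned callable has to be checked before it is applied.
def _example_apply_init() -> Linear:
    layer = Linear(1024, 1024, bias=True)
    init_fn = get_init_fn("kaiming_uniform")
    if init_fn is not None:  # the "none" style returns None
        init_fn(layer)
    return layer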


def init_linear_to_sonar(layer: Linear, sonar_std: float) -> None:
    """
    Initialize the post-LCM projection in such a way that, if it is fed
    layer-normed LCM outputs (with zero mean and unit variance), its outputs
    have zero mean and the variance of SONAR embeddings.
    """
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)

    # For weights drawn from U(-bound, bound), Var(W) = bound^2 / 3, and each
    # output coordinate sums `input_dim` independent terms. Choosing
    # bound = sonar_std * sqrt(3 / input_dim) therefore gives an output
    # variance of sonar_std^2 for unit-variance inputs.
    bound = sonar_std * (3 / layer.input_dim) ** 0.5

    torch.nn.init.uniform_(layer.weight, a=-bound, b=bound)
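

# A quick variance sanity sketch (illustrative, assuming fairseq2's
# Linear(input_dim, output_dim, bias) signature): with zero-mean, unit-variance
# inputs, as produced by layer norm, the empirical output std of a
# SONAR-initialized layer should land close to `sonar_std`.
def _example_check_sonar_init(dim: int = 1024, n_samples: int = 4096) -> float:
    layer = Linear(dim, dim, bias=False)
    init_linear_to_sonar(layer, sonar_std=SONAR_STD)
    x = torch.randn(n_samples, dim)  # stand-in for layer-normed LCM outputs
    return float(layer(x).std())  # expected to be close to SONAR_STD (~0.006)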


def init_linear_xavier(layer: Linear) -> None:
    torch.nn.init.xavier_uniform_(layer.weight)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_zero(layer: Linear) -> None:
    torch.nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_trunc_normal(layer: Linear) -> None:
    torch.nn.init.trunc_normal_(layer.weight, std=1e-3)
    if layer.bias is not None:
        torch.nn.init.zeros_(layer.bias)


def init_linear_kaiming_uniform(layer: Linear) -> None:
    torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))

    if layer.bias is not None:
        fan_in = layer.weight.size(1)

        m = 1
        if layer.weight.ndim > 2:
            for s in layer.weight.shape[2:]:
                m *= s

        fan_in *= m

        # We do not calculate the true standard deviation of the uniform
        # distribution (i.e. multiply with sqrt(3)). See
        # https://github.com/pytorch/pytorch/issues/57109#issuecomment-828847575.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0

        torch.nn.init.uniform_(layer.bias, -bound, bound)


def parse_norm_order(var: str) -> TransformerNormOrder:
    norm_order: TransformerNormOrder
    if var == "pre":
        norm_order = TransformerNormOrder.PRE
    elif var == "post":
        norm_order = TransformerNormOrder.POST
    elif var == "normformer":
        norm_order = TransformerNormOrder.PRE_WITH_NORMFORMER
    else:
        raise ValueError(f"Unknown normalization order {var}")

    return norm_order


def parse_activation_fn(var: Optional[str] = None) -> Optional[Module]:
    if var is None:
        return None

    activ_fn: Module

    if var == "relu":
        activ_fn = torch.nn.ReLU()
    elif var == "tanh":
        activ_fn = torch.nn.Tanh()
    elif var == "elu":
        activ_fn = torch.nn.ELU()
    elif var == "leaky_relu":
        activ_fn = torch.nn.LeakyReLU()
    elif var == "prelu":
        activ_fn = torch.nn.PReLU()
    elif var == "selu":
        activ_fn = torch.nn.SELU()
    elif var == "gelu":
        activ_fn = torch.nn.GELU()
    elif var == "silu":
        activ_fn = torch.nn.SiLU()
    elif var == "softsign":
        activ_fn = torch.nn.Softsign()
    elif var == "sigmoid":
        activ_fn = torch.nn.Sigmoid()
    elif var == "hardsigmoid":
        activ_fn = torch.nn.Hardsigmoid()
    else:
        raise ValueError(f"Unknown activation function {var}")

    return activ_fn
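

# Hedged usage sketch for the two parsers above; the config values are
# illustrative. parse_norm_order maps a config string onto fairseq2's
# TransformerNormOrder enum, and parse_activation_fn instantiates the matching
# torch.nn activation module (or returns None when no name is given).
def _example_parse_config() -> Module:
    norm_order = parse_norm_order("pre")
    assert norm_order is TransformerNormOrder.PRE
    activ_fn = parse_activation_fn("gelu")
    assert activ_fn is not None  # only a None input yields None
    return activ_fn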