import torch
from torch import nn
from src.utils.parameter import LearnableParameter


__all__ = ['init_weights', 'build_qk_scale_func']


def init_weights(m, linear=None, rpe=None, activation='leaky_relu'):
    """Manual weight initialization. Allows setting specific init modes
    for certain modules. In particular, the linear and RPE layers are
    initialized with Xavier uniform initialization by default:
    https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
    Supported initializations are:
      - 'xavier_uniform'
      - 'xavier_normal'
      - 'kaiming_uniform'
      - 'kaiming_normal'
      - 'trunc_normal'
    """
    # Imported locally, typically to avoid a circular import with src.nn
    from src.nn import SelfAttentionBlock

    linear = 'xavier_uniform' if linear is None else linear
    rpe = linear if rpe is None else rpe

    if isinstance(m, LearnableParameter):
        nn.init.trunc_normal_(m, std=0.02)
        return

    if isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
        return

    if isinstance(m, nn.Linear):
        _linear_init(m, method=linear, activation=activation)
        return

    if isinstance(m, SelfAttentionBlock):
        if m.k_rpe is not None:
            _linear_init(m.k_rpe, method=rpe, activation=activation)
        if m.q_rpe is not None:
            _linear_init(m.q_rpe, method=rpe, activation=activation)
        return

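# Example usage (a sketch, not part of the original module): since
# `init_weights` operates on a single module, it is typically applied
# recursively to a whole model with `nn.Module.apply`, and
# `functools.partial` passes non-default init modes. The toy `model`
# below is a hypothetical stand-in:
#
#     from functools import partial
#
#     model = nn.Sequential(nn.Linear(32, 64), nn.LeakyReLU(), nn.Linear(64, 8))
#     model.apply(partial(init_weights, linear='kaiming_normal'))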

def _linear_init(m, method='xavier_uniform', activation='leaky_relu'):
    # The gain is only used by the Xavier schemes below; the Kaiming
    # schemes recompute it internally from `nonlinearity`
    gain = nn.init.calculate_gain(activation)

    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

    if method == 'xavier_uniform':
        nn.init.xavier_uniform_(m.weight, gain=gain)
    elif method == 'xavier_normal':
        nn.init.xavier_normal_(m.weight, gain=gain)
    elif method == 'kaiming_uniform':
        nn.init.kaiming_uniform_(m.weight, nonlinearity=activation)
    elif method == 'kaiming_normal':
        nn.init.kaiming_normal_(m.weight, nonlinearity=activation)
    elif method == 'trunc_normal':
        nn.init.trunc_normal_(m.weight, std=0.02)
    else:
        raise NotImplementedError(f"Unknown initialization method: {method}")

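# Example usage (a sketch, not part of the original module): initialize
# a single layer directly, bypassing `init_weights`. With
# activation='leaky_relu', nn.init.calculate_gain returns
# sqrt(2 / (1 + 0.01 ** 2)) ~= 1.414 for the Xavier schemes:
#
#     lin = nn.Linear(16, 16)
#     _linear_init(lin, method='xavier_normal', activation='leaky_relu')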

def build_qk_scale_func(dim, num_heads, qk_scale):
    """Builds the QK-scale function that will be used to produce
    the qk-scale. This function follows the template:
        f(s), where `s` is the `edge_index[0]`
    even if it does not use it.
    """
    # If qk_scale is not provided, default to D * G, i.e.
    # 1 / (sqrt(dim // num_heads) * sqrt(group size))
    if qk_scale is None:
        def f(s):
            D = (dim // num_heads) ** -0.5
            G = (s.bincount() ** -0.5)[s].view(-1, 1, 1)
            return D * G
        return f

    # If qk_scale is provided as anything but a string (typically a
    # scalar), it will be used as is
    if not isinstance(qk_scale, str):
        def f(s):
            return qk_scale
        return f

    # Convert input str to lowercase and remove spaces before
    # parsing
    qk_scale = qk_scale.lower().replace(' ', '')

    if qk_scale in ['d+g', 'g+d']:
        def f(s):
            D = (dim // num_heads) ** -0.5
            G = (s.bincount() ** -0.5)[s].view(-1, 1, 1)
            return D + G
        return f

    if qk_scale in ['dg', 'gd', 'd*g', 'g*d', 'd.g', 'g.d']:
        def f(s):
            D = (dim // num_heads) ** -0.5
            G = (s.bincount() ** -0.5)[s].view(-1, 1, 1)
            return D * G
        return f

    if qk_scale == 'd':
        def f(s):
            D = (dim // num_heads) ** -0.5
            return D
        return f

    if qk_scale == 'g':
        def f(s):
            G = (s.bincount() ** -0.5)[s].view(-1, 1, 1)
            return G
        return f

    raise ValueError(
        f"Unable to build QK scaling scheme for qk_scale='{qk_scale}'")