English
File size: 289 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch.nn as nn


__all__ = ['LearnableParameter']


class LearnableParameter(nn.Parameter):
    """Marker subclass of ``nn.Parameter`` for free-standing learnable
    tensors (e.g. learnable position encodings, queries, keys, ...).

    It behaves like ``nn.Parameter`` and takes the same constructor
    arguments (``data=None, requires_grad=True``). Its purpose is to be
    distinguishable by type, so that weight-initialization code can
    single these tensors out and apply a custom initialization to them
    (presumably via an ``isinstance`` check — confirm against the caller).
    """

    def __repr__(self):
        # ``nn.Parameter.__repr__`` hard-codes the "Parameter containing:"
        # prefix, which would hide this subclass when the parameter is
        # printed. Report the concrete class name instead (this also covers
        # further subclasses). ``super(nn.Parameter, self)`` skips
        # Parameter's own repr and reuses the plain tensor formatting.
        header = f'{type(self).__name__} containing:\n'
        return header + super(nn.Parameter, self).__repr__()