| import torch.nn as nn | |
| __all__ = ['LearnableParameter'] | |
| class LearnableParameter(nn.Parameter): | |
| """A simple class to be used for learnable parameters (e.g. learnable | |
| position encodings, queries, keys, ...). Using this is useful to use | |
| custom weight initialization. | |
| """ | |