File size: 289 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import torch.nn as nn
__all__ = ["LearnableParameter"]


class LearnableParameter(nn.Parameter):
    """Marker subclass of :class:`torch.nn.Parameter` for free-standing learnable tensors.

    Meant for parameters that are not owned by a standard layer, such as
    learnable position encodings, queries or keys. Giving them a distinct
    type lets weight-initialization code single them out (for instance via
    ``isinstance(p, LearnableParameter)``) and apply a custom scheme to them.

    The class adds no behavior of its own: construction, autograd and
    registration on an ``nn.Module`` are exactly those of
    :class:`torch.nn.Parameter`.
    """
|