toto10's picture
b04c629d473b73ca15d008cade0dbdf01deeb1a8f48216e43b46a735e2975a9a
91281ee
raw
history blame
577 Bytes
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
class Scale(nn.Module):
    """A learnable scale parameter.

    This layer scales the input by a learnable factor. It multiplies a
    learnable scalar parameter (a 0-dimensional tensor, i.e. shape ``()``)
    with an input of any shape; the scalar broadcasts over every element.

    Args:
        scale (float): Initial value of scale factor. Default: 1.0
    """

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        # Kept as a 0-dim float32 tensor on purpose: changing it to shape (1,)
        # would change this parameter's state_dict entry and break loading of
        # checkpoints saved with the scalar form.
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied element-wise by the learnable scale.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: ``x * self.scale``, with the same shape as ``x``.
        """
        return x * self.scale