Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Union | |
| import torch | |
| import torch.nn as nn | |
class LayerScale(nn.Module):
    """LayerScale layer.

    Multiplies the input by a learnable per-channel scale vector.

    Args:
        dim (int): Dimension of input features (number of channels).
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 1e-5.
        inplace (bool): Whether to do the multiplication in-place on the
            input tensor. Defaults to False.
        data_format (str): The input data format, could be 'channels_last'
            or 'channels_first'. With 'channels_last' the scale is broadcast
            along the last dimension, e.g. (B, N, C) data. With
            'channels_first' the scale is reshaped to (C, 1, 1), e.g. for
            (B, C, H, W) data. Defaults to 'channels_last'.

    Raises:
        AssertionError: If ``data_format`` is neither 'channels_last' nor
            'channels_first'.
    """

    def __init__(self,
                 dim: int,
                 layer_scale_init_value: Union[float, torch.Tensor] = 1e-5,
                 inplace: bool = False,
                 data_format: str = 'channels_last'):
        super().__init__()
        # Raise explicitly instead of using a bare ``assert`` so the check
        # is not stripped under ``python -O``. The exception type is kept
        # as AssertionError for backward compatibility with callers.
        if data_format not in ('channels_last', 'channels_first'):
            raise AssertionError(
                "'data_format' could only be channels_last or channels_first."
            )
        self.inplace = inplace
        self.data_format = data_format
        # One learnable scale per channel, initialised to the given value.
        self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` channel-wise by the learned weight.

        Args:
            x (torch.Tensor): Input whose channel dimension matches ``dim``;
                its position depends on ``data_format``.

        Returns:
            torch.Tensor: The scaled tensor. When ``inplace`` is True this
            is ``x`` itself, modified in place.
        """
        weight = self.weight
        if self.data_format == 'channels_first':
            # Reshape to (C, 1, 1) so the scale broadcasts over the two
            # trailing dimensions.
            weight = weight.view(-1, 1, 1)
        if self.inplace:
            return x.mul_(weight)
        return x * weight