| import math | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from core.networks.dynamic_conv import DynamicConv | |
class DynamicLinear(nn.Module):
    """Sequence-wise linear layer whose weights are conditioned on ``cond``.

    Implemented as a 1x1 ``DynamicConv`` (kernel_size=1, stride=1,
    padding=0) applied over a ``(B, C_in, L, 1)`` view of the input, so each
    sequence position is projected from ``in_planes`` to ``out_planes``.

    Args:
        in_planes: Number of input channels (``C_in``).
        out_planes: Number of output channels (``C_out``).
        cond_planes: Number of channels in the conditioning vector.
        bias: Passed through to ``DynamicConv``.
        K: Passed through to ``DynamicConv`` (presumably the number of
            candidate kernels; see that class to confirm).
        temperature: Passed through to ``DynamicConv``.
        ratio: Passed through to ``DynamicConv``.
        init_weight: Passed through to ``DynamicConv``.
    """

    def __init__(self, in_planes, out_planes, cond_planes, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
        super().__init__()
        self.dynamic_conv = DynamicConv(
            in_planes,
            out_planes,
            cond_planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            K=K,
            ratio=ratio,
            temperature=temperature,
            init_weight=init_weight,
        )

    def forward(self, x, cond):
        """Apply the conditioned projection to every sequence position.

        Args:
            x: Tensor of shape ``(L, B, C_in)``.
            cond: Tensor of shape ``(B, C_style)``.

        Returns:
            Tensor of shape ``(L, B, C_out)``.
        """
        # (L, B, C_in) -> (B, C_in, L) -> (B, C_in, L, 1): a 2-D "image" of width 1.
        x = x.permute(1, 2, 0).unsqueeze(-1)
        out = self.dynamic_conv(x, cond)
        # (B, C_out, L, 1) -> (B, C_out, L) -> (L, B, C_out).
        # Only squeeze the trailing width axis: a bare squeeze() would also
        # drop B, C_out or L when any of them is 1 and break the permute.
        out = out.squeeze(-1).permute(2, 0, 1)
        return out