| |
| |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
|
|
|
|
def get_1d_sincos_pos_embed(embed_dim: int, pos: Tensor) -> Tensor:
    """Build 1D sine-cosine positional embeddings for the given positions.

    Args:
        embed_dim: Embedding dimension; must be even (half sine, half cosine).
        pos: Tensor of positions, any shape or dtype; flattened to (M,).

    Returns:
        Float32 tensor of shape (M, embed_dim): ``[sin(pos*omega), cos(pos*omega)]``.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    # Classic transformer frequency schedule: 1 / 10000^(2i / embed_dim).
    omega = 1.0 / (10000**omega)
    # Cast positions to omega's dtype: torch.einsum requires matching operand
    # dtypes, so integer or float64 position tensors would otherwise raise.
    pos = pos.reshape(-1).to(omega.dtype)
    # Outer product: (M,) x (D/2,) -> (M, D/2).
    out = torch.einsum("m,d->md", pos, omega)
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)
|
|
|
|
class FCResLayer(nn.Module):
    """Fully-connected residual block: ``x + relu(w2(relu(w1(x))))``.

    Input and output both have ``linear_size`` features.
    """

    def __init__(self, linear_size: int = 128):
        super().__init__()
        self.w1 = nn.Linear(linear_size, linear_size)
        self.w2 = nn.Linear(linear_size, linear_size)

    def forward(self, x: Tensor) -> Tensor:
        hidden = F.relu(self.w1(x))
        residual = F.relu(self.w2(hidden))
        return x + residual
|
|
|
|
class TransformerWeightGenerator(nn.Module):
    """Generate convolution weights and a bias from wavelength embeddings.

    The input sequence is flanked by ``wt_num`` learned weight tokens and one
    learned bias token, run through a transformer encoder, and then projected:
    the positions matching the input yield the weights, the bias token yields
    the bias.
    """

    def __init__(self, input_dim: int, output_dim: int, embed_dim: int, num_heads: int = 4, num_layers: int = 1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim, nhead=num_heads, activation="gelu",
            norm_first=False, batch_first=False, dropout=0.0,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )
        self.fc_weight = nn.Linear(input_dim, output_dim)
        self.fc_bias = nn.Linear(input_dim, embed_dim)
        self.wt_num = 128
        self.weight_tokens = nn.Parameter(torch.empty(self.wt_num, input_dim))
        self.bias_token = nn.Parameter(torch.empty(1, input_dim))
        nn.init.normal_(self.weight_tokens, std=0.02)
        nn.init.normal_(self.bias_token, std=0.02)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """x: (num_waves, input_dim) unbatched sequence of wave embeddings.

        Returns ``(weights, bias)`` with shapes ``(num_waves, output_dim)``
        and ``(embed_dim,)``.
        """
        seq = torch.cat([self.weight_tokens, x, self.bias_token], dim=0)
        encoded = self.transformer_encoder(seq)
        mid = slice(self.wt_num, -1)
        # Residual connection from the raw (pre-encoder) input positions.
        weights = self.fc_weight(encoded[mid] + seq[mid])
        bias = self.fc_bias(encoded[-1])
        return weights, bias
|
|
|
|
class TransformerWeightGeneratorDecoder(TransformerWeightGenerator):
    """Variant of :class:`TransformerWeightGenerator` that emits one scalar
    bias per input position instead of a single ``embed_dim`` bias vector."""

    def __init__(self, input_dim: int, output_dim: int, embed_dim: int, num_heads: int = 4, num_layers: int = 1):
        super().__init__(input_dim, output_dim, embed_dim, num_heads, num_layers)
        # Replace the parent's bias head with a per-position scalar projection.
        self.fc_bias = nn.Linear(input_dim, 1)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """x: (num_waves, input_dim). Returns ``(weights, bias)`` with shapes
        ``(num_waves, output_dim)`` and ``(num_waves, 1)``."""
        seq = torch.cat([self.weight_tokens, x, self.bias_token], dim=0)
        encoded = self.transformer_encoder(seq)
        mid = slice(self.wt_num, -1)
        inputs = seq[mid]
        weights = self.fc_weight(encoded[mid] + inputs)
        # The bias token is broadcast to every input position before projecting.
        token = self.bias_token.expand(inputs.shape[0], -1)
        bias = self.fc_bias(encoded[mid] + token)
        return weights, bias
|
|
|
|
class DynamicConv(nn.Module):
    """2D convolution whose kernel is generated on the fly from the input
    wavelengths, letting one layer handle arbitrary spectral channel sets.

    Note: ``inter_dim`` is accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        wv_planes: int,
        inter_dim: int = 128,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        embed_dim: int = 128,
        num_layers: int = 1,
        num_heads: int = 4,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.wv_planes = wv_planes
        self.embed_dim = embed_dim
        self._num_kernel = kernel_size * kernel_size * embed_dim
        self.stride = stride
        self.padding = padding
        # Damps generated weights/bias; presumably for training stability.
        self.scaler = 0.1

        self.weight_generator = TransformerWeightGenerator(
            wv_planes, self._num_kernel, embed_dim, num_heads=num_heads, num_layers=num_layers
        )
        self.fclayer = FCResLayer(wv_planes)
        self._init_linears()

    def _init_linears(self) -> None:
        # Xavier-init every Linear inside the generator and the FC layer.
        for module in (self.weight_generator, self.fclayer):
            for sub in module.modules():
                if isinstance(sub, nn.Linear):
                    init.xavier_uniform_(sub.weight)
                    if sub.bias is not None:
                        sub.bias.data.fill_(0.01)

    def forward(self, img_feat: Tensor, wvs: Tensor) -> Tensor:
        """Convolve ``img_feat`` (N, len(wvs), H, W) with wavelength-conditioned
        kernels; returns (N, embed_dim, H', W')."""
        # Wavelengths are scaled by 1000 before sinusoidal embedding
        # (assumes micrometer-scale inputs -- TODO confirm with callers).
        waves = get_1d_sincos_pos_embed(self.wv_planes, wvs * 1000)
        waves = self.fclayer(waves)
        weight, bias = self.weight_generator(waves)
        in_channels = wvs.size(0)
        # (in, kH, kW, out) -> (out, in, kH, kW) as F.conv2d expects.
        kernel = weight.view(in_channels, self.kernel_size, self.kernel_size, self.embed_dim)
        kernel = kernel.permute(3, 0, 1, 2)
        if bias is not None:
            bias = bias.view(self.embed_dim) * self.scaler
        return F.conv2d(img_feat, kernel * self.scaler, bias, (self.stride, self.stride), self.padding)
|
|
|
|
class DynamicConvDecoder(nn.Module):
    """Decoder-side dynamic convolution: generates one kernel (and one scalar
    bias) per output wavelength, mapping ``embed_dim`` feature channels back to
    ``len(wvs)`` output channels.

    Note: ``inter_dim`` is accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        wv_planes: int,
        inter_dim: int = 128,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        embed_dim: int = 128,
        num_layers: int = 2,
        num_heads: int = 4,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.wv_planes = wv_planes
        self.embed_dim = embed_dim
        self._num_kernel = kernel_size * kernel_size * embed_dim
        self.stride = stride
        self.padding = padding
        # Damps generated weights/bias; presumably for training stability.
        self.scaler = 0.1

        self.weight_generator = TransformerWeightGeneratorDecoder(
            wv_planes, self._num_kernel, embed_dim, num_heads=num_heads, num_layers=num_layers
        )
        self.fclayer = FCResLayer(wv_planes)
        # Xavier-init every Linear inside the generator and the FC layer.
        for m in [self.weight_generator, self.fclayer]:
            for mod in m.modules():
                if isinstance(mod, nn.Linear):
                    init.xavier_uniform_(mod.weight)
                    if mod.bias is not None:
                        mod.bias.data.fill_(0.01)

    def forward(self, img_feat: Tensor, wvs: Tensor) -> Tensor:
        """Convolve ``img_feat`` (N, embed_dim, H, W) with per-wavelength
        kernels; returns (N, len(wvs), H', W')."""
        # Wavelengths are scaled by 1000 before sinusoidal embedding
        # (assumes micrometer-scale inputs -- TODO confirm with callers).
        waves = get_1d_sincos_pos_embed(self.wv_planes, wvs * 1000)
        waves = self.fclayer(waves)
        weight, bias = self.weight_generator(waves)
        inplanes = wvs.size(0)
        # (out, kH, kW, in) -> (out, in, kH, kW): one kernel per wavelength.
        dynamic_weight = weight.view(inplanes, self.kernel_size, self.kernel_size, self.embed_dim)
        dynamic_weight = dynamic_weight.permute(0, 3, 1, 2)
        if bias is not None:
            # Flatten (inplanes, 1) -> (inplanes,) with view(-1), NOT squeeze():
            # squeeze() drops ALL size-1 dims, so a single-wavelength (1, 1)
            # bias would collapse to a 0-d tensor and F.conv2d, which requires
            # a 1-D bias of length out_channels, would raise.
            bias = bias.view(-1) * self.scaler
        return F.conv2d(img_feat, dynamic_weight * self.scaler, bias, (self.stride, self.stride), self.padding)
|
|