# EO-VAE / _eo_vae / dynamic_conv.py
# Provenance: uploaded by BiliSakura — "Update all files for EO-VAE" (commit f6a2144, verified)
# Apache-2.0 - Based on EO-VAE dynamic convolution
# DynamicConv, DynamicConv_decoder - wavelength-conditioned convolutions
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
def get_1d_sincos_pos_embed(embed_dim: int, pos: Tensor) -> Tensor:
    """Build 1-D sine/cosine positional embeddings for the given positions.

    Args:
        embed_dim: Embedding size per position; must be even (half the
            dimensions carry sin, half carry cos).
        pos: Tensor of positions, any shape; flattened to ``(M,)``.
            Integer tensors are cast to float32 (``einsum`` requires
            matching dtypes); floating tensors keep their dtype.

    Returns:
        Tensor of shape ``(M, embed_dim)`` laid out as
        ``[sin(pos * omega) | cos(pos * omega)]`` with frequencies
        ``omega_d = 1 / 10000 ** (2d / embed_dim)``.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / (10000**omega)
    pos = pos.reshape(-1)
    # Generalization: the original crashed on integer (or float64) positions
    # because einsum needs both operands in the same dtype. float32 inputs
    # take the exact original path.
    if pos.is_floating_point():
        omega = omega.to(pos.dtype)
    else:
        pos = pos.to(torch.float32)
    out = torch.einsum("m,d->md", pos, omega)  # (M, embed_dim // 2)
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)
class FCResLayer(nn.Module):
    """Fully connected residual block: ``x + relu(W2(relu(W1(x))))``."""

    def __init__(self, linear_size: int = 128):
        super().__init__()
        self.w1 = nn.Linear(linear_size, linear_size)
        self.w2 = nn.Linear(linear_size, linear_size)

    def forward(self, x: Tensor) -> Tensor:
        hidden = F.relu(self.w1(x))
        residual = F.relu(self.w2(hidden))
        return x + residual
class TransformerWeightGenerator(nn.Module):
    """Hypernetwork: turn wavelength embeddings into conv weights and a bias.

    The input sequence of wavelength embeddings is flanked by 128 learned
    weight tokens and one learned bias token, run through a transformer
    encoder, and projected to per-wavelength weight vectors plus a single
    shared bias vector of size ``embed_dim``.
    """

    def __init__(self, input_dim: int, output_dim: int, embed_dim: int, num_heads: int = 4, num_layers: int = 1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim, nhead=num_heads, activation="gelu",
            norm_first=False, batch_first=False, dropout=0.0,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )
        self.fc_weight = nn.Linear(input_dim, output_dim)
        self.fc_bias = nn.Linear(input_dim, embed_dim)
        # Learned query tokens prepended to the wavelength sequence.
        self.wt_num = 128
        self.weight_tokens = nn.Parameter(torch.empty(self.wt_num, input_dim))
        self.bias_token = nn.Parameter(torch.empty(1, input_dim))
        nn.init.normal_(self.weight_tokens, std=0.02)
        nn.init.normal_(self.bias_token, std=0.02)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        # Sequence layout: [weight tokens | wavelength tokens | bias token].
        seq = torch.cat([self.weight_tokens, x, self.bias_token], dim=0)
        encoded = self.transformer_encoder(seq)
        wave = slice(self.wt_num, -1)
        # Residual connection from the raw wavelength tokens before projecting.
        weights = self.fc_weight(encoded[wave] + seq[wave])
        bias = self.fc_bias(encoded[-1])
        return weights, bias
class TransformerWeightGeneratorDecoder(TransformerWeightGenerator):
    """Decoder variant: one scalar bias per wavelength instead of a single
    shared ``embed_dim`` bias vector."""

    def __init__(self, input_dim: int, output_dim: int, embed_dim: int, num_heads: int = 4, num_layers: int = 1):
        super().__init__(input_dim, output_dim, embed_dim, num_heads, num_layers)
        # Replace the parent's shared-bias head with a per-wavelength scalar head.
        self.fc_bias = nn.Linear(input_dim, 1)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        tokens = torch.cat([self.weight_tokens, x, self.bias_token], dim=0)
        encoded = self.transformer_encoder(tokens)
        wave_in = tokens[self.wt_num:-1]
        wave_out = encoded[self.wt_num:-1]
        weights = self.fc_weight(wave_out + wave_in)
        # The learned bias token is broadcast onto every wavelength encoding,
        # yielding a (num_waves, 1) bias tensor.
        bias = self.fc_bias(wave_out + self.bias_token.expand(wave_in.shape[0], -1))
        return weights, bias
class DynamicConv(nn.Module):
    """Wavelength-conditioned 2-D convolution (encoder side).

    A transformer hypernetwork maps sine/cosine wavelength embeddings to the
    kernel of a conv taking ``len(wvs)`` input channels to ``embed_dim``
    output channels; generated weights and bias are damped by ``scaler``.
    """

    def __init__(
        self,
        wv_planes: int,
        inter_dim: int = 128,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        embed_dim: int = 128,
        num_layers: int = 1,
        num_heads: int = 4,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.wv_planes = wv_planes
        self.embed_dim = embed_dim
        self._num_kernel = kernel_size * kernel_size * embed_dim
        self.stride = stride
        self.padding = padding
        # Damping factor applied to generated weights and bias.
        self.scaler = 0.1
        self.weight_generator = TransformerWeightGenerator(
            wv_planes, self._num_kernel, embed_dim, num_heads=num_heads, num_layers=num_layers
        )
        self.fclayer = FCResLayer(wv_planes)
        self._reset_linear_layers()

    def _reset_linear_layers(self) -> None:
        # Xavier-uniform weights + small constant bias for every Linear in
        # the hypernetwork and the residual FC layer.
        for module in [self.weight_generator, self.fclayer]:
            for sub in module.modules():
                if isinstance(sub, nn.Linear):
                    init.xavier_uniform_(sub.weight)
                    if sub.bias is not None:
                        sub.bias.data.fill_(0.01)

    def forward(self, img_feat: Tensor, wvs: Tensor) -> Tensor:
        # Wavelengths are scaled by 1000 before embedding — presumably a
        # micrometre-to-nanometre style unit change; confirm against callers.
        waves = get_1d_sincos_pos_embed(self.wv_planes, wvs * 1000)
        waves = self.fclayer(waves)
        weight, bias = self.weight_generator(waves)
        num_in = wvs.size(0)
        kernel = weight.view(num_in, self.kernel_size, self.kernel_size, self.embed_dim)
        kernel = kernel.permute(3, 0, 1, 2)  # (embed_dim, num_in, k, k)
        if bias is not None:
            bias = bias.view(self.embed_dim) * self.scaler
        return F.conv2d(
            img_feat, kernel * self.scaler, bias,
            (self.stride, self.stride), self.padding,
        )
class DynamicConvDecoder(nn.Module):
    """Wavelength-conditioned 2-D convolution (decoder side).

    Mirrors ``DynamicConv`` but maps ``embed_dim`` input channels to
    ``len(wvs)`` output channels (one per requested wavelength), with a
    per-wavelength scalar bias from ``TransformerWeightGeneratorDecoder``.
    """

    def __init__(
        self,
        wv_planes: int,
        inter_dim: int = 128,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        embed_dim: int = 128,
        num_layers: int = 2,
        num_heads: int = 4,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.wv_planes = wv_planes
        self.embed_dim = embed_dim
        self._num_kernel = kernel_size * kernel_size * embed_dim
        self.stride = stride
        self.padding = padding
        # Damping factor applied to generated weights and bias.
        self.scaler = 0.1
        self.weight_generator = TransformerWeightGeneratorDecoder(
            wv_planes, self._num_kernel, embed_dim, num_heads=num_heads, num_layers=num_layers
        )
        self.fclayer = FCResLayer(wv_planes)
        # Xavier weights + small constant bias for all Linear submodules.
        for m in [self.weight_generator, self.fclayer]:
            for mod in m.modules():
                if isinstance(mod, nn.Linear):
                    init.xavier_uniform_(mod.weight)
                    if mod.bias is not None:
                        mod.bias.data.fill_(0.01)

    def forward(self, img_feat: Tensor, wvs: Tensor) -> Tensor:
        # Wavelengths are scaled by 1000 before embedding — presumably a
        # micrometre-to-nanometre style unit change; confirm against callers.
        waves = get_1d_sincos_pos_embed(self.wv_planes, wvs * 1000)
        waves = self.fclayer(waves)
        weight, bias = self.weight_generator(waves)
        out_planes = wvs.size(0)
        dynamic_weight = weight.view(out_planes, self.kernel_size, self.kernel_size, self.embed_dim)
        # (out_planes, embed_dim, k, k): one output channel per wavelength.
        dynamic_weight = dynamic_weight.permute(0, 3, 1, 2)
        if bias is not None:
            # Bug fix: bias arrives as (out_planes, 1). A bare .squeeze()
            # collapses (1, 1) to a 0-d tensor when a single wavelength is
            # requested, but F.conv2d requires a 1-D bias of length
            # out_channels — squeeze only the trailing dimension.
            bias = bias.squeeze(-1) * self.scaler
        return F.conv2d(img_feat, dynamic_weight * self.scaler, bias, (self.stride, self.stride), self.padding)