BiliSakura committed on
Commit f6a2144 · verified · 1 Parent(s): 774dec2

Update all files for EO-VAE

Files changed (1)
  1. _eo_vae/dynamic_conv.py +156 -0
_eo_vae/dynamic_conv.py ADDED
@@ -0,0 +1,156 @@
+ # Apache-2.0 - Based on EO-VAE dynamic convolution
+ # DynamicConv, DynamicConvDecoder - wavelength-conditioned convolutions
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.init as init
+ from torch import Tensor
+
+
+ def get_1d_sincos_pos_embed(embed_dim: int, pos: Tensor) -> Tensor:
+     assert embed_dim % 2 == 0
+     omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / (10000**omega)
+     pos = pos.reshape(-1)
+     out = torch.einsum("m,d->md", pos, omega)
+     return torch.cat([torch.sin(out), torch.cos(out)], dim=1)
+
+
+ class FCResLayer(nn.Module):
+     def __init__(self, linear_size: int = 128):
+         super().__init__()
+         self.w1 = nn.Linear(linear_size, linear_size)
+         self.w2 = nn.Linear(linear_size, linear_size)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x + F.relu(self.w2(F.relu(self.w1(x))))
+
+
+ class TransformerWeightGenerator(nn.Module):
+     def __init__(self, input_dim: int, output_dim: int, embed_dim: int, num_heads: int = 4, num_layers: int = 1):
+         super().__init__()
+         enc = nn.TransformerEncoderLayer(
+             d_model=input_dim, nhead=num_heads, activation="gelu",
+             norm_first=False, batch_first=False, dropout=0.0,
+         )
+         self.transformer_encoder = nn.TransformerEncoder(
+             enc, num_layers=num_layers, enable_nested_tensor=False
+         )
+         self.fc_weight = nn.Linear(input_dim, output_dim)
+         self.fc_bias = nn.Linear(input_dim, embed_dim)
+         self.wt_num = 128
+         self.weight_tokens = nn.Parameter(torch.empty(self.wt_num, input_dim))
+         self.bias_token = nn.Parameter(torch.empty(1, input_dim))
+         nn.init.normal_(self.weight_tokens, std=0.02)
+         nn.init.normal_(self.bias_token, std=0.02)
+
+     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
+         x = torch.cat([self.weight_tokens, x, self.bias_token], dim=0)
+         out = self.transformer_encoder(x)
+         weights = self.fc_weight(out[self.wt_num:-1] + x[self.wt_num:-1])
+         bias = self.fc_bias(out[-1])
+         return weights, bias
+
+
+ class TransformerWeightGeneratorDecoder(TransformerWeightGenerator):
+     def __init__(self, input_dim: int, output_dim: int, embed_dim: int, num_heads: int = 4, num_layers: int = 1):
+         super().__init__(input_dim, output_dim, embed_dim, num_heads, num_layers)
+         self.fc_bias = nn.Linear(input_dim, 1)
+
+     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
+         x = torch.cat([self.weight_tokens, x, self.bias_token], dim=0)
+         out = self.transformer_encoder(x)
+         pos = x[self.wt_num:-1]
+         weights = self.fc_weight(out[self.wt_num:-1] + pos)
+         bias = self.fc_bias(out[self.wt_num:-1] + self.bias_token.expand(pos.shape[0], -1))
+         return weights, bias
+
+
+ class DynamicConv(nn.Module):
+     def __init__(
+         self,
+         wv_planes: int,
+         inter_dim: int = 128,
+         kernel_size: int = 3,
+         stride: int = 1,
+         padding: int = 1,
+         embed_dim: int = 128,
+         num_layers: int = 1,
+         num_heads: int = 4,
+     ):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.wv_planes = wv_planes
+         self.embed_dim = embed_dim
+         self._num_kernel = kernel_size * kernel_size * embed_dim
+         self.stride = stride
+         self.padding = padding
+         self.scaler = 0.1
+
+         self.weight_generator = TransformerWeightGenerator(
+             wv_planes, self._num_kernel, embed_dim, num_heads=num_heads, num_layers=num_layers
+         )
+         self.fclayer = FCResLayer(wv_planes)
+         for m in [self.weight_generator, self.fclayer]:
+             for mod in m.modules():
+                 if isinstance(mod, nn.Linear):
+                     init.xavier_uniform_(mod.weight)
+                     if mod.bias is not None:
+                         mod.bias.data.fill_(0.01)
+
+     def forward(self, img_feat: Tensor, wvs: Tensor) -> Tensor:
+         waves = get_1d_sincos_pos_embed(self.wv_planes, wvs * 1000)
+         waves = self.fclayer(waves)
+         weight, bias = self.weight_generator(waves)
+         inplanes = wvs.size(0)
+         dynamic_weight = weight.view(inplanes, self.kernel_size, self.kernel_size, self.embed_dim)
+         dynamic_weight = dynamic_weight.permute(3, 0, 1, 2)  # (embed_dim, inplanes, kH, kW)
+         if bias is not None:
+             bias = bias.view(self.embed_dim) * self.scaler
+         return F.conv2d(img_feat, dynamic_weight * self.scaler, bias, (self.stride, self.stride), self.padding)
+
+
+ class DynamicConvDecoder(nn.Module):
+     def __init__(
+         self,
+         wv_planes: int,
+         inter_dim: int = 128,
+         kernel_size: int = 3,
+         stride: int = 1,
+         padding: int = 1,
+         embed_dim: int = 128,
+         num_layers: int = 2,
+         num_heads: int = 4,
+     ):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.wv_planes = wv_planes
+         self.embed_dim = embed_dim
+         self._num_kernel = kernel_size * kernel_size * embed_dim
+         self.stride = stride
+         self.padding = padding
+         self.scaler = 0.1
+
+         self.weight_generator = TransformerWeightGeneratorDecoder(
+             wv_planes, self._num_kernel, embed_dim, num_heads=num_heads, num_layers=num_layers
+         )
+         self.fclayer = FCResLayer(wv_planes)
+         for m in [self.weight_generator, self.fclayer]:
+             for mod in m.modules():
+                 if isinstance(mod, nn.Linear):
+                     init.xavier_uniform_(mod.weight)
+                     if mod.bias is not None:
+                         mod.bias.data.fill_(0.01)
+
+     def forward(self, img_feat: Tensor, wvs: Tensor) -> Tensor:
+         waves = get_1d_sincos_pos_embed(self.wv_planes, wvs * 1000)
+         waves = self.fclayer(waves)
+         weight, bias = self.weight_generator(waves)
+         inplanes = wvs.size(0)
+         dynamic_weight = weight.view(inplanes, self.kernel_size, self.kernel_size, self.embed_dim)
+         dynamic_weight = dynamic_weight.permute(0, 3, 1, 2)  # (inplanes, embed_dim, kH, kW)
+         if bias is not None:
+             bias = bias.squeeze(-1) * self.scaler  # (inplanes, 1) -> (inplanes,); squeeze(-1) keeps a single band from collapsing to 0-d
+         return F.conv2d(img_feat, dynamic_weight * self.scaler, bias, (self.stride, self.stride), self.padding)
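
Usage sketch (not part of the commit): a minimal round trip through the two layers, assuming the module imports as `_eo_vae.dynamic_conv` and that wavelengths are passed in micrometers (forward() multiplies by 1000 before the sinusoidal embedding). The four band wavelengths are illustrative Sentinel-2-like values, not taken from the repository.

    import torch
    from _eo_vae.dynamic_conv import DynamicConv, DynamicConvDecoder

    wv_planes, embed_dim = 128, 128
    # Hypothetical RGB+NIR band centers in micrometers (Sentinel-2-like).
    wvs = torch.tensor([0.490, 0.560, 0.665, 0.842])

    enc = DynamicConv(wv_planes, embed_dim=embed_dim)         # generated weights: (embed_dim, 4, 3, 3)
    dec = DynamicConvDecoder(wv_planes, embed_dim=embed_dim)  # generated weights: (4, embed_dim, 3, 3)

    img = torch.randn(2, 4, 64, 64)  # (batch, bands, H, W)
    feat = enc(img, wvs)             # -> (2, embed_dim, 64, 64)
    out = dec(feat, wvs)             # -> (2, 4, 64, 64)

The two permutes carry the asymmetry: the encoder reshapes generated weights to (embed_dim, bands, kH, kW) so any number of input bands maps into a fixed embed_dim, while the decoder uses (bands, embed_dim, kH, kW) to emit one output channel per requested wavelength.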