earthflow committed on
Commit
d64d4bd
·
verified ·
1 Parent(s): b0e3ffb

Upload dofa_dinov3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dofa_dinov3.py +252 -0
dofa_dinov3.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+ # --------------------------------------------------------
4
+ # References:
5
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # --------------------------------------------------------
7
+
8
+ from functools import partial
9
+ import math
10
+ import einops
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F # Add this import for F.pad
15
+ from timm.models.vision_transformer import VisionTransformer
16
+ from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid_torch
17
+ import pdb
18
+
19
+ import timm
20
+
21
class TransformerWeightGenerator(nn.Module):
    """Generate convolution weights and a bias vector from wavelength embeddings.

    A small TransformerEncoder attends over learnable weight tokens, the
    per-channel wavelength embeddings, and one bias token; the wavelength-token
    outputs are projected to flattened kernel weights (``fc_weight``) and the
    bias-token output to a bias vector (``fc_bias``).

    Args:
        input_dim: dimensionality of each wavelength embedding (transformer d_model).
        output_dim: flattened kernel size produced per wavelength token.
        embed_dim: size of the generated bias vector.
        num_heads: attention heads in the encoder layer.
        num_layers: number of encoder layers.
    """

    def __init__(self, input_dim, output_dim, embed_dim, num_heads=4, num_layers=1):
        super(TransformerWeightGenerator, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            activation='gelu',
            norm_first=False,
            batch_first=False,
            # Fix: the original passed ``dropout=False`` — the argument is a
            # float probability; the bool only worked because bool is an int.
            dropout=0.0,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )

        # Linear layers mapping transformer outputs to the weight/bias shapes.
        self.fc_weight = nn.Linear(input_dim, output_dim)
        self.fc_bias = nn.Linear(input_dim, embed_dim)
        self.wt_num = 128  # number of learnable weight tokens prepended to the sequence
        self.weight_tokens = nn.Parameter(torch.empty([self.wt_num, input_dim]))
        self.bias_token = nn.Parameter(torch.empty([1, input_dim]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as the
        # cutoff is too big (2.)
        torch.nn.init.normal_(self.weight_tokens, std=.02)
        torch.nn.init.normal_(self.bias_token, std=.02)

    def forward(self, x):
        """Return ``(weights, bias)`` generated from wavelength embeddings ``x``.

        NOTE(review): the original comment claimed x is [seq_len, batch,
        input_dim], but x is concatenated along dim 0 with 2-D parameter
        tokens, so it must be an unbatched [seq_len, input_dim] sequence.
        """
        pos_wave = x
        x = torch.cat([self.weight_tokens, pos_wave], dim=0)
        x = torch.cat([x, self.bias_token], dim=0)
        transformer_output = self.transformer_encoder(x)
        # Residual connection with the raw wavelength embeddings before projection.
        weights = self.fc_weight(transformer_output[self.wt_num:-1] + pos_wave)
        bias = self.fc_bias(transformer_output[-1])  # last (bias-token) output -> bias
        return weights, bias
47
+
48
class Basic1d(nn.Module):
    """A linear map, optionally followed by LayerNorm and an in-place ReLU.

    With ``bias=True`` the block is just ``nn.Linear``; with ``bias=False``
    a LayerNorm and a ReLU are appended after the bias-free linear layer.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        linear = nn.Linear(in_channels, out_channels, bias)
        self.conv = nn.Sequential(linear)
        if not bias:
            self.conv.add_module('ln', nn.LayerNorm(out_channels))
            self.conv.add_module('relu', nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)
60
+
61
class FCResLayer(nn.Module):
    """Fully connected residual block: ``x + ReLU(W2(ReLU(W1(x))))``."""

    def __init__(self, linear_size=128):
        super(FCResLayer, self).__init__()
        self.l_size = linear_size
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.w2 = nn.Linear(self.l_size, self.l_size)

    def forward(self, x):
        # Two linear+ReLU stages, then a skip connection back to the input.
        residual = self.nonlin2(self.w2(self.nonlin1(self.w1(x))))
        return x + residual
79
+
80
+
81
class Dynamic_MLP_OFA(nn.Module):
    """Dynamic patch embedding whose conv weights are generated per wavelength.

    Each input channel's (normalized) wavelength is turned into a sinusoidal
    embedding, refined by an FC residual layer, and fed to a transformer
    weight generator that emits the weights and bias of a ``kernel_size`` x
    ``kernel_size`` convolution. That generated convolution (stride equal to
    the kernel size, i.e. non-overlapping patches) embeds the image.

    Args:
        wv_planes: dimensionality of the wavelength embedding.
        inter_dim: intermediate dimension; stored for interface compatibility
            but not used by the current implementation.
        kernel_size: patch/kernel size of the generated convolution.
        embed_dim: output channel count of the generated convolution.
    """

    def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
        super().__init__()
        # Fix: the original assigned ``self.kernel_size`` twice.
        self.kernel_size = kernel_size
        self.wv_planes = wv_planes
        self.embed_dim = embed_dim
        # Number of weight entries generated per input channel.
        self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
        self.inter_dim = inter_dim
        self.patch_size = (kernel_size, kernel_size)

        self.weight_generator = TransformerWeightGenerator(wv_planes, self._num_kernel, embed_dim)
        self.scaler = 0.1  # damps the generated weights/bias

        self.fclayer = FCResLayer(wv_planes)

        self._init_weights()

    def _get_weights(self, waves):
        """Return ``(weight, bias)`` generated from wavelength embeddings."""
        # Fix: removed the dead ``dweights = []`` local from the original.
        return self.weight_generator(waves)

    def weight_init(self, m):
        # Xavier init for every linear layer; small positive bias.
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def _init_weights(self):
        """Initialize the weight generator and the FC residual layer."""
        self.weight_generator.apply(self.weight_init)
        self.fclayer.apply(self.weight_init)

    def forward(self, img_feat, wvs):
        """Embed ``img_feat`` with a convolution generated from wavelengths.

        Args:
            img_feat: image tensor of shape [B, C, H, W] where C == len(wvs)
                -- TODO confirm against callers.
            wvs: 1-D tensor of per-channel wavelengths (presumably in
                micrometers, judging by the ``* 1000`` scaling -- verify).

        Returns:
            Tuple of (feature map [B, embed_dim, H/k, W/k], refined
            wavelength embeddings [C, wv_planes]).
        """
        inplanes = wvs.size(0)
        # Wavelengths -> 1-D sin/cos positional embeddings of size wv_planes.
        waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000)
        waves = self.fclayer(waves)
        weight, bias = self._get_weights(waves)

        # [inplanes, k, k, embed_dim] -> [embed_dim, inplanes, k, k] (conv layout).
        dynamic_weight = weight.view(inplanes, self.kernel_size, self.kernel_size, self.embed_dim)
        dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])

        if bias is not None:
            bias = bias.view([self.embed_dim]) * self.scaler

        weights = dynamic_weight * self.scaler

        # stride == kernel_size: non-overlapping patch embedding.
        dynamic_out = F.conv2d(img_feat, weights, bias=bias, stride=self.kernel_size)

        return dynamic_out, waves
148
+
149
class DOFAViT(nn.Module):
    """DOFA classifier: wavelength-conditioned dynamic patch embedding feeding
    a timm DINOv3 ViT-Large backbone.

    NOTE(review): ``depth``, ``num_heads``, ``mlp_ratio`` and
    ``drop_path_rate`` are effectively fixed by the hard-coded timm model name
    below; the constructor keeps them only for interface compatibility.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        drop_rate=0.0,
        out_indices=None,
        drop_path_rate=0.0,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        wv_planes=128,
        num_classes=45,
        global_pool=True,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        self.wv_planes = wv_planes
        self.out_indices = out_indices
        # NOTE(review): global pooling is forced on; the ``global_pool``
        # argument is ignored (preserved from the original implementation).
        self.global_pool = True
        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.img_size = img_size
        if isinstance(img_size, tuple):
            self.img_size = self.img_size[0]

        self.num_patches = (self.img_size // patch_size) ** 2
        self.patch_embed = Dynamic_MLP_OFA(wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim)
        self.model = timm.create_model('vit_large_patch16_dinov3.lvd1689m', pretrained=False)

        self.dynamic_img_size = True
        self.waves = None
        self.norm = norm_layer(embed_dim)

        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, wave_list):
        """Encode image ``x`` given its per-channel wavelengths ``wave_list``.

        Returns the pooled (or cls-token) feature vector before the head.
        """
        # The dynamic patch embedding always runs in fp32, even under autocast.
        with torch.autocast("cuda", enabled=False):
            waves = torch.tensor(wave_list, device=x.device).float()
            x, _ = self.patch_embed(x, waves)
        # [B, C, H, W] -> [B, H, W, C]. Fix: the grid side was hard-coded to
        # 14 (i.e. 224/16); derive it from num_patches so other image sizes
        # divisible by the patch size also work.
        side = math.isqrt(self.num_patches)
        x = einops.rearrange(x, 'b c h w -> b h w c', h=side, w=side)
        x, rot_pos_embed = self.model._pos_embed(x)

        x = self.model.norm_pre(x)
        # NOTE(review): the last backbone block is skipped (``[:-1]``) and
        # ``outx`` captures the output of the final iterated block — preserved
        # as in the original; confirm this truncation is intentional.
        for i, blk in enumerate(self.model.blocks[:-1]):
            x = blk(x, rope=rot_pos_embed)
            if i == len(self.model.blocks) - 2:
                outx = x

        if self.global_pool:
            x = self.model.norm(outx)
            x = x[:, self.model.num_prefix_tokens:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.model.norm(x)
            outcome = x[:, 0]
        return outcome

    def forward_head(self, x, pre_logits=False):
        """Apply head dropout and (unless ``pre_logits``) the classifier."""
        # Fix: use this module's own head dropout, configured by ``drop_rate``;
        # the original used the backbone's ``self.model.head_drop``, leaving
        # ``self.head_drop`` defined but unused.
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x, wave_list):
        x = self.forward_features(x, wave_list)
        x = self.forward_head(x)
        return x
225
+
226
+
227
def vit_base_patch16(**kwargs):
    """Build a DOFA ViT-Base/16 (embed_dim 768, depth 12, 12 heads)."""
    return DOFAViT(
        out_indices=[4, 6, 10, 11],
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
239
+
240
+
241
def vit_large_patch16(**kwargs):
    """Build a DOFA ViT-Large/16 (embed_dim 1024, depth 24, 16 heads)."""
    return DOFAViT(
        out_indices=[5, 11, 17, 23],
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )