zdai257 committed on
Commit 11f4370 · verified · 1 Parent(s): 4e928c5

Initial Commit: network builder and model.pt

Files changed (2)
  1. GazeMoE.pt +3 -0
  2. gazemoe_builder.py +408 -0
GazeMoE.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ed8db85e877f4b51ddc0c6afe76e0ea338361bcb64c6a3cc1cee379b390dc65
size 14647206
gazemoe_builder.py ADDED
@@ -0,0 +1,408 @@
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from timm.models.vision_transformer import Block
import math


# Official DINOv2 backbones from torch hub (https://github.com/facebookresearch/dinov2#pretrained-backbones-via-pytorch-hub)
class DinoV2Backbone(nn.Module):
    def __init__(self, model_name):
        super(DinoV2Backbone, self).__init__()
        self.model = torch.hub.load('facebookresearch/dinov2', model_name)

    def forward(self, x):
        b, c, h, w = x.shape
        out_h, out_w = self.get_out_size((h, w))
        x = self.model.forward_features(x)['x_norm_patchtokens']
        x = x.view(x.size(0), out_h, out_w, -1).permute(0, 3, 1, 2)  # "b (out_h out_w) c -> b c out_h out_w"
        return x

    def get_dimension(self):
        return self.model.embed_dim

    def get_out_size(self, in_size):
        h, w = in_size
        return (h // self.model.patch_size, w // self.model.patch_size)

    def get_transform(self, in_size):
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            transforms.Resize(in_size),
        ])


class DinoV2BackboneMultiScale(nn.Module):
    def __init__(self, model_name, num_scales=3):
        super().__init__()
        self.base_backbone = DinoV2Backbone(model_name)
        # Store the desired number of scales
        self.num_scales = num_scales
        if self.num_scales < 1:
            raise ValueError("num_scales must be at least 1")

    def forward(self, x):
        # Obtain the original feature map [B, C, H, W]
        features = self.base_backbone.forward(x)
        multi_scale_features = []
        current_features = features
        for i in range(self.num_scales):
            if i == 0:
                # First scale is the original feature map
                multi_scale_features.append(current_features)
            else:
                # Subsequent scales are downsampled
                # Using 0.5^i as scale factor relative to the original
                scale_factor = 0.5 ** i
                downsampled_features = nn.functional.interpolate(
                    features, scale_factor=scale_factor, mode='bilinear', align_corners=False
                )
                multi_scale_features.append(downsampled_features)
        # Return a list of feature maps
        return multi_scale_features

    def get_out_size(self, in_size):
        return self.base_backbone.get_out_size(in_size)

    def get_multi_scale_channels(self):
        C = self.base_backbone.get_dimension()
        # Return a list of C repeated num_scales times
        return [C] * self.num_scales

    def get_transform(self, size):
        return self.base_backbone.get_transform(size)


def repeat_tensors(tensor, repeat_counts):
    repeated_tensors = [tensor[i:i+1].repeat(repeat, *[1] * (tensor.ndim - 1)) for i, repeat in enumerate(repeat_counts)]
    return torch.cat(repeated_tensors, dim=0)


def split_tensors(tensor, split_counts):
    indices = torch.cumsum(torch.tensor([0] + split_counts), dim=0)
    return [tensor[indices[i]:indices[i+1]] for i in range(len(split_counts))]


class MultiScaleFusionLite(nn.Module):
    def __init__(self, in_channels_list, out_channels, target_size):
        """
        Args:
            in_channels_list: List of channel dimensions for each feature map.
            out_channels: Desired number of channels after fusion.
            target_size: Tuple (height, width) for spatial alignment.
        """
        super().__init__()
        self.target_size = target_size
        self.convs = nn.ModuleList([
            nn.Conv2d(in_ch, out_channels, kernel_size=1)
            for in_ch in in_channels_list
        ])
        # Learnable scalar weights for each scale
        self.scale_weights = nn.Parameter(torch.ones(len(in_channels_list)))
        self.refine_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def forward(self, feature_maps):
        processed_maps = []
        for conv, feat in zip(self.convs, feature_maps):
            feat_proj = conv(feat)
            feat_resized = nn.functional.interpolate(feat_proj, size=self.target_size, mode='bilinear',
                                                     align_corners=False)
            processed_maps.append(feat_resized)
        weights = torch.softmax(self.scale_weights, dim=0)
        fused = sum(w * feat for w, feat in zip(weights, processed_maps))
        fused = self.refine_conv(fused)
        return fused


class MoELayer(nn.Module):
    def __init__(self, in_features, out_features, num_experts=8, num_shared_experts=2, top_k=2, hidden_dim=None):
        super().__init__()
        self.num_experts = num_experts  # Routed experts
        self.num_shared_experts = num_shared_experts  # Shared experts
        self.top_k = top_k
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_dim = hidden_dim if hidden_dim is not None else in_features * 4

        # Routed expert networks
        self.routed_experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features, self.hidden_dim),
                nn.GELU(),
                nn.Linear(self.hidden_dim, out_features)
            ) for _ in range(num_experts)
        ])

        # Shared expert networks
        self.shared_experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features, self.hidden_dim),
                nn.GELU(),
                nn.Linear(self.hidden_dim, out_features)
            ) for _ in range(num_shared_experts)
        ])

        # Gating network for routed experts only
        self.gate = nn.Linear(in_features, num_experts)

    def forward(self, x):
        # x: [batch_size, seq_len, in_features] or [batch_size, in_features]
        batch_shape = x.shape[:-1]
        x_flat = x.view(-1, self.in_features)  # [batch_size * seq_len, in_features]

        # Initialize output
        output = torch.zeros(x_flat.shape[0], self.out_features, device=x.device)

        # Shared experts: always applied
        for expert in self.shared_experts:
            output += expert(x_flat) / (self.num_shared_experts + 1e-10)  # Average shared contributions

        # Routed experts: top-k selection
        gate_logits = self.gate(x_flat)  # [batch_size * seq_len, num_experts]
        gate_weights = torch.softmax(gate_logits, dim=-1)  # [batch_size * seq_len, num_experts]
        top_k_weights, top_k_indices = gate_weights.topk(self.top_k, dim=-1)  # [batch_size * seq_len, top_k]
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-10)  # Normalize

        # Compute weighted sum of routed expert outputs
        for k in range(self.top_k):
            expert_idx = top_k_indices[:, k]  # [batch_size * seq_len]
            weights = top_k_weights[:, k].unsqueeze(-1)  # [batch_size * seq_len, 1]
            for i in range(self.num_experts):
                mask = (expert_idx == i).float().unsqueeze(-1)  # [batch_size * seq_len, 1]
                expert_output = self.routed_experts[i](x_flat)  # [batch_size * seq_len, out_features]
                output += mask * weights * expert_output

        # Reshape back to original shape
        output = output.view(*batch_shape, self.out_features)
        return output


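# Illustrative shape check for the MoE routing above (a minimal sketch; the layer
# sizes are arbitrary and chosen only for the example):
#
#     moe = MoELayer(in_features=256, out_features=256, num_experts=4, num_shared_experts=1, top_k=2)
#     tokens = torch.randn(2, 10, 256)   # [batch, seq_len, in_features]
#     out = moe(tokens)                  # shared experts always applied + top-2 routed experts per token
#     assert out.shape == (2, 10, 256)
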
class MoEBlock(Block):
    def __init__(self, dim, num_heads, mlp_ratio=4., drop_path=0.1, num_experts=8, num_shared_experts=2, top_k=2):
        super().__init__(dim, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path)
        # Replace the FFN (self.mlp) with MoELayer
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MoELayer(
            in_features=dim,
            out_features=dim,
            num_experts=num_experts,
            num_shared_experts=num_shared_experts,
            top_k=top_k,
            hidden_dim=hidden_dim
        )


# Shared Transformer (for 'vanilla' decoder block)
class SharedTransformer(nn.Module):
    def __init__(self, transformer_block, num_layers):
        super().__init__()
        self.block = transformer_block  # A single transformer block (e.g., vanilla Block)
        self.num_layers = num_layers

    def forward(self, x):
        for _ in range(self.num_layers):
            x = self.block(x)
        return x


class GazeMoE(nn.Module):
    def __init__(self, backbone, inout=False, dim=256, mlp_ratio=4, num_layers=3, in_size=(448, 448), out_size=(64, 64),
                 num_experts=8, num_shared_experts=2, top_k=2, dropout=0.1, moe_type="vanilla", is_msf=False):
        super().__init__()
        self.backbone = backbone
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_layers = num_layers
        self.featmap_h, self.featmap_w = backbone.get_out_size(in_size)
        self.in_size = in_size
        self.out_size = out_size
        self.inout = inout
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.top_k = top_k
        if not is_msf:
            self.ms_fusion = nn.Conv2d(backbone.get_dimension(), self.dim, 1)
        else:
            # Multi-scale fusion module (lightweight version)
            multi_scale_channels = backbone.get_multi_scale_channels()
            self.ms_fusion = MultiScaleFusionLite(
                in_channels_list=multi_scale_channels,
                out_channels=self.dim,
                target_size=(self.featmap_h, self.featmap_w)
            )

        self.register_buffer("pos_embed",
                             positionalencoding2d(self.dim, self.featmap_h, self.featmap_w).squeeze(dim=0).squeeze(
                                 dim=0))

        if moe_type == "vanilla":
            self.transformer = nn.Sequential(*[
                Block(dim=self.dim, num_heads=8, mlp_ratio=self.mlp_ratio, drop_path=dropout)
                for _ in range(num_layers)
            ])
        elif moe_type == "shared":
            # Create one vanilla block and share it across num_layers iterations.
            vanilla_block = Block(dim=self.dim, num_heads=8, mlp_ratio=self.mlp_ratio, drop_path=dropout)
            self.transformer = SharedTransformer(vanilla_block, num_layers)
        else:
            # Create Transformer blocks with MoE
            self.transformer = nn.Sequential(*[
                MoEBlock(
                    dim=self.dim,
                    num_heads=8,
                    mlp_ratio=self.mlp_ratio,
                    drop_path=dropout,
                    num_experts=self.num_experts,
                    num_shared_experts=self.num_shared_experts,
                    top_k=self.top_k
                ) for _ in range(num_layers)
            ])

        self.heatmap_head = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
            nn.Conv2d(dim, 1, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

        self.head_token = nn.Embedding(1, self.dim)
        if self.inout:
            self.inout_head = nn.Sequential(
                nn.Linear(self.dim, 128),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(128, 1),
                nn.Sigmoid()
            )
            self.inout_token = nn.Embedding(1, self.dim)

    def forward(self, input):
        num_ppl_per_img = [len(bbox_list) for bbox_list in input["bboxes"]]
        # Multi-scale features and fusion
        feats = self.backbone.forward(input["images"])
        x = self.ms_fusion(feats)  # [B, dim, featmap_h, featmap_w]

        x = x + self.pos_embed
        x = repeat_tensors(x, num_ppl_per_img)
        head_maps = torch.cat(self.get_input_head_maps(input["bboxes"]), dim=0).to(x.device)
        head_map_embeddings = head_maps.unsqueeze(dim=1) * self.head_token.weight.unsqueeze(-1).unsqueeze(-1)
        x = x + head_map_embeddings
        x = x.flatten(start_dim=2).permute(0, 2, 1)

        if self.inout:
            x = torch.cat([self.inout_token.weight.unsqueeze(dim=0).repeat(x.shape[0], 1, 1), x], dim=1)

        x = self.transformer(x)

        if self.inout:
            inout_tokens = x[:, 0, :]
            inout_preds = self.inout_head(inout_tokens).squeeze(dim=-1)
            inout_preds = split_tensors(inout_preds, num_ppl_per_img)
            x = x[:, 1:, :]

        x = x.reshape(x.shape[0], self.featmap_h, self.featmap_w, x.shape[2]).permute(0, 3, 1, 2)
        x = self.heatmap_head(x).squeeze(dim=1)
        x = torchvision.transforms.functional.resize(x, self.out_size)
        heatmap_preds = split_tensors(x, num_ppl_per_img)

        return {"heatmap": heatmap_preds, "inout": inout_preds if self.inout else None}

    def get_input_head_maps(self, bboxes):
        head_maps = []
        for bbox_list in bboxes:
            img_head_maps = []
            for bbox in bbox_list:
                if bbox is None:
                    img_head_maps.append(torch.zeros(self.featmap_h, self.featmap_w))
                else:
                    xmin, ymin, xmax, ymax = bbox
                    width, height = self.featmap_w, self.featmap_h
                    xmin = round(xmin * width)
                    ymin = round(ymin * height)
                    xmax = round(xmax * width)
                    ymax = round(ymax * height)
                    head_map = torch.zeros((height, width))
                    head_map[ymin:ymax, xmin:xmax] = 1
                    img_head_maps.append(head_map)
            head_maps.append(torch.stack(img_head_maps))
        return head_maps

    def get_gazemoe_state_dict(self, include_backbone=False):
        if include_backbone:
            return self.state_dict()
        else:
            return {k: v for k, v in self.state_dict().items() if not k.startswith("backbone")}

    def load_gazemoe_state_dict(self, ckpt_state_dict, include_backbone=False):
        current_state_dict = self.state_dict()
        keys1 = current_state_dict.keys()
        keys2 = ckpt_state_dict.keys()

        if not include_backbone:
            keys1 = set([k for k in keys1 if not k.startswith("backbone")])
            keys2 = set([k for k in keys2 if not k.startswith("backbone")])
        else:
            keys1 = set(keys1)
            keys2 = set(keys2)

        if len(keys2 - keys1) > 0:
            print("WARNING unused keys in provided state dict: ", keys2 - keys1)
        if len(keys1 - keys2) > 0:
            print("WARNING provided state dict does not have values for keys: ", keys1 - keys2)

        for k in list(keys1 & keys2):
            current_state_dict[k] = ckpt_state_dict[k]

        self.load_state_dict(current_state_dict, strict=False)


def positionalencoding2d(d_model, height, width):
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with a dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    d_model_half = d_model // 2
    div_term = torch.exp(torch.arange(0., d_model_half, 2) * -(math.log(10000.0) / d_model_half))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    pe[0:d_model_half:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model_half:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model_half::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model_half + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    return pe


def gazemoe_dinov2_vitl14_inout(bbtype, d_model, mlp_ratio, num_layers, num_experts, num_shared_experts, top_k, dropout,
                                moe_type, is_msf):
    if bbtype == "DINOv2":
        backbone = DinoV2BackboneMultiScale('dinov2_vitl14', num_scales=is_msf)
    else:
        raise TypeError("backbone not supported!")

    transform = backbone.get_transform((448, 448))
    model = GazeMoE(backbone, inout=True, dim=d_model, mlp_ratio=mlp_ratio, num_layers=num_layers,
                    num_experts=num_experts,
                    num_shared_experts=num_shared_experts, top_k=top_k, dropout=dropout,
                    moe_type=moe_type, is_msf=is_msf)
    return model, transform


def get_gazemoe_model(configuration=None):
    factory = {
        "gazemoe_dinov2_vitl14_inout": gazemoe_dinov2_vitl14_inout,
    }
    return factory["gazemoe_dinov2_vitl14_inout"](
        bbtype='DINOv2',
        d_model=256,
        mlp_ratio=1,
        num_layers=3,
        num_experts=4,
        num_shared_experts=1,
        top_k=2,
        dropout=0.1,
        moe_type='moe',
        is_msf=1,
    )
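
For reference, a minimal usage sketch tying the two committed files together. It assumes GazeMoE.pt stores the state dict produced by get_gazemoe_state_dict(include_backbone=False) for the default get_gazemoe_model() configuration, that each bbox is a normalized (xmin, ymin, xmax, ymax) head box per person as get_input_head_maps expects, and that the image path is a placeholder:

import torch
from PIL import Image
from gazemoe_builder import get_gazemoe_model

# Build the default GazeMoE (DINOv2 ViT-L/14 backbone, MoE decoder) and its input transform.
model, transform = get_gazemoe_model()

# Load the committed weights; the checkpoint is assumed to exclude backbone weights.
state_dict = torch.load("GazeMoE.pt", map_location="cpu")
model.load_gazemoe_state_dict(state_dict, include_backbone=False)
model.eval()

# One image, one person; the head bbox is normalized to [0, 1] image coordinates.
image = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0)  # [1, 3, 448, 448]
sample = {"images": image, "bboxes": [[(0.30, 0.10, 0.55, 0.45)]]}

with torch.no_grad():
    out = model(sample)

print(out["heatmap"][0].shape)  # per-person gaze heatmaps for image 0, [1, 64, 64]
print(out["inout"][0])          # per-person in-frame probability for image 0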