Upload DINOSAUR.py
Browse files- slots/DINOSAUR.py +78 -5
slots/DINOSAUR.py
CHANGED
|
@@ -221,7 +221,78 @@ class Decoder(nn.Module):
|
|
| 221 |
|
| 222 |
slot_maps = self.layer4(slot_maps) # (B * S, token, 1024 + 1)
|
| 223 |
|
| 224 |
-
return slot_maps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
|
| 227 |
class ISA(nn.Module):
|
|
@@ -517,7 +588,9 @@ class DINOSAURpp(nn.Module):
|
|
| 517 |
else:
|
| 518 |
self.slot_encoder = SA(args, input_dim=1024)
|
| 519 |
|
| 520 |
-
self.slot_decoder = Decoder(args)
|
|
|
|
|
|
|
| 521 |
|
| 522 |
self.pos_dec = nn.Parameter(torch.Tensor(1, self.token_num, self.slot_dim))
|
| 523 |
init.normal_(self.pos_dec, mean=0., std=.02)
|
|
@@ -572,17 +645,17 @@ class DINOSAURpp(nn.Module):
|
|
| 572 |
rel_grid = self.slot_encoder.get_rel_grid(attn) # (B, S, token, D_slot)
|
| 573 |
|
| 574 |
slot_maps = self.sbd_slots(slots) + rel_grid # (B, S, token, D_slot)
|
| 575 |
-
slot_maps = self.slot_decoder(slot_maps) # (B, S, token, 1024 + 1)
|
| 576 |
|
| 577 |
else:
|
| 578 |
slots = self.slot_encoder(features) # (B, S, D_slot), (B, S, token)
|
| 579 |
assert torch.sum(torch.isnan(slots)) == 0
|
| 580 |
|
| 581 |
slot_maps, pos_maps = self.sbd_slots(slots)
|
| 582 |
-
slot_maps = self.slot_decoder(slot_maps) # (B, S, token, 1024 + 1)
|
| 583 |
|
| 584 |
reconstruction, masks = self.reconstruct_feature_map(slot_maps) # (B, token, 1024), (B, S, token)
|
| 585 |
|
| 586 |
-
return reconstruction, slots, masks
|
| 587 |
|
| 588 |
|
|
|
|
| 221 |
|
| 222 |
slot_maps = self.layer4(slot_maps) # (B * S, token, 1024 + 1)
|
| 223 |
|
| 224 |
+
return slot_maps, slot_maps
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class Decoder_to_DINOV2(nn.Module):
    """MLP broadcast decoder with an auxiliary DINOv2-space projection.

    Decodes per-slot broadcast maps into 1024 feature channels plus one
    alpha/mask logit, and additionally taps the activation after the first
    hidden layer to produce a 768-dim projection (DINOv2 token width —
    TODO confirm against the target backbone).
    """

    def __init__(self, args):
        """:arg args: config mapping; only ``args['slot_dim']`` is read here."""
        super().__init__()

        # === Dimensions ===
        in_dim = args['slot_dim']  # D_slot of the broadcast slot maps
        width = 2048               # hidden width of the MLP decoder

        # === MLP based decoder ===
        # Four linear layers; the last emits 1024 feature channels + 1 alpha.
        self.layer1 = nn.Linear(in_dim, width)
        self.layer2 = nn.Linear(width, width)
        self.layer3 = nn.Linear(width, width)
        self.layer4 = nn.Linear(width, 1024 + 1)

        # Side head projecting the first hidden activation to 768 channels.
        self.layer_to_dinov2 = nn.Linear(width, 768)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, slot_maps):
        """Decode broadcast slot maps.

        :arg slot_maps: (B * S, token, D_slot)
        :returns: tuple of
            - (B * S, token, 1024 + 1) decoded features + alpha logit
            - (B * S, token, 768) DINOv2-space projection
        """
        h = self.relu(self.layer1(slot_maps))  # (B * S, token, width)
        # The DINOv2 branch is tapped here, before layers 2-4.
        x_dinov2 = self.layer_to_dinov2(h)
        h = self.relu(self.layer2(h))
        h = self.relu(self.layer3(h))
        h = self.layer4(h)                     # (B * S, token, 1024 + 1)
        return h, x_dinov2
|
| 254 |
+
|
| 255 |
+
from torch.nn.init import trunc_normal_
|
| 256 |
+
class DINOHead(nn.Module):
    """Projection head adapted from DINO.

    An MLP maps ``in_dim`` to a ``bottleneck_dim`` embedding, which is then
    pushed through two final linear layers to ``out_dim``. ``forward``
    returns both the final output and the bottleneck embedding.

    ``norm_last_layer`` is kept for signature compatibility with the original
    weight-normalized DINO head; it is unused in this variant.
    """

    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=768):
        super().__init__()
        depth = max(nlayers, 1)
        if depth == 1:
            # Degenerate head: single projection straight to the bottleneck.
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            # depth-1 Linear(+BN)+GELU stages followed by the bottleneck layer.
            blocks = []
            dims = [in_dim] + [hidden_dim] * (depth - 1)
            for d_in, d_out in zip(dims[:-1], dims[1:]):
                blocks.append(nn.Linear(d_in, d_out))
                if use_bn:
                    blocks.append(nn.BatchNorm1d(d_out))
                blocks.append(nn.GELU())
            blocks.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*blocks)
        # NOTE(review): apply() runs before last_layer1/last_layer2 are built,
        # so those two keep PyTorch's default init (same ordering as upstream
        # DINO, where the last layer was handled separately) — confirm intended.
        self.apply(self._init_weights)
        self.gelu = nn.GELU()
        self.last_layer1 = nn.Linear(bottleneck_dim, bottleneck_dim)
        self.last_layer2 = nn.Linear(bottleneck_dim, out_dim)

    def _init_weights(self, m):
        # Truncated-normal weights (std .02) and zero bias on Linear modules.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(out, x_dinov2)``: final logits and bottleneck embedding."""
        x_dinov2 = self.mlp(x)
        out = self.last_layer2(self.gelu(self.last_layer1(x_dinov2)))
        return out, x_dinov2
|
| 296 |
|
| 297 |
|
| 298 |
class ISA(nn.Module):
|
|
|
|
| 588 |
else:
|
| 589 |
self.slot_encoder = SA(args, input_dim=1024)
|
| 590 |
|
| 591 |
+
self.slot_decoder = Decoder(args) #ori easy mlp
|
| 592 |
+
# self.slot_decoder = DINOHead(in_dim=256, out_dim=1024+1, nlayers=3, bottleneck_dim=768) #ori easy mlp
|
| 593 |
+
# self.slot_decoder = Decoder_to_DINOV2(args) #ori easy mlp
|
| 594 |
|
| 595 |
self.pos_dec = nn.Parameter(torch.Tensor(1, self.token_num, self.slot_dim))
|
| 596 |
init.normal_(self.pos_dec, mean=0., std=.02)
|
|
|
|
| 645 |
rel_grid = self.slot_encoder.get_rel_grid(attn) # (B, S, token, D_slot)
|
| 646 |
|
| 647 |
slot_maps = self.sbd_slots(slots) + rel_grid # (B, S, token, D_slot)
|
| 648 |
+
slot_maps, x_dinov2 = self.slot_decoder(slot_maps) # (B, S, token, 1024 + 1)
|
| 649 |
|
| 650 |
else:
|
| 651 |
slots = self.slot_encoder(features) # (B, S, D_slot), (B, S, token)
|
| 652 |
assert torch.sum(torch.isnan(slots)) == 0
|
| 653 |
|
| 654 |
slot_maps, pos_maps = self.sbd_slots(slots)
|
| 655 |
+
slot_maps, x_dinov2 = self.slot_decoder(slot_maps) # (B, S, token, 1024 + 1)
|
| 656 |
|
| 657 |
reconstruction, masks = self.reconstruct_feature_map(slot_maps) # (B, token, 1024), (B, S, token)
|
| 658 |
|
| 659 |
+
return reconstruction, slots, masks, x_dinov2
|
| 660 |
|
| 661 |
|