import torch
import torch.nn as nn

import utils
from utils import trunc_normal_


class CSyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm variant that only centers (and optionally scales) using the
    running statistics, while still updating the center with a second pass."""
    def __init__(self, *args, with_var=False, **kwargs):
        super(CSyncBatchNorm, self).__init__(*args, **kwargs)
        self.with_var = with_var

    def forward(self, x):
        # center norm: use running stats (eval mode); drop the variance if not requested
        self.training = False
        if not self.with_var:
            self.running_var = torch.ones_like(self.running_var)
        normed_x = super(CSyncBatchNorm, self).forward(x)
        # update center: run a second pass in train mode to refresh running stats
        self.training = True
        _ = super(CSyncBatchNorm, self).forward(x)
        return normed_x


class PSyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm synchronized within 'bunches' of processes rather than globally."""
    def __init__(self, *args, bunch_size, **kwargs):
        procs_per_bunch = min(bunch_size, utils.get_world_size())
        assert utils.get_world_size() % procs_per_bunch == 0
        n_bunch = utils.get_world_size() // procs_per_bunch
        #
        ranks = list(range(utils.get_world_size()))
        print('---ALL RANKS----\n{}'.format(ranks))
        rank_groups = [ranks[i * procs_per_bunch: (i + 1) * procs_per_bunch] for i in range(n_bunch)]
        print('---RANK GROUPS----\n{}'.format(rank_groups))
        process_groups = [torch.distributed.new_group(pids) for pids in rank_groups]
        bunch_id = utils.get_rank() // procs_per_bunch
        process_group = process_groups[bunch_id]
        print('---CURRENT GROUP----\n{}'.format(process_group))
        super(PSyncBatchNorm, self).__init__(*args, process_group=process_group, **kwargs)


class CustomSequential(nn.Sequential):
    """Sequential container that permutes inputs so BatchNorm layers see the
    feature dimension in position 1, even for inputs with more than 2 dims."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    def forward(self, input):
        for module in self:
            dim = len(input.shape)
            if isinstance(module, self.bn_types) and dim > 2:
                # move the last (feature) dim to position 1 for BN, then undo the permutation
                perm = list(range(dim - 1))
                perm.insert(1, dim - 1)
                inv_perm = list(range(dim)) + [1]
                inv_perm.pop(1)
                input = module(input.permute(*perm)).permute(*inv_perm)
            else:
                input = module(input)
        return input


class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, norm=None, act='gelu', last_norm=None,
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True,
                 **kwargs):
        super().__init__()
        norm = self._build_norm(norm, hidden_dim)
        last_norm = self._build_norm(last_norm, out_dim, affine=False, **kwargs)
        act = self._build_act(act)

        nlayers = max(nlayers, 1)
        if nlayers == 1:
            if bottleneck_dim > 0:
                self.mlp = nn.Linear(in_dim, bottleneck_dim)
            else:
                self.mlp = nn.Linear(in_dim, out_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if norm is not None:
                layers.append(norm)
            layers.append(act)
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if norm is not None:
                    layers.append(norm)
                layers.append(act)
            if bottleneck_dim > 0:
                layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            else:
                layers.append(nn.Linear(hidden_dim, out_dim))
            self.mlp = CustomSequential(*layers)
        self.apply(self._init_weights)

        if bottleneck_dim > 0:
            # weight-normalized prototype layer; optionally freeze its scale (g)
            self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
            self.last_layer.weight_g.data.fill_(1)
            if norm_last_layer:
                self.last_layer.weight_g.requires_grad = False
        else:
            self.last_layer = None

        self.last_norm = last_norm

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        if self.last_layer is not None:
            x = nn.functional.normalize(x, dim=-1, p=2)
            x = self.last_layer(x)
        if self.last_norm is not None:
            x = self.last_norm(x)
        return x

    def _build_norm(self, norm, hidden_dim, **kwargs):
        if norm == 'bn':
            norm = nn.BatchNorm1d(hidden_dim, **kwargs)
        elif norm == 'syncbn':
            norm = nn.SyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'csyncbn':
            norm = CSyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'psyncbn':
            norm = PSyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'ln':
            norm = nn.LayerNorm(hidden_dim, **kwargs)
        else:
            assert norm is None, "unknown norm type {}".format(norm)
        return norm

    def _build_act(self, act):
        if act == 'relu':
            act = nn.ReLU()
        elif act == 'gelu':
            act = nn.GELU()
        else:
            assert False, "unknown act type {}".format(act)
        return act


class iBOTHead(DINOHead):
    def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None,
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True,
                 shared_head=False, **kwargs):
        super(iBOTHead, self).__init__(*args,
                                       norm=norm,
                                       act=act,
                                       last_norm=last_norm,
                                       nlayers=nlayers,
                                       hidden_dim=hidden_dim,
                                       bottleneck_dim=bottleneck_dim,
                                       norm_last_layer=norm_last_layer,
                                       **kwargs)

        if not shared_head:
            # separate projection for patch tokens
            if bottleneck_dim > 0:
                self.last_layer2 = nn.utils.weight_norm(nn.Linear(bottleneck_dim, patch_out_dim, bias=False))
                self.last_layer2.weight_g.data.fill_(1)
                if norm_last_layer:
                    self.last_layer2.weight_g.requires_grad = False
            else:
                self.mlp2 = nn.Linear(hidden_dim, patch_out_dim)
                self.last_layer2 = None
            self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs)
        else:
            # share the [CLS] projection with the patch tokens
            if bottleneck_dim > 0:
                self.last_layer2 = self.last_layer
            else:
                self.mlp2 = self.mlp[-1]
                self.last_layer2 = None
            self.last_norm2 = self.last_norm

    def forward(self, x):
        if len(x.shape) == 2:
            return super(iBOTHead, self).forward(x)

        if self.last_layer is not None:
            x = self.mlp(x)
            x = nn.functional.normalize(x, dim=-1, p=2)
            x1 = self.last_layer(x[:, 0])      # [CLS] token
            x2 = self.last_layer2(x[:, 1:])    # patch tokens
        else:
            x = self.mlp[:-1](x)
            x1 = self.mlp[-1](x[:, 0])
            x2 = self.mlp2(x[:, 1:])

        if self.last_norm is not None:
            x1 = self.last_norm(x1)
            x2 = self.last_norm2(x2)
        return x1, x2


class TemporalSideContext(nn.Module):
    """Transformer encoder over the frame axis of per-frame features."""
    def __init__(self, D, max_len=64, n_layers=6, n_head=8, dropout=0.1):
        super().__init__()
        # self.pos_t = nn.Embedding(max_len, D)  # learnable embedding for positions
        layer = nn.TransformerEncoderLayer(D, n_head, 4 * D, dropout=dropout, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, n_layers)

    def forward(self, x):
        # x: [B, T, D]
        B, T, D = x.shape
        device = x.device
        # Generate relative frame positions [0, 1, ..., T-1]
        # pos_ids = torch.arange(T, device=device).unsqueeze(0)   # [1, T]
        # pos_embed = self.pos_t(pos_ids)                         # [1, T, D]
        # x = x + pos_embed
        return self.enc(x)  # [B, T, D]


class TemporalHead(nn.Module):
    """
    Converts backbone features [B, T, D] → logits [B, T, 1] for Plackett–Luce ranking.
    """
    def __init__(self, backbone_dim: int, hidden_mul: float = 0.5, max_len: int = 64):
        super().__init__()
        hidden_dim = int(backbone_dim * hidden_mul)
        self.reduce = nn.Sequential(
            nn.Linear(backbone_dim, hidden_dim),
            nn.GELU()
        )
        self.temporal = TemporalSideContext(hidden_dim, max_len=max_len)
        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x: torch.Tensor):
        # x: [B, T, D]
        x = self.reduce(x)       # [B, T, hidden]
        x = self.temporal(x)     # [B, T, hidden]
        return self.scorer(x)    # [B, T, 1]
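
# Minimal shape-check sketch (illustrative only): the dimensions below are
# assumptions, not values from any training config, and it requires the repo's
# `utils` module to be importable. It sticks to the default norm=None path so
# no distributed process group is needed (the SyncBatchNorm variants would
# require one).
if __name__ == "__main__":
    B, N, T, D = 2, 5, 8, 384  # batch, 1 + num patch tokens, frames, backbone dim (assumed)

    # iBOTHead: 3-D input -> ([CLS] logits, patch logits)
    ibot_head = iBOTHead(in_dim=D, out_dim=8192, patch_out_dim=8192)
    tokens = torch.randn(B, N, D)            # [CLS] token followed by patch tokens
    cls_out, patch_out = ibot_head(tokens)
    print(cls_out.shape, patch_out.shape)    # [B, 8192], [B, N-1, 8192]

    # TemporalHead: per-frame features -> one ranking logit per frame
    temporal_head = TemporalHead(backbone_dim=D)
    frames = torch.randn(B, T, D)
    print(temporal_head(frames).shape)       # [B, T, 1]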