# PL-Stitch/models/head.py
import torch
import torch.nn as nn
import utils
from utils import trunc_normal_
class CSyncBatchNorm(nn.SyncBatchNorm):
def __init__(self,
*args,
with_var=False,
**kwargs):
super(CSyncBatchNorm, self).__init__(*args, **kwargs)
self.with_var = with_var
def forward(self, x):
        # center norm: run the parent forward in eval mode; with running_var forced
        # to ones below, only the running mean is subtracted unless with_var is set
        self.training = False
if not self.with_var:
self.running_var = torch.ones_like(self.running_var)
normed_x = super(CSyncBatchNorm, self).forward(x)
        # update center: a second forward pass in training mode refreshes the running stats
        self.training = True
_ = super(CSyncBatchNorm, self).forward(x)
return normed_x
class PSyncBatchNorm(nn.SyncBatchNorm):
def __init__(self,
*args,
bunch_size,
**kwargs):
        procs_per_bunch = min(bunch_size, utils.get_world_size())
        assert utils.get_world_size() % procs_per_bunch == 0
        n_bunch = utils.get_world_size() // procs_per_bunch
        # build one process group per bunch so BN statistics sync only within a bunch
ranks = list(range(utils.get_world_size()))
print('---ALL RANKS----\n{}'.format(ranks))
rank_groups = [ranks[i*procs_per_bunch: (i+1)*procs_per_bunch] for i in range(n_bunch)]
print('---RANK GROUPS----\n{}'.format(rank_groups))
process_groups = [torch.distributed.new_group(pids) for pids in rank_groups]
bunch_id = utils.get_rank() // procs_per_bunch
process_group = process_groups[bunch_id]
print('---CURRENT GROUP----\n{}'.format(process_group))
super(PSyncBatchNorm, self).__init__(*args, process_group=process_group, **kwargs)
class CustomSequential(nn.Sequential):
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
def forward(self, input):
for module in self:
dim = len(input.shape)
            if isinstance(module, self.bn_types) and dim > 2:
                # BatchNorm expects channels in dim 1; move the trailing feature dim
                # there, normalize, then apply the inverse permutation to restore layout
                perm = list(range(dim - 1))
                perm.insert(1, dim - 1)
                inv_perm = list(range(dim)) + [1]
                inv_perm.pop(1)
                input = module(input.permute(*perm)).permute(*inv_perm)
else:
input = module(input)
return input
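# A minimal usage sketch (the sizes and the helper name _demo_custom_sequential are
# illustrative assumptions, not part of the original module): CustomSequential lets
# a BatchNorm layer handle [B, T, C] inputs by permuting the feature dim into the
# channel position and back.
def _demo_custom_sequential():
    mlp = CustomSequential(
        nn.Linear(16, 32),
        nn.BatchNorm1d(32),   # sees the tensor as [B, 32, T] inside forward
        nn.GELU(),
    )
    x = torch.randn(4, 7, 16)   # [B, T, C]
    return mlp(x).shape         # torch.Size([4, 7, 32]), layout restored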
class DINOHead(nn.Module):
def __init__(self, in_dim, out_dim, norm=None, act='gelu', last_norm=None,
nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, **kwargs):
super().__init__()
norm = self._build_norm(norm, hidden_dim)
last_norm = self._build_norm(last_norm, out_dim, affine=False, **kwargs)
act = self._build_act(act)
nlayers = max(nlayers, 1)
if nlayers == 1:
if bottleneck_dim > 0:
self.mlp = nn.Linear(in_dim, bottleneck_dim)
else:
self.mlp = nn.Linear(in_dim, out_dim)
else:
layers = [nn.Linear(in_dim, hidden_dim)]
if norm is not None:
layers.append(norm)
layers.append(act)
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
if norm is not None:
layers.append(norm)
layers.append(act)
if bottleneck_dim > 0:
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
else:
layers.append(nn.Linear(hidden_dim, out_dim))
self.mlp = CustomSequential(*layers)
self.apply(self._init_weights)
if bottleneck_dim > 0:
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
else:
self.last_layer = None
self.last_norm = last_norm
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
if self.last_layer is not None:
x = nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
if self.last_norm is not None:
x = self.last_norm(x)
return x
def _build_norm(self, norm, hidden_dim, **kwargs):
if norm == 'bn':
norm = nn.BatchNorm1d(hidden_dim, **kwargs)
elif norm == 'syncbn':
norm = nn.SyncBatchNorm(hidden_dim, **kwargs)
elif norm == 'csyncbn':
norm = CSyncBatchNorm(hidden_dim, **kwargs)
elif norm == 'psyncbn':
norm = PSyncBatchNorm(hidden_dim, **kwargs)
elif norm == 'ln':
norm = nn.LayerNorm(hidden_dim, **kwargs)
else:
assert norm is None, "unknown norm type {}".format(norm)
return norm
def _build_act(self, act):
if act == 'relu':
act = nn.ReLU()
elif act == 'gelu':
act = nn.GELU()
else:
assert False, "unknown act type {}".format(act)
return act
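# A minimal usage sketch for DINOHead (the sizes and the helper name _demo_dino_head
# are illustrative assumptions): with the defaults (nlayers=3, bottleneck_dim=256)
# the head is MLP -> L2-normalize -> weight-normed last layer, mapping [B, in_dim]
# CLS features to [B, out_dim] prototype logits.
def _demo_dino_head():
    head = DINOHead(in_dim=384, out_dim=8192, norm=None, act='gelu')
    feats = torch.randn(8, 384)   # [B, in_dim] CLS features from the backbone
    return head(feats).shape      # torch.Size([8, 8192])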
class iBOTHead(DINOHead):
def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None,
nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True,
shared_head=False, **kwargs):
super(iBOTHead, self).__init__(*args,
norm=norm,
act=act,
last_norm=last_norm,
nlayers=nlayers,
hidden_dim=hidden_dim,
bottleneck_dim=bottleneck_dim,
norm_last_layer=norm_last_layer,
**kwargs)
if not shared_head:
if bottleneck_dim > 0:
self.last_layer2 = nn.utils.weight_norm(nn.Linear(bottleneck_dim, patch_out_dim, bias=False))
self.last_layer2.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer2.weight_g.requires_grad = False
else:
self.mlp2 = nn.Linear(hidden_dim, patch_out_dim)
self.last_layer2 = None
self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs)
else:
if bottleneck_dim > 0:
self.last_layer2 = self.last_layer
else:
self.mlp2 = self.mlp[-1]
self.last_layer2 = None
self.last_norm2 = self.last_norm
def forward(self, x):
if len(x.shape) == 2:
return super(iBOTHead, self).forward(x)
if self.last_layer is not None:
x = self.mlp(x)
x = nn.functional.normalize(x, dim=-1, p=2)
x1 = self.last_layer(x[:, 0])
x2 = self.last_layer2(x[:, 1:])
else:
x = self.mlp[:-1](x)
x1 = self.mlp[-1](x[:, 0])
x2 = self.mlp2(x[:, 1:])
if self.last_norm is not None:
x1 = self.last_norm(x1)
x2 = self.last_norm2(x2)
return x1, x2
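# A minimal usage sketch for iBOTHead (sizes and the helper name _demo_ibot_head are
# illustrative assumptions): a token sequence [B, 1 + N, D] yields CLS logits
# [B, out_dim] from last_layer and patch logits [B, N, patch_out_dim] from last_layer2.
def _demo_ibot_head():
    head = iBOTHead(384, 8192, patch_out_dim=8192, shared_head=False)
    tokens = torch.randn(2, 1 + 196, 384)    # CLS token followed by 196 patch tokens
    cls_logits, patch_logits = head(tokens)  # [2, 8192], [2, 196, 8192]
    return cls_logits.shape, patch_logits.shape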
class TemporalSideContext(nn.Module):
def __init__(self, D, max_len=64, n_layers=6, n_head=8, dropout=0.1):
super().__init__()
#self.pos_t = nn.Embedding(max_len, D) # learnable embedding for positions
layer = nn.TransformerEncoderLayer(D, n_head, 4*D,
dropout=dropout, batch_first=True)
self.enc = nn.TransformerEncoder(layer, n_layers)
def forward(self, x): # x [B,T,D]
B,T,D = x.shape
device = x.device
# Generate relative frame positions [0, 1, ..., T-1]
#pos_ids = torch.arange(T, device=device).unsqueeze(0) # [1, T]
#pos_embed = self.pos_t(pos_ids) # [1, T, D]
#x = x + pos_embed
return self.enc(x) # [B,T,D]
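# A minimal usage sketch for TemporalSideContext (D=256 and the helper name are
# illustrative assumptions): the encoder mixes information across the T frame
# positions while keeping the [B, T, D] layout (batch_first=True).
def _demo_temporal_side_context():
    ctx = TemporalSideContext(D=256, max_len=64)
    frames = torch.randn(2, 16, 256)   # [B, T, D] per-frame features
    return ctx(frames).shape           # torch.Size([2, 16, 256])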
class TemporalHead(nn.Module):
"""
Converts backbone features [B,T,D] → logits [B,T,1] for Plackett–Luce.
"""
def __init__(self, backbone_dim: int, hidden_mul: float = 0.5, max_len: int = 64):
super().__init__()
hidden_dim = int(backbone_dim * hidden_mul)
self.reduce = nn.Sequential(
nn.Linear(backbone_dim, hidden_dim),
nn.GELU()
)
self.temporal = TemporalSideContext(hidden_dim, max_len=max_len)
self.scorer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, x: torch.Tensor): # x : [B,T,D]
x = self.reduce(x) # [B,T,hidden]
x = self.temporal(x) # [B,T,hidden]
return self.scorer(x) # [B,T,1]
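# A minimal end-to-end sketch for TemporalHead (backbone_dim=384 and the helper name
# are illustrative assumptions): backbone features [B, T, D] are reduced, mixed over
# time, and scored into one logit per frame for the Plackett-Luce objective.
def _demo_temporal_head():
    head = TemporalHead(backbone_dim=384, hidden_mul=0.5, max_len=64)
    feats = torch.randn(2, 16, 384)   # [B, T, D] per-frame backbone features
    return head(feats).shape          # torch.Size([2, 16, 1])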