| |
| |
| |
| |
|
|
| |
| |
| import math |
| import apex |
| import einops |
| import torch |
| import numpy as np |
| import torch.nn.functional as F |
| from torch.nn.init import trunc_normal_ |
| from megatron import get_args, print_rank_0 |
| from megatron.model.utils import get_linear_layer |
| from megatron.model.vision.vit_backbone import VitBackbone |
| from megatron.model.module import MegatronModule |
| from megatron.model.vision.mit_backbone import mit_b5_avg |
| from megatron.model.vision.esvit_swin_backbone import get_swin |
|
|
|
|
| class DINOLoss(torch.nn.Module): |
| def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, |
| warmup_teacher_temp_epochs, nepochs, student_temp=0.1, |
| center_momentum=0.9): |
| super().__init__() |
| self.student_temp = student_temp |
| self.center_momentum = center_momentum |
| self.ncrops = ncrops |
| self.register_buffer("center", torch.zeros(1, out_dim)) |
| |
| |
| self.teacher_temp_schedule = np.concatenate(( |
| np.linspace(warmup_teacher_temp, |
| teacher_temp, warmup_teacher_temp_epochs), |
| np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp |
| )) |
| self.teacher_temp = teacher_temp |
|
|
| def forward(self, student_output, teacher_output, iteration): |
| """ |
| Cross-entropy between softmax outputs of the teacher |
| and student network. |
| """ |
| args = get_args() |
| student_out = student_output / self.student_temp |
| student_out = student_out.chunk(self.ncrops) |
|
|
| epoch = iteration // args.iter_per_epoch |
|
|
| |
| temp = self.teacher_temp_schedule[epoch] |
| teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) |
|
|
| teacher_out = teacher_out.detach().chunk(2) |
|
|
| total_loss = 0 |
| n_loss_terms = 0 |
| for iq, q in enumerate(teacher_out): |
| for v in range(len(student_out)): |
| if v == iq: |
| |
| continue |
| loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) |
| total_loss += loss.mean() |
| n_loss_terms += 1 |
| total_loss /= n_loss_terms |
| self.update_center(teacher_output) |
| return total_loss |
|
|
| @torch.no_grad() |
| def update_center(self, teacher_output): |
| """ |
| Update center used for teacher output. |
| """ |
| batch_center = torch.sum(teacher_output, dim=0, keepdim=True) |
| torch.distributed.all_reduce(batch_center) |
| batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size()) |
| self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) |
|
|
| class DINOHead(torch.nn.Module): |
| def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3): |
| super().__init__() |
| args = get_args() |
| hidden_dim = args.dino_head_hidden_size |
| bottleneck_dim = args.dino_bottleneck_size |
| nlayers = max(nlayers, 1) |
| if nlayers == 1: |
| self.mlp = torch.nn.Linear(in_dim, bottleneck_dim) |
| else: |
| layers = [torch.nn.Linear(in_dim, hidden_dim)] |
| layers.append(torch.nn.GELU()) |
| for _ in range(nlayers - 2): |
| layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) |
| layers.append(torch.nn.GELU()) |
| layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim)) |
| self.mlp = torch.nn.Sequential(*layers) |
| self.apply(self._init_weights) |
| self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False)) |
| self.last_layer.weight_g.data.fill_(1) |
| if norm_last_layer: |
| self.last_layer.weight_g.requires_grad = False |
|
|
| def _init_weights(self, m): |
| if isinstance(m, torch.nn.Linear): |
| trunc_normal_(m.weight, std=.02) |
| if isinstance(m, torch.nn.Linear) and m.bias is not None: |
| torch.nn.init.constant_(m.bias, 0) |
|
|
| def forward(self, x): |
| x = self.mlp(x) |
| x = torch.nn.functional.normalize(x, dim=-1, p=2) |
| x = self.last_layer(x) |
| return x |
|
|
|
|
| class MultiCropWrapper(MegatronModule): |
|
|
| """ |
| Perform forward pass separately on each resolution input. |
| The inputs corresponding to a single resolution are clubbed and single |
| forward is run on the same resolution inputs. Hence we do several |
| forward passes = number of different resolutions used. We then |
| concatenate all the output features and run the head forward on these |
| concatenated features. |
| """ |
| def __init__(self, backbone, head): |
| super(MultiCropWrapper, self).__init__() |
| |
| |
| self.backbone = backbone |
| self.head = head |
|
|
| def forward(self, x): |
| |
| if not isinstance(x, list): |
| x = [x] |
| idx_crops = torch.cumsum(torch.unique_consecutive( |
| torch.tensor([inp.shape[-1] for inp in x]), |
| return_counts=True, |
| )[1], 0) |
|
|
| start_idx = 0 |
| for end_idx in idx_crops: |
| _out = self.backbone(torch.cat(x[start_idx: end_idx])) |
| if start_idx == 0: |
| output = _out |
| else: |
| output = torch.cat((output, _out)) |
| start_idx = end_idx |
| |
| if self.training: |
| return self.head(output) |
| else: |
| return output |
|
|
|
|
| def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, |
| warmup_epochs=0, start_warmup_value=0): |
| warmup_schedule = np.array([]) |
| warmup_iters = warmup_epochs * niter_per_ep |
| if warmup_epochs > 0: |
| warmup_schedule = \ |
| np.linspace(start_warmup_value, base_value, warmup_iters) |
|
|
| iters = np.arange(epochs * niter_per_ep - warmup_iters) |
| schedule = final_value + 0.5 * (base_value - final_value) \ |
| * (1 + np.cos(np.pi * iters / len(iters))) |
|
|
| schedule = np.concatenate((warmup_schedule, schedule)) |
| assert len(schedule) == epochs * niter_per_ep |
| return schedule |
|
|
|
|
| def get_student_backbone_and_num_features(pre_process=True, post_process=True): |
| args = get_args() |
|
|
| if args.vision_backbone_type == 'vit': |
| student = VitBackbone(pre_process=pre_process, |
| post_process=post_process, |
| drop_path_rate=0.1, |
| single_token_output=True) |
| num_features = args.hidden_size |
| elif args.vision_backbone_type == 'mit': |
| student = mit_b5_avg(drop_path_rate=0.1) |
| num_features = 512 |
| elif args.vision_backbone_type == 'swin': |
| student = get_swin() |
| num_features = student.num_features |
| else: |
| raise Exception('{} vision backbone is not supported.'.format( |
| args.vision_backbone_type)) |
| |
| return student, num_features |
|
|
| def get_teacher_backbone_and_num_features(pre_process=True, post_process=True): |
| args = get_args() |
|
|
| if args.vision_backbone_type == 'vit': |
| teacher = VitBackbone(pre_process=pre_process, |
| post_process=post_process, |
| single_token_output=True) |
| num_features = args.hidden_size |
| elif args.vision_backbone_type == 'mit': |
| teacher = mit_b5_avg(drop_path_rate=0.0) |
| num_features = 512 |
| elif args.vision_backbone_type == 'swin': |
| teacher = get_swin(is_teacher=True) |
| num_features = teacher.num_features |
| else: |
| raise Exception('{} vision backbone is not supported.'.format( |
| args.vision_backbone_type)) |
| return teacher, num_features |
|
|
|
|
| class DINOPretrainModel(MegatronModule): |
| def __init__(self, pre_process=True, post_process=True): |
| super(DINOPretrainModel, self).__init__() |
| args = get_args() |
| self.out_dim = 65536 |
|
|
| self.dino_loss = DINOLoss( |
| self.out_dim, |
| args.dino_local_crops_number + 2, |
| args.dino_warmup_teacher_temp, |
| args.dino_teacher_temp, |
| args.dino_warmup_teacher_temp_epochs, |
| 300, |
| ) |
|
|
| self.pre_process = pre_process |
| self.post_process = post_process |
| self.momentum_teacher = 0.996 |
|
|
| student_backbone, num_features = \ |
| get_student_backbone_and_num_features(pre_process, post_process) |
|
|
| self.student = MultiCropWrapper( |
| student_backbone, |
| DINOHead(num_features, self.out_dim, |
| norm_last_layer=args.dino_norm_last_layer) |
| ) |
|
|
| self.momentum_schedule = cosine_scheduler( |
| self.momentum_teacher, 1, |
| args.train_iters // args.iter_per_epoch, |
| args.iter_per_epoch |
| ) |
|
|
| teacher_backbone, num_features = \ |
| get_teacher_backbone_and_num_features(pre_process, post_process) |
| self.teacher = MultiCropWrapper( |
| teacher_backbone, |
| DINOHead(num_features, self.out_dim) |
| ) |
| self.teacher.load_state_dict(self.student.state_dict()) |
|
|
| for p in self.teacher.parameters(): |
| if hasattr(p, "requires_grad") and p.requires_grad is not None: |
| p.requires_grad = False |
|
|
| def set_input_tensor(self, tensor): |
| pass |
|
|
| def forward(self, input): |
| student_output = None |
| if self.training: |
| student_output = self.student(input) |
| teacher_output = self.teacher(input[:2]) |
| else: |
| teacher_output = self.teacher(input) |
| return student_output, teacher_output |
|
|
| def cancel_gradients_last_layer(self, iteration): |
| args = get_args() |
| epoch = iteration // args.iter_per_epoch |
| if epoch < args.dino_freeze_last_layer: |
| for n, p in self.student.named_parameters(): |
| if "last_layer" in n: |
| p.grad = None |
|
|
| def update_momentum(self, iteration): |
| with torch.no_grad(): |
| m = self.momentum_schedule[iteration] |
| for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()): |
| param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) |
|
|
|
|