| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| functions for building multi-granular losses. |
| """ |
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from utils import concat_all_gather |
|
|
|
|
class InfoNCELoss(nn.Module):
    """
    Vanilla InfoNCE loss (MoCo-style) with a memory queue of negatives.

    Args:
        ncrops: total number of crops fed to the student network
            (2 global crops + the local crops).
        dim: feature dimension of the queue, determined by the output
            dimension of the student network.
        queue_size: number of negative keys kept in the queue.
        temperature: temperature parameter for the InfoNCE softmax.
    """

    def __init__(self, ncrops, dim=256, queue_size=65536, temperature=0.2):
        super().__init__()
        self.queue_size = queue_size
        self.temperature = temperature

        # Queue of negative features, shape (dim, queue_size), kept L2-normalized
        # column-wise so dot products are cosine similarities.
        self.register_buffer("queue", torch.randn(dim, queue_size))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        # Pointer to the next write position in the circular queue.
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        self.CrossEntropyLoss = nn.CrossEntropyLoss()
        self.ncrops = ncrops

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Overwrite the oldest queue entries with `keys` gathered from all
        processes, wrapping around at the end of the queue.
        """
        keys = concat_all_gather(keys)
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        if ptr + batch_size <= self.queue_size:
            # Fast path: the whole batch fits without wrapping.
            self.queue[:, ptr : ptr + batch_size] = keys.T
            ptr = (ptr + batch_size) % self.queue_size
        else:
            # Wrap-around: fill up to the end, then continue from position 0.
            keys_t = keys.T
            queue_remaining_size = self.queue_size - ptr
            self.queue[:, ptr:] = keys_t[:, :queue_remaining_size]
            self.queue[:, : batch_size - queue_remaining_size] = keys_t[
                :, queue_remaining_size:
            ]
            ptr = batch_size - queue_remaining_size

        self.queue_ptr[0] = ptr

    def forward(self, student_output, teacher_output, epoch):
        """
        InfoNCE loss between student predictions and the two teacher
        (global-crop) targets, with the queue providing the negatives.

        `epoch` is unused; it is kept so all granular losses share one
        calling convention. Returns the mean of the large-crop and
        small-crop losses.
        """
        preds = student_output.chunk(self.ncrops)
        # Teacher only sees the 2 global crops.
        targets = teacher_output.detach().chunk(2)
        small_crop_loss, large_crop_loss = 0, 0
        small_loss_terms, large_loss_terms = 0, 0
        # Snapshot the queue so the negatives stay fixed while the current
        # targets are enqueued inside the loop below.
        queue_feat = self.queue.clone().detach()

        for t_idx, targ in enumerate(targets):
            for p_idx, pred in enumerate(preds):
                if t_idx == p_idx:
                    # Never contrast a crop against itself.
                    continue
                # Positive logits: Nx1 (student crop vs. matching teacher target).
                l_pos = torch.einsum("nc,nc->n", [pred, targ]).unsqueeze(-1)
                # Negative logits: NxK against the queue snapshot.
                l_neg = torch.einsum("nc,ck->nk", [pred, queue_feat])
                logits = torch.cat([l_pos, l_neg], dim=1)
                logits /= self.temperature
                # The positive key sits at index 0 of every row.
                labels = torch.zeros(
                    logits.shape[0], dtype=torch.long, device=logits.device
                )
                loss = self.CrossEntropyLoss(logits, labels)
                if p_idx < 2:
                    large_crop_loss += loss
                    large_loss_terms += 1
                else:
                    small_crop_loss += loss
                    small_loss_terms += 1
            self._dequeue_and_enqueue(targ)

        # Guard against ZeroDivisionError when there are no local (small)
        # crops, i.e. ncrops == 2: small_loss_terms stays 0 and the original
        # unconditional division would crash.
        if large_loss_terms > 0:
            large_crop_loss /= large_loss_terms
        if small_loss_terms > 0:
            small_crop_loss /= small_loss_terms
        loss = 0.5 * (large_crop_loss + small_crop_loss)
        return loss
|
|
|
|
class ClusteringLoss(nn.Module):
    """
    Clustering loss, very similar to the one in DINO.

    Args:
        out_dim: center dimension, determined by the output dimension of
            the student network.
        ncrops: total number of crops fed to the student network.
        warmup_teacher_temp: initial value of the teacher temperature.
        teacher_temp: final value (after linear warmup) of the teacher temperature.
        warmup_teacher_temp_epochs: number of warmup epochs for the teacher
            temperature.
        nepochs: total number of training epochs.
        student_temp: temperature applied to the student output.
        center_momentum: EMA parameter for the center update.
    """

    def __init__(
        self,
        out_dim,
        ncrops,
        warmup_teacher_temp,
        teacher_temp,
        warmup_teacher_temp_epochs,
        nepochs,
        student_temp=0.1,
        center_momentum=0.9,
    ):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))

        # Per-epoch teacher temperature: linear warmup from
        # warmup_teacher_temp to teacher_temp, then held constant.
        self.teacher_temp_schedule = np.concatenate(
            (
                np.linspace(
                    warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs
                ),
                np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
            )
        )

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student
        networks, averaged over the large- and small-crop terms.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # Teacher centering and sharpening; the temperature follows the
        # per-epoch warmup schedule. Teacher only sees the 2 global crops.
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        loss_large_crop, loss_small_crop = 0.0, 0.0
        loss_terms_large_crop, loss_terms_small_crop = 0, 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # Skip the pair where teacher and student see the same view.
                    continue
                loss = torch.sum(
                    -q * F.log_softmax(student_out[v], dim=-1), dim=-1
                ).mean()
                if v < 2:
                    loss_large_crop += loss
                    loss_terms_large_crop += 1
                else:
                    loss_small_crop += loss
                    loss_terms_small_crop += 1

        self.update_center(teacher_output)
        # Guard against ZeroDivisionError when there are no local (small)
        # crops (ncrops == 2): the small-crop counter stays 0 and the
        # original unconditional division would crash.
        if loss_terms_large_crop > 0:
            loss_large_crop /= loss_terms_large_crop
        if loss_terms_small_crop > 0:
            loss_small_crop /= loss_terms_small_crop
        total_loss = 0.5 * (loss_large_crop + loss_small_crop)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        EMA update of the center used for the teacher output, averaging
        the batch center across all distributed processes.
        """
        batch_center = torch.mean(teacher_output, dim=0, keepdim=False)
        dist.all_reduce(batch_center)
        batch_center = batch_center / dist.get_world_size()

        # Exponential moving average; (out_dim,) broadcasts against (1, out_dim).
        self.center = self.center * self.center_momentum + batch_center * (
            1 - self.center_momentum
        )
|
|
|
|
def get_multi_granular_loss(args):
    """
    Build the three multi-granular losses and their weights.

    Returns a pair of dicts keyed by granularity name: one mapping to the
    loss modules (moved to GPU), the other to their scalar loss weights.
    """
    total_crops = args.local_crops_number + 2
    all_losses = {}
    all_weights = {}

    # Instance-level supervision: InfoNCE over per-instance features.
    all_losses["instance-sup."] = InfoNCELoss(
        total_crops,
        dim=args.instance_out_dim,
        queue_size=args.instance_queue_size,
        temperature=args.instance_temp,
    ).cuda()
    all_weights["instance-sup."] = args.loss_weights[0]

    # Local-group-level supervision: InfoNCE over local-group features.
    all_losses["local-group-sup."] = InfoNCELoss(
        total_crops,
        dim=args.local_group_out_dim,
        queue_size=args.local_group_queue_size,
        temperature=args.local_group_temp,
    ).cuda()
    all_weights["local-group-sup."] = args.loss_weights[1]

    # Group-level supervision: DINO-style clustering loss.
    all_losses["group-sup."] = ClusteringLoss(
        args.group_out_dim,
        total_crops,
        args.group_warmup_teacher_temp,
        args.group_teacher_temp,
        args.group_warmup_teacher_temp_epochs,
        args.epochs,
        student_temp=args.group_student_temp,
        center_momentum=0.9,
    ).cuda()
    all_weights["group-sup."] = args.loss_weights[2]

    return all_losses, all_weights
|
|