import random
import torch
import torch.nn.functional as F
class CLIPSelf:
    def __call__(self, batch, model, dist_model, loss, device, cast_dtype, distributed, args):
        # `model` is the student being trained; `dist_model` is the frozen teacher
        # whose crop embeddings the student's region features are distilled towards.
        if distributed:
            model = model.module
            dist_model = dist_model.module
        images, normed_boxes, image_crops = batch  # note texts are not paired with images

        images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
        normed_boxes = normed_boxes.to(device=device, dtype=cast_dtype, non_blocking=True)
        image_crops = image_crops.to(device=device, dtype=cast_dtype, non_blocking=True)

        if args.multiscale:
            # Randomly resize the (square) input images to one of the target scales.
            cur_h, cur_w = images.shape[2:]
            assert cur_h == cur_w
            if cur_h == 1024:
                tar_sizes = [320, 640, 896, 1024]
            elif cur_h == 896:
                tar_sizes = [336, 448, 672, 896]
            else:
                raise NotImplementedError
            tar_size = random.choice(tar_sizes)
            images = F.interpolate(images, size=(tar_size, tar_size), mode='bilinear')

        # Keep only the valid boxes (last column is a validity flag) and their crops.
        rois_list = []
        crops_list = []
        for bboxes_per_image, crops_per_image in zip(normed_boxes, image_crops):
            valid = bboxes_per_image[:, -1] > 0.5
            rois_list.append(bboxes_per_image[valid, :4])
            crops_list.append(crops_per_image[valid])
        image_crops = torch.cat(crops_list)

        # The teacher encodes the image crops; no gradients are needed for it.
        with torch.no_grad():
            teacher_crop_features = dist_model.encode_image(image_crops, normalize=False)
        # The student extracts region features for the same boxes from the full images.
        student_roi_features = model.encode_pseudo_boxes(images, rois_list, normalize=False,
                                                         extract_type=args.extract_type)

        normed_student_features = F.normalize(student_roi_features, dim=-1)
        normed_teacher_features = F.normalize(teacher_crop_features, dim=-1)

        # Cosine distillation loss between student region features and teacher crop features.
        loss_cosine = 1.0 - (normed_student_features *
                             normed_teacher_features).sum(-1).mean()
        losses = dict(loss_cosine=loss_cosine * args.cosine_weight)

        return losses, len(images), model.logit_scale.exp()
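

# The block below is an illustrative, hypothetical usage sketch and is not part of the
# original file. It stubs out the student/teacher interfaces with a dummy module so the
# expected batch format and the methods CLIPSelf relies on (`encode_image`,
# `encode_pseudo_boxes`, `logit_scale`) are visible. `DummyCLIP` and the argparse fields
# used here are assumptions for the example only.
if __name__ == "__main__":
    from types import SimpleNamespace
    import torch.nn as nn

    class DummyCLIP(nn.Module):
        """Stands in for a CLIP-style model exposing the interface CLIPSelf expects."""
        def __init__(self, dim=512):
            super().__init__()
            self.logit_scale = nn.Parameter(torch.tensor(2.659))
            self.proj = nn.Linear(3, dim)

        def encode_image(self, crops, normalize=False):
            # crops: (N, 3, H, W) -> (N, dim) pooled features
            return self.proj(crops.mean(dim=(2, 3)))

        def encode_pseudo_boxes(self, images, rois_list, normalize=False, extract_type='v1'):
            # one feature per box, concatenated over the batch (naively repeated here)
            pooled = self.proj(images.mean(dim=(2, 3)))  # (B, dim)
            reps = [pooled[i:i + 1].repeat(len(r), 1) for i, r in enumerate(rois_list)]
            return torch.cat(reps)

    student, teacher = DummyCLIP(), DummyCLIP()
    args = SimpleNamespace(multiscale=False, extract_type='v1', cosine_weight=1.0)

    # batch: full images, per-image normalized boxes (x1, y1, x2, y2, validity flag),
    # and the pre-cropped image patches that the frozen teacher will encode
    images = torch.randn(2, 3, 224, 224)
    normed_boxes = torch.rand(2, 4, 5)
    normed_boxes[..., -1] = 1.0  # mark all boxes as valid
    image_crops = torch.randn(2, 4, 3, 224, 224)

    method = CLIPSelf()
    losses, batch_size, logit_scale = method(
        (images, normed_boxes, image_crops), student, teacher, loss=None,
        device='cpu', cast_dtype=torch.float32, distributed=False, args=args)
    print(losses, batch_size, logit_scale.item())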