# DenseVLM/src/training/zero_shot.py
# Repository page residue: uploaded by lyhisme ("Upload 151 files"), commit c02d17f (verified).
import logging
import torch
import torch.nn.functional as F
from training.dist_utils import all_gather
from tqdm import tqdm
from .distributed import is_master
from open_clip import get_cast_dtype
from .precision import get_autocast
def run(model, dataloader, args):
    """Classify every valid region yielded by ``dataloader`` and collect the results.

    Each region is embedded three ways -- ``encode_pseudo_boxes`` on the full
    image, ``encode_masks`` with the ground-truth masks, and an encoding of the
    region's image crop -- and every embedding is scored against the
    L2-normalized class embeddings taken from ``dataloader.dataset.embeddings``.

    Columns of the per-image box tensor, as they are read here:
    ``0:4`` box passed to the ROI extractor, ``4`` class label, ``5`` validity
    score (kept when ``> 0.5``), ``6`` box size, ``7`` is-thing flag.

    Args:
        model: Model under evaluation. It is unwrapped through ``model.module``
            when ``args.distributed`` is set and ``args.horovod`` is not.
        dataloader: Yields ``(images, bboxes, image_crops, gt_masks,
            masked_image_crops)`` batches; its dataset exposes ``embeddings``
            as a NumPy array.
        args: Namespace providing ``device``, ``precision``, ``extract_type``,
            ``image_ave_pool``, ``distributed`` and ``horovod``.

    Returns:
        A 9-tuple of tensors concatenated over all valid regions (and gathered
        across processes in non-horovod distributed mode), in this order:
        top-5 hit matrices for rois / crops / maskpool (float, one column per
        rank in the top-5), ground-truth-class logits for rois / crops /
        maskpool, box sizes, is-thing flags, class labels.
    """
    class_embeds = torch.from_numpy(dataloader.dataset.embeddings).float()
    class_embeds = F.normalize(class_embeds, dim=-1)
    class_embeds = class_embeds.to(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    if cast_dtype is not None:
        class_embeds = class_embeds.to(dtype=cast_dtype)

    with torch.no_grad():
        # Per-batch accumulators, one list per returned tensor.
        roi_hits, crop_hits, maskpool_hits = [], [], []
        roi_sims, crop_sims, maskpool_sims = [], [], []
        sizes_acc, thing_acc, label_acc = [], [], []

        progress = tqdm(dataloader, disable=not is_master(args))
        for images, bboxes, image_crops, gt_masks, masked_image_crops in progress:
            # NOTE: masked_image_crops is moved/cast like the other inputs but
            # only takes part in the zip below; its contents are never read.
            images, bboxes, image_crops, masked_image_crops, gt_masks = [
                tensor.to(args.device)
                for tensor in (images, bboxes, image_crops,
                               masked_image_crops, gt_masks)
            ]
            if cast_dtype is not None:
                images, bboxes, image_crops, masked_image_crops, gt_masks = [
                    tensor.to(dtype=cast_dtype)
                    for tensor in (images, bboxes, image_crops,
                                   masked_image_crops, gt_masks)
                ]

            # Keep only the regions flagged valid, image by image.
            rois, labels, kept_crops, kept_masks = [], [], [], []
            sizes, thing_flags = [], []
            for boxes, crops, masks, _ in zip(bboxes, image_crops,
                                              gt_masks, masked_image_crops):
                keep = boxes[:, 5] > 0.5
                rois.append(boxes[keep, :4])
                labels.append(boxes[keep, 4])
                kept_crops.append(crops[keep])
                kept_masks.append(masks[keep])
                sizes.append(boxes[keep, 6])
                thing_flags.append(boxes[keep, 7])

            labels = torch.cat(labels, dim=0).to(torch.long)
            if labels.shape[0] == 0:
                # Nothing to classify in this batch.
                continue

            image_crops = torch.cat(kept_crops)
            sizes_acc.append(torch.cat(sizes, dim=0).float())
            thing_acc.append(torch.cat(thing_flags, dim=0))

            with autocast():
                # predict
                unwrap = args.distributed and not args.horovod
                module = model.module if unwrap else model

                roi_features = module.encode_pseudo_boxes(
                    images, rois, normalize=True,
                    extract_type=args.extract_type)
                maskpool_features = module.encode_masks(
                    images, kept_masks, normalize=True,
                    mask_attn=args.extract_type == "v1")

                # Crop features: either average-pooled dense features or the
                # regular image embedding.
                if args.image_ave_pool:
                    dense_map = module.visual.encode_dense(image_crops,
                                                           keep_shape=True)
                    crop_features = F.normalize(dense_map.mean(dim=(-2, -1)),
                                                dim=-1)
                else:
                    crop_features = module.encode_image(image_crops,
                                                        normalize=True)

                if cast_dtype is not None:
                    roi_features = roi_features.to(dtype=cast_dtype)
                    crop_features = crop_features.to(dtype=cast_dtype)
                    maskpool_features = maskpool_features.to(dtype=cast_dtype)

                roi_logits = roi_features @ class_embeds.T
                crop_logits = crop_features @ class_embeds.T
                maskpool_logits = maskpool_features @ class_embeds.T

            # Column vector of targets, broadcast against the top-5 indices.
            target_col = labels.view(-1, 1)
            streams = (
                (roi_logits, roi_hits, roi_sims),
                (crop_logits, crop_hits, crop_sims),
                (maskpool_logits, maskpool_hits, maskpool_sims),
            )
            for logits, hits, sims in streams:
                _, top5_inds = logits.topk(5)
                hits.append(top5_inds == target_col)
                sims.append(torch.gather(logits, dim=1, index=target_col)[:, 0])
            label_acc.append(labels)

    # TODO: gather correct matrix across gpus
    correct_rois = torch.cat(roi_hits).float()
    correct_crops = torch.cat(crop_hits).float()
    correct_maskpool = torch.cat(maskpool_hits).float()
    similarity_rois = torch.cat(roi_sims).float()
    similarity_crops = torch.cat(crop_sims).float()
    similarity_maskpool = torch.cat(maskpool_sims).float()
    all_box_sizes = torch.cat(sizes_acc)
    all_is_thing = torch.cat(thing_acc)
    all_cls_labels = torch.cat(label_acc)

    if args.distributed and not args.horovod:
        # The call order is kept fixed so every process issues the same
        # sequence of gathers.
        correct_rois = multi_gpu_sync(correct_rois)
        correct_crops = multi_gpu_sync(correct_crops)
        correct_maskpool = multi_gpu_sync(correct_maskpool)
        all_box_sizes = multi_gpu_sync(all_box_sizes)
        all_is_thing = multi_gpu_sync(all_is_thing)
        similarity_rois = multi_gpu_sync(similarity_rois)
        similarity_crops = multi_gpu_sync(similarity_crops)
        similarity_maskpool = multi_gpu_sync(similarity_maskpool)
        all_cls_labels = multi_gpu_sync(all_cls_labels)

    return (
        correct_rois, correct_crops, correct_maskpool,
        similarity_rois, similarity_crops, similarity_maskpool,
        all_box_sizes, all_is_thing, all_cls_labels,
    )
def multi_gpu_sync(x):
    """Gather ``x`` from all processes and concatenate the pieces along dim 0.

    The tensor is moved to the CPU before being handed to ``all_gather``
    (presumably returning one tensor per process -- confirm against
    ``training.dist_utils``); every gathered piece is moved back to the
    device ``x`` originally lived on before concatenation.
    """
    home_device = x.device
    pieces = []
    for piece in all_gather(x.cpu()):
        pieces.append(piece.to(home_device))
    return torch.cat(pieces)
def macc_with_is_thing(correct_matrix, is_thing, all_cls_labels, prefix):
    """Compute mean per-class top-1 / top-5 accuracy for thing and stuff regions.

    Args:
        correct_matrix: 2-D tensor with one row per region; column ``k`` is
            non-zero when the ground-truth class is the rank-``k`` prediction.
            Column 0 is read as the top-1 hit and the row sum as the top-5 hit.
        is_thing: 1-D tensor with one flag per region. Regions with
            ``is_thing > 0`` count as "thing", those with ``is_thing < 1`` as
            "stuff".
        all_cls_labels: 1-D tensor of ground-truth class ids, one per region.
        prefix: String prepended to every key of the returned dict.

    Returns:
        Dict with the keys ``{prefix}.thing.macc1``, ``{prefix}.thing.macc5``,
        ``{prefix}.stuff.macc1`` and ``{prefix}.stuff.macc5`` mapped to Python
        floats. A subset that contains no region yields ``nan`` instead of
        raising.
    """
    def _macc(corrects, cls_labels):
        # An empty subset has no defined accuracy; min()/max() on an empty
        # tensor would raise, so report NaN explicitly.
        if cls_labels.numel() == 0:
            return float('nan')
        # torch.unique returns the present class ids in ascending order, so
        # only classes that actually occur are averaged. The per-class mean is
        # kept in float32 rather than rounded through fp16.
        acc_per_cls = [
            corrects[cls_labels == lb].float().mean().item()
            for lb in torch.unique(cls_labels)
        ]
        return sum(acc_per_cls) / len(acc_per_cls)

    thing_mask = is_thing > 0
    stuff_mask = is_thing < 1
    subsets = (
        ('thing', correct_matrix[thing_mask], all_cls_labels[thing_mask].long()),
        ('stuff', correct_matrix[stuff_mask], all_cls_labels[stuff_mask].long()),
    )

    results = {}
    for kind, matrix, labels in subsets:
        results[f'{prefix}.{kind}.macc1'] = _macc(matrix[:, 0], labels)
        results[f'{prefix}.{kind}.macc5'] = _macc(matrix.sum(-1), labels)
    return results
def zero_shot_eval(model, data, epoch, args):
    """Run the region-classification evaluation when it is due for ``epoch``.

    Evaluation is skipped (an empty dict is returned) when ``data`` has no
    ``'val'`` entry, when ``args.zeroshot_frequency`` is 0, or when ``epoch``
    is neither a multiple of ``args.zeroshot_frequency`` nor equal to
    ``args.epochs``.

    Returns:
        Dict of thing/stuff mean per-class accuracies keyed with the prefixes
        ``rois``, ``crops`` and ``maskpool``, as produced by
        ``macc_with_is_thing``.
    """
    if 'val' not in data or args.zeroshot_frequency == 0:
        return {}
    if not ((epoch % args.zeroshot_frequency) == 0 or epoch == args.epochs):
        return {}

    logging.info('Region classifier')

    # Only the hit matrices, is-thing flags and labels feed the metrics; the
    # similarity tensors and box sizes returned by run() are not used here.
    (correct_rois, correct_crops, correct_maskpool,
     _similarity_rois, _similarity_crops, _similarity_maskpool,
     _all_box_sizes, all_is_thing, all_cls_labels) = run(
        model, data['val'].dataloader, args)

    results = {}
    per_stream = (
        ('rois', correct_rois),
        ('crops', correct_crops),
        ('maskpool', correct_maskpool),
    )
    for prefix, correct_matrix in per_stream:
        results.update(
            macc_with_is_thing(correct_matrix, all_is_thing, all_cls_labels, prefix))
    return results