| | import argparse |
| | import os |
| | import sys |
| | import torch |
| | import shutil |
| | import utils |
| | import numpy as np |
| | import metrics |
| | import torch.nn as nn |
| | import torch.backends.cudnn as cudnn |
| | import vision_transformer as vits |
| |
|
| | from tqdm import tqdm |
| | from datasets import build_dataset |
| | from meters import AverageEpochMeter |
| | from object_discovery import ncut, detect_box, get_feats |
| | from visualization import visualize_fms, visualize_predictions_gt, visualize_img, visualize_eigvec |
| |
|
| | from collections import OrderedDict |
| | |
| |
|
def evaluate():
    """Evaluate a frozen ViT backbone + linear classifier on the validation set.

    Reads configuration from the module-level ``args`` namespace. For every
    validation image it reports top-1/top-5 classification accuracy plus
    GT-known and top-1 localization accuracy, where boxes are predicted by
    normalized-cut (``ncut``) over the last attention block's key features.
    Optionally dumps visualizations of predictions and failure cases.
    """
    # Distributed mode is only meaningful on CUDA.
    if args.device != 'cuda':
        args.distributed = False
    else:
        utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)
    cudnn.benchmark = True

    # ---- backbone ----
    if args.arch in vits.__dict__.keys():
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=args.num_labels)
        # Classifier input: one CLS token per tapped block, plus (optionally)
        # the average-pooled patch tokens of the last block.
        embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
        print(f'embed_dim: {embed_dim}')
    else:
        print(f"Unknown architecture: {args.arch}")
        sys.exit(1)

    model.to(device)
    model.eval()

    linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
    linear_classifier = linear_classifier.to(device)

    # ---- load weights ----
    checkpoint_model = torch.load(args.pretrained_weights, map_location='cpu')
    if args.classifier_weights is None:
        # Single checkpoint carrying both the backbone and the classifier.
        linear_classifier.load_state_dict(checkpoint_model['classifier'])
        model.load_state_dict(checkpoint_model['model'], strict=False)
    else:
        # Separate backbone / classifier checkpoints.
        model.load_state_dict(checkpoint_model)
        checkpoint_classifier = torch.load(args.classifier_weights, map_location='cpu')
        # Strip the 'module.' prefix added by DistributedDataParallel.
        new_state_dict = {k[7:]: v for k, v in checkpoint_classifier['state_dict'].items()}
        linear_classifier.load_state_dict(new_state_dict)
    print('Load from checkpoint done.')

    # ---- data ----
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        sampler_val = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    val_loader = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # ---- meters and bookkeeping ----
    scales = [args.patch_size, args.patch_size]
    top1_cls_meter = AverageEpochMeter('Top-1 Cls')
    top5_cls_meter = AverageEpochMeter('Top-5 Cls')
    gt_loc_meter = AverageEpochMeter('GT-Known Loc')
    top1_loc_meter = AverageEpochMeter('Top-1 Loc')
    feat_out = {}
    bbox_error = 0   # localized wrong AND classified wrong (box is the problem)
    cls_error = 0    # localized right but classified wrong
    skip = 0         # very large images skipped to bound memory use

    # Capture the qkv projection of the last attention block; ncut consumes
    # the key features extracted from it.
    def hook_fn_forward_qkv(module, input, output):
        feat_out["qkv"] = output
    model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv)

    val_loader = tqdm(val_loader)
    for i, (images, labels, gt_boxes) in enumerate(val_loader):
        init_image_size = images[0].shape
        # Pad so H and W are multiples of the patch size; batch size is 1.
        images = utils.padding_img(images[0], args)
        images = images.unsqueeze(0)
        if images.shape[-1] > 1000 and images.shape[-2] > 1000:
            skip = skip + 1
            continue
        with torch.no_grad():
            images = images.to(device)
            labels = labels.to(device)

            intermediate_output, shape = model.get_intermediate_layers(images, args.n_last_blocks)
            # Concatenate the CLS tokens of the tapped blocks.
            output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            if args.avgpool_patchtokens:
                # Append the average-pooled patch tokens of the last block.
                output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                output = output.reshape(output.shape[0], -1)
            output = linear_classifier(output)

            batch_size = images.size(0)
            top1_cls, top5_cls = utils.accuracy(output, labels, topk=(1, 5))
            top1_cls_meter.update(top1_cls, batch_size)
            top5_cls_meter.update(top5_cls, batch_size)

        # ---- localization via normalized cut over the hooked features ----
        dims = (images[0].shape[-2] // args.patch_size, images[0].shape[-1] // args.patch_size)
        k = get_feats(feat_out, shape)
        bboxes, mask, seed, eigvec = ncut(k, dims, scales, init_image_size=init_image_size[1:], eps=args.eps, tau=args.tau)
        bboxes_copy = np.copy(bboxes)
        bboxes = torch.FloatTensor(bboxes).unsqueeze(0)
        gt_loc, top1_loc = metrics.loc_accuracy(output, labels, gt_boxes, bboxes)

        if args.visual:
            visualize_predictions_gt(images[0], bboxes_copy, gt_boxes, i, seed, dims, scales, './output/all')
            visualize_eigvec(eigvec, './output/all', i, dims, scales)

        # Split failures: wrong box vs. right box / wrong class.
        if top1_loc == 0 and gt_loc == 0:
            bbox_error = bbox_error + 1
            if args.visual:
                visualize_predictions_gt(images[0], bboxes_copy, gt_boxes, i, seed, dims, scales, './output/bbox_error')
        elif top1_loc == 0 and gt_loc == 1:
            cls_error = cls_error + 1
            if args.visual:
                visualize_predictions_gt(images[0], bboxes_copy, gt_boxes, i, seed, dims, scales, './output/classification_error')

        gt_loc_meter.update(gt_loc, batch_size)
        top1_loc_meter.update(top1_loc, batch_size)
        # Running averages, shown on the progress bar every iteration.
        top1_cls = top1_cls_meter.compute()
        top5_cls = top5_cls_meter.compute()
        gt_loc = gt_loc_meter.compute()
        top1_loc = top1_loc_meter.compute()
        val_loader.set_description(f'Top1_cls: {top1_cls:.4f}, top5_cls: {top5_cls:.4f}, gt_loc: {gt_loc:.4f}, top1_loc: {top1_loc:.4f}')
    print(f'Top1_cls: {top1_cls}, top5_cls: {top5_cls}, gt_loc: {gt_loc}, top1_loc: {top1_loc}')
    print(f'Bbox error: {bbox_error}, cls error: {cls_error}')
    print(f'Skip large image: {skip}')
| | |
| | |
class LinearClassifier(nn.Module):
    """A single fully-connected layer applied on top of frozen features."""

    def __init__(self, dim, num_labels=1000):
        super(LinearClassifier, self).__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        # Small-gaussian weights, zero bias — same init as the original.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        # Collapse all non-batch dimensions before the projection.
        flat = x.view(x.size(0), -1)
        return self.linear(flat)
| |
|
if __name__ == '__main__':
    # Command-line entry point: parse evaluation options, then run evaluate().
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
        for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
    parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
        help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
        We typically set this to False for ViT-Small and to True with ViT-Base.""")
    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
    parser.add_argument('--dataset', default='cub', type=str, choices=['cub', 'imagenet'], help='Dataset to evaluate on')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--input_size', default=224, type=int, help='Input image size, default(224).')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--classifier_weights', default=None, type=str, help="Path to linear classifier pretrained weights to evaluate.")
    parser.add_argument('--batch_size_per_gpu', default=256, type=int, help='Per-GPU batch-size')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
    parser.add_argument('--num_labels', default=200, type=int, help='Number of labels for linear classifier')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--distributed', default=False, action='store_true', help='Use distributed evaluation')
    parser.add_argument('--ori_size', default=False, action='store_true', help='Evaluate on image raw size')
    parser.add_argument('--visual', default=False, action='store_true', help='Visualize error examples on ./test')
    parser.add_argument('--eps', default=1e-5, type=float, help='hyperparameter for tokencut')
    parser.add_argument('--tau', default=0.05, type=float, help='hyperparameter for tokencut')
    parser.add_argument('--no_center_crop', default=False, action='store_true', help='Center crop input image')
    args = parser.parse_args()
    evaluate()
| |
|