| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import os |
| import sys |
| import pickle |
| import argparse |
|
|
| import torch |
| from torch import nn |
| import torch.distributed as dist |
| import torch.backends.cudnn as cudnn |
| from torchvision import models as torchvision_models |
| from torchvision import transforms as pth_transforms |
| from PIL import Image, ImageFile |
| import numpy as np |
|
|
| import utils |
| import vision_transformer as vits |
| from eval_knn import extract_features |
|
|
|
|
class CopydaysDataset():
    """Directory layout and evaluation protocol for the Copydays dataset.

    The dataset root is expected to contain one sub-directory per block:
    ``original``, ``strong``, ``jpegqual/<q>`` and ``crops/<c>``. Every block
    is used as a set of queries; only the ``original`` block is searched.
    """

    def __init__(self, basedir):
        """Record the dataset root and the fixed block structure.

        Args:
            basedir: path of the dataset root directory. No file system
                access happens here.
        """
        self.basedir = basedir
        # Block order matters: eval_result slices the result rows in this order.
        self.block_names = (
            ['original', 'strong'] +
            ['jpegqual/%d' % i for i in
             [3, 5, 8, 10, 15, 20, 30, 50, 75]] +
            ['crops/%d' % i for i in
             [10, 15, 20, 30, 40, 50, 60, 70, 80]])
        self.nblocks = len(self.block_names)

        # Every block is queried.
        self.query_blocks = range(self.nblocks)
        # Hard-coded number of queries per block: 157 everywhere except
        # block 1 ('strong'), which has 229.
        self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157
        self.q_block_sizes[1] = 229

        # Only block 0 ('original') makes up the searched database.
        self.database_blocks = [0]

    def get_block(self, i):
        """Return the sorted full paths of the ``.jpg`` files of block ``i``."""
        dirname = self.basedir + '/' + self.block_names[i]
        return [dirname + '/' + fname
                for fname in sorted(os.listdir(dirname))
                if fname.endswith('.jpg')]

    def get_block_filenames(self, subdir_name):
        """Return the sorted bare ``.jpg`` file names found in ``subdir_name``."""
        dirname = self.basedir + '/' + subdir_name
        return [fname
                for fname in sorted(os.listdir(dirname))
                if fname.endswith('.jpg')]

    def eval_result(self, ids, distances):
        """Print and return the mean average precision of every query block.

        Args:
            ids: sequence with one row per query, the queries being laid out
                block after block following ``q_block_sizes``. Each row is
                the ranked list of retrieved database indices, where index
                ``k`` denotes the ``k``-th sorted file of the ``original``
                block.
            distances: accepted for interface compatibility; not used.

        Returns:
            dict mapping each block name to its mAP (a float), in block
            order. The sum of APs is always divided by the nominal block
            size, even when ``ids`` holds fewer rows.
        """
        results = {}
        j0 = 0
        for i in range(self.nblocks):
            n_queries = int(self.q_block_sizes[i])
            j1 = j0 + n_queries
            block_name = self.block_names[i]
            block_ids = ids[j0:j1]

            if block_name != 'strong':
                # Query k of the block is a transformed copy of original k.
                positives_per_query = [[k] for k in range(n_queries)]
            else:
                # A 'strong' query matches every original whose file name
                # shares its first four characters.
                originals = self.get_block_filenames('original')
                strongs = self.get_block_filenames('strong')
                positives_per_query = [
                    [j for j, bname in enumerate(originals)
                     if bname[:4] == qname[:4]]
                    for qname in strongs]

            sum_AP = 0
            for qno, row in enumerate(block_ids):
                positives = positives_per_query[qno]
                # Positions in the ranked list at which a positive shows up.
                ranks = [rank for rank, bno in enumerate(row)
                         if bno in positives]
                sum_AP += score_ap_from_ranks_1(ranks, len(positives))

            block_map = float(sum_AP / n_queries)
            print("eval on %s mAP=%.3f" % (block_name, block_map))
            results[block_name] = block_map
            j0 = j1
        return results
|
|
|
|
| |
def score_ap_from_ranks_1(ranks, nres):
    """Compute the average precision of one search.

    The precision/recall curve is integrated with the trapezoidal rule: each
    true positive adds one recall step of width ``1 / nres`` whose height is
    the mean of the precision just before and just after that positive.

    Args:
        ranks: ordered list of the 0-based ranks of the true positives.
        nres: total number of positives in the dataset for this query.

    Returns:
        The average precision as a float. A query without any positive
        (``nres <= 0``) scores 0.0 instead of raising ``ZeroDivisionError``.
    """
    if nres <= 0:
        # No recall step can be defined; score the query as a complete miss.
        return 0.0

    recall_step = 1.0 / nres
    ap = 0.0
    for ntp, rank in enumerate(ranks):
        # Precision right before this positive; at rank 0 nothing has been
        # retrieved yet, so it is taken as 1 by convention.
        precision_0 = 1.0 if rank == 0 else ntp / float(rank)
        # Precision once this positive is included.
        precision_1 = (ntp + 1) / float(rank + 1)
        ap += (precision_1 + precision_0) * recall_step / 2.0
    return ap
|
|
|
|
class ImgListDataset(torch.utils.data.Dataset):
    """Dataset backed by a plain list of image paths.

    Items are ``(image, position)`` pairs, where ``position`` is the index of
    the image inside the list handed to the constructor.
    """

    def __init__(self, img_list, transform=None):
        # Optional callable applied to each RGB image before it is returned.
        self.transform = transform
        self.samples = img_list

    def __len__(self):
        """Number of image paths held by the dataset."""
        return len(self.samples)

    def __getitem__(self, i):
        """Load image ``i`` as RGB, transform it if requested, return it with ``i``."""
        path = self.samples[i]
        with open(path, 'rb') as handle:
            # Convert while the file is still open so the pixel data is read
            # before the handle is closed.
            rgb = Image.open(handle).convert('RGB')
        if self.transform is None:
            return rgb, i
        return self.transform(rgb), i
|
|
|
|
def is_image_file(s):
    """Tell whether file name ``s`` ends with a known image extension.

    The comparison ignores case, so ``photo.JPG`` is accepted just like
    ``photo.jpg``. Only the text after the last dot is considered.

    Args:
        s: file name or path.

    Returns:
        True when the extension is one of the supported image formats.
    """
    ext = s.split(".")[-1].lower()
    return ext in ('jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp')
|
|
|
|
@torch.no_grad()
def extract_features(image_list, model, args):
    """Compute one descriptor per image and collect them on rank 0.

    Images are resized to ``args.imsize`` x ``args.imsize``, normalised and
    split across processes with a non-shuffling ``DistributedSampler``. For
    each image the descriptor is the first output token (named
    ``cls_output_token`` here) concatenated with the generalized-mean pooling
    (exponent 4) of the remaining tokens, which are reshaped to a
    ``(h, w)`` grid derived from ``model.patch_embed.patch_size``.

    Note: this definition shadows the ``extract_features`` imported from
    ``eval_knn`` at the top of the file.

    Args:
        image_list: list of image file paths.
        model: network exposing ``get_intermediate_layers`` and
            ``patch_embed.patch_size``.
        args: namespace providing ``imsize``, ``batch_size_per_gpu``,
            ``num_workers`` and ``use_cuda``.

    Returns:
        On rank 0, a tensor with one row per entry of ``image_list`` (on GPU
        when ``args.use_cuda`` is set); ``None`` on every other rank, and on
        rank 0 as well if the loader yields no batch.
    """
    preprocess = pth_transforms.Compose([
        pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img_dataset = ImgListDataset(image_list, transform=preprocess)
    data_loader = torch.utils.data.DataLoader(
        img_dataset,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        drop_last=False,
        sampler=torch.utils.data.DistributedSampler(img_dataset, shuffle=False),
    )

    features = None
    metric_logger = utils.MetricLogger(delimiter=" ")
    for samples, index in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        tokens = model.get_intermediate_layers(samples, n=1)[0].clone()

        # First token of the sequence.
        cls_output_token = tokens[:, 0, :]

        # Generalized-mean pooling (exponent 4) over the remaining tokens.
        patch = model.patch_embed.patch_size
        b = len(samples)
        h = int(samples.shape[-2] / patch)
        w = int(samples.shape[-1] / patch)
        d = tokens.shape[-1]
        patch_map = tokens[:, 1:, :].reshape(b, h, w, d)
        patch_map = patch_map.clamp(min=1e-6).permute(0, 3, 1, 2)
        pooled = nn.functional.avg_pool2d(patch_map.pow(4), (h, w)).pow(1. / 4).reshape(b, -1)

        # Final descriptor: first token followed by the pooled patch tokens.
        feats = torch.cat((cls_output_token, pooled), dim=1)

        # Lazily allocate the output matrix on rank 0 once the width is known.
        if dist.get_rank() == 0 and features is None:
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
            if args.use_cuda:
                features = features.cuda(non_blocking=True)

        # Gather the dataset positions handled by every process.
        index_buffer = torch.empty(dist.get_world_size(), index.size(0),
                                   dtype=index.dtype, device=index.device)
        y_l = list(index_buffer.unbind(0))
        index_work = torch.distributed.all_gather(y_l, index, async_op=True)
        index_work.wait()
        index_all = torch.cat(y_l)

        # Gather the matching descriptors from every process.
        feats_buffer = torch.empty(dist.get_world_size(), feats.size(0), feats.size(1),
                                   dtype=feats.dtype, device=feats.device)
        output_l = list(feats_buffer.unbind(0))
        feats_work = torch.distributed.all_gather(output_l, feats, async_op=True)
        feats_work.wait()

        # Rank 0 writes the gathered rows at their dataset positions.
        if dist.get_rank() == 0:
            if args.use_cuda:
                features.index_copy_(0, index_all, torch.cat(output_l))
            else:
                features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
    return features
|
|
|
|
if __name__ == '__main__':
    # Entry point: extract descriptors for the Copydays queries/database
    # (plus optional distractors), optionally whiten them, then run a
    # cosine-similarity search on rank 0 and report mAP per block.
    parser = argparse.ArgumentParser('Copy detection on Copydays')
    parser.add_argument('--data_path', default='/path/to/copydays/', type=str,
        help="See https://lear.inrialpes.fr/~jegou/data.php#copydays")
    parser.add_argument('--whitening_path', default='/path/to/whitening_data/', type=str,
        help="""Path to directory with images used for computing the whitening operator.
        In our paper, we use 20k random images from YFCC100M.""")
    parser.add_argument('--distractors_path', default='/path/to/distractors/', type=str,
        help="Path to directory with distractors images. In our paper, we use 10k random images from YFCC100M.")
    parser.add_argument('--imsize', default=320, type=int, help='Image size (square image)')
    parser.add_argument('--batch_size_per_gpu', default=16, type=int, help='Per-GPU batch-size')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
    parser.add_argument('--arch', default='vit_base', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ building the network ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    else:
        print(f"Architecture {args.arch} non supported")
        sys.exit(1)
    if args.use_cuda:
        model.cuda()
    model.eval()
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)

    dataset = CopydaysDataset(args.data_path)

    # ============ query features ============
    # extract_features returns a tensor on rank 0 only, hence the rank guard
    # before concatenating.
    queries = [extract_features(dataset.get_block(q), model, args)
               for q in dataset.query_blocks]
    if utils.get_rank() == 0:
        queries = torch.cat(queries)
        print(f"Extraction of queries features done. Shape: {queries.shape}")

    # ============ database features ============
    database = [extract_features(dataset.get_block(b), model, args)
                for b in dataset.database_blocks]

    # ============ optional distractors ============
    if os.path.isdir(args.distractors_path):
        print("Using distractors...")
        # os.listdir returns entries in arbitrary order; sort so that every
        # process maps a given sampler index to the same file.
        list_distractors = [os.path.join(args.distractors_path, s)
                            for s in sorted(os.listdir(args.distractors_path))
                            if is_image_file(s)]
        if list_distractors:
            database.append(extract_features(list_distractors, model, args))
        else:
            # Nothing to extract: an empty list would yield None and break
            # the concatenation below.
            print(f"No image found in {args.distractors_path}, distractors skipped.")
    if utils.get_rank() == 0:
        database = torch.cat(database)
        print(f"Extraction of database and distractors features done. Shape: {database.shape}")

    # ============ optional whitening ============
    if os.path.isdir(args.whitening_path):
        print(f"Extracting features on images from {args.whitening_path} for learning the whitening operator.")
        # Sorted for the same cross-process consistency reason as above.
        list_whit = [os.path.join(args.whitening_path, s)
                     for s in sorted(os.listdir(args.whitening_path))
                     if is_image_file(s)]
        if list_whit:
            features_for_whitening = extract_features(list_whit, model, args)
            if utils.get_rank() == 0:
                # Centre database and queries on the whitening-set mean.
                mean_feature = torch.mean(features_for_whitening, dim=0)
                database -= mean_feature
                queries -= mean_feature
                pca = utils.PCA(dim=database.shape[-1], whit=0.5)
                # NOTE(review): this matrix is built from the whitening
                # features as extracted, i.e. without subtracting
                # mean_feature, while database/queries are centred above --
                # confirm this is intended before changing it.
                cov = torch.mm(features_for_whitening.T, features_for_whitening) / features_for_whitening.shape[0]
                pca.train_pca(cov.cpu().numpy())
                database = pca.apply(database)
                queries = pca.apply(queries)
        else:
            print(f"No image found in {args.whitening_path}, whitening skipped.")

    # ============ copy detection (rank 0 only) ============
    if utils.get_rank() == 0:
        # L2-normalise so that the dot product below is a cosine similarity.
        database = nn.functional.normalize(database, dim=1, p=2)
        queries = nn.functional.normalize(queries, dim=1, p=2)

        # Keep the 20 most similar database entries per query (fewer when
        # the database itself is smaller than that).
        similarity = torch.mm(queries, database.T)
        topk = min(20, database.shape[0])
        distances, indices = similarity.topk(topk, largest=True, sorted=True)

        # Prints the mAP of every block.
        dataset.eval_result(indices, distances)
    # Keep the other processes alive until rank 0 has finished evaluating.
    dist.barrier()
|
|
|
|