| |
| |
| |
| |
|
|
| import argparse |
| from functools import partial |
| import json |
| import logging |
| import os |
| import sys |
| from typing import List, Optional |
|
|
| import torch |
| from torch.nn.functional import one_hot, softmax |
|
|
| import dinov2.distributed as distributed |
| from dinov2.data import SamplerType, make_data_loader, make_dataset |
| from dinov2.data.transforms import make_classification_eval_transform |
| from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric |
| from dinov2.eval.setup import get_args_parser as get_setup_args_parser |
| from dinov2.eval.setup import setup_and_build_model |
| from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features |
|
|
|
|
# Module-level logger shared by the functions below; it is looked up under the name "dinov2".
logger = logging.getLogger("dinov2")
|
|
|
|
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser for the k-NN evaluation.

    The returned parser inherits the options of the evaluation setup parser
    (``get_setup_args_parser``) and adds the k-NN specific options below.

    Args:
        description: Text passed to ``argparse.ArgumentParser`` as its description.
        parents: Optional parsers forwarded to the setup parser as its parents.
        add_help: Whether the returned parser registers ``-h/--help``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    # The setup parser is only used as a parent here, so it is built without its
    # own help option; the final parser decides via `add_help`.
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--nb_knn",
        nargs="+",
        type=int,
        help="Number of NN to use. 20 is usually working the best.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--gather-on-cpu",
        action="store_true",
        # Adjacent string literals are concatenated verbatim, so the separating
        # space has to be part of one of them.
        help="Whether to gather the train features on cpu, slower "
        "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
    )
    parser.add_argument(
        "--n-per-class-list",
        nargs="+",
        type=int,
        help="Number to take per class",
    )
    parser.add_argument(
        "--n-tries",
        type=int,
        help="Number of tries",
    )
    # Defaults are set in one place rather than on each add_argument call.
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        nb_knn=[10, 20, 100, 200],
        temperature=0.07,
        batch_size=256,
        n_per_class_list=[-1],
        n_tries=1,
    )
    return parser
|
|
|
|
class KnnModule(torch.nn.Module):
    """
    k-NN lookup of test features against this rank's slice of the train features.

    The train features and labels are split into `global_size` chunks and every rank keeps
    the chunk found at its own rank index. In `compute_neighbors` the ranks take turns:
    the test features of the current rank are broadcast to all ranks, each rank computes
    partial top-k results against its own train chunk, and the partial results are gathered
    and reduced again on the rank that owns those test features.
    """

    def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
        super().__init__()

        self.global_rank = distributed.get_global_rank()
        self.global_size = distributed.get_global_size()

        self.device = device
        # Keep only this rank's chunk. Features are stored transposed so that a plain
        # matrix product with the query features yields the similarity matrix.
        feature_chunk = train_features.chunk(self.global_size)[self.global_rank]
        label_chunk = train_labels.chunk(self.global_size)[self.global_rank]
        self.train_features_rank_T = feature_chunk.T.to(self.device)
        self.candidates = label_chunk.view(1, -1).to(self.device)

        self.nb_knn = nb_knn
        self.max_k = max(self.nb_knn)
        self.T = T
        self.num_classes = num_classes

    def _get_knn_sims_and_labels(self, similarity, train_labels):
        """Return the `max_k` largest similarities per row and the labels at those positions."""
        best = similarity.topk(self.max_k, largest=True, sorted=True)
        return best.values, torch.gather(train_labels, 1, best.indices)

    def _similarity_for_rank(self, features_rank, source_rank):
        """Receive the features of `source_rank` and score them against the local train chunk."""
        # Step 1: every rank learns the shape of the tensor that `source_rank` will send.
        incoming_shape = torch.tensor(features_rank.shape).to(self.device)
        torch.distributed.broadcast(incoming_shape, source_rank)

        # Step 2: the source sends its own features, every other rank receives into a buffer.
        if self.global_rank == source_rank:
            queries = features_rank
        else:
            queries = torch.zeros(*incoming_shape, dtype=features_rank.dtype, device=self.device)
        torch.distributed.broadcast(queries, source_rank)

        # Step 3: partial top-k against the train features held by this rank.
        local_similarity = torch.mm(queries, self.train_features_rank_T)
        local_labels = self.candidates.expand(len(local_similarity), -1)
        return self._get_knn_sims_and_labels(local_similarity, local_labels)

    def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
        """Collect the partial top-k of all ranks on `target_rank`; other ranks get None."""
        is_target = self.global_rank == target_rank

        # Only the destination rank provides receive buffers for the gather.
        sims_per_rank = labels_per_rank = None
        if is_target:
            sims_per_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
            labels_per_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]

        torch.distributed.gather(topk_sims, sims_per_rank, dst=target_rank)
        torch.distributed.gather(neighbors_labels, labels_per_rank, dst=target_rank)

        if not is_target:
            return None

        # Merge the per-rank candidates and keep the overall top `max_k` among them.
        merged_sims = torch.cat(sims_per_rank, dim=1)
        merged_labels = torch.cat(labels_per_rank, dim=1)
        return self._get_knn_sims_and_labels(merged_sims, merged_labels)

    def compute_neighbors(self, features_rank):
        """Run one broadcast/gather round per rank and return the result owned by this rank."""
        for source_rank in range(self.global_size):
            partial_sims, partial_labels = self._similarity_for_rank(features_rank, source_rank)
            gathered = self._gather_all_knn_for_rank(partial_sims, partial_labels, source_rank)
            if gathered is not None:
                # Only reached in the round where this rank is the gather destination.
                own_sims, own_labels = gathered
        return own_sims, own_labels

    def forward(self, features_rank):
        """
        Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`

        Returns a dict mapping each k of `self.nb_knn` to the per-class sum of the voting
        weights of the k nearest neighbors. The softmax weights are computed once over all
        `self.max_k` neighbors and are not re-normalised for smaller k.
        """
        assert all(k <= self.max_k for k in self.nb_knn)

        topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
        batch_size = neighbors_labels.shape[0]
        vote_weights = softmax(topk_sims / self.T, 1)
        # One-hot labels scaled by the weight of the corresponding neighbor.
        weighted_votes = torch.mul(
            one_hot(neighbors_labels, num_classes=self.num_classes),
            vote_weights.view(batch_size, -1, 1),
        )
        probas_for_k = {}
        for k in self.nb_knn:
            probas_for_k[k] = torch.sum(weighted_votes[:, :k, :], 1)
        return probas_for_k
|
|
|
|
class DictKeysModule(torch.nn.Module):
    """Select one entry of a nested dict by a fixed key path and pair it with the targets."""

    def __init__(self, keys):
        super().__init__()
        # Successive keys to follow, outermost level first.
        self.keys = keys

    def forward(self, features_dict, targets):
        """Return ``{"preds": <entry found by following self.keys>, "target": targets}``."""
        selected = features_dict
        for key in self.keys:
            selected = selected[key]
        return {"preds": selected, "target": targets}
|
|
|
|
def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
    """Instantiate one k-NN module per training-set configuration.

    Args:
        module: Callable invoked with ``train_features``, ``train_labels`` and ``nb_knn``
            keyword arguments; it returns the module for that training set.
        n_per_class_list: Iterable of ints. A negative value selects the full training
            set; a non-negative value ``n`` selects ``n`` random samples per class.
        n_tries: Number of random subsets drawn for every non-negative entry.
        nb_knn: Sequence (list or tuple) of neighbor counts to evaluate.
        train_features: Tensor of train features, indexable by an index tensor.
        train_labels: Tensor of train labels aligned with ``train_features``.

    Returns:
        A ``ModuleDictWithForward`` keyed by ``"full"`` and/or ``"<n> per class"``. Each
        value is a ``ModuleDictWithForward`` keyed by the try index as a string (``"1"``
        for the full training set).
    """
    modules = {}
    # Built on first use: the class -> indices mapping is only needed for subsampling.
    mapping = None
    for npc in n_per_class_list:
        if npc < 0:
            full_module = module(
                train_features=train_features,
                train_labels=train_labels,
                nb_knn=nb_knn,
            )
            modules["full"] = ModuleDictWithForward({"1": full_module})
            continue
        if mapping is None:
            mapping = create_class_indices_mapping(train_labels)
        # Evaluate the requested k values plus k == npc, dropping every k above npc.
        # A set union is used instead of `nb_knn + [npc]` so tuples work as well as lists.
        k_list = sorted(k for k in set(nb_knn) | {npc} if k <= npc)
        all_tries = {}
        for t in range(n_tries):
            # The try index doubles as the random seed of the subset.
            final_indices = filter_train(mapping, npc, seed=t)
            all_tries[str(t)] = module(
                train_features=train_features[final_indices],
                train_labels=train_labels[final_indices],
                nb_knn=list(k_list),
            )
        modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)

    return ModuleDictWithForward(modules)
|
|
|
|
def filter_train(mapping, n_per_class, seed):
    """Randomly pick up to `n_per_class` indices from every class.

    Args:
        mapping: Dict whose values are index tensors, one per class (as built by
            `create_class_indices_mapping`, where each value has shape ``(n, 1)``).
        n_per_class: Maximum number of indices kept per class.
        seed: Seed passed to ``torch.manual_seed`` before sampling. Note that this
            reseeds the global torch random number generator.

    Returns:
        A 1-D tensor with the selected indices of all classes, in the iteration order
        of `mapping`.
    """
    torch.manual_seed(seed)
    final_indices = []
    for class_indices in mapping.values():
        picked = torch.randperm(len(class_indices))[:n_per_class]
        final_indices.append(class_indices[picked])
    # reshape(-1) rather than squeeze(): when a single index is selected in total,
    # squeeze() would return a 0-dim tensor, and indexing with it drops the sample
    # dimension of the features instead of keeping a batch of one.
    return torch.cat(final_indices).reshape(-1)
|
|
|
|
def create_class_indices_mapping(labels):
    """Group the positions of `labels` by label value.

    Returns a dict with one entry per distinct label. Each key is a 0-dim tensor holding
    the label value, and each value is the ``(n, 1)`` tensor of positions (the result of
    ``nonzero()``) where that label occurs.
    """
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    mapping = {}
    for position, label in enumerate(unique_labels):
        # `inverse` stores, for every element, the position of its label in `unique_labels`.
        mapping[label] = (inverse == position).nonzero()
    return mapping
|
|
|
|
class ModuleDictWithForward(torch.nn.ModuleDict):
    """A ``ModuleDict`` that can be called: every child module receives the same arguments."""

    def forward(self, *args, **kwargs):
        """Return a dict mapping each child's name to that child's output."""
        outputs = {}
        for name, child in self._modules.items():
            outputs[name] = child(*args, **kwargs)
        return outputs
|
|
|
|
def eval_knn(
    model,
    train_dataset,
    val_dataset,
    accuracy_averaging,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    gather_on_cpu,
    n_per_class_list=None,
    n_tries=1,
):
    """Run the k-NN classification of `val_dataset` against `train_dataset` features.

    Args:
        model: Feature extractor; it is wrapped in ``ModelWithNormalize`` before use.
        train_dataset: Dataset whose features and labels form the k-NN reference set.
        val_dataset: Dataset that is classified.
        accuracy_averaging: Averaging mode forwarded to ``build_topk_accuracy_metric``.
        nb_knn: Neighbor counts to evaluate.
        temperature: Temperature of the softmax applied to the neighbor similarities.
        batch_size: Batch size for feature extraction and for the validation loader.
        num_workers: Number of data loading workers.
        gather_on_cpu: Forwarded to ``extract_features``.
        n_per_class_list: Training-set configurations; a negative entry means the full
            training set, ``n >= 0`` means ``n`` samples per class. Defaults to ``[-1]``.
        n_tries: Number of random subsets evaluated per non-negative entry.

    Returns:
        The results dict returned by ``evaluate`` in which the per-try entries
        ``(n_per_class, try, k)`` have been replaced by ``(n_per_class, k)`` entries
        holding, for every metric, the mean over the tries.
    """
    # None instead of a mutable list as the default value.
    if n_per_class_list is None:
        n_per_class_list = [-1]

    model = ModelWithNormalize(model)

    logger.info("Extracting features for train set...")
    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    logger.info(f"Train features created, shape {train_features.shape}.")

    val_dataloader = make_data_loader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        # torch's DataLoader rejects persistent_workers=True when num_workers == 0.
        # NOTE(review): assumes make_data_loader forwards this flag to DataLoader - confirm.
        persistent_workers=num_workers > 0,
    )
    # Labels are assumed to be class indices starting at 0 - TODO confirm for custom datasets.
    num_classes = train_labels.max() + 1
    metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)

    device = torch.cuda.current_device()
    partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
    knn_module_dict = create_module_dict(
        module=partial_module,
        n_per_class_list=n_per_class_list,
        n_tries=n_tries,
        nb_knn=nb_knn,
        train_features=train_features,
        train_labels=train_labels,
    )

    # One postprocessor and one metric collection per (configuration, try, k) triple.
    postprocessors, metrics = {}, {}
    for n_per_class, knn_module in knn_module_dict.items():
        for t, knn_try in knn_module.items():
            for k in knn_try.nb_knn:
                result_key = (n_per_class, t, k)
                postprocessors[result_key] = DictKeysModule([n_per_class, t, k])
                metrics[result_key] = metric_collection.clone()
    model_with_knn = torch.nn.Sequential(model, knn_module_dict)

    logger.info("Start the k-NN classification.")
    _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)

    # Average the results over the tries for each training-set configuration.
    for n_per_class, knn_module in knn_module_dict.items():
        tries = list(knn_module.keys())
        first_try = tries[0]
        for k in knn_module[first_try].nb_knn:
            metric_names = results_dict[(n_per_class, first_try, k)].keys()
            results_dict[(n_per_class, k)] = {
                name: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][name] for t in tries]))
                for name in metric_names
            }
            # The per-try entries are no longer needed once averaged.
            for t in tries:
                del results_dict[(n_per_class, t, k)]

    return results_dict
|
|
|
|
def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
    transform=None,
    gather_on_cpu=False,
    batch_size=256,
    num_workers=5,
    n_per_class_list=None,
    n_tries=1,
):
    """Build the datasets, run the k-NN evaluation and record the accuracies.

    Args:
        model: Feature extractor to evaluate.
        output_dir: Directory in which ``results_eval_knn.json`` is appended to.
        train_dataset_str: Dataset description string for the reference set.
        val_dataset_str: Dataset description string for the evaluated set.
        nb_knn: Neighbor counts to evaluate (list or tuple).
        temperature: Temperature of the softmax applied to the neighbor similarities.
        autocast_dtype: dtype used for CUDA autocast during the evaluation.
        accuracy_averaging: Averaging mode of the accuracy metric.
        transform: Transform applied to both datasets; when falsy, the result of
            ``make_classification_eval_transform()`` is used.
        gather_on_cpu: Forwarded to ``eval_knn``.
        batch_size: Batch size.
        num_workers: Number of data loading workers.
        n_per_class_list: Training-set configurations, see ``eval_knn``. Defaults to ``[-1]``.
        n_tries: Number of random subsets per configuration.

    Returns:
        On the main process, a dict mapping ``"<key> Top 1"`` / ``"<key> Top 5"`` to the
        accuracy in percent; on every other process, an empty dict.
    """
    # None instead of a mutable list as the default value.
    if n_per_class_list is None:
        n_per_class_list = [-1]
    # The command line supplies a list while the default is a tuple; normalise so
    # downstream list operations behave the same for both.
    nb_knn = list(nb_knn)

    transform = transform or make_classification_eval_transform()

    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=transform,
    )
    val_dataset = make_dataset(
        dataset_str=val_dataset_str,
        transform=transform,
    )

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_knn = eval_knn(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            accuracy_averaging=accuracy_averaging,
            nb_knn=nb_knn,
            temperature=temperature,
            batch_size=batch_size,
            num_workers=num_workers,
            gather_on_cpu=gather_on_cpu,
            n_per_class_list=n_per_class_list,
            n_tries=n_tries,
        )

    results_dict = {}
    if distributed.is_main_process():
        for knn_, knn_metrics in results_dict_knn.items():
            # Metric values are scaled by 100 and reported as percentages.
            top1 = knn_metrics["top-1"].item() * 100.0
            top5 = knn_metrics["top-5"].item() * 100.0
            results_dict[f"{knn_} Top 1"] = top1
            results_dict[f"{knn_} Top 5"] = top5
            logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")

        # Only the main process has results, so only it touches the metrics file.
        # The file is opened in append mode: one JSON object per line, per metric.
        os.makedirs(output_dir, exist_ok=True)
        metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
        with open(metrics_file_path, "a", encoding="utf-8") as f:
            f.writelines(json.dumps({k: v}) + "\n" for k, v in results_dict.items())

    # Keep the other processes waiting until the main process has finished writing.
    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict
|
|
|
|
def main(args):
    """Build the model from the parsed command-line `args`, run the k-NN evaluation, return 0."""
    model, autocast_dtype = setup_and_build_model(args)

    # Options taken verbatim from the command line.
    cli_options = dict(
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        gather_on_cpu=args.gather_on_cpu,
        batch_size=args.batch_size,
        n_per_class_list=args.n_per_class_list,
        n_tries=args.n_tries,
    )
    # Options that are fixed when running as a script.
    fixed_options = dict(
        accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
        transform=None,
        num_workers=5,
    )
    eval_knn_with_model(
        model=model,
        autocast_dtype=autocast_dtype,
        **cli_options,
        **fixed_options,
    )
    return 0
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, run the evaluation and use the
    # value returned by main() as the process exit status.
    cli_parser = get_args_parser(description="DINOv2 k-NN evaluation")
    sys.exit(main(cli_parser.parse_args()))
|
|