| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | from functools import partial |
| | import json |
| | import logging |
| | import os |
| | import sys |
| | from typing import List, Optional |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn.parallel import DistributedDataParallel |
| | from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer |
| |
|
| | from dinov2.data import SamplerType, make_data_loader, make_dataset |
| | from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform |
| | import dinov2.distributed as distributed |
| | from dinov2.eval.metrics import MetricType, build_metric |
| | from dinov2.eval.setup import get_args_parser as get_setup_args_parser |
| | from dinov2.eval.setup import setup_and_build_model |
| | from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate |
| | from dinov2.logging import MetricLogger |
| |
|
| |
|
| | logger = logging.getLogger("dinov2") |
| |
|
| |
|
| | def get_args_parser( |
| | description: Optional[str] = None, |
| | parents: Optional[List[argparse.ArgumentParser]] = None, |
| | add_help: bool = True, |
| | ): |
| | parents = parents or [] |
| | setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) |
| | parents = [setup_args_parser] |
| | parser = argparse.ArgumentParser( |
| | description=description, |
| | parents=parents, |
| | add_help=add_help, |
| | ) |
| | parser.add_argument( |
| | "--train-dataset", |
| | dest="train_dataset_str", |
| | type=str, |
| | help="Training dataset", |
| | ) |
| | parser.add_argument( |
| | "--val-dataset", |
| | dest="val_dataset_str", |
| | type=str, |
| | help="Validation dataset", |
| | ) |
| | parser.add_argument( |
| | "--test-datasets", |
| | dest="test_dataset_strs", |
| | type=str, |
| | nargs="+", |
| | help="Test datasets, none to reuse the validation dataset", |
| | ) |
| | parser.add_argument( |
| | "--epochs", |
| | type=int, |
| | help="Number of training epochs", |
| | ) |
| | parser.add_argument( |
| | "--batch-size", |
| | type=int, |
| | help="Batch Size (per GPU)", |
| | ) |
| | parser.add_argument( |
| | "--num-workers", |
| | type=int, |
| | help="Number de Workers", |
| | ) |
| | parser.add_argument( |
| | "--epoch-length", |
| | type=int, |
| | help="Length of an epoch in number of iterations", |
| | ) |
| | parser.add_argument( |
| | "--save-checkpoint-frequency", |
| | type=int, |
| | help="Number of epochs between two named checkpoint saves.", |
| | ) |
| | parser.add_argument( |
| | "--eval-period-iterations", |
| | type=int, |
| | help="Number of iterations between two evaluations.", |
| | ) |
| | parser.add_argument( |
| | "--learning-rates", |
| | nargs="+", |
| | type=float, |
| | help="Learning rates to grid search.", |
| | ) |
| | parser.add_argument( |
| | "--no-resume", |
| | action="store_true", |
| | help="Whether to not resume from existing checkpoints", |
| | ) |
| | parser.add_argument( |
| | "--val-metric-type", |
| | type=MetricType, |
| | choices=list(MetricType), |
| | help="Validation metric", |
| | ) |
| | parser.add_argument( |
| | "--test-metric-types", |
| | type=MetricType, |
| | choices=list(MetricType), |
| | nargs="+", |
| | help="Evaluation metric", |
| | ) |
| | parser.add_argument( |
| | "--classifier-fpath", |
| | type=str, |
| | help="Path to a file containing pretrained linear classifiers", |
| | ) |
| | parser.add_argument( |
| | "--val-class-mapping-fpath", |
| | type=str, |
| | help="Path to a file containing a mapping to adjust classifier outputs", |
| | ) |
| | parser.add_argument( |
| | "--test-class-mapping-fpaths", |
| | nargs="+", |
| | type=str, |
| | help="Path to a file containing a mapping to adjust classifier outputs", |
| | ) |
| | parser.set_defaults( |
| | train_dataset_str="ImageNet:split=TRAIN", |
| | val_dataset_str="ImageNet:split=VAL", |
| | test_dataset_strs=None, |
| | epochs=10, |
| | batch_size=128, |
| | num_workers=8, |
| | epoch_length=1250, |
| | save_checkpoint_frequency=20, |
| | eval_period_iterations=1250, |
| | learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1], |
| | val_metric_type=MetricType.MEAN_ACCURACY, |
| | test_metric_types=None, |
| | classifier_fpath=None, |
| | val_class_mapping_fpath=None, |
| | test_class_mapping_fpaths=[None], |
| | ) |
| | return parser |
| |
|
| |
|
| | def has_ddp_wrapper(m: nn.Module) -> bool: |
| | return isinstance(m, DistributedDataParallel) |
| |
|
| |
|
| | def remove_ddp_wrapper(m: nn.Module) -> nn.Module: |
| | return m.module if has_ddp_wrapper(m) else m |
| |
|
| |
|
| | def _pad_and_collate(batch): |
| | maxlen = max(len(targets) for image, targets in batch) |
| | padded_batch = [ |
| | (image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch |
| | ] |
| | return torch.utils.data.default_collate(padded_batch) |
| |
|
| |
|
| | def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool): |
| | intermediate_output = x_tokens_list[-use_n_blocks:] |
| | output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) |
| | if use_avgpool: |
| | output = torch.cat( |
| | ( |
| | output, |
| | torch.mean(intermediate_output[-1][0], dim=1), |
| | ), |
| | dim=-1, |
| | ) |
| | output = output.reshape(output.shape[0], -1) |
| | return output.float() |
| |
|
| |
|
| | class LinearClassifier(nn.Module): |
| | """Linear layer to train on top of frozen features""" |
| |
|
| | def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000): |
| | super().__init__() |
| | self.out_dim = out_dim |
| | self.use_n_blocks = use_n_blocks |
| | self.use_avgpool = use_avgpool |
| | self.num_classes = num_classes |
| | self.linear = nn.Linear(out_dim, num_classes) |
| | self.linear.weight.data.normal_(mean=0.0, std=0.01) |
| | self.linear.bias.data.zero_() |
| |
|
| | def forward(self, x_tokens_list): |
| | output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool) |
| | return self.linear(output) |
| |
|
| |
|
| | class AllClassifiers(nn.Module): |
| | def __init__(self, classifiers_dict): |
| | super().__init__() |
| | self.classifiers_dict = nn.ModuleDict() |
| | self.classifiers_dict.update(classifiers_dict) |
| |
|
| | def forward(self, inputs): |
| | return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} |
| |
|
| | def __len__(self): |
| | return len(self.classifiers_dict) |
| |
|
| |
|
| | class LinearPostprocessor(nn.Module): |
| | def __init__(self, linear_classifier, class_mapping=None): |
| | super().__init__() |
| | self.linear_classifier = linear_classifier |
| | self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) |
| |
|
| | def forward(self, samples, targets): |
| | preds = self.linear_classifier(samples) |
| | return { |
| | "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, |
| | "target": targets, |
| | } |
| |
|
| |
|
| | def scale_lr(learning_rates, batch_size): |
| | return learning_rates * (batch_size * distributed.get_global_size()) / 256.0 |
| |
|
| |
|
| | def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000): |
| | linear_classifiers_dict = nn.ModuleDict() |
| | optim_param_groups = [] |
| | for n in n_last_blocks_list: |
| | for avgpool in [False, True]: |
| | for _lr in learning_rates: |
| | lr = scale_lr(_lr, batch_size) |
| | out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1] |
| | linear_classifier = LinearClassifier( |
| | out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes |
| | ) |
| | linear_classifier = linear_classifier.cuda() |
| | linear_classifiers_dict[ |
| | f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_") |
| | ] = linear_classifier |
| | optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr}) |
| |
|
| | linear_classifiers = AllClassifiers(linear_classifiers_dict) |
| | if distributed.is_enabled(): |
| | linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) |
| |
|
| | return linear_classifiers, optim_param_groups |
| |
|
| |
|
| | @torch.no_grad() |
| | def evaluate_linear_classifiers( |
| | feature_model, |
| | linear_classifiers, |
| | data_loader, |
| | metric_type, |
| | metrics_file_path, |
| | training_num_classes, |
| | iteration, |
| | prefixstring="", |
| | class_mapping=None, |
| | best_classifier_on_val=None, |
| | ): |
| | logger.info("running validation !") |
| |
|
| | num_classes = len(class_mapping) if class_mapping is not None else training_num_classes |
| | metric = build_metric(metric_type, num_classes=num_classes) |
| | postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()} |
| | metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} |
| |
|
| | _, results_dict_temp = evaluate( |
| | feature_model, |
| | data_loader, |
| | postprocessors, |
| | metrics, |
| | torch.cuda.current_device(), |
| | ) |
| |
|
| | logger.info("") |
| | results_dict = {} |
| | max_accuracy = 0 |
| | best_classifier = "" |
| | for i, (classifier_string, metric) in enumerate(results_dict_temp.items()): |
| | logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") |
| | if ( |
| | best_classifier_on_val is None and metric["top-1"].item() > max_accuracy |
| | ) or classifier_string == best_classifier_on_val: |
| | max_accuracy = metric["top-1"].item() |
| | best_classifier = classifier_string |
| |
|
| | results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} |
| |
|
| | logger.info(f"best classifier: {results_dict['best_classifier']}") |
| |
|
| | if distributed.is_main_process(): |
| | with open(metrics_file_path, "a") as f: |
| | f.write(f"iter: {iteration}\n") |
| | for k, v in results_dict.items(): |
| | f.write(json.dumps({k: v}) + "\n") |
| | f.write("\n") |
| |
|
| | return results_dict |
| |
|
| |
|
| | def eval_linear( |
| | *, |
| | feature_model, |
| | linear_classifiers, |
| | train_data_loader, |
| | val_data_loader, |
| | metrics_file_path, |
| | optimizer, |
| | scheduler, |
| | output_dir, |
| | max_iter, |
| | checkpoint_period, |
| | running_checkpoint_period, |
| | eval_period, |
| | metric_type, |
| | training_num_classes, |
| | resume=True, |
| | classifier_fpath=None, |
| | val_class_mapping=None, |
| | ): |
| | checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) |
| | start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 |
| |
|
| | periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter) |
| | iteration = start_iter |
| | logger.info("Starting training from iteration {}".format(start_iter)) |
| | metric_logger = MetricLogger(delimiter=" ") |
| | header = "Training" |
| |
|
| | for data, labels in metric_logger.log_every( |
| | train_data_loader, |
| | 10, |
| | header, |
| | max_iter, |
| | start_iter, |
| | ): |
| | data = data.cuda(non_blocking=True) |
| | labels = labels.cuda(non_blocking=True) |
| |
|
| | features = feature_model(data) |
| | outputs = linear_classifiers(features) |
| |
|
| | losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()} |
| | loss = sum(losses.values()) |
| |
|
| | |
| | optimizer.zero_grad() |
| | loss.backward() |
| |
|
| | |
| | optimizer.step() |
| | scheduler.step() |
| |
|
| | |
| | if iteration % 10 == 0: |
| | torch.cuda.synchronize() |
| | metric_logger.update(loss=loss.item()) |
| | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) |
| | print("lr", optimizer.param_groups[0]["lr"]) |
| |
|
| | if iteration - start_iter > 5: |
| | if iteration % running_checkpoint_period == 0: |
| | torch.cuda.synchronize() |
| | if distributed.is_main_process(): |
| | logger.info("Checkpointing running_checkpoint") |
| | periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration) |
| | torch.cuda.synchronize() |
| | periodic_checkpointer.step(iteration) |
| |
|
| | if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: |
| | _ = evaluate_linear_classifiers( |
| | feature_model=feature_model, |
| | linear_classifiers=remove_ddp_wrapper(linear_classifiers), |
| | data_loader=val_data_loader, |
| | metrics_file_path=metrics_file_path, |
| | prefixstring=f"ITER: {iteration}", |
| | metric_type=metric_type, |
| | training_num_classes=training_num_classes, |
| | iteration=iteration, |
| | class_mapping=val_class_mapping, |
| | ) |
| | torch.cuda.synchronize() |
| |
|
| | iteration = iteration + 1 |
| |
|
| | val_results_dict = evaluate_linear_classifiers( |
| | feature_model=feature_model, |
| | linear_classifiers=remove_ddp_wrapper(linear_classifiers), |
| | data_loader=val_data_loader, |
| | metrics_file_path=metrics_file_path, |
| | metric_type=metric_type, |
| | training_num_classes=training_num_classes, |
| | iteration=iteration, |
| | class_mapping=val_class_mapping, |
| | ) |
| | return val_results_dict, feature_model, linear_classifiers, iteration |
| |
|
| |
|
| | def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type): |
| | test_dataset = make_dataset( |
| | dataset_str=test_dataset_str, |
| | transform=make_classification_eval_transform(), |
| | ) |
| | test_data_loader = make_data_loader( |
| | dataset=test_dataset, |
| | batch_size=batch_size, |
| | num_workers=num_workers, |
| | sampler_type=SamplerType.DISTRIBUTED, |
| | drop_last=False, |
| | shuffle=False, |
| | persistent_workers=False, |
| | collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None, |
| | ) |
| | return test_data_loader |
| |
|
| |
|
| | def test_on_datasets( |
| | feature_model, |
| | linear_classifiers, |
| | test_dataset_strs, |
| | batch_size, |
| | num_workers, |
| | test_metric_types, |
| | metrics_file_path, |
| | training_num_classes, |
| | iteration, |
| | best_classifier_on_val, |
| | prefixstring="", |
| | test_class_mappings=[None], |
| | ): |
| | results_dict = {} |
| | for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types): |
| | logger.info(f"Testing on {test_dataset_str}") |
| | test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type) |
| | dataset_results_dict = evaluate_linear_classifiers( |
| | feature_model, |
| | remove_ddp_wrapper(linear_classifiers), |
| | test_data_loader, |
| | metric_type, |
| | metrics_file_path, |
| | training_num_classes, |
| | iteration, |
| | prefixstring="", |
| | class_mapping=class_mapping, |
| | best_classifier_on_val=best_classifier_on_val, |
| | ) |
| | results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"] |
| | return results_dict |
| |
|
| |
|
| | def run_eval_linear( |
| | model, |
| | output_dir, |
| | train_dataset_str, |
| | val_dataset_str, |
| | batch_size, |
| | epochs, |
| | epoch_length, |
| | num_workers, |
| | save_checkpoint_frequency, |
| | eval_period_iterations, |
| | learning_rates, |
| | autocast_dtype, |
| | test_dataset_strs=None, |
| | resume=True, |
| | classifier_fpath=None, |
| | val_class_mapping_fpath=None, |
| | test_class_mapping_fpaths=[None], |
| | val_metric_type=MetricType.MEAN_ACCURACY, |
| | test_metric_types=None, |
| | ): |
| | seed = 0 |
| |
|
| | if test_dataset_strs is None: |
| | test_dataset_strs = [val_dataset_str] |
| | if test_metric_types is None: |
| | test_metric_types = [val_metric_type] * len(test_dataset_strs) |
| | else: |
| | assert len(test_metric_types) == len(test_dataset_strs) |
| | assert len(test_dataset_strs) == len(test_class_mapping_fpaths) |
| |
|
| | train_transform = make_classification_train_transform() |
| | train_dataset = make_dataset( |
| | dataset_str=train_dataset_str, |
| | transform=train_transform, |
| | ) |
| | training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int)))) |
| | sampler_type = SamplerType.SHARDED_INFINITE |
| | |
| |
|
| | n_last_blocks_list = [1, 4] |
| | n_last_blocks = max(n_last_blocks_list) |
| | autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) |
| | feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) |
| | sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda()) |
| |
|
| | linear_classifiers, optim_param_groups = setup_linear_classifiers( |
| | sample_output, |
| | n_last_blocks_list, |
| | learning_rates, |
| | batch_size, |
| | training_num_classes, |
| | ) |
| |
|
| | optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0) |
| | max_iter = epochs * epoch_length |
| | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) |
| | checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) |
| | start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 |
| | train_data_loader = make_data_loader( |
| | dataset=train_dataset, |
| | batch_size=batch_size, |
| | num_workers=num_workers, |
| | shuffle=True, |
| | seed=seed, |
| | sampler_type=sampler_type, |
| | sampler_advance=start_iter, |
| | drop_last=True, |
| | persistent_workers=True, |
| | ) |
| | val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type) |
| |
|
| | checkpoint_period = save_checkpoint_frequency * epoch_length |
| |
|
| | if val_class_mapping_fpath is not None: |
| | logger.info(f"Using class mapping from {val_class_mapping_fpath}") |
| | val_class_mapping = np.load(val_class_mapping_fpath) |
| | else: |
| | val_class_mapping = None |
| |
|
| | test_class_mappings = [] |
| | for class_mapping_fpath in test_class_mapping_fpaths: |
| | if class_mapping_fpath is not None and class_mapping_fpath != "None": |
| | logger.info(f"Using class mapping from {class_mapping_fpath}") |
| | class_mapping = np.load(class_mapping_fpath) |
| | else: |
| | class_mapping = None |
| | test_class_mappings.append(class_mapping) |
| |
|
| | metrics_file_path = os.path.join(output_dir, "results_eval_linear.json") |
| | val_results_dict, feature_model, linear_classifiers, iteration = eval_linear( |
| | feature_model=feature_model, |
| | linear_classifiers=linear_classifiers, |
| | train_data_loader=train_data_loader, |
| | val_data_loader=val_data_loader, |
| | metrics_file_path=metrics_file_path, |
| | optimizer=optimizer, |
| | scheduler=scheduler, |
| | output_dir=output_dir, |
| | max_iter=max_iter, |
| | checkpoint_period=checkpoint_period, |
| | running_checkpoint_period=epoch_length, |
| | eval_period=eval_period_iterations, |
| | metric_type=val_metric_type, |
| | training_num_classes=training_num_classes, |
| | resume=resume, |
| | val_class_mapping=val_class_mapping, |
| | classifier_fpath=classifier_fpath, |
| | ) |
| | results_dict = {} |
| | if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str: |
| | results_dict = test_on_datasets( |
| | feature_model, |
| | linear_classifiers, |
| | test_dataset_strs, |
| | batch_size, |
| | 0, |
| | test_metric_types, |
| | metrics_file_path, |
| | training_num_classes, |
| | iteration, |
| | val_results_dict["best_classifier"]["name"], |
| | prefixstring="", |
| | test_class_mappings=test_class_mappings, |
| | ) |
| | results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"] |
| | results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"] |
| | logger.info("Test Results Dict " + str(results_dict)) |
| |
|
| | return results_dict |
| |
|
| |
|
| | def main(args): |
| | model, autocast_dtype = setup_and_build_model(args) |
| | run_eval_linear( |
| | model=model, |
| | output_dir=args.output_dir, |
| | train_dataset_str=args.train_dataset_str, |
| | val_dataset_str=args.val_dataset_str, |
| | test_dataset_strs=args.test_dataset_strs, |
| | batch_size=args.batch_size, |
| | epochs=args.epochs, |
| | epoch_length=args.epoch_length, |
| | num_workers=args.num_workers, |
| | save_checkpoint_frequency=args.save_checkpoint_frequency, |
| | eval_period_iterations=args.eval_period_iterations, |
| | learning_rates=args.learning_rates, |
| | autocast_dtype=autocast_dtype, |
| | resume=not args.no_resume, |
| | classifier_fpath=args.classifier_fpath, |
| | val_metric_type=args.val_metric_type, |
| | test_metric_types=args.test_metric_types, |
| | val_class_mapping_fpath=args.val_class_mapping_fpath, |
| | test_class_mapping_fpaths=args.test_class_mapping_fpaths, |
| | ) |
| | return 0 |
| |
|
| |
|
| | if __name__ == "__main__": |
| | description = "DINOv2 linear evaluation" |
| | args_parser = get_args_parser(description=description) |
| | args = args_parser.parse_args() |
| | sys.exit(main(args)) |
| |
|