import argparse
import logging
import math
import os
from functools import partial

from fvcore.common.checkpoint import PeriodicCheckpointer
import torch

from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
import dinov2.distributed as distributed
from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger
from dinov2.utils.config import setup
from dinov2.utils.utils import CosineScheduler

from dinov2.train.ssl_meta_arch import SSLMetaArch


torch.backends.cuda.matmul.allow_tf32 = True  # PyTorch 1.12 sets this to False by default
logger = logging.getLogger("dinov2")


def get_args_parser(add_help: bool = True):
    parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Do not attempt to resume from the checkpoint directory",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--eval", type=str, default="", help="Eval type to perform")
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        default="",
        type=str,
        help="Output directory to save logs and checkpoints",
    )

    return parser
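
# A typical multi-GPU launch looks roughly like this (paths, config name, and
# dataset spec are illustrative, not guaranteed to match every checkout):
#
#   torchrun --nproc_per_node=8 dinov2/train/train.py \
#       --config-file dinov2/configs/train/vitl16_short.yaml \
#       --output-dir <PATH/TO/OUTPUT/DIR> \
#       train.dataset_path=<DATASET_SPEC>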


def build_optimizer(cfg, params_groups):
    # lr and weight_decay are deliberately not set here: apply_optim_scheduler()
    # writes them into every param group at each training iteration.
    return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2))


def build_schedulers(cfg):
    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    lr = dict(
        base_value=cfg.optim["lr"],
        final_value=cfg.optim["min_lr"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
        warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
        start_warmup_value=0,
    )
    wd = dict(
        base_value=cfg.optim["weight_decay"],
        final_value=cfg.optim["weight_decay_end"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
    )
    momentum = dict(
        base_value=cfg.teacher["momentum_teacher"],
        final_value=cfg.teacher["final_momentum_teacher"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
    )
    teacher_temp = dict(
        base_value=cfg.teacher["teacher_temp"],
        final_value=cfg.teacher["teacher_temp"],
        total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
        warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
        start_warmup_value=cfg.teacher["warmup_teacher_temp"],
    )

    lr_schedule = CosineScheduler(**lr)
    wd_schedule = CosineScheduler(**wd)
    momentum_schedule = CosineScheduler(**momentum)
    teacher_temp_schedule = CosineScheduler(**teacher_temp)
    last_layer_lr_schedule = CosineScheduler(**lr)

    # freeze the last layer during the first epochs by zeroing its learning rate
    last_layer_lr_schedule.schedule[
        : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
    ] = 0

    logger.info("Schedulers ready.")

    return (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    )
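
# For intuition, the per-iteration values indexed from these schedules are
# assumed to follow a standard warmup-then-cosine curve; a minimal sketch of
# that formula (the real implementation is dinov2.utils.utils.CosineScheduler):
#
#   def cosine_value(it, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0):
#       if it < warmup_iters:
#           return start_warmup_value + (base_value - start_warmup_value) * it / warmup_iters
#       progress = (it - warmup_iters) / (total_iters - warmup_iters)
#       return final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * progress))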


def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
    for param_group in optimizer.param_groups:
        is_last_layer = param_group["is_last_layer"]
        lr_multiplier = param_group["lr_multiplier"]
        wd_multiplier = param_group["wd_multiplier"]
        param_group["weight_decay"] = wd * wd_multiplier
        param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier


def do_test(cfg, model, iteration):
    new_state_dict = model.teacher.state_dict()

    if distributed.is_main_process():
        iterstring = str(iteration)
        eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring)
        os.makedirs(eval_dir, exist_ok=True)
        # save teacher checkpoint
        teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
        torch.save({"teacher": new_state_dict}, teacher_ckp_path)
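
# The saved teacher weights can later be restored independently of the
# training stack, e.g. (sketch):
#
#   teacher_state = torch.load(teacher_ckp_path, map_location="cpu")["teacher"]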


def do_train(cfg, model, resume=False):
    model.train()
    inputs_dtype = torch.half
    fp16_scaler = model.fp16_scaler  # for mixed precision training

    # setup optimizer

    optimizer = build_optimizer(cfg, model.get_params_groups())
    (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    ) = build_schedulers(cfg)

    # checkpointer
    checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True)

    start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer,
        period=3 * OFFICIAL_EPOCH_LENGTH,
        max_iter=max_iter,
        max_to_keep=3,
    )

    # setup data preprocessing

    img_size = cfg.crops.global_crops_size
    patch_size = cfg.student.patch_size
    n_tokens = (img_size // patch_size) ** 2
    mask_generator = MaskingGenerator(
        input_size=(img_size // patch_size, img_size // patch_size),
        max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,  # mask at most half of the patches
    )

    data_transform = DataAugmentationDINO(
        cfg.crops.global_crops_scale,
        cfg.crops.local_crops_scale,
        cfg.crops.local_crops_number,
        global_crops_size=cfg.crops.global_crops_size,
        local_crops_size=cfg.crops.local_crops_size,
    )

    collate_fn = partial(
        collate_data_and_cast,
        mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
        mask_probability=cfg.ibot.mask_sample_probability,
        n_tokens=n_tokens,
        mask_generator=mask_generator,
        dtype=inputs_dtype,
    )

    # setup data loader

    dataset = make_dataset(
        dataset_str=cfg.train.dataset_path,
        transform=data_transform,
        target_transform=lambda _: (),
    )
    # infinite sharded sampler: each rank draws from its own shard, with no epoch boundaries
    sampler_type = SamplerType.SHARDED_INFINITE
    data_loader = make_data_loader(
        dataset=dataset,
        batch_size=cfg.train.batch_size_per_gpu,
        num_workers=cfg.train.num_workers,
        shuffle=True,
        seed=start_iter,
        sampler_type=sampler_type,
        sampler_advance=0,
        drop_last=True,
        collate_fn=collate_fn,
    )
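
    # collate_data_and_cast is assumed to yield dict batches roughly of this
    # shape (sketch; see dinov2.data for the exact keys):
    #
    #   {
    #       "collated_global_crops": Tensor[2 * B, 3, global_size, global_size],
    #       "collated_local_crops": Tensor[n_local * B, 3, local_size, local_size],
    #       "collated_masks": ...,  # iBOT masks over global-crop patch tokens
    #   }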

    # training loop

    iteration = start_iter

    logger.info("Starting training from iteration {}".format(start_iter))
    metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
    metric_logger = MetricLogger(delimiter="  ", output_file=metrics_file)
    header = "Training"

    for data in metric_logger.log_every(
        data_loader,
        10,
        header,
        max_iter,
        start_iter,
    ):
        current_batch_size = data["collated_global_crops"].shape[0] / 2
        if iteration > max_iter:
            return

        # apply schedules

        lr = lr_schedule[iteration]
        wd = wd_schedule[iteration]
        mom = momentum_schedule[iteration]
        teacher_temp = teacher_temp_schedule[iteration]
        last_layer_lr = last_layer_lr_schedule[iteration]
        apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)

        # compute losses

        optimizer.zero_grad(set_to_none=True)
        loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)

        # clip gradients

        if fp16_scaler is not None:
            if cfg.optim.clip_grad:
                fp16_scaler.unscale_(optimizer)
                for v in model.student.values():
                    v.clip_grad_norm_(cfg.optim.clip_grad)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()
        else:
            if cfg.optim.clip_grad:
                for v in model.student.values():
                    v.clip_grad_norm_(cfg.optim.clip_grad)
            optimizer.step()

        # perform teacher EMA update

        model.update_teacher(mom)
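
        # Conceptually, the EMA update above does something like the following
        # (sketch; the actual implementation lives in SSLMetaArch.update_teacher):
        #
        #   with torch.no_grad():
        #       for p_t, p_s in zip(teacher_params, student_params):
        #           p_t.mul_(mom).add_(p_s, alpha=1 - mom)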

        # logging

        if distributed.get_global_size() > 1:
            for v in loss_dict.values():
                torch.distributed.all_reduce(v)
        loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}

        if math.isnan(sum(loss_dict_reduced.values())):
            logger.info("NaN detected")
            raise AssertionError
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        metric_logger.update(lr=lr)
        metric_logger.update(wd=wd)
        metric_logger.update(mom=mom)
        metric_logger.update(last_layer_lr=last_layer_lr)
        metric_logger.update(current_batch_size=current_batch_size)
        metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)

        # checkpointing and testing

        if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
            do_test(cfg, model, f"training_{iteration}")
            torch.cuda.synchronize()
        periodic_checkpointer.step(iteration)

        iteration = iteration + 1
    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def main(args):
    cfg = setup(args)

    model = SSLMetaArch(cfg).to(torch.device("cuda"))
    model.prepare_for_distributed_training()

    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        iteration = (
            FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
            .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)
            .get("iteration", -1)
            + 1
        )
        return do_test(cfg, model, f"manual_{iteration}")

    do_train(cfg, model, resume=not args.no_resume)


if __name__ == "__main__":
    args = get_args_parser(add_help=True).parse_args()
    main(args)