# Prior2DSM/src/dinov3/train/train.py
# Uploaded by osherr ("Upload 222 files"), commit bc90483 (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import argparse
import copy
import gc
import logging
import math
import os
import sys
from functools import partial
from pathlib import Path
import torch
import torch.distributed
from torch.distributed._tensor import DTensor
import dinov3.distributed as distributed
from dinov3.checkpointer import (
find_latest_checkpoint,
keep_checkpoint_copy,
keep_last_n_checkpoints,
load_checkpoint,
register_dont_save_hooks,
save_checkpoint,
)
from dinov3.configs import setup_config, setup_job, setup_multidistillation
from dinov3.data import (
MaskingGenerator,
SamplerType,
collate_data_and_cast,
make_data_loader,
make_dataset,
CombinedDataLoader,
)
from dinov3.logging import MetricLogger, setup_logging
from dinov3.train.cosine_lr_scheduler import CosineScheduler, linear_warmup_cosine_decay
from dinov3.train.multidist_meta_arch import MultiDistillationMetaArch
from dinov3.train.ssl_meta_arch import SSLMetaArch
# Fail fast at import time when the installed PyTorch is older than 2.1.
assert torch.__version__ >= (2, 1)
# Allow TF32 for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True  # pytorch 1.12 sets this to false by default
# cuDNN benchmark mode is left off; the trailing "True" records the alternative setting.
torch.backends.cudnn.benchmark = False  # True
# Logger used by every function in this module.
logger = logging.getLogger("dinov3")
def get_args_parser(add_help: bool = True):
    """Create the command-line parser used by the DINOv3 training script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; set to ``False``
            when this parser is meant to be used as a parent of another one.

    Returns:
        The configured ``argparse.ArgumentParser``. Trailing positional
        arguments are collected verbatim into ``opts``.
    """
    parser = argparse.ArgumentParser("DINOv3 training", add_help=add_help)
    add = parser.add_argument

    # Same text the parser has always exposed for the trailing overrides.
    opts_help = (
        "Modify config options at the end of the command. For Yacs configs, use\n"
        'space-separated "PATH.KEY VALUE" pairs.\n'
        'For python-based LazyConfig, use "path.key=value".'
    )

    # Registration order is kept as-is because it determines the help layout.
    add("--config-file", default="", metavar="FILE", help="path to config file")
    add(
        "--no-resume",
        action="store_true",
        help="Whether to not attempt to resume from the checkpoint directory. ",
    )
    add("--eval-only", action="store_true", help="perform evaluation only")
    add("--eval", type=str, default="", help="Eval type to perform")
    add("--eval_pretrained_weights", type=str, default="", help="Path to pretrained weights")
    add("opts", help=opts_help, default=None, nargs=argparse.REMAINDER)
    add("--output-dir", default="./local_dino", type=str, help="Path to save logs and checkpoints.")
    add("--seed", default=0, type=int, help="RNG seed")
    add("--benchmark-codebase", action="store_true", help="test the codebase for a few iters")
    add("--test-ibot", action="store_true", help="test ibot")
    add("--profiling", action="store_true", help="do profiling")
    add("--dump-fsdp-weights", action="store_true", help="dump fsdp weights")
    add("--record_ref_losses", action="store_true", help="record reference losses")
    add("--ref_losses_path", default="", type=str)
    add("--multi-distillation", action="store_true", help="run multi-distillation")
    return parser
def build_optimizer(cfg, params_groups):
    """Create the AdamW optimizer for the given parameter groups.

    Only the beta coefficients are taken from ``cfg.optim``; the learning rate
    and weight decay of each group are left at the optimizer defaults here.
    """
    adamw_betas = (cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)
    return torch.optim.AdamW(params_groups, betas=adamw_betas)
def build_schedulers(cfg):
    """Build the per-iteration training schedules.

    When the config has a ``schedules`` entry the work is delegated to
    ``build_schedulers_v2``. Otherwise five ``CosineScheduler`` objects are
    built from ``cfg.optim`` and ``cfg.teacher``.

    Returns:
        A tuple ``(lr, weight_decay, momentum, teacher_temp, last_layer_lr)``.
        The last-layer schedule is a copy of the LR schedule whose first
        ``freeze_last_layer_epochs`` epochs are set to zero.
    """
    if "schedules" in cfg:
        logger.info("Using schedules v2")
        return build_schedulers_v2(cfg)

    epoch_len = cfg.train.OFFICIAL_EPOCH_LENGTH
    optim_cfg = cfg.optim
    teacher_cfg = cfg.teacher

    lr_kwargs = {
        "base_value": optim_cfg["lr"],
        "final_value": optim_cfg["min_lr"],
        "total_iters": optim_cfg["epochs"] * epoch_len,
        "warmup_iters": optim_cfg["warmup_epochs"] * epoch_len,
        "start_warmup_value": 0,
        "trunc_extra": optim_cfg["schedule_trunc_extra"],
    }
    wd_kwargs = {
        "base_value": optim_cfg["weight_decay"],
        "final_value": optim_cfg["weight_decay_end"],
        "total_iters": optim_cfg["epochs"] * epoch_len,
        "trunc_extra": optim_cfg["schedule_trunc_extra"],
    }
    momentum_kwargs = {
        "base_value": teacher_cfg["momentum_teacher"],
        "final_value": teacher_cfg["final_momentum_teacher"],
        "total_iters": optim_cfg["epochs"] * epoch_len,
        "trunc_extra": optim_cfg["schedule_trunc_extra"],
    }
    # The teacher temperature only warms up: base and final values coincide
    # and the schedule length equals the warmup length.
    temp_warmup_iters = teacher_cfg["warmup_teacher_temp_epochs"] * epoch_len
    teacher_temp_kwargs = {
        "base_value": teacher_cfg["teacher_temp"],
        "final_value": teacher_cfg["teacher_temp"],
        "total_iters": temp_warmup_iters,
        "warmup_iters": temp_warmup_iters,
        "start_warmup_value": teacher_cfg["warmup_teacher_temp"],
    }

    lr_schedule = CosineScheduler(**lr_kwargs)
    wd_schedule = CosineScheduler(**wd_kwargs)
    momentum_schedule = CosineScheduler(**momentum_kwargs)
    teacher_temp_schedule = CosineScheduler(**teacher_temp_kwargs)

    # Independent copy of the LR schedule, zeroed while the last layer is frozen.
    last_layer_lr_schedule = CosineScheduler(**lr_kwargs)
    frozen_iters = optim_cfg["freeze_last_layer_epochs"] * epoch_len
    last_layer_lr_schedule.schedule[:frozen_iters] = 0  # mimicking the original schedules

    logger.info("Schedulers ready.")
    return (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    )
def build_schedulers_v2(cfg):
    """Build the schedules described by ``cfg.schedules``.

    Each schedule is produced by ``linear_warmup_cosine_decay``. The LR peak
    and end values are first rescaled according to ``cfg.optim.scaling_rule``
    (``"linear_wrt_256"`` or ``"sqrt_wrt_1024"``); any other value leaves them
    untouched.

    Returns:
        A tuple ``(lr, weight_decay, momentum, teacher_temp, last_layer_lr)``.
        ``last_layer_lr`` is a copy of ``lr`` whose first
        ``freeze_last_layer_epochs`` epochs are set to zero.
    """
    iter_per_epoch = cfg.train.OFFICIAL_EPOCH_LENGTH
    total_iterations = cfg.train.OFFICIAL_EPOCH_LENGTH * cfg.optim.epochs
    logger.info(f"Total training iterations {total_iterations}")

    def _schedule(section, peak, end):
        # One warmup + cosine schedule for a single `cfg.schedules.<name>` section.
        cosine_iterations = None
        if "cosine_epochs" in section:
            cosine_iterations = iter_per_epoch * section.cosine_epochs
        return linear_warmup_cosine_decay(
            start=section.start,
            peak=peak,
            end=end,
            warmup_iterations=iter_per_epoch * section.warmup_epochs,
            total_iterations=total_iterations,
            cosine_iterations=cosine_iterations,
        )

    # LR scaling rules
    lr_peak = cfg.schedules.lr.peak
    lr_end = cfg.schedules.lr.end
    lr_scale = None
    if cfg.optim.scaling_rule == "linear_wrt_256":
        lr_scale = cfg.train.batch_size_per_gpu * distributed.get_world_size() / 256.0
    elif cfg.optim.scaling_rule == "sqrt_wrt_1024":
        lr_scale = 4 * math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_world_size() / 1024.0)

    if lr_scale is None:
        logger.info(f"No scaling rule for {cfg.optim.scaling_rule=}")
    else:
        lr_peak = lr_peak * lr_scale
        lr_end = lr_end * lr_scale
        logger.info(
            f"Scaling rule {cfg.optim.scaling_rule}, LR peak {cfg.schedules.lr.peak} -> {lr_peak}, LR end {cfg.schedules.lr.end} -> {lr_end}"
        )

    lr = _schedule(cfg.schedules.lr, lr_peak, lr_end)

    # The last layer follows the LR schedule but stays frozen for the first epochs.
    last_layer_lr = lr.copy()
    last_layer_lr[: iter_per_epoch * cfg.schedules.lr.freeze_last_layer_epochs] = 0

    weight_decay = _schedule(
        cfg.schedules.weight_decay,
        cfg.schedules.weight_decay.peak,
        cfg.schedules.weight_decay.end,
    )
    momentum = _schedule(
        cfg.schedules.momentum,
        cfg.schedules.momentum.peak,
        cfg.schedules.momentum.end,
    )
    teacher_temp = _schedule(
        cfg.schedules.teacher_temp,
        cfg.schedules.teacher_temp.peak,
        cfg.schedules.teacher_temp.end,
    )
    return lr, weight_decay, momentum, teacher_temp, last_layer_lr
def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
    """Write the current schedule values into every optimizer param group.

    Each group must carry the keys ``is_last_layer``, ``lr_multiplier`` and
    ``wd_multiplier``. Groups flagged as last layer use ``last_layer_lr`` as
    their base learning rate, all others use ``lr``. The groups are modified
    in place and nothing is returned.
    """
    for group in optimizer.param_groups:
        # Read all three keys before writing anything into the group.
        last_layer, lr_scale, wd_scale = (
            group[key] for key in ("is_last_layer", "lr_multiplier", "wd_multiplier")
        )
        group["weight_decay"] = wd * wd_scale
        base_lr = last_layer_lr if last_layer else lr
        group["lr"] = base_lr * lr_scale
def do_test(cfg, model, iteration, process_group, do_low_freq=False):
    """Dump the EMA teacher (``model.model_ema``) as an evaluation checkpoint.

    The checkpoint goes under ``<cfg.train.output_dir>/eval/<iteration>/``.

    Args:
        cfg: Run configuration; reads ``train.output_dir`` and
            ``train.sharded_eval_checkpoint``.
        model: Object exposing the teacher as ``model_ema``.
        iteration: Label used as the sub-directory name (converted with ``str``).
        process_group: Forwarded to ``save_checkpoint`` in the sharded branch.
        do_low_freq: Not used anywhere in this function body.

    Returns:
        None. Ranks that are not the subgroup main process return early, after
        taking part in the collective steps of their branch.
    """
    # dump a sharded checkpoint
    eval_dir = Path(cfg.train.output_dir) / "eval" / str(iteration)
    if distributed.is_subgroup_main_process():
        eval_dir.mkdir(parents=True, exist_ok=True)
    if cfg.train.sharded_eval_checkpoint:
        ckpt_path = eval_dir / "sharded_teacher_checkpoint"
        if distributed.is_subgroup_main_process():
            ckpt_path.mkdir(parents=True, exist_ok=True)
        # Every rank waits here before the sharded save starts.
        torch.distributed.barrier()
        teacher_backbone = model.model_ema
        save_checkpoint(
            ckpt_dir=ckpt_path, iteration=iteration, model=teacher_backbone, overwrite=True, process_group=process_group
        )
        if not distributed.is_subgroup_main_process():
            return
    else:
        new_state_dict = model.model_ema.state_dict()
        # Replace every DTensor entry by its full tensor; this loop runs on all
        # ranks, before the non-main ranks return below.
        for k, tensor in list(new_state_dict.items()):
            if isinstance(tensor, DTensor):
                new_state_dict[k] = tensor.full_tensor()
        if not distributed.is_subgroup_main_process():
            return
        # save teacher checkpoint
        ckpt_path = eval_dir / "teacher_checkpoint.pth"
        torch.save({"teacher": new_state_dict}, ckpt_path)
    # NOTE(review): nesting reconstructed from a flattened source; this log line
    # is assumed to run for both branches on the main process — confirm.
    logger.info("Saved eval checkpoint: %s", ckpt_path)
def build_data_loader_from_cfg(
    cfg,
    model,
    start_iter,
):
    """Create the training data loader for a single crop-size configuration.

    Args:
        cfg: Run configuration (reads ``crops``, ``student``, ``ibot``,
            ``compute_precision``, ``multidistillation`` and ``train``).
        model: Provides ``build_data_augmentation_dino(cfg)`` for the dataset
            transform.
        start_iter: Iteration to resume from; used to offset the sampler seed
            and to advance the sampler by the samples already consumed.

    Returns:
        The loader returned by ``make_data_loader``.
    """
    # Collate function
    crop_size = cfg.crops.global_crops_size
    patch = cfg.student.patch_size
    grid_side = crop_size // patch
    mask_generator = MaskingGenerator(
        input_size=(grid_side, grid_side),
        # NOTE(review): evaluated left to right as ((0.5 * s) // p * s) // p, which is
        # not 0.5 * (s // p) ** 2 for every size — kept as-is, confirm it is intended.
        max_num_patches=0.5 * crop_size // patch * crop_size // patch,
    )

    if not cfg.multidistillation.enabled:
        local_batch_size = None  # will default to the standard local batch size matching the data batch size
        dataloader_batch_size_per_gpu = cfg.train.batch_size_per_gpu
    else:
        assert cfg.multidistillation.global_batch_size % distributed.get_subgroup_size() == 0
        local_batch_size = cfg.multidistillation.global_batch_size // distributed.get_subgroup_size()
        # Ceiling division of the global batch over all ranks.
        dataloader_batch_size_per_gpu = (
            cfg.multidistillation.global_batch_size + (distributed.get_world_size() - 1)
        ) // distributed.get_world_size()

    dtype_by_name = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    collate_fn = partial(
        collate_data_and_cast,
        mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
        mask_probability=cfg.ibot.mask_sample_probability,
        dtype=dtype_by_name[cfg.compute_precision.param_dtype],
        n_tokens=grid_side**2,
        mask_generator=mask_generator,
        random_circular_shift=cfg.ibot.mask_random_circular_shift,
        local_batch_size=local_batch_size,
    )

    dataset = make_dataset(
        dataset_str=cfg.train.dataset_path,
        transform=model.build_data_augmentation_dino(cfg),
        target_transform=lambda _: (),
    )

    # The sharded sampler is used only for map-style datasets with caching enabled.
    if not isinstance(dataset, torch.utils.data.IterableDataset) and cfg.train.cache_dataset:
        sampler_type = SamplerType.SHARDED_INFINITE
    else:
        sampler_type = SamplerType.INFINITE

    return make_data_loader(
        dataset=dataset,
        batch_size=dataloader_batch_size_per_gpu,
        num_workers=cfg.train.num_workers,
        shuffle=True,
        seed=cfg.train.seed + start_iter + 1,
        sampler_type=sampler_type,
        sampler_advance=start_iter * dataloader_batch_size_per_gpu,
        drop_last=True,
        collate_fn=collate_fn,
    )
def build_multi_resolution_data_loader_from_cfg(
    cfg,
    model,
    start_iter,
    seed=65537,
):
    """Create one data loader per crop-size combination and combine them.

    ``cfg.crops.global_crops_size``, ``local_crops_size``,
    ``gram_teacher_crops_size`` and ``global_local_crop_pairs_ratios`` may each
    be a scalar or a sequence; scalars are wrapped into one-element lists. All
    four must end up with the same length.

    Args:
        cfg: Run configuration; it is deep-copied for every combination.
        model: Forwarded to ``build_data_loader_from_cfg``.
        start_iter: Forwarded to ``build_data_loader_from_cfg``.
        seed: Seed passed to ``CombinedDataLoader`` when several loaders exist.

    Returns:
        The single loader when only one combination is configured, otherwise
        a ``CombinedDataLoader`` over all of them.
    """
    crops = cfg.crops

    global_crops_sizes = crops.global_crops_size
    if isinstance(global_crops_sizes, int):
        global_crops_sizes = [global_crops_sizes]

    local_crops_sizes = crops.local_crops_size
    if isinstance(local_crops_sizes, int):
        local_crops_sizes = [local_crops_sizes]

    # The gram teacher crop size may also be None, which counts as a scalar.
    gram_teacher_crops_sizes = crops.gram_teacher_crops_size
    if gram_teacher_crops_sizes is None or isinstance(gram_teacher_crops_sizes, int):
        gram_teacher_crops_sizes = [gram_teacher_crops_sizes]

    loader_ratios = crops.global_local_crop_pairs_ratios
    if type(loader_ratios) in [int, float]:
        loader_ratios = [loader_ratios]

    assert len(global_crops_sizes) == len(local_crops_sizes) == len(gram_teacher_crops_sizes) == len(loader_ratios)

    loaders = []
    size_triples = zip(global_crops_sizes, local_crops_sizes, gram_teacher_crops_sizes)
    for index, (global_size, local_size, gram_size) in enumerate(size_triples):
        cfg_i = copy.deepcopy(cfg)
        cfg_i.crops.global_crops_size = global_size
        cfg_i.crops.local_crops_size = local_size
        cfg_i.crops.gram_teacher_crops_size = gram_size
        # Give every resolution its own seed.
        cfg_i.train.seed = cfg.train.seed + index + 1
        loaders.append(build_data_loader_from_cfg(cfg=cfg_i, model=model, start_iter=start_iter))

    if len(loaders) == 1:
        return loaders[0]

    return CombinedDataLoader(
        loaders_with_ratios=zip(loaders, loader_ratios),
        batch_size=cfg.train.batch_size_per_gpu,
        combining_mode=0,
        seed=seed,
        name="MultiResDL",
    )
def do_train(cfg, model, resume=False):
    """Run the training loop.

    Builds the optimizer, schedules and data loader, optionally resumes from
    the latest checkpoint under ``<cfg.train.output_dir>/ckpt``, then iterates
    over the data: forward/backward, optional gradient clipping, NaN check,
    optimizer step, EMA update, optional Gram teacher update, metric logging,
    periodic evaluation dump and periodic checkpointing.

    Args:
        cfg: Run configuration.
        model: Meta-architecture exposing ``student``, ``forward_backward``,
            ``update_ema``, ``get_params_groups``, ``init_weights`` and the
            Gram-teacher hooks used below.
        resume: When true and a checkpoint is found, training restarts from
            the iteration following the checkpointed one.

    Returns:
        A dict mapping each meter name to its global average once the loader
        is exhausted, or ``None`` if the loop exits through the
        ``iteration > max_iter`` guard.

    Raises:
        RuntimeError: After more than two consecutive iterations with a NaN
            loss, unless multi-distillation is enabled.
    """
    process_subgroup = distributed.get_process_subgroup()
    ckpt_dir = Path(cfg.train.output_dir, "ckpt").expanduser()
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    model.train()
    # Optimizer
    optimizer = build_optimizer(cfg, model.get_params_groups())
    (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    ) = build_schedulers(cfg)
    if cfg.multidistillation.enabled:
        # Keep every state-dict entry whose key starts with "teacher" out of the saved checkpoints.
        register_dont_save_hooks(
            model,
            dont_save=[k for k, _ in model.state_dict().items() if k.startswith("teacher")],
        )
    model.init_weights()
    start_iter = 0
    if resume and (last_checkpoint_dir := find_latest_checkpoint(ckpt_dir)):
        logger.info(f"Checkpoint found {last_checkpoint_dir}")
        # load_checkpoint returns the saved iteration; resume from the next one.
        start_iter = (
            load_checkpoint(
                last_checkpoint_dir,
                model=model,
                optimizer=optimizer,
                strict_loading=False,
                process_group=process_subgroup,
            )
            + 1
        )
    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
    if cfg.multidistillation.enabled:
        global_batch_size = cfg.multidistillation.global_batch_size
    else:
        global_batch_size = cfg.train.batch_size_per_gpu * distributed.get_world_size()
    # Build data loader
    data_loader = build_multi_resolution_data_loader_from_cfg(
        cfg=cfg,
        model=model,
        start_iter=start_iter,
    )
    # Metric logging
    logger.info("Starting training from iteration %d", start_iter)
    metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
    metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
    # Manual garbage collection
    gc.disable()
    gc.collect()
    # Training loop
    student = model.student
    iteration = start_iter
    num_gram_updates = 0
    if (
        cfg.gram.use_loss
        and model.has_gram_teacher
        and cfg.gram.rep_update
        and start_iter > 0
        and start_iter >= cfg.gram.it_first_update
    ):
        # If `start_iter == it_first_update`, we have performed one gram teacher update after
        # iteration `start_iter - 1`, except if we are starting training from scratch and `start_iter == 0`.
        num_gram_updates = math.ceil((start_iter + 1 - cfg.gram.it_first_update) / cfg.gram.update_frequency)
        logger.info(f"Gram was updated {num_gram_updates} times before iteration {start_iter}")
    consecutive_nan_count = 0
    for data in metric_logger.log_every(
        data_loader,
        print_freq=10,
        header="Training",
        n_iterations=max_iter,
        start_iteration=start_iter,
    ):
        it = iteration
        data["global_batch_size"] = global_batch_size
        # Safety guard; note that this path returns None rather than the metrics dict.
        if iteration > max_iter:
            return
        # Garbage collection (trigger manually so it happens on all ranks at the same time)
        if (iteration + 1) % 150 == 0:
            logger.info("Garbage collection")
            gc.collect()
        if cfg.gram.use_loss and model.gram_it_load_ema_teacher == it:
            logger.info(f"Loading EMA teacher into Gram teacher before iteration {it}")
            model.gram_load_ema_teacher()
        # Learning rates and other schedules
        lr = lr_schedule[it]
        wd = wd_schedule[it]
        mom = momentum_schedule[it]
        teacher_temp = teacher_temp_schedule[it]
        last_layer_lr = last_layer_lr_schedule[it]
        apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
        # Forward backward
        optimizer.zero_grad(set_to_none=True)
        total_loss, metrics_dict = model.forward_backward(data, teacher_temp=teacher_temp, iteration=it)
        # Gradient clipping
        if cfg.optim.clip_grad:
            for k, v in student.items():
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    v.parameters(),
                    max_norm=cfg.optim.clip_grad,
                )
                # Use the DTensor class imported at module level (as do_test does)
                # instead of reaching through `torch.distributed.tensor`, which is
                # not guaranteed to expose DTensor on every supported torch version.
                metrics_dict[f"{k}_grad_norm"] = (
                    grad_norm.full_tensor().item() if isinstance(grad_norm, DTensor) else grad_norm.item()
                )
        # Reduce total_loss to check for NaNs, reduce metrics for logging
        total_loss_all_ranks = total_loss.new_empty(distributed.get_subgroup_size())
        torch.distributed.all_gather_into_tensor(
            total_loss_all_ranks,
            total_loss.detach(),
            group=distributed.get_process_subgroup(),
        )
        total_loss = total_loss_all_ranks.mean()
        metrics_values = torch.stack(
            [torch.as_tensor(v, dtype=torch.float32, device=total_loss.device).detach() for v in metrics_dict.values()]
        )
        torch.distributed.all_reduce(
            metrics_values,
            op=torch.distributed.ReduceOp.AVG,
            group=distributed.get_process_subgroup(),
        )
        metrics_dict = dict(zip(metrics_dict.keys(), metrics_values))
        if total_loss_all_ranks.isnan().any():
            consecutive_nan_count += 1
            which_ranks = total_loss_all_ranks.isnan().nonzero().flatten().tolist()
            logger.warning("NaN loss detected on ranks: %s", which_ranks)
            logger.warning("Consecutive NaNs: %d", consecutive_nan_count)
            metrics_dict_str = "\n".join([f"{k}: {v}" for k, v in metrics_dict.items()])
            logger.warning("All-reduced metrics:\n%s", metrics_dict_str)
            if consecutive_nan_count > 2 and not cfg.multidistillation.enabled:
                msg = "Too many consecutive nans detected in loss, aborting..."
                logger.error(msg)
                raise RuntimeError(msg)
        else:
            consecutive_nan_count = 0
        # Step optimizer
        optimizer.step()
        model.update_ema(mom)
        # [GRAM] Update gram teacher when using gram teacher and frequent updates
        if (
            cfg.gram.use_loss
            and model.gram_rep_update
            and (it + 1) >= model.gram_it_first_update
            and (it + 1) % model.gram_update_frequency == 0
            and (cfg.gram.max_updates is None or num_gram_updates < cfg.gram.max_updates)
        ):
            logger.info(f"Updating Gram teacher from EMA teacher after iteration {it}")
            model.update_gram()
            num_gram_updates += 1
        # Log metrics
        metric_logger.update(lr=lr)
        metric_logger.update(wd=wd)
        metric_logger.update(mom=mom)
        metric_logger.update(last_layer_lr=last_layer_lr)
        metric_logger.update(total_loss=total_loss, **metrics_dict)
        # Submit evaluation jobs
        if (
            cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0
            # and iteration != max_iter - 1
        ):
            do_test(cfg, model, f"training_{iteration}", process_group=process_subgroup)
            torch.cuda.synchronize()
        # Checkpointing
        if (iteration + 1) % cfg.checkpointing.period == 0:
            torch.cuda.synchronize()
            save_checkpoint(
                ckpt_dir / str(iteration),
                iteration=iteration,
                model=model,
                optimizer=optimizer,
                overwrite=True,
                process_group=process_subgroup,
            )
            # Checkpoint retention is handled by the subgroup main process only.
            if distributed.is_subgroup_main_process():
                keep_last_n_checkpoints(ckpt_dir, cfg.checkpointing.max_to_keep)
                if "keep_every" in cfg.checkpointing and (iteration + 1) % cfg.checkpointing.keep_every == 0:
                    keep_checkpoint_copy(ckpt_dir / str(iteration))
        iteration = iteration + 1
    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def main(argv=None):
    """Entry point: parse arguments, build the model and train or evaluate.

    Args:
        argv: Optional argument vector. When ``None`` the parser reads the
            process arguments. Otherwise ``argv[1:]`` is parsed and
            ``output_dir`` is overwritten with ``sys.argv[1]``.

    Returns:
        The result of ``do_test`` in ``--eval-only`` mode, otherwise ``None``.

    Raises:
        ValueError: If ``cfg.MODEL.META_ARCHITECTURE`` is not a known
            meta-architecture name.
    """
    if argv is None:
        args = get_args_parser().parse_args()
    else:
        args = get_args_parser().parse_args(argv[1:])
        # NOTE(review): presumably a launcher convention where the first process
        # argument is the output directory — confirm against the caller.
        args.output_dir = sys.argv[1]
    if args.multi_distillation:
        print("performing multidistillation run")
        cfg = setup_multidistillation(args)
        torch.distributed.barrier()
        logger.info("setup_multidistillation done")
        assert cfg.MODEL.META_ARCHITECTURE == "MultiDistillationMetaArch"
    else:
        setup_job(output_dir=args.output_dir, seed=args.seed)
        cfg = setup_config(args, strict_cfg=False)
        logger.info(cfg)
        setup_logging(
            output=os.path.join(os.path.abspath(args.output_dir), "nan_logs"),
            name="nan_logger",
        )
    meta_arch = {
        "SSLMetaArch": SSLMetaArch,
        "MultiDistillationMetaArch": MultiDistillationMetaArch,
    }.get(cfg.MODEL.META_ARCHITECTURE, None)
    if meta_arch is None:
        raise ValueError(f"Unknown MODEL.META_ARCHITECTURE {cfg.MODEL.META_ARCHITECTURE}")
    logger.info(f"Making meta arch {meta_arch.__name__}")
    # Build the model on the meta device; real CUDA storage is created by `_apply` below.
    with torch.device("meta"):
        model = meta_arch(cfg)
    model.prepare_for_distributed_training()
    # Fill all values with `nans` so that we identify
    # non-initialized values
    # NOTE(review): for signed integer dtypes the sentinel 2 ** (bits - 1) is one past
    # the dtype maximum — confirm the model holds no such tensors.
    model._apply(
        lambda t: torch.full_like(
            t,
            fill_value=math.nan if t.dtype.is_floating_point else (2 ** (t.dtype.itemsize * 8 - 1)),
            device="cuda",
        ),
        recurse=True,
    )
    logger.info(f"Model after distributed:\n{model}")
    if args.eval_only:
        model.init_weights()
        iteration = (
            model.get_checkpointer_class()(model, save_dir=cfg.train.output_dir)
            .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)
            .get("iteration", -1)
            + 1
        )
        # do_test requires a process group; pass the same subgroup do_train uses.
        return do_test(
            cfg,
            model,
            f"manual_{iteration}",
            process_group=distributed.get_process_subgroup(),
        )
    do_train(cfg, model, resume=not args.no_resume)
# Run the entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()