| | """ |
| | implment some functions for optimizers |
| | """ |
| | import numpy as np |
| | import torch |
| |
|
| | import utils |
| |
|
| |
|
| | def clip_gradients(model, clip): |
| | """ |
| | clip gradient if gradient norm > clip |
| | """ |
| | norms = [] |
| | for name, p in model.named_parameters(): |
| | if p.grad is not None: |
| | param_norm = p.grad.data.norm(2) |
| | norms.append(param_norm.item()) |
| | clip_coef = clip / (param_norm + 1e-6) |
| | if clip_coef < 1: |
| | p.grad.data.mul_(clip_coef) |
| | return norms |
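

# Note (hedged, not from the original code): unlike torch.nn.utils.clip_grad_norm_, which
# rescales all gradients by a single global norm, clip_gradients rescales each parameter's
# gradient by its own norm. A minimal usage sketch with hypothetical `model` and `loss`:
#
#     loss.backward()
#     per_param_norms = clip_gradients(model, clip=3.0)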


def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """
    Cancel the gradients of the last layer while epoch < freeze_last_layer.
    """
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None
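

# A hedged sketch of the usual call order inside one training iteration (the names
# `loss`, `student`, `optimizer` and `args` are assumptions, not defined in this module):
#
#     loss.backward()
#     if args.clip_grad:
#         _ = clip_gradients(student, args.clip_grad)
#     cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
#     optimizer.step()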


def cosine_scheduler(
    base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0
):
    """
    Linear warmup from start_warmup_value to base_value over the first warmup_epochs epochs,
    then cosine annealing from base_value to final_value over the remaining
    epochs - warmup_epochs epochs. Returns one value per iteration, epochs * niter_per_ep in total.
    """
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters))
    )

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
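

# A minimal sketch (assumption, not part of this module) of how the returned schedule is
# typically consumed, indexing with the global iteration it = epoch * niter_per_ep + i
# (`data_loader` and `optimizer` are hypothetical):
#
#     lr_schedule = cosine_scheduler(5e-4, 1e-6, epochs=100, niter_per_ep=1250, warmup_epochs=10)
#     for epoch in range(100):
#         for i, batch in enumerate(data_loader):
#             it = epoch * 1250 + i
#             for param_group in optimizer.param_groups:
#                 param_group["lr"] = lr_schedule[it]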


def get_params_groups(model):
    """
    Divide the parameters into four groups:
    - "normal_params": regular weights (weight decay applied),
    - "patch_embed": patch-embedding weights (weight decay applied),
    - "no_wd": biases and other 1-D parameters (weight decay disabled),
    - "patch_embed_no_wd": patch-embedding biases and 1-D parameters (weight decay disabled).
    """
    regularized = []
    not_regularized = []
    patch_embed = []
    patch_embed_not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if name.endswith(".bias") or len(param.shape) == 1:
            if "patch_embed" in name:
                patch_embed_not_regularized.append(param)
            else:
                not_regularized.append(param)
        elif "patch_embed" in name:
            patch_embed.append(param)
        else:
            regularized.append(param)
    return [
        {"name": "normal_params", "params": regularized},
        {"name": "patch_embed", "params": patch_embed},
        {
            "name": "no_wd",
            "params": not_regularized,
            "apply_wd": False,
            "weight_decay": 0.0,
        },
        {
            "name": "patch_embed_no_wd",
            "params": patch_embed_not_regularized,
            "apply_wd": False,
            "weight_decay": 0.0,
        },
    ]
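

# A small, hedged sanity check (not from the original code) showing where parameters land;
# `toy` is a hypothetical module:
#
#     import torch.nn as nn
#     toy = nn.Linear(4, 2)
#     groups = get_params_groups(toy)
#     # weight (2-D) -> "normal_params"; bias (1-D) -> "no_wd" with weight_decay=0.0
#     print([(g["name"], len(g["params"])) for g in groups])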


class LARS(torch.optim.Optimizer):
    """
    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
    """

    def __init__(
        self,
        params,
        lr=0,
        weight_decay=0,
        momentum=0.9,
        eta=0.001,
        weight_decay_filter=None,
        lars_adaptation_filter=None,
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            eta=eta,
            weight_decay_filter=weight_decay_filter,
            lars_adaptation_filter=lars_adaptation_filter,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad

                if dp is None:
                    continue

                # apply weight decay only to non-1-D parameters (skip biases / norm params)
                if p.ndim != 1:
                    dp = dp.add(p, alpha=g["weight_decay"])

                if p.ndim != 1:
                    # LARS trust ratio: eta * ||w|| / ||update||, falling back to 1
                    # whenever either norm is zero
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0, (g["eta"] * param_norm / update_norm), one
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                # SGD-style momentum on the adapted update
                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)

                p.add_(mu, alpha=-g["lr"])
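

# Hedged usage sketch (assumption): LARS is typically constructed with the default lr=0 and
# paired with a per-iteration learning-rate schedule that is written into
# optimizer.param_groups each step (see get_optimizer below). For example, with a
# hypothetical `model`:
#
#     optimizer = LARS(get_params_groups(model), momentum=0.9)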


def get_optimizer(student, len_dataloader, args):
    """
    Build the optimizer, an optional fp16 gradient scaler, and the per-iteration
    learning-rate, weight-decay and teacher-momentum schedules.
    """
    params_groups = get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            params_groups, lr=0, momentum=0.9
        )
    elif args.optimizer == "lars":
        optimizer = LARS(params_groups)
    else:
        raise ValueError(f"Unknown optimizer: {args.optimizer}")

    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # linear scaling rule: base lr is scaled by the total batch size / 256
    lr_schedule = cosine_scheduler(
        args.lr
        * (args.batch_size_per_gpu * utils.get_world_size())
        / 256.0,
        args.min_lr,
        args.epochs,
        len_dataloader,
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        len_dataloader,
    )
    # teacher momentum schedule: from args.momentum_teacher up to 1 over training
    momentum_schedule = cosine_scheduler(
        args.momentum_teacher, 1, args.epochs, len_dataloader
    )
    print("Loss, optimizer and schedulers ready.")

    return optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule
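

# A hedged end-to-end sketch (not part of this module) of how the returned schedules are
# typically applied per iteration; `args`, `student` and `data_loader` are hypothetical:
#
#     optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule = get_optimizer(
#         student, len(data_loader), args
#     )
#     for epoch in range(args.epochs):
#         for i, batch in enumerate(data_loader):
#             it = epoch * len(data_loader) + i
#             for param_group in optimizer.param_groups:
#                 param_group["lr"] = lr_schedule[it]
#                 if param_group.get("apply_wd", True):
#                     param_group["weight_decay"] = wd_schedule[it]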