# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
"""Collections of utilities related to optimization."""
from bisect import bisect_right
import os
import torch
def update_ema(model, model_ema, decay):
    """Apply an exponential moving average update.

    The weights are updated in-place as follows:
        w_ema = w_ema * decay + (1 - decay) * w

    Args:
        model: active model that is being optimized
        model_ema: running average model
        decay: exponential decay parameter
    """
    with torch.no_grad():
        if hasattr(model, "module"):
            # unwrap DistributedDataParallel
            model = model.module
        msd = model.state_dict()
        for k, ema_v in model_ema.state_dict().items():
            model_v = msd[k].detach()
            # blend the current weights into the running average
            ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)
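

# A minimal usage sketch for ``update_ema``: keep a deep copy of the model and
# fold the optimized weights into it after every optimizer step. The toy model,
# data, and decay value below are illustrative assumptions, not part of this
# module.
def _update_ema_demo():
    import copy

    model = torch.nn.Linear(4, 2)
    model_ema = copy.deepcopy(model)  # running-average copy, never optimized directly
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # after each step, blend the new weights into the EMA copy
        update_ema(model, model_ema, decay=0.9998)
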
def adjust_learning_rate(
    optimizer,
    epoch: int,
    curr_step: int,
    args,
):
    """Linearly warm up the learning rate of the text encoder.

    During the first epoch, the lr of the last parameter group (assumed to
    hold the text encoder parameters) is ramped linearly from 0 to its peak
    over the first ``args.num_warmup_steps`` optimization steps. After the
    warmup window, or after the first epoch, this function is a no-op.

    Args:
        optimizer: torch optimizer to update.
        epoch (int): index of the current epoch.
        curr_step (int): number of optimization steps taken so far.
        args: additional training-dependent args:
            - num_warmup_steps (int): number of steps over which the lr is
              increased to its peak.
            - lr_backbone (float): peak learning rate for the warmed-up group.
    """
    try:
        num_warmup_steps = args.num_warmup_steps
    except AttributeError:
        # no warmup configured, nothing to do
        return
    if epoch > 0:
        # warmup only applies during the first epoch
        return
    if curr_step > num_warmup_steps:
        # warmup window is over, leave the lr untouched
        return
    # linear ramp from 0 to 1 over the warmup window
    text_encoder_gamma = float(curr_step) / float(max(1, num_warmup_steps))
    # the last param group is assumed to hold the text encoder parameters
    optimizer.param_groups[-1]["lr"] = args.lr_backbone * text_encoder_gamma
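

# A minimal usage sketch for ``adjust_learning_rate``: call it once per
# optimization step with a running step counter. The ``argparse.Namespace``
# values below are illustrative assumptions, not defaults of this module.
def _warmup_demo():
    import argparse

    model = torch.nn.Linear(4, 2)
    # single param group, so it is also the last group that gets warmed up
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)
    args = argparse.Namespace(num_warmup_steps=100, lr_backbone=1e-5)
    curr_step = 0
    for epoch in range(2):
        for _ in range(100):
            # ramps the last group's lr toward args.lr_backbone during epoch 0,
            # then becomes a no-op for later epochs
            adjust_learning_rate(optimizer, epoch, curr_step, args)
            curr_step += 1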