# ------------------------------------------------------------------------
# LW-DETR
# Copyright (c) 2024 Baidu. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

"""util for drop scheduler."""
import numpy as np


def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode='standard', schedule='constant'):
    """Build an array of drop-rate values of length ``epochs * niter_per_ep``.

    The array has one entry per (epoch, iteration) slot; presumably it is
    indexed by the global training iteration -- confirm against the caller.

    Args:
        drop_rate: The non-zero drop rate used in the "active" part of the
            schedule.
        epochs: Total number of epochs.
        niter_per_ep: Number of iterations per epoch.
        cutoff_epoch: Epoch at which the schedule switches between its early
            and late segments. Ignored when ``mode == 'standard'``.
        mode: One of ``'standard'``, ``'early'`` or ``'late'``.
            ``'standard'``: ``drop_rate`` everywhere.
            ``'early'``: ``drop_rate`` (constant or linearly decayed to 0)
            for the first ``cutoff_epoch`` epochs, then 0.
            ``'late'``: 0 for the first ``cutoff_epoch`` epochs, then
            ``drop_rate``.
        schedule: Shape of the early segment in ``'early'`` mode
            (``'constant'`` or ``'linear'``). In ``'late'`` mode only
            ``'constant'`` is accepted. Not checked in ``'standard'`` mode.

    Returns:
        A 1-D ``numpy.ndarray`` with ``epochs * niter_per_ep`` elements.

    Raises:
        ValueError: If ``mode`` is unknown, if ``schedule`` is not allowed
            for the chosen ``mode``, or if ``cutoff_epoch`` yields a
            negative segment length.
    """
    # Explicit raises instead of ``assert``: assertions are stripped under
    # ``python -O``, which would turn bad arguments into an obscure
    # UnboundLocalError further down.
    if mode not in ('standard', 'early', 'late'):
        raise ValueError(
            f"mode must be 'standard', 'early' or 'late', got {mode!r}"
        )

    total_iters = epochs * niter_per_ep
    if mode == 'standard':
        return np.full(total_iters, drop_rate)

    allowed_schedules = ('constant', 'linear') if mode == 'early' else ('constant',)
    if schedule not in allowed_schedules:
        raise ValueError(
            f"schedule must be one of {allowed_schedules} when mode={mode!r}, "
            f"got {schedule!r}"
        )

    early_iters = cutoff_epoch * niter_per_ep
    late_iters = (epochs - cutoff_epoch) * niter_per_ep
    # numpy would reject negative sizes anyway (also with ValueError); fail
    # early with a message that names the offending argument.
    if early_iters < 0 or late_iters < 0:
        raise ValueError(
            f"cutoff_epoch={cutoff_epoch!r} gives a negative segment length "
            f"for epochs={epochs!r}, niter_per_ep={niter_per_ep!r}"
        )

    if mode == 'early':
        if schedule == 'constant':
            early_schedule = np.full(early_iters, drop_rate)
        else:  # 'linear': decay from drop_rate to 0 over the early segment
            early_schedule = np.linspace(drop_rate, 0, early_iters)
        late_schedule = np.full(late_iters, 0)
    else:  # mode == 'late'
        early_schedule = np.full(early_iters, 0)
        late_schedule = np.full(late_iters, drop_rate)

    final_schedule = np.concatenate((early_schedule, late_schedule))
    # Internal sanity check (not input validation).
    assert len(final_schedule) == total_iters
    return final_schedule