|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for building per-iteration drop-rate schedules."""
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode='standard', schedule='constant'):
    """Build a per-iteration drop-rate schedule.

    Args:
        drop_rate: Drop rate used while the schedule is active.
        epochs: Total number of epochs.
        niter_per_ep: Number of iterations per epoch.
        cutoff_epoch: Epoch boundary between the "early" and "late" phases.
            Ignored when ``mode == 'standard'``; otherwise it must satisfy
            ``0 <= cutoff_epoch <= epochs``.
        mode: One of
            * ``'standard'`` -- ``drop_rate`` at every iteration;
            * ``'early'`` -- active only during the first ``cutoff_epoch``
              epochs, 0 afterwards;
            * ``'late'`` -- 0 during the first ``cutoff_epoch`` epochs,
              ``drop_rate`` afterwards.
        schedule: Shape of the active phase. ``'early'`` mode accepts
            ``'constant'`` or ``'linear'`` (a linear decay from ``drop_rate``
            to 0 that includes both endpoints). ``'late'`` mode accepts only
            ``'constant'``. Ignored in ``'standard'`` mode.

    Returns:
        A 1-D numpy array of length ``epochs * niter_per_ep`` holding the
        drop rate for each iteration.

    Raises:
        ValueError: If ``mode`` or ``schedule`` is unsupported, or if
            ``cutoff_epoch`` lies outside ``[0, epochs]`` for a non-standard
            mode.
    """
    # Explicit checks instead of ``assert``. Asserts are stripped under
    # ``python -O``; an unknown mode would then crash later with an
    # UnboundLocalError, and an unsupported 'late' schedule would be ignored.
    if mode not in ('standard', 'early', 'late'):
        raise ValueError(f"mode must be 'standard', 'early' or 'late', got {mode!r}")

    total_iters = epochs * niter_per_ep

    if mode == 'standard':
        return np.full(total_iters, drop_rate)

    # A cutoff outside [0, epochs] would give a negative phase length.
    if not 0 <= cutoff_epoch <= epochs:
        raise ValueError(f"cutoff_epoch must be in [0, {epochs}], got {cutoff_epoch!r}")

    early_iters = cutoff_epoch * niter_per_ep
    late_iters = (epochs - cutoff_epoch) * niter_per_ep

    if mode == 'early':
        if schedule == 'constant':
            early_schedule = np.full(early_iters, drop_rate)
        elif schedule == 'linear':
            # linspace includes both endpoints, so the last early iteration is 0.
            early_schedule = np.linspace(drop_rate, 0, early_iters)
        else:
            raise ValueError(
                f"schedule must be 'constant' or 'linear' for mode 'early', got {schedule!r}"
            )
        final_schedule = np.concatenate((early_schedule, np.full(late_iters, 0)))
    else:  # mode == 'late'
        if schedule != 'constant':
            raise ValueError(f"schedule must be 'constant' for mode 'late', got {schedule!r}")
        final_schedule = np.concatenate((np.full(early_iters, 0), np.full(late_iters, drop_rate)))

    # Internal invariant (not input validation), so an assert is appropriate.
    assert len(final_schedule) == total_iters
    return final_schedule
|
|
|
|