Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| from functools import partial | |
| # step scheduler | |
| def fn_LinearWarmup(warmup_steps, step): | |
| if step < warmup_steps: # linear warmup | |
| return float(step) / float(max(1, warmup_steps)) | |
| else: | |
| return 1.0 | |
| def Scheduler_LinearWarmup(warmup_steps): | |
| return partial(fn_LinearWarmup, warmup_steps) | |
| def fn_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min, step): | |
| if step < warmup_steps: # linear warmup | |
| return float(step) / float(max(1, warmup_steps)) | |
| else: # cosine learning rate schedule | |
| multipler = 0.5 * (math.cos((step - warmup_steps) / (max_steps - warmup_steps) * math.pi) + 1) | |
| return max(multipler, multipler_min) | |
| def Scheduler_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min): | |
| return partial(fn_LinearWarmup_CosineDecay, warmup_steps, max_steps, multipler_min) |