# nas / PFMBench / src / data / esm / utils / noise_schedules.py
# Uploaded by yuccaaa using the upload-large-folder tool (commit 9627ce0, verified).
import math
import torch
def cosine_schedule(t: torch.Tensor):
    """Cosine schedule, the one used in the MaskGIT paper.

    ``t`` is a tensor of size (batch_size,) with values between 0 and 1.
    The result is ``cos(t * pi / 2)`` computed elementwise. It equals 1 at
    ``t = 0`` and is (up to floating-point rounding) 0 at ``t = 1``.
    """
    angle = t * math.pi * 0.5
    return torch.cos(angle)
def cubic_schedule(t):
    """Cubic schedule: ``1 - t**3``.

    Uses only arithmetic operators, so it works elementwise on tensors as
    well as on plain Python numbers. It maps ``t = 0`` to 1 and ``t = 1`` to 0.
    """
    cubed = t**3
    return 1 - cubed
def linear_schedule(t):
    """Linear schedule: ``1 - t``.

    Uses only arithmetic operators, so it works elementwise on tensors as
    well as on plain Python numbers. It maps ``t = 0`` to 1 and ``t = 1`` to 0.
    """
    complement = 1 - t
    return complement
def square_root_schedule(t):
    """Square-root schedule: ``1 - sqrt(t)``.

    The input must be a tensor, because the root is taken with ``torch.sqrt``.
    It maps ``t = 0`` to 1 and ``t = 1`` to 0.
    """
    root = torch.sqrt(t)
    return 1 - root
def square_schedule(t):
    """Quadratic schedule: ``1 - t**2``.

    Uses only arithmetic operators, so it works elementwise on tensors as
    well as on plain Python numbers. It maps ``t = 0`` to 1 and ``t = 1`` to 0.
    """
    squared = t**2
    return 1 - squared
# Lookup table from a schedule name to the schedule function. Every
# registered function returns 1 at t = 0 and (up to rounding) 0 at t = 1.
NOISE_SCHEDULE_REGISTRY = {
    "cosine": cosine_schedule,
    "linear": linear_schedule,
    # "square_root_schedule" is the original key and is kept so existing
    # configs/callers still resolve. "square_root" is an alias that matches
    # the suffix-free naming used by every other entry.
    "square_root_schedule": square_root_schedule,
    "square_root": square_root_schedule,
    "cubic": cubic_schedule,
    "square": square_schedule,
}