| import math | |
| import torch | |
def cosine_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return the cosine noise schedule ``cos(pi / 2 * t)``, elementwise.

    This is the schedule used in the MaskGIT paper. The value falls from 1 at
    ``t = 0`` to 0 at ``t = 1``.

    Args:
        t: Tensor of size ``(batch_size,)`` with values between 0 and 1.

    Returns:
        Tensor of the same shape as ``t`` holding the schedule values.
    """
    return torch.cos(t * math.pi * 0.5)
def cubic_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return the cubic noise schedule ``1 - t**3``, elementwise.

    The value falls from 1 at ``t = 0`` to 0 at ``t = 1``.

    Args:
        t: Schedule position(s); presumably values in ``[0, 1]`` like the
            other schedules in this module -- confirm against callers.

    Returns:
        ``1 - t**3``, with the same shape as ``t``.
    """
    return 1 - t**3
def linear_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return the linear noise schedule ``1 - t``, elementwise.

    The value falls from 1 at ``t = 0`` to 0 at ``t = 1``.

    Args:
        t: Schedule position(s); presumably values in ``[0, 1]`` like the
            other schedules in this module -- confirm against callers.

    Returns:
        ``1 - t``, with the same shape as ``t``.
    """
    return 1 - t
def square_root_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return the square-root noise schedule ``1 - sqrt(t)``, elementwise.

    The value falls from 1 at ``t = 0`` to 0 at ``t = 1``.

    Args:
        t: Tensor of schedule positions. ``torch.sqrt`` needs a tensor (a
            plain Python float is rejected) and yields NaN for negative
            entries, so values are expected to be non-negative.

    Returns:
        Tensor of the same shape as ``t`` holding ``1 - sqrt(t)``.
    """
    return 1 - torch.sqrt(t)
def square_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return the quadratic noise schedule ``1 - t**2``, elementwise.

    The value falls from 1 at ``t = 0`` to 0 at ``t = 1``.

    Args:
        t: Schedule position(s); presumably values in ``[0, 1]`` like the
            other schedules in this module -- confirm against callers.

    Returns:
        ``1 - t**2``, with the same shape as ``t``.
    """
    return 1 - t**2
# Maps a schedule name to the function implementing it, so a schedule can be
# selected by string (e.g. NOISE_SCHEDULE_REGISTRY["cosine"](t)).
NOISE_SCHEDULE_REGISTRY = {
    "cosine": cosine_schedule,
    "linear": linear_schedule,
    # Original key; kept so existing lookups by this name keep working.
    "square_root_schedule": square_root_schedule,
    # Alias that follows the suffix-free naming of the other keys.
    "square_root": square_root_schedule,
    "cubic": cubic_schedule,
    "square": square_schedule,
}