Spaces:
No application file
No application file
| import abc | |
| from omegaconf import DictConfig | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
def get_schedule_from_config(config: DictConfig):
    """
    Build a Schedule instance from a config.

    Expects `config.type` to name one of the known schedules; the
    "geometric" and "polynomial" types additionally read `config.min` /
    `config.max` and `config.exp` respectively.

    Raises:
        ValueError: if `config.type` is not a recognized schedule name.
    """
    schedule_type = config.type
    if schedule_type == "geometric":
        return GeometricSchedule(min_val=config.min, max_val=config.max)
    if schedule_type == "linear":
        return LinearSchedule()
    if schedule_type == "sin":
        return SinSchedule()
    if schedule_type == "cosine":
        return CosineSchedule()
    if schedule_type == "polynomial":
        return PolynomialSchedule(exp=config.exp)
    raise ValueError(f"Invalid schedule type: {config.type}")
class Schedule(abc.ABC):
    """
    Abstract interface for masking / noising schedules.

    Conceptually a function a : [0, 1] -> [0, 1] satisfying a(0) = 0 and
    a(1) = 1, at least approximately. Subclasses provide `at`,
    `derivative_at` and `inv`; the sampling helpers below are derived
    from those.
    """

    def at(self, t: Tensor):
        """Evaluate a(t)."""
        raise NotImplementedError

    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t)."""
        raise NotImplementedError

    def rate_scale_factor(self, t: Tensor) -> Tensor:
        """
        Evaluate d/dt a(t) / (1 - a(t)), a factor common in rate-matrix
        calculations.
        """
        numerator = self.derivative_at(t)
        return numerator / (1 - self.at(t))

    def sample(self, shape, device) -> Tensor:
        """
        Sample times from the schedule: returns a tensor of shape `shape`
        with values in [0, 1], obtained by pushing uniform draws through
        the inverse schedule (so a(t) is uniform).
        """
        uniform = torch.rand(shape, device=device)
        return self.inv(uniform)

    def sample_truncated(self, threshold, shape, device) -> Tensor:
        """
        Sample from a truncated schedule: returns a tensor of shape `shape`
        with values in [threshold, 1], by restricting the uniform draw to
        [a(threshold), 1] before inverting.
        """
        uniform = torch.rand(shape, device=device)
        lo = self.at(threshold)
        return self.inv(lo + uniform * (1 - lo))

    def inv(self, alpha: Tensor):
        """Inverse schedule: given alpha in [0, 1], return t with a(t) = alpha."""
        raise NotImplementedError
class LinearSchedule(Schedule, nn.Module):
    """
    Identity schedule: a(t) = t.

    Now also an nn.Module, matching every other schedule in this file
    (Sin/Cosine are likewise parameter-free nn.Modules), so it can be
    registered inside module hierarchies uniformly.
    """

    def __init__(self):
        # Required for nn.Module bookkeeping (the original bare `pass`
        # skipped this because it had no base initialization to run).
        super().__init__()

    def at(self, t: Tensor):
        """Return a(t) = t."""
        return t

    def derivative_at(self, t: Tensor):
        """Return d/dt a(t) = 1."""
        # ones_like already inherits t's device; the explicit
        # device=t.device in the original was redundant.
        return torch.ones_like(t)

    def inv(self, alpha: Tensor):
        """Return t such that a(t) = alpha, i.e. alpha itself."""
        return alpha
class GeometricSchedule(Schedule, nn.Module):
    """
    Geometric schedule: a(t) = exp(-(min ** (1 - t)) * max ** t).

    The bounds are registered as buffers so they travel with the module
    across device / dtype moves.
    """

    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        self.register_buffer("min", Tensor([min_val]))
        self.register_buffer("max", Tensor([max_val]))

    def at(self, t: Tensor):
        """Evaluate a(t) = exp(-(min ** (1 - t)) * max ** t)."""
        lo = self.min.to(t.device)
        hi = self.max.to(t.device)
        return torch.exp(-(lo ** (1 - t)) * hi**t)

    def derivative_at(self, t):
        """Evaluate d/dt a(t) = a(t) * min**(1-t) * max**t * (log min - log max)."""
        lo = self.min.to(t.device)
        hi = self.max.to(t.device)
        geometric_mean = lo ** (1 - t) * hi**t
        return self.at(t) * geometric_mean * (lo.log() - hi.log())

    def inv(self, alpha: Tensor):
        """Solve a(t) = alpha for t: t = (log(-log alpha) - log min) / (log max - log min)."""
        log_lo = self.min.to(alpha.device).log()
        log_hi = self.max.to(alpha.device).log()
        return (torch.log(-torch.log(alpha)) - log_lo) / (log_hi - log_lo)
class SinSchedule(Schedule, nn.Module):
    """Sinusoidal schedule: a(t) = sin(pi/2 * t)."""

    def __init__(self):
        super().__init__()

    def at(self, t: Tensor):
        """Evaluate a(t) = sin(pi/2 * t)."""
        half_pi = torch.pi / 2
        return torch.sin(half_pi * t)

    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t) = (pi/2) * cos(pi/2 * t)."""
        half_pi = torch.pi / 2
        return half_pi * torch.cos(half_pi * t)

    def inv(self, alpha: Tensor):
        """Return t = (2/pi) * asin(alpha); alpha is clamped to [0, 1] first."""
        clamped = alpha.clamp(min=0.0, max=1.0)
        return (2 / torch.pi) * torch.asin(clamped)
class CosineSchedule(Schedule, nn.Module):
    """Cosine schedule: a(t) = 1 - cos(pi/2 * t)."""

    def __init__(self):
        super().__init__()

    def at(self, t: Tensor):
        """Evaluate a(t) = 1 - cos(pi/2 * t)."""
        half_pi = torch.pi / 2
        return 1 - torch.cos(half_pi * t)

    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t) = (pi/2) * sin(pi/2 * t)."""
        half_pi = torch.pi / 2
        return half_pi * torch.sin(half_pi * t)

    def rate_scale_factor(self, t):
        """
        Closed-form override of d/dt a(t) / (1 - a(t)):
        (pi/2) * sin / cos = (pi/2) * tan(pi/2 * t).
        """
        half_pi = torch.pi / 2
        return half_pi * torch.tan(half_pi * t)

    def inv(self, alpha):
        """Return t = (2/pi) * arccos(1 - alpha); alpha is clamped to [0, 1] first."""
        clamped = alpha.clamp(min=0.0, max=1.0)
        return (2 / torch.pi) * torch.arccos(1 - clamped)
class PolynomialSchedule(Schedule, nn.Module):
    """Polynomial schedule: a(t) = t ** exp."""

    def __init__(self, exp):
        super().__init__()
        # Polynomial exponent; exp = 1 reduces to the linear schedule.
        self.exp = exp

    def at(self, t: Tensor):
        """Evaluate a(t) = t ** exp."""
        return t**self.exp

    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t) = exp * t ** (exp - 1)."""
        return self.exp * t ** (self.exp - 1)

    def inv(self, alpha: Tensor):
        """Return t = alpha ** (1 / exp)."""
        return alpha ** (1 / self.exp)