Bailan-Alex's picture
Upload folder using huggingface_hub
4f2b2f4 verified
import abc
from omegaconf import DictConfig
import torch
import torch.nn as nn
from torch import Tensor
def get_schedule_from_config(config: DictConfig):
match config.type:
case "geometric":
return GeometricSchedule(min_val=config.min, max_val=config.max)
case "linear":
return LinearSchedule()
case "sin":
return SinSchedule()
case "cosine":
return CosineSchedule()
case "polynomial":
return PolynomialSchedule(exp=config.exp)
case _:
raise ValueError(f"Invalid schedule type: {config.type}")
class Schedule(abc.ABC):
"""
Generic schedule class for masking or noising
This represents function a : [0, 1] -> [0, 1] satisfying a(0) = 0, a(1) = 1 or at least approximately
"""
@abc.abstractmethod
def at(self, t: Tensor):
"""
Return value a(t)
"""
raise NotImplementedError
@abc.abstractmethod
def derivative_at(self, t: Tensor):
"""
Return d/dt a(t)
"""
raise NotImplementedError
def rate_scale_factor(self, t: Tensor) -> Tensor:
"""
Return d/dt a(t) / (1 - a(t)) common in rate matrix calculation
"""
return self.derivative_at(t) / (1 - self.at(t))
def sample(self, shape, device) -> Tensor:
"""
Sample from the schedule, returns a tensor of shape `shape` with values in [0, 1]
"""
uniform = torch.rand(shape, device=device)
return self.inv(uniform)
def sample_truncated(self, threshold, shape, device) -> Tensor:
"""
Sample from a truncated schedule, returns a tensor of shape `shape` with values in [threshold, 1]
"""
uniform = torch.rand(shape, device=device)
threshold = self.at(threshold)
return self.inv(uniform * (1 - threshold) + threshold)
@abc.abstractmethod
def inv(self, alpha: Tensor):
"""
Given alpha in [0, 1] such that a(t)=alpha, returns the corresponding t.
"""
raise NotImplementedError
class LinearSchedule(Schedule):
def __init__(self):
pass
def at(self, t: Tensor):
return t
def derivative_at(self, t: Tensor):
return torch.ones_like(t, device=t.device)
def inv(self, alpha: Tensor):
return alpha
class GeometricSchedule(Schedule, nn.Module):
def __init__(self, min_val: float, max_val: float):
super().__init__()
self.register_buffer("min", Tensor([min_val]))
self.register_buffer("max", Tensor([max_val]))
def at(self, t: Tensor):
min_val = self.min.to(t.device)
max_val = self.max.to(t.device)
return torch.exp(-(min_val ** (1 - t)) * max_val**t)
def derivative_at(self, t):
min_val = self.min.to(t.device)
max_val = self.max.to(t.device)
return (
self.at(t)
* min_val ** (1 - t)
* max_val**t
* (min_val.log() - max_val.log())
)
def inv(self, alpha: Tensor):
log_min = self.min.to(alpha.device).log()
log_max = self.max.to(alpha.device).log()
return (torch.log(-torch.log(alpha)) - log_min) / (log_max - log_min)
class SinSchedule(Schedule, nn.Module):
def __init__(self):
super().__init__()
def at(self, t: Tensor):
return torch.sin(torch.pi / 2 * t)
def derivative_at(self, t: Tensor):
return (torch.pi / 2) * torch.cos(torch.pi / 2 * t)
def inv(self, alpha: Tensor):
return (2 / torch.pi) * torch.asin(alpha.clamp(min=0., max=1.))
class CosineSchedule(Schedule, nn.Module):
def __init__(self):
super().__init__()
def at(self, t: Tensor):
return 1 - torch.cos(torch.pi / 2 * t)
def derivative_at(self, t: Tensor):
return (torch.pi / 2) * torch.sin(torch.pi / 2 * t)
def rate_scale_factor(self, t):
return (torch.pi/2) * torch.tan(torch.pi / 2 * t)
def inv(self, alpha):
return (2 / torch.pi) * torch.arccos(1 - alpha.clamp(min=0., max=1.))
class PolynomialSchedule(Schedule, nn.Module):
def __init__(self, exp):
super().__init__()
self.exp = exp
def at(self, t: Tensor):
return t ** self.exp
def derivative_at(self, t: Tensor):
return self.exp * t ** (self.exp - 1)
def inv(self, alpha: Tensor):
return alpha ** (1 / self.exp)