# Masking/noising schedules and a config-driven factory (see Schedule below).
import abc
from omegaconf import DictConfig
import torch
import torch.nn as nn
from torch import Tensor
def get_schedule_from_config(config: DictConfig):
    """
    Build a Schedule instance from a config.

    `config.type` selects the schedule class; the geometric schedule also
    reads `config.min`/`config.max`, and the polynomial one reads `config.exp`.

    Raises:
        ValueError: if `config.type` names no known schedule.
    """
    schedule_type = config.type
    if schedule_type == "geometric":
        return GeometricSchedule(min_val=config.min, max_val=config.max)
    if schedule_type == "linear":
        return LinearSchedule()
    if schedule_type == "sin":
        return SinSchedule()
    if schedule_type == "cosine":
        return CosineSchedule()
    if schedule_type == "polynomial":
        return PolynomialSchedule(exp=config.exp)
    raise ValueError(f"Invalid schedule type: {config.type}")
class Schedule(abc.ABC):
    """
    Abstract base class for masking/noising schedules.

    A schedule represents a function a : [0, 1] -> [0, 1] satisfying
    a(0) = 0 and a(1) = 1, at least approximately.
    """

    @abc.abstractmethod
    def at(self, t: Tensor):
        """Evaluate a(t)."""
        raise NotImplementedError

    @abc.abstractmethod
    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t)."""
        raise NotImplementedError

    def rate_scale_factor(self, t: Tensor) -> Tensor:
        """
        Evaluate d/dt a(t) / (1 - a(t)), a factor common in rate matrix
        calculations.
        """
        return self.derivative_at(t) / (1 - self.at(t))

    def sample(self, shape, device) -> Tensor:
        """
        Sample times from the schedule via inverse-transform sampling:
        draws uniform alpha and maps it through `inv`. Returns a tensor of
        shape `shape` with values in [0, 1].
        """
        return self.inv(torch.rand(shape, device=device))

    def sample_truncated(self, threshold, shape, device) -> Tensor:
        """
        Like `sample`, but restricted to times in [threshold, 1]: the uniform
        draw is rescaled into [a(threshold), 1] before inverting.
        """
        u = torch.rand(shape, device=device)
        low = self.at(threshold)
        return self.inv(low + u * (1 - low))

    @abc.abstractmethod
    def inv(self, alpha: Tensor):
        """Inverse of the schedule: the t with a(t) = alpha, alpha in [0, 1]."""
        raise NotImplementedError
class LinearSchedule(Schedule):
    """Identity schedule a(t) = t."""

    def __init__(self):
        pass

    def at(self, t: Tensor):
        return t

    def derivative_at(self, t: Tensor):
        # a(t) = t, so the derivative is 1 everywhere.
        return torch.ones_like(t)

    def inv(self, alpha: Tensor):
        return alpha
class GeometricSchedule(Schedule, nn.Module):
    """
    Geometric schedule a(t) = exp(-min^(1-t) * max^t).

    NOTE(review): the endpoints are a(0) = exp(-min) and a(1) = exp(-max),
    so the base-class boundary conditions hold only approximately,
    depending on the chosen min/max.
    """

    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        # Buffers so the bounds follow the module across devices/state_dict.
        self.register_buffer("min", Tensor([min_val]))
        self.register_buffer("max", Tensor([max_val]))

    def at(self, t: Tensor):
        lo = self.min.to(t.device)
        hi = self.max.to(t.device)
        return torch.exp(-(lo ** (1 - t)) * hi**t)

    def derivative_at(self, t):
        # d/dt exp(-lo^(1-t) hi^t) = a(t) * lo^(1-t) hi^t * (log lo - log hi)
        lo = self.min.to(t.device)
        hi = self.max.to(t.device)
        geometric_mean = lo ** (1 - t) * hi**t
        return self.at(t) * geometric_mean * (lo.log() - hi.log())

    def inv(self, alpha: Tensor):
        # From a(t) = alpha: log(-log alpha) = (1-t) log lo + t log hi,
        # solved for t below.
        log_lo = self.min.to(alpha.device).log()
        log_hi = self.max.to(alpha.device).log()
        return (torch.log(-torch.log(alpha)) - log_lo) / (log_hi - log_lo)
class SinSchedule(Schedule, nn.Module):
    """Sinusoidal schedule a(t) = sin(pi * t / 2)."""

    def __init__(self):
        super().__init__()

    def at(self, t: Tensor):
        half_pi = torch.pi / 2
        return torch.sin(half_pi * t)

    def derivative_at(self, t: Tensor):
        half_pi = torch.pi / 2
        return half_pi * torch.cos(half_pi * t)

    def inv(self, alpha: Tensor):
        # Clamp guards asin against inputs slightly outside [0, 1].
        return torch.asin(alpha.clamp(min=0., max=1.)) * (2 / torch.pi)
class CosineSchedule(Schedule, nn.Module):
    """Cosine schedule a(t) = 1 - cos(pi * t / 2)."""

    def __init__(self):
        super().__init__()

    def at(self, t: Tensor):
        half_pi = torch.pi / 2
        return 1 - torch.cos(half_pi * t)

    def derivative_at(self, t: Tensor):
        half_pi = torch.pi / 2
        return half_pi * torch.sin(half_pi * t)

    def rate_scale_factor(self, t):
        # Closed form of derivative_at(t) / (1 - at(t)):
        # (pi/2) sin / cos = (pi/2) tan.
        half_pi = torch.pi / 2
        return half_pi * torch.tan(half_pi * t)

    def inv(self, alpha):
        # Clamp guards arccos against inputs slightly outside [0, 1].
        return torch.arccos(1 - alpha.clamp(min=0., max=1.)) * (2 / torch.pi)
class PolynomialSchedule(Schedule, nn.Module):
    """Polynomial schedule a(t) = t ** exp."""

    def __init__(self, exp):
        super().__init__()
        # Polynomial exponent; exp > 0 presumably — TODO confirm with callers.
        self.exp = exp

    def at(self, t: Tensor):
        return t**self.exp

    def derivative_at(self, t: Tensor):
        return self.exp * t ** (self.exp - 1)

    def inv(self, alpha: Tensor):
        return alpha ** (1 / self.exp)