File size: 4,484 Bytes
4f2b2f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import abc
from omegaconf import DictConfig
import torch
import torch.nn as nn
from torch import Tensor


def get_schedule_from_config(config: DictConfig):
    """
    Build a Schedule from a config object.

    ``config.type`` selects the schedule class. "geometric" additionally
    reads ``config.min`` / ``config.max``; "polynomial" reads ``config.exp``.

    Raises:
        ValueError: if ``config.type`` is not a recognized schedule name.
    """
    schedule_type = config.type
    if schedule_type == "geometric":
        return GeometricSchedule(min_val=config.min, max_val=config.max)
    if schedule_type == "linear":
        return LinearSchedule()
    if schedule_type == "sin":
        return SinSchedule()
    if schedule_type == "cosine":
        return CosineSchedule()
    if schedule_type == "polynomial":
        return PolynomialSchedule(exp=config.exp)
    raise ValueError(f"Invalid schedule type: {config.type}")


class Schedule(abc.ABC):
    """
    Abstract base class for masking / noising schedules.

    A schedule represents a function a : [0, 1] -> [0, 1] satisfying
    a(0) = 0 and a(1) = 1, at least approximately.
    """

    @abc.abstractmethod
    def at(self, t: Tensor):
        """Evaluate a(t)."""
        raise NotImplementedError

    @abc.abstractmethod
    def derivative_at(self, t: Tensor):
        """Evaluate the time derivative d/dt a(t)."""
        raise NotImplementedError

    @abc.abstractmethod
    def inv(self, alpha: Tensor):
        """Given alpha in [0, 1] with a(t) = alpha, return the corresponding t."""
        raise NotImplementedError

    def rate_scale_factor(self, t: Tensor) -> Tensor:
        """Evaluate d/dt a(t) / (1 - a(t)), a factor common in rate-matrix calculations."""
        slope = self.derivative_at(t)
        return slope / (1 - self.at(t))

    def sample(self, shape, device) -> Tensor:
        """
        Sample times from the schedule via inverse transform: draw u ~ U[0, 1)
        and return inv(u). Result has shape `shape` with values in [0, 1].
        """
        return self.inv(torch.rand(shape, device=device))

    def sample_truncated(self, threshold, shape, device) -> Tensor:
        """
        Sample from the schedule restricted to [threshold, 1]: draw u ~ U[0, 1),
        rescale it into [a(threshold), 1), and invert. Result has shape `shape`.
        """
        u = torch.rand(shape, device=device)
        low = self.at(threshold)
        return self.inv(u * (1 - low) + low)


class LinearSchedule(Schedule):
    """Identity schedule: a(t) = t, so sampled times are uniform on [0, 1)."""

    def at(self, t: Tensor) -> Tensor:
        """Return a(t) = t."""
        return t

    def derivative_at(self, t: Tensor) -> Tensor:
        """Return d/dt a(t) = 1 everywhere."""
        # ones_like already matches t's device (and dtype); the explicit
        # device kwarg in the original was redundant.
        return torch.ones_like(t)

    def inv(self, alpha: Tensor) -> Tensor:
        """Return the t with a(t) = alpha, which is alpha itself."""
        return alpha


class GeometricSchedule(Schedule, nn.Module):
    """
    Geometric-interpolation schedule: a(t) = exp(-(min ** (1 - t)) * max ** t).

    NOTE(review): for typical 0 < min < 1 < max this is *decreasing* in t
    (a(0) = exp(-min) near 1, a(1) = exp(-max) near 0), so it matches the
    base-class a(0)=0, a(1)=1 convention only up to orientation — confirm
    against callers.
    """

    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        # Buffers (not parameters) so the endpoints follow the module across
        # devices without being trained.
        self.register_buffer("min", Tensor([min_val]))
        self.register_buffer("max", Tensor([max_val]))

    def at(self, t: Tensor):
        """Evaluate a(t) = exp(-(min ** (1 - t)) * max ** t)."""
        lo = self.min.to(t.device)
        hi = self.max.to(t.device)
        return torch.exp(-(lo ** (1 - t)) * hi**t)

    def derivative_at(self, t):
        """Evaluate d/dt a(t) = a(t) * min**(1-t) * max**t * (log min - log max)."""
        lo = self.min.to(t.device)
        hi = self.max.to(t.device)
        return self.at(t) * lo ** (1 - t) * hi**t * (lo.log() - hi.log())

    def inv(self, alpha: Tensor):
        """
        Invert a: from log(-log alpha) = (1 - t) log min + t log max,
        solve t = (log(-log alpha) - log min) / (log max - log min).
        """
        log_lo = self.min.to(alpha.device).log()
        log_hi = self.max.to(alpha.device).log()
        return (torch.log(-torch.log(alpha)) - log_lo) / (log_hi - log_lo)


class SinSchedule(Schedule, nn.Module):
    """Sine schedule: a(t) = sin(pi/2 * t); concave, a(0) = 0, a(1) = 1."""

    def __init__(self):
        super().__init__()

    def at(self, t: Tensor):
        """Evaluate a(t) = sin(pi t / 2)."""
        half_pi = torch.pi / 2
        return torch.sin(half_pi * t)

    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t) = (pi/2) cos(pi t / 2)."""
        half_pi = torch.pi / 2
        return half_pi * torch.cos(half_pi * t)

    def inv(self, alpha: Tensor):
        """Invert: t = (2/pi) asin(alpha). alpha is clamped to [0, 1] first."""
        clamped = alpha.clamp(min=0.0, max=1.0)
        return (2 / torch.pi) * torch.asin(clamped)


class CosineSchedule(Schedule, nn.Module):
    """Cosine schedule: a(t) = 1 - cos(pi/2 * t); convex, a(0) = 0, a(1) = 1."""

    def __init__(self):
        super().__init__()

    def at(self, t: Tensor):
        """Evaluate a(t) = 1 - cos(pi t / 2)."""
        half_pi = torch.pi / 2
        return 1 - torch.cos(half_pi * t)

    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t) = (pi/2) sin(pi t / 2)."""
        half_pi = torch.pi / 2
        return half_pi * torch.sin(half_pi * t)

    def rate_scale_factor(self, t):
        """
        Closed form of d/dt a(t) / (1 - a(t)):
        (pi/2) sin / cos = (pi/2) tan(pi t / 2).
        """
        half_pi = torch.pi / 2
        return half_pi * torch.tan(half_pi * t)

    def inv(self, alpha):
        """Invert: t = (2/pi) arccos(1 - alpha). alpha is clamped to [0, 1] first."""
        clamped = alpha.clamp(min=0.0, max=1.0)
        return (2 / torch.pi) * torch.arccos(1 - clamped)

class PolynomialSchedule(Schedule, nn.Module):
    """Polynomial schedule: a(t) = t ** exp; a(0) = 0 and a(1) = 1 for exp > 0."""

    def __init__(self, exp):
        super().__init__()
        # Polynomial exponent; presumably exp > 0 so inv's 1/exp is defined — TODO confirm.
        self.exp = exp

    def at(self, t: Tensor):
        """Evaluate a(t) = t ** exp."""
        return t**self.exp

    def derivative_at(self, t: Tensor):
        """Evaluate d/dt a(t) = exp * t ** (exp - 1)."""
        return self.exp * t ** (self.exp - 1)

    def inv(self, alpha: Tensor):
        """Invert: t = alpha ** (1 / exp)."""
        return alpha ** (1 / self.exp)