Kiria-Nozan commited on
Commit
044aa8d
·
verified ·
1 Parent(s): 92779dc

initial release

Browse files
Files changed (1) hide show
  1. noise_schedule.py +151 -0
noise_schedule.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ # Flags required to enable jit fusion kernels
7
+ torch._C._jit_set_profiling_mode(False)
8
+ torch._C._jit_set_profiling_executor(False)
9
+ torch._C._jit_override_can_fuse_on_cpu(True)
10
+ torch._C._jit_override_can_fuse_on_gpu(True)
11
+
12
+
13
+ def get_noise(config, dtype=torch.float32):
14
+ if config.noise.type == 'geometric':
15
+ return GeometricNoise(config.noise.sigma_min,
16
+ config.noise.sigma_max)
17
+ elif config.noise.type == 'loglinear':
18
+ return LogLinearNoise()
19
+ elif config.noise.type == 'cosine':
20
+ return CosineNoise()
21
+ elif config.noise.type == 'cosinesqr':
22
+ return CosineSqrNoise()
23
+ elif config.noise.type == 'linear':
24
+ return Linear(config.noise.sigma_min,
25
+ config.noise.sigma_max,
26
+ dtype)
27
+ else:
28
+ raise ValueError(f'{config.noise.type} is not a valid noise')
29
+
30
+
31
+ def binary_discretization(z):
32
+ z_hard = torch.sign(z)
33
+ z_soft = z / torch.norm(z, dim=-1, keepdim=True)
34
+ return z_soft + (z_hard - z_soft).detach()
35
+
36
+
37
+ class Noise(abc.ABC, nn.Module):
38
+ """
39
+ Baseline forward method to get the total + rate of noise at a timestep
40
+ """
41
+ def forward(self, t):
42
+ # Assume time goes from 0 to 1
43
+ return self.total_noise(t), self.rate_noise(t)
44
+
45
+ @abc.abstractmethod
46
+ def rate_noise(self, t):
47
+ """
48
+ Rate of change of noise ie g(t)
49
+ """
50
+ pass
51
+
52
+ @abc.abstractmethod
53
+ def total_noise(self, t):
54
+ """
55
+ Total noise ie \int_0^t g(t) dt + g(0)
56
+ """
57
+ pass
58
+
59
+
60
+ class CosineNoise(Noise):
61
+ def __init__(self, eps=1e-3):
62
+ super().__init__()
63
+ self.eps = eps
64
+
65
+ def rate_noise(self, t):
66
+ cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
67
+ sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
68
+ scale = torch.pi / 2
69
+ return scale * sin / (cos + self.eps)
70
+
71
+ def total_noise(self, t):
72
+ cos = torch.cos(t * torch.pi / 2)
73
+ return - torch.log(self.eps + (1 - self.eps) * cos)
74
+
75
+
76
+ class CosineSqrNoise(Noise):
77
+ def __init__(self, eps=1e-3):
78
+ super().__init__()
79
+ self.eps = eps
80
+
81
+ def rate_noise(self, t):
82
+ cos = (1 - self.eps) * (
83
+ torch.cos(t * torch.pi / 2) ** 2)
84
+ sin = (1 - self.eps) * torch.sin(t * torch.pi)
85
+ scale = torch.pi / 2
86
+ return scale * sin / (cos + self.eps)
87
+
88
+ def total_noise(self, t):
89
+ cos = torch.cos(t * torch.pi / 2) ** 2
90
+ return - torch.log(self.eps + (1 - self.eps) * cos)
91
+
92
+
93
+ class Linear(Noise):
94
+ def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
95
+ super().__init__()
96
+ self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
97
+ self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
98
+
99
+ def rate_noise(self, t):
100
+ return self.sigma_max - self.sigma_min
101
+
102
+ def total_noise(self, t):
103
+ return self.sigma_min + t * (self.sigma_max - self.sigma_min)
104
+
105
+ def importance_sampling_transformation(self, t):
106
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
107
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
108
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
109
+ return (sigma_t - self.sigma_min) / (
110
+ self.sigma_max - self.sigma_min)
111
+
112
+
113
+ class GeometricNoise(Noise):
114
+ def __init__(self, sigma_min=1e-3, sigma_max=1):
115
+ super().__init__()
116
+ self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
117
+
118
+ def rate_noise(self, t):
119
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
120
+ self.sigmas[1].log() - self.sigmas[0].log())
121
+
122
+ def total_noise(self, t):
123
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
124
+
125
+
126
+ class LogLinearNoise(Noise):
127
+ """Log Linear noise schedule.
128
+
129
+ Built such that 1 - 1/e^(n(t)) interpolates between 0 and
130
+ ~1 when t varies from 0 to 1. Total noise is
131
+ -log(1 - (1 - eps) * t), so the sigma will be
132
+ (1 - eps) * t.
133
+ """
134
+ def __init__(self, eps=1e-3):
135
+ super().__init__()
136
+ self.eps = eps
137
+ self.sigma_max = self.total_noise(torch.tensor(1.0))
138
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
139
+
140
+ def rate_noise(self, t):
141
+ return (1 - self.eps) / (1 - (1 - self.eps) * t)
142
+
143
+ def total_noise(self, t):
144
+ return -torch.log1p(-(1 - self.eps) * t)
145
+
146
+ def importance_sampling_transformation(self, t):
147
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
148
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
149
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
150
+ t = - torch.expm1(- sigma_t) / (1 - self.eps)
151
+ return t