WeiChow committed on
Commit
3a98aab
·
verified ·
1 Parent(s): 2bb06dc

Upload scheduler/scheduler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scheduler/scheduler.py +227 -0
scheduler/scheduler.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
23
+ import torch.nn.functional as F
24
+
25
def gumbel_noise(t, generator=None):
    """Draw Gumbel(0, 1) noise with the same shape as `t`.

    Sampling happens on the generator's device (when one is given) so seeded
    generation stays reproducible; the result is then moved to `t`'s device.
    """
    target_device = t.device if generator is None else generator.device
    uniform = torch.zeros_like(t, device=target_device).uniform_(0, 1, generator=generator).to(t.device)
    # -log(-log(U)) maps U ~ Uniform(0, 1) to a Gumbel(0, 1) sample;
    # both clamps guard against log(0).
    inner = -torch.log(uniform.clamp(1e-20))
    return -torch.log(inner.clamp(1e-20))
29
+
30
+
31
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    """Select, per row, the `mask_len` lowest-confidence entries for re-masking.

    Confidence is `log(probs)` perturbed by temperature-scaled Gumbel noise;
    returns a boolean tensor that is True where the entry should be re-masked.
    """
    perturbed = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
    ascending = torch.sort(perturbed, dim=-1).values
    # The mask_len-th smallest perturbed confidence is the cut-off per row.
    threshold = torch.gather(ascending, 1, mask_len.long())
    return perturbed < threshold
37
+
38
+
39
@dataclass
class SchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.

    NOTE(review): for this token-based scheduler both tensors hold integer token
    ids (see `Scheduler.step`), not pixel values — the shapes above are the
    generic scheduler convention; confirm against callers.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
55
+
56
+
57
class Scheduler(SchedulerMixin, ConfigMixin):
    """MaskGit-style scheduler for iterative unmasking of discrete token sequences.

    Each call to `step` samples token ids from the model's logits for the
    currently-masked positions, then re-masks the lowest-confidence predictions
    according to a cosine or linear schedule, so fewer tokens remain masked at
    every subsequent step.
    """

    order = 1

    # Per-step Gumbel temperatures; populated by `set_timesteps`.
    temperatures: torch.Tensor

    @register_to_config
    def __init__(
        self,
        mask_token_id: int,
        masking_schedule: str = "cosine",
    ):
        """
        Args:
            mask_token_id: Token id marking a still-masked position.
            masking_schedule: "cosine" or "linear"; how the masked fraction
                decays over the inference steps.
        """
        # `mask_token_id` / `masking_schedule` are stored on `self.config` by
        # `register_to_config`; these two are filled in by `set_timesteps`.
        self.temperatures = None
        self.timesteps = None

    def set_timesteps(
        self,
        num_inference_steps: int,
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
        device: Union[str, torch.device] = None,
    ):
        """Prepare the reversed timestep sequence and the temperature schedule.

        Args:
            num_inference_steps: Number of unmasking steps to run.
            temperature: A `(start, end)` pair interpolated linearly over all
                steps, or a single value annealed down to 0.01.
            device: Device on which `timesteps` / `temperatures` are placed.
        """
        self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)

        if isinstance(temperature, (tuple, list)):
            self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
        else:
            self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)

    ### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
    def top_k_top_p_filtering(
        self,
        logits,
        top_k: int = 0,
        top_p: float = 1.0,
        filter_value: float = -float("Inf"),
        min_tokens_to_keep: int = 1,
    ):
        """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        NOTE: `logits` is modified in place (removed entries are set to
        `filter_value`) and also returned for convenience.

        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            top_k: if > 0, keep only the top k tokens with highest probability (top-k filtering).
            top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p
                (nucleus filtering, Holtzman et al., http://arxiv.org/abs/1904.09751).
            filter_value: value written into filtered-out positions.
            min_tokens_to_keep: always keep at least this many tokens per example.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
        """
        if top_k > 0:
            # Clamp k into [min_tokens_to_keep, vocab_size], then drop every
            # logit below the k-th largest.
            top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens whose cumulative probability exceeds the threshold.
            sorted_indices_to_remove = cumulative_probs > top_p
            if min_tokens_to_keep > 1:
                sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
            # Shift right so the first token crossing the threshold is kept too.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Scatter the removal mask back to the original (unsorted) order.
            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
                -1, sorted_indices, sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits

    def step(
        self,
        model_output: torch.Tensor,
        timestep: torch.long,
        sample: torch.LongTensor,
        starting_mask_ratio: int = 1,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        using_topk_topp: Optional[bool] = False,
        sampling_temperature: Optional[float] = 1.0,
        top_k: int = 8192,
        top_p: float = 0.2,
    ) -> Union[SchedulerOutput, Tuple]:
        """Run one unmasking step.

        Args:
            model_output: Logits, `(batch, seq_len, codebook)` or
                `(batch, codebook, height, width)`.
            timestep: Current timestep; must be one of `self.timesteps`.
            sample: Current token ids, `(batch, seq_len)` or `(batch, height, width)`.
            starting_mask_ratio: Multiplier applied to the schedule's mask ratio.
            generator: Optional RNG for reproducible sampling.
            return_dict: If False, return `(prev_sample, pred_original_sample)`.
            using_topk_topp: If True, temperature-scale the logits and apply
                top-k/top-p filtering before sampling.
            sampling_temperature: Logit temperature used when `using_topk_topp`.
            top_k: Top-k cut-off used when `using_topk_topp` (was hard-coded 8192).
            top_p: Nucleus cut-off used when `using_topk_topp` (was hard-coded 0.2).

        Returns:
            `SchedulerOutput(prev_sample, pred_original_sample)` or the tuple
            form when `return_dict` is False.
        """
        two_dim_input = sample.ndim == 3 and model_output.ndim == 4

        if two_dim_input:
            # Flatten the spatial grid into a token sequence.
            batch_size, codebook_size, height, width = model_output.shape
            sample = sample.reshape(batch_size, height * width)
            model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)

        unknown_map = sample == self.config.mask_token_id

        if using_topk_topp:
            # Temperature-scale then restrict the distribution (merged from
            # two identical `if using_topk_topp:` guards in the original).
            model_output = model_output / max(sampling_temperature, 1e-5)
            if top_k > 0 or top_p < 1.0:
                model_output = self.top_k_top_p_filtering(model_output, top_k=top_k, top_p=top_p)

        probs = model_output.softmax(dim=-1)

        device = probs.device
        probs_ = probs.to(generator.device) if generator is not None else probs  # handles when generator is on CPU
        if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
            probs_ = probs_.float()  # multinomial is not implemented for cpu half precision
        probs_ = probs_.reshape(-1, probs.size(-1))
        pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
        pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
        # Only masked positions take the newly sampled ids; known tokens stay.
        pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)

        if timestep == 0:
            # Final step: everything is unmasked.
            prev_sample = pred_original_sample
        else:
            seq_len = sample.shape[1]
            step_idx = (self.timesteps == timestep).nonzero()
            ratio = (step_idx + 1) / len(self.timesteps)

            if self.config.masking_schedule == "cosine":
                mask_ratio = torch.cos(ratio * math.pi / 2)
            elif self.config.masking_schedule == "linear":
                mask_ratio = 1 - ratio
            else:
                raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")

            mask_ratio = starting_mask_ratio * mask_ratio

            mask_len = (seq_len * mask_ratio).floor()
            # do not mask more than amount previously masked
            mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            # mask at least one
            mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)

            selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)

            # Masks tokens with lower confidence.
            prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)

        if two_dim_input:
            prev_sample = prev_sample.reshape(batch_size, height, width)
            pred_original_sample = pred_original_sample.reshape(batch_size, height, width)

        if not return_dict:
            return (prev_sample, pred_original_sample)

        return SchedulerOutput(prev_sample, pred_original_sample)

    def add_noise(self, sample, timesteps, generator=None):
        """Randomly mask `sample` to the masking level implied by `timesteps`.

        Each position is independently replaced with `mask_token_id` with
        probability given by the masking schedule at that timestep.
        """
        step_idx = (self.timesteps == timesteps).nonzero()
        ratio = (step_idx + 1) / len(self.timesteps)

        if self.config.masking_schedule == "cosine":
            mask_ratio = torch.cos(ratio * math.pi / 2)
        elif self.config.masking_schedule == "linear":
            mask_ratio = 1 - ratio
        else:
            raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")

        # Draw on the generator's device (if any) for reproducibility, then
        # move to the sample's device before thresholding.
        mask_indices = (
            torch.rand(
                sample.shape, device=generator.device if generator is not None else sample.device, generator=generator
            ).to(sample.device)
            < mask_ratio
        )

        masked_sample = sample.clone()

        masked_sample[mask_indices] = self.config.mask_token_id

        return masked_sample