Delete generation_utils.py

#2
by exdysa - opened
Files changed (1)
  1. generation_utils.py +0 -464
generation_utils.py DELETED
@@ -1,464 +0,0 @@
- # coding=utf-8
- # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import warnings
- import copy
- from dataclasses import dataclass
- from typing import Any, Dict, Optional, Tuple, Union
-
- import torch
- import torch.distributions as dists
- from torch.nn import functional as F
- from transformers import __version__
- from transformers.generation.configuration_utils import GenerationConfig
- from transformers.utils import (
-     ModelOutput,
-     is_torchdynamo_compiling,
-     logging,
- )
-
- logger = logging.get_logger(__name__)
-
-
- def top_p_logits(logits, top_p=None):
-     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-     sorted_indices_to_remove = cumulative_probs > top_p
-     # Shift the indices to the right to keep the first token above the threshold
-     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-     sorted_indices_to_remove[..., 0] = 0
-
-     mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
-     mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
-     logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
-     return logits
-
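For intuition about the nucleus filter being removed here, a minimal sketch with made-up tensor values, assuming `top_p_logits` is in scope:

```python
import torch
from torch.nn import functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
filtered = top_p_logits(logits.clone(), top_p=0.9)
# Tokens outside the smallest set whose cumulative probability exceeds
# top_p are pushed to the dtype minimum, so softmax assigns them ~0 mass.
print(F.softmax(filtered, dim=-1))  # last entry becomes ~0
```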
- def top_k_logits(logits, top_k=None):
-     top_k = min(top_k, logits.size(-1))  # Safety check
-     # Remove all tokens with a probability less than the last token of the top-k
-     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-     logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
-     return logits
-
-
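The top-k counterpart behaves the same way; a toy check (values invented, `top_k_logits` assumed in scope):

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
filtered = top_k_logits(logits.clone(), top_k=2)
# Only the two highest-scoring entries keep their values;
# the rest drop to the dtype minimum and are never sampled.
```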
- def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
-     if temperature > 0:
-         logits = logits / temperature
-     if top_p is not None and top_p < 1:
-         logits = top_p_logits(logits, top_p)
-     if top_k is not None:
-         logits = top_k_logits(logits, top_k)
-     probs = torch.softmax(logits, dim=-1)
-
-     if temperature > 0:
-         try:
-             x0 = dists.Categorical(probs=probs).sample()
-             confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
-         except Exception:
-             # Fall back to greedy decoding if sampling fails (e.g. degenerate probs)
-             confidence, x0 = probs.max(dim=-1)
-     else:
-         confidence, x0 = probs.max(dim=-1)
-
-     if margin_confidence:
-         sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
-         # Calculate confidence as the margin between top-1 and top-2 probabilities
-         top1_probs = sorted_probs[:, 0]
-         top2_probs = sorted_probs[:, 1]
-         confidence = top1_probs - top2_probs
-
-     if neg_entropy:
-         epsilon = 1e-10
-         log_probs = torch.log(probs + epsilon)
-         # Negative entropy: closer to 0 means a more peaked (more confident) distribution
-         confidence = torch.sum(probs * log_probs, dim=-1)
-
-     return confidence, x0
-
-
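The three confidence measures drive the decoding algorithms further down. A small comparison sketch (toy logits, `sample_tokens` assumed in scope):

```python
import torch

logits = torch.tensor([[3.0, 2.5, 0.1, -1.0]])

conf_greedy, tok = sample_tokens(logits)                         # top-1 probability
conf_margin, _ = sample_tokens(logits, margin_confidence=True)   # top1 - top2
conf_entropy, _ = sample_tokens(logits, neg_entropy=True)        # -H(p)
# All three rank positions by how "decided" the model is; the entropy
# variant uses the whole distribution rather than just the top tokens.
```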
- @dataclass
- class DreamModelOutput(ModelOutput):
-     sequences: torch.LongTensor = None
-     history: Optional[Tuple[torch.FloatTensor]] = None
-
-
- class DreamGenerationConfig(GenerationConfig):
-     def __init__(self, **kwargs):
-         self.temperature: float = kwargs.pop("temperature", 0.0)
-         self.top_p: Optional[float] = kwargs.pop("top_p", None)
-         self.top_k: Optional[int] = kwargs.pop("top_k", None)
-         self.max_length = kwargs.pop("max_length", 20)
-         self.max_new_tokens = kwargs.pop("max_new_tokens", None)
-         # Diffusion-specific params
-         self.eps: float = kwargs.pop("eps", 1e-3)
-         self.steps: int = kwargs.pop("steps", 512)
-         self.alg: str = kwargs.pop("alg", "origin")
-         self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
-
-         # Parameters that define the output variables of `generate`
-         self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
-         self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
-         self.output_history: bool = kwargs.pop("output_history", False)
-
-         # Special tokens that can be used at generation time
-         self.mask_token_id = kwargs.pop("mask_token_id", None)
-         self.pad_token_id = kwargs.pop("pad_token_id", None)
-         self.bos_token_id = kwargs.pop("bos_token_id", None)
-         self.eos_token_id = kwargs.pop("eos_token_id", None)
-
-         # Wild card
-         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
-
-         # The remaining attributes do not parametrize `.generate()`, but are informative and/or used
-         # by the hub interface.
-         self._from_model_config = kwargs.pop("_from_model_config", False)
-         self._commit_hash = kwargs.pop("_commit_hash", None)
-         self.transformers_version = kwargs.pop("transformers_version", __version__)
-
-         # Additional attributes without default values
-         if not self._from_model_config:
-             # We don't want to copy values from the model config if we're initializing a
-             # `GenerationConfig` from a model's default configuration file
-             for key, value in kwargs.items():
-                 try:
-                     setattr(self, key, value)
-                 except AttributeError as err:
-                     logger.error(f"Can't set {key} with value {value} for {self}")
-                     raise err
-
-         # Validate the values of the attributes
-         self.validate(is_init=True)
-
-     def validate(self, is_init=False):
-         pass
-
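A representative configuration, to show which knobs this class adds on top of the standard sampling parameters (the specific values here are illustrative, not recommended defaults):

```python
config = DreamGenerationConfig(
    max_new_tokens=256,
    steps=256,          # number of denoising iterations
    temperature=0.2,
    top_p=0.95,
    alg="entropy",      # one of: origin, maskgit_plus, topk_margin, entropy
    alg_temp=0.1,       # only used by the confidence-based algorithms
)
```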
- class DreamGenerationMixin:
-     @staticmethod
-     def _expand_inputs_for_generation(
-         expand_size: int = 1,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.LongTensor] = None,
-     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
-         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
-         # Do not call torch.repeat_interleave if expand_size is 1 because it clones
-         # the input tensor and thus requires more memory although no change is applied
-         if expand_size == 1:
-             return input_ids, attention_mask
-         if input_ids is not None:
-             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
-         if attention_mask is not None:
-             attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
-         return input_ids, attention_mask
-
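A quick check of the expansion semantics, assuming the class above is importable:

```python
import torch

# With num_return_sequences=2, each prompt row is duplicated in place,
# so a [2, 4] batch becomes [4, 4] ordered (a, a, b, b), not (a, b, a, b).
ids = torch.arange(8).reshape(2, 4)
expanded, _ = DreamGenerationMixin._expand_inputs_for_generation(2, ids, None)
assert expanded.shape == (4, 4)
```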
-     def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
-         """Performs validation related to the resulting generated length"""
-         # Can't throw warnings/exceptions during compilation
-         if is_torchdynamo_compiling():
-             return
-
-         # 1. Max length warnings related to poor parameterization
-         if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
-             # 20 is the default max_length of the generation config
-             warnings.warn(
-                 f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
-                 "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
-                 "generation.",
-                 UserWarning,
-             )
-         if input_ids_length >= generation_config.max_length:
-             input_ids_string = "input_ids"
-             raise ValueError(
-                 f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
-                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-                 " increasing `max_length` or, better yet, setting `max_new_tokens`."
-             )
-
-     def _prepare_generated_length(
-         self,
-         generation_config,
-         has_default_max_length,
-         input_ids_length,
-     ):
-         """Prepares max and min length in generation configs to avoid clashes between similar attributes"""
-         if generation_config.max_new_tokens is not None:
-             if not has_default_max_length and generation_config.max_length is not None:
-                 logger.warning(
-                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length` (="
-                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
-                     "Please refer to the documentation for more information. "
-                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-                 )
-             generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-
-         elif has_default_max_length:
-             if generation_config.max_length == DreamGenerationConfig().max_length:
-                 generation_config.max_length = generation_config.max_length + input_ids_length
-                 max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
-                 if max_position_embeddings is not None:
-                     generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
-
-         return generation_config
-
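The length-resolution arithmetic in a nutshell: `max_new_tokens` is relative to the prompt, while `max_length` is an absolute cap that includes it. A worked example with invented numbers:

```python
input_ids_length = 12   # prompt tokens
max_new_tokens = 20
# When max_new_tokens is set, it wins and the absolute cap becomes:
max_length = max_new_tokens + input_ids_length  # -> 32
```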
-     def _prepare_generation_config(
-         self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
-     ) -> DreamGenerationConfig:
-         """
-         Prepares the base generation config, then applies any generation configuration options from kwargs.
-         This function handles retrocompatibility with respect to configuration files.
-         """
-         # Priority: `generation_config` argument > `model.generation_config` (the default generation config)
-         using_model_generation_config = False
-         if generation_config is None:
-             generation_config = DreamGenerationConfig.from_model_config(self.config)
-             using_model_generation_config = True
-
-         # `torch.compile` can't compile `copy.deepcopy`; arguments in `kwargs` that are part of
-         # `generation_config` will mutate the object with `.update`. As such, passing these arguments
-         # through `kwargs` is disabled -- an exception will be raised in `_validate_model_kwargs`.
-         if not is_torchdynamo_compiling():
-             generation_config = copy.deepcopy(generation_config)
-             _kwargs = generation_config.update(**kwargs)
-             # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
-             if not using_model_generation_config:
-                 if generation_config.bos_token_id is None:
-                     generation_config.bos_token_id = self.generation_config.bos_token_id
-                 if generation_config.eos_token_id is None:
-                     generation_config.eos_token_id = self.generation_config.eos_token_id
-                 if generation_config.pad_token_id is None:
-                     generation_config.pad_token_id = self.generation_config.pad_token_id
-                 if generation_config.mask_token_id is None:
-                     generation_config.mask_token_id = self.generation_config.mask_token_id
-
-         return generation_config
-
-     def _prepare_special_tokens(
-         self,
-         generation_config: DreamGenerationConfig,
-         device: Optional[Union[torch.device, str]] = None,
-     ):
-         """
-         Prepares the special tokens for generation, overwriting the generation config with their processed
-         versions converted to tensor.
-
-         Note that `generation_config` is changed in place and stops being serializable after this method is
-         called. That is no problem if called within `generate` (`generation_config` is a local copy that
-         doesn't leave the function). However, if called outside `generate`, consider creating a copy of
-         `generation_config` first.
-         """
-         # Convert special tokens to tensors
-         def _tensor_or_none(token, device=None):
-             if token is None:
-                 return token
-             device = device if device is not None else self.device
-             if isinstance(token, torch.Tensor):
-                 return token.to(device)
-             return torch.tensor(token, device=device, dtype=torch.long)
-
-         bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
-         eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
-         pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
-         mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
-
-         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
-         if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
-             eos_token_tensor = eos_token_tensor.unsqueeze(0)
-
-         # Set pad token if unset (and there are conditions to do so)
-         if pad_token_tensor is None and eos_token_tensor is not None:
-             pad_token_tensor = eos_token_tensor[0]
-             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
-
-         # Update generation config with the updated special tokens tensors
-         # NOTE: this must be written into a different attribute name than the one holding the original special tokens
-         # (in their non-tensor form), in order to enable end-to-end compilation. See
-         # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
-         generation_config._bos_token_tensor = bos_token_tensor
-         generation_config._eos_token_tensor = eos_token_tensor
-         generation_config._pad_token_tensor = pad_token_tensor
-         generation_config._mask_token_tensor = mask_token_tensor
-
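The special-token normalization above boils down to two conventions, sketched here with invented values:

```python
import torch

# Scalar special tokens become 1-element tensors, so the rest of the
# pipeline treats "one eos" and "several eos" uniformly.
eos = torch.tensor(2)
if eos.ndim == 0:
    eos = eos.unsqueeze(0)   # tensor([2])
pad = eos[0]                  # pad falls back to the first eos when unset
```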
-     @torch.no_grad()
-     def diffusion_generate(
-         self,
-         inputs: Optional[torch.Tensor] = None,
-         generation_config: Optional[DreamGenerationConfig] = None,
-         **kwargs,
-     ) -> Union[DreamModelOutput, torch.LongTensor]:
-         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-         generation_config = self._prepare_generation_config(generation_config, **kwargs)
-         generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
-         generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
-
-         # 2. Define model inputs
-         assert inputs is not None
-         input_ids = inputs
-         device = input_ids.device
-         attention_mask = kwargs.pop("attention_mask", None)
-         self._prepare_special_tokens(generation_config, device=device)
-
-         # 3. Prepare `max_length`
-         input_ids_length = input_ids.shape[-1]
-         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-         generation_config = self._prepare_generated_length(
-             generation_config=generation_config,
-             has_default_max_length=has_default_max_length,
-             input_ids_length=input_ids_length,
-         )
-         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
-
-         # 4. Check input_ids
-         if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
-             warnings.warn(
-                 "You are calling .generate() with the `input_ids` being on a device type different"
-                 f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
-                 f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
-                 " Please make sure that you have put `input_ids` to the"
-                 f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
-                 " running `.generate()`.",
-                 UserWarning,
-             )
-         if (
-             hasattr(generation_config, "pad_token_id")
-             and torch.any(input_ids == generation_config.pad_token_id)
-             and attention_mask is None
-         ):
-             warnings.warn(
-                 "Padding was detected but no attention mask is passed here. For correct "
-                 "generation results, please set `attention_mask` when batch-padding inputs.",
-                 UserWarning,
-             )
-
-         input_ids, attention_mask = self._expand_inputs_for_generation(
-             expand_size=generation_config.num_return_sequences,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-         )
-
-         result = self._sample(
-             input_ids,
-             attention_mask=attention_mask,
-             generation_config=generation_config,
-             generation_tokens_hook_func=generation_tokens_hook_func,
-             generation_logits_hook_func=generation_logits_hook_func,
-         )
-         return result
-
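For context on what this deletion removes, a typical end-to-end call looked roughly like the sketch below. It assumes a Dream checkpoint loaded with `trust_remote_code=True` (the repo id is a placeholder for whichever Dream checkpoint is in use):

```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt")
out = model.diffusion_generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    steps=64,
    temperature=0.2,
    alg="entropy",
    return_dict_in_generate=True,
    output_history=True,   # keep a snapshot of x after every denoising step
)
print(tokenizer.decode(out.sequences[0]))
```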
-     def _sample(
-         self,
-         input_ids: torch.LongTensor,
-         attention_mask: Optional[torch.LongTensor],
-         generation_config: DreamGenerationConfig,
-         generation_tokens_hook_func,
-         generation_logits_hook_func,
-     ) -> Union[DreamModelOutput, torch.LongTensor]:
-         # Init values
-         output_history = generation_config.output_history
-         return_dict_in_generate = generation_config.return_dict_in_generate
-         max_length = generation_config.max_length
-         mask_token_id = generation_config.mask_token_id
-         steps = generation_config.steps
-         eps = generation_config.eps
-         alg = generation_config.alg
-         alg_temp = generation_config.alg_temp
-         temperature = generation_config.temperature
-         top_p = generation_config.top_p
-         top_k = generation_config.top_k
-
-         histories = [] if (return_dict_in_generate and output_history) else None
-
-         # Pad input_ids to max_length with the mask token
-         x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
-
-         if attention_mask is not None and torch.any(attention_mask == 0.0):
-             # We do not mask the [MASK] tokens, so value = 1.0
-             attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
-             tok_idx = attention_mask.long().cumsum(-1) - 1
-             tok_idx.masked_fill_(attention_mask == 0, 1)
-             # attention_mask is of shape [B, N]; broadcast to [B, 1, N, N]
-             attention_mask = torch.logical_and(
-                 attention_mask.unsqueeze(1).unsqueeze(-2),
-                 attention_mask.unsqueeze(1).unsqueeze(-1),
-             )
-         else:
-             tok_idx = None
-             attention_mask = "full"
-
-         timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
-
-         # This allows user-defined token control of the intermediate steps
-         x = generation_tokens_hook_func(None, x, None)
-         for i in range(steps):
-             mask_index = (x == mask_token_id)
-             logits = self(x, attention_mask, tok_idx).logits
-             # Shift logits one step to the right so position i is predicted from context up to i-1
-             logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
-
-             # This allows user-defined logits control of the intermediate steps
-             logits = generation_logits_hook_func(i, x, logits)
-
-             mask_logits = logits[mask_index]
-             t = timesteps[i]
-             s = timesteps[i + 1]
-
-             if alg == 'origin':
-                 p_transfer = 1 - s / t if i < steps - 1 else 1
-                 x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
-                 transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
-                 _, x0[transfer_index_t_s] = sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
-                 x[mask_index] = x0.clone()
-             else:
-                 if alg == 'maskgit_plus':
-                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
-                 elif alg == 'topk_margin':
-                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
-                 elif alg == 'entropy':
-                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
-                 else:
-                     raise RuntimeError(f"Unknown alg: {alg}")
-                 num_mask_token = mask_index.sum() / mask_index.shape[0]
-                 number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
-                 full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
-                 full_confidence[mask_index] = confidence
-                 if number_transfer_tokens > 0:
-                     if alg_temp is None or alg_temp == 0:
-                         _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
-                     else:
-                         full_confidence = full_confidence / alg_temp
-                         full_confidence = F.softmax(full_confidence, dim=-1)
-                         transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
-                     x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
-                     x_[mask_index] = x0.clone()
-                     row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
-                     x[row_indices, transfer_index] = x_[row_indices, transfer_index]
-
-             # This allows user-defined token control of the intermediate steps
-             x = generation_tokens_hook_func(i, x, logits)
-
-             if histories is not None:
-                 histories.append(x.clone())
-
-         if return_dict_in_generate:
-             return DreamModelOutput(
-                 sequences=x,
-                 history=histories,
-             )
-         else:
-             return x
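For intuition about the loop above, a toy sketch of how the linear timestep schedule translates into per-step unmasking budgets (numbers invented; the real code recomputes the masked count from `x` each step):

```python
import torch

steps, eps = 4, 1e-3
timesteps = torch.linspace(1, eps, steps + 1)
num_masked = 16.0
for i in range(steps):
    t, s = timesteps[i], timesteps[i + 1]
    # Reveal roughly a (1 - s/t) fraction of remaining masks; the last step reveals all.
    reveal = int(num_masked * (1 - s / t)) if i < steps - 1 else int(num_masked)
    print(f"step {i}: reveal {reveal} of {int(num_masked)} masked tokens")
    num_masked -= reveal
```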