autoprogrammer committed on
Commit
bb6d931
·
verified ·
1 Parent(s): 6be7e05

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +86 -18
generation_utils.py CHANGED
@@ -109,6 +109,10 @@ class DreamGenerationConfig(GenerationConfig):
109
  self.alg: str = kwargs.pop("alg", 'origin')
110
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
111
 
 
 
 
 
112
  # Parameters that define the output variables of `generate`
113
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
114
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
@@ -164,6 +168,56 @@ class DreamGenerationMixin:
164
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
165
  return input_ids, attention_mask
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
168
  """Performs validation related to the resulting generated length"""
169
 
@@ -382,6 +436,10 @@ class DreamGenerationMixin:
382
  top_p = generation_config.top_p
383
  top_k = generation_config.top_k
384
 
 
 
 
 
385
  histories = [] if (return_dict_in_generate and output_history) else None
386
 
387
  # pad input_ids to max_length
@@ -404,6 +462,9 @@ class DreamGenerationMixin:
404
 
405
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
406
 
 
 
 
407
  # this allows user-defined token control of the intermediate steps
408
  x = generation_tokens_hook_func(None, x, None)
409
  for i in range(steps):
@@ -425,29 +486,36 @@ class DreamGenerationMixin:
425
  _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
426
  x[mask_index] = x0.clone()
427
  else:
428
- if alg == 'maskgit_plus':
429
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
430
- elif alg == 'topk_margin':
431
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
432
- elif alg == 'entropy':
433
  confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
434
  else:
435
  raise RuntimeError(f"Unknown alg: {alg}")
436
- num_mask_token = mask_index.sum() / mask_index.shape[0]
437
- number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
438
- full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
439
- full_confidence[mask_index] = confidence
440
- if number_transfer_tokens > 0:
441
- if alg_temp is None or alg_temp == 0:
442
- _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
443
- else:
444
- full_confidence = full_confidence / alg_temp
445
- full_confidence = F.softmax(full_confidence, dim=-1)
446
- transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
447
- x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
448
- x_[mask_index] = x0.clone()
449
- row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
450
- x[row_indices,transfer_index] = x_[row_indices,transfer_index]
 
 
 
 
 
 
 
451
 
452
  # this allows user-defined token control of the intermediate steps
453
  x = generation_tokens_hook_func(i, x, logits)
 
109
  self.alg: str = kwargs.pop("alg", 'origin')
110
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
111
 
112
+ # RCR specific parameters
113
+ self.rcr: bool = kwargs.pop("rcr", False)
114
+ self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
115
+
116
  # Parameters that define the output variables of `generate`
117
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
118
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
 
168
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
169
  return input_ids, attention_mask
170
 
171
+ def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
172
+ mask_token_id, step, total_steps, s, t):
173
+ """
174
+ Apply Running Confidence Remasking (RCR) logic adapted for Dream model.
175
+ """
176
+ batch_size = x.shape[0]
177
+
178
+ # Calculate number of tokens to transfer using Dream's scheduling
179
+ num_mask_token = mask_index.sum() / mask_index.shape[0]
180
+ number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
181
+
182
+ # Update predictions for masked positions only
183
+ x0 = torch.where(mask_index, x0, x)
184
+ confidence = torch.where(mask_index, confidence, torch.tensor(-float('inf'), device=x0.device))
185
+
186
+ # RCR: Select tokens based on cumulative confidence
187
+ for j in range(batch_size):
188
+ if number_transfer_tokens > 0:
189
+ batch_confidence = confidence[j]
190
+ batch_mask_index = mask_index[j]
191
+
192
+ # Select top confident tokens to transfer
193
+ _, select_indices = torch.topk(batch_confidence, k=number_transfer_tokens, largest=True)
194
+ x[j, select_indices] = x0[j, select_indices]
195
+ overtime_confidence[j, select_indices] = batch_confidence[select_indices].clone().float()
196
+
197
+ # RCR: Re-mask lowest confidence tokens for next steps
198
+ if step < total_steps - 1:
199
+ # Find tokens that have been generated (non-zero confidence)
200
+ generated_mask = overtime_confidence[j] > 0
201
+ if generated_mask.any():
202
+ # Calculate tokens to re-mask for next iteration
203
+ next_num_mask_tokens = int(num_mask_token * (1 - torch.linspace(1, s, total_steps + 1, device=x.device)[step + 2] / t))
204
+
205
+ if next_num_mask_tokens > 0:
206
+ # Get confidence of generated tokens
207
+ generated_confidence = overtime_confidence[j][generated_mask]
208
+ generated_indices = torch.where(generated_mask)[0]
209
+
210
+ if len(generated_confidence) >= next_num_mask_tokens:
211
+ # Re-mask lowest confidence tokens
212
+ _, local_mask_indices = torch.topk(
213
+ generated_confidence,
214
+ k=next_num_mask_tokens,
215
+ largest=False
216
+ )
217
+ global_mask_indices = generated_indices[local_mask_indices]
218
+ x[j, global_mask_indices] = mask_token_id
219
+ overtime_confidence[j, global_mask_indices] = 0.0
220
+
221
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
222
  """Performs validation related to the resulting generated length"""
223
 
 
436
  top_p = generation_config.top_p
437
  top_k = generation_config.top_k
438
 
439
+ # RCR specific values
440
+ rcr = generation_config.rcr
441
+ conf_alg = generation_config.conf_alg
442
+
443
  histories = [] if (return_dict_in_generate and output_history) else None
444
 
445
  # pad input_ids to max_length
 
462
 
463
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
464
 
465
+ # RCR tracking - initialize overtime confidence tracking
466
+ overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
467
+
468
  # this allows user-defined token control of the intermediate steps
469
  x = generation_tokens_hook_func(None, x, None)
470
  for i in range(steps):
 
486
  _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
487
  x[mask_index] = x0.clone()
488
  else:
489
+ if alg == 'maskgit_plus' or (rcr and conf_alg == 'maskgit_plus'):
490
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
491
+ elif alg == 'topk_margin' or (rcr and conf_alg == 'topk_margin'):
492
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
493
+ elif alg == 'entropy' or (rcr and conf_alg == 'entropy'):
494
  confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
495
  else:
496
  raise RuntimeError(f"Unknown alg: {alg}")
497
+
498
+ # Apply RCR logic if enabled
499
+ if rcr:
500
+ self._apply_rcr_logic(x, x0, confidence, mask_index, overtime_confidence,
501
+ mask_token_id, i, steps, s, t)
502
+ else:
503
+ # Original Dream sampling logic
504
+ num_mask_token = mask_index.sum() / mask_index.shape[0]
505
+ number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
506
+ full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
507
+ full_confidence[mask_index] = confidence
508
+ if number_transfer_tokens > 0:
509
+ if alg_temp is None or alg_temp == 0:
510
+ _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
511
+ else:
512
+ full_confidence = full_confidence / alg_temp
513
+ full_confidence = F.softmax(full_confidence, dim=-1)
514
+ transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
515
+ x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
516
+ x_[mask_index] = x0.clone()
517
+ row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
518
+ x[row_indices,transfer_index] = x_[row_indices,transfer_index]
519
 
520
  # this allows user-defined token control of the intermediate steps
521
  x = generation_tokens_hook_func(i, x, logits)