Update generation_utils.py
Browse files- generation_utils.py +86 -18
generation_utils.py
CHANGED
|
@@ -109,6 +109,10 @@ class DreamGenerationConfig(GenerationConfig):
|
|
| 109 |
self.alg: str = kwargs.pop("alg", 'origin')
|
| 110 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
# Parameters that define the output variables of `generate`
|
| 113 |
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
| 114 |
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
|
|
@@ -164,6 +168,56 @@ class DreamGenerationMixin:
|
|
| 164 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 165 |
return input_ids, attention_mask
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 168 |
"""Performs validation related to the resulting generated length"""
|
| 169 |
|
|
@@ -382,6 +436,10 @@ class DreamGenerationMixin:
|
|
| 382 |
top_p = generation_config.top_p
|
| 383 |
top_k = generation_config.top_k
|
| 384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 386 |
|
| 387 |
# pad input_ids to max_length
|
|
@@ -404,6 +462,9 @@ class DreamGenerationMixin:
|
|
| 404 |
|
| 405 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 406 |
|
|
|
|
|
|
|
|
|
|
| 407 |
# this allows user-defined token control of the intermediate steps
|
| 408 |
x = generation_tokens_hook_func(None, x, None)
|
| 409 |
for i in range(steps):
|
|
@@ -425,29 +486,36 @@ class DreamGenerationMixin:
|
|
| 425 |
_, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
|
| 426 |
x[mask_index] = x0.clone()
|
| 427 |
else:
|
| 428 |
-
if alg == 'maskgit_plus':
|
| 429 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 430 |
-
elif alg == 'topk_margin':
|
| 431 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
|
| 432 |
-
elif alg == 'entropy':
|
| 433 |
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
| 434 |
else:
|
| 435 |
raise RuntimeError(f"Unknown alg: {alg}")
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
# this allows user-defined token control of the intermediate steps
|
| 453 |
x = generation_tokens_hook_func(i, x, logits)
|
|
|
|
| 109 |
self.alg: str = kwargs.pop("alg", 'origin')
|
| 110 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 111 |
|
| 112 |
+
# RCR specific parameters
|
| 113 |
+
self.rcr: bool = kwargs.pop("rcr", False)
|
| 114 |
+
self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
|
| 115 |
+
|
| 116 |
# Parameters that define the output variables of `generate`
|
| 117 |
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
| 118 |
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
|
|
|
|
| 168 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 169 |
return input_ids, attention_mask
|
| 170 |
|
| 171 |
+
def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
|
| 172 |
+
mask_token_id, step, total_steps, s, t):
|
| 173 |
+
"""
|
| 174 |
+
Apply Running Confidence Remasking (RCR) logic adapted for Dream model.
|
| 175 |
+
"""
|
| 176 |
+
batch_size = x.shape[0]
|
| 177 |
+
|
| 178 |
+
# Calculate number of tokens to transfer using Dream's scheduling
|
| 179 |
+
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 180 |
+
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
|
| 181 |
+
|
| 182 |
+
# Update predictions for masked positions only
|
| 183 |
+
x0 = torch.where(mask_index, x0, x)
|
| 184 |
+
confidence = torch.where(mask_index, confidence, torch.tensor(-float('inf'), device=x0.device))
|
| 185 |
+
|
| 186 |
+
# RCR: Select tokens based on cumulative confidence
|
| 187 |
+
for j in range(batch_size):
|
| 188 |
+
if number_transfer_tokens > 0:
|
| 189 |
+
batch_confidence = confidence[j]
|
| 190 |
+
batch_mask_index = mask_index[j]
|
| 191 |
+
|
| 192 |
+
# Select top confident tokens to transfer
|
| 193 |
+
_, select_indices = torch.topk(batch_confidence, k=number_transfer_tokens, largest=True)
|
| 194 |
+
x[j, select_indices] = x0[j, select_indices]
|
| 195 |
+
overtime_confidence[j, select_indices] = batch_confidence[select_indices].clone().float()
|
| 196 |
+
|
| 197 |
+
# RCR: Re-mask lowest confidence tokens for next steps
|
| 198 |
+
if step < total_steps - 1:
|
| 199 |
+
# Find tokens that have been generated (non-zero confidence)
|
| 200 |
+
generated_mask = overtime_confidence[j] > 0
|
| 201 |
+
if generated_mask.any():
|
| 202 |
+
# Calculate tokens to re-mask for next iteration
|
| 203 |
+
next_num_mask_tokens = int(num_mask_token * (1 - torch.linspace(1, s, total_steps + 1, device=x.device)[step + 2] / t))
|
| 204 |
+
|
| 205 |
+
if next_num_mask_tokens > 0:
|
| 206 |
+
# Get confidence of generated tokens
|
| 207 |
+
generated_confidence = overtime_confidence[j][generated_mask]
|
| 208 |
+
generated_indices = torch.where(generated_mask)[0]
|
| 209 |
+
|
| 210 |
+
if len(generated_confidence) >= next_num_mask_tokens:
|
| 211 |
+
# Re-mask lowest confidence tokens
|
| 212 |
+
_, local_mask_indices = torch.topk(
|
| 213 |
+
generated_confidence,
|
| 214 |
+
k=next_num_mask_tokens,
|
| 215 |
+
largest=False
|
| 216 |
+
)
|
| 217 |
+
global_mask_indices = generated_indices[local_mask_indices]
|
| 218 |
+
x[j, global_mask_indices] = mask_token_id
|
| 219 |
+
overtime_confidence[j, global_mask_indices] = 0.0
|
| 220 |
+
|
| 221 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 222 |
"""Performs validation related to the resulting generated length"""
|
| 223 |
|
|
|
|
| 436 |
top_p = generation_config.top_p
|
| 437 |
top_k = generation_config.top_k
|
| 438 |
|
| 439 |
+
# RCR specific values
|
| 440 |
+
rcr = generation_config.rcr
|
| 441 |
+
conf_alg = generation_config.conf_alg
|
| 442 |
+
|
| 443 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 444 |
|
| 445 |
# pad input_ids to max_length
|
|
|
|
| 462 |
|
| 463 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 464 |
|
| 465 |
+
# RCR tracking - initialize overtime confidence tracking
|
| 466 |
+
overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
|
| 467 |
+
|
| 468 |
# this allows user-defined token control of the intermediate steps
|
| 469 |
x = generation_tokens_hook_func(None, x, None)
|
| 470 |
for i in range(steps):
|
|
|
|
| 486 |
_, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
|
| 487 |
x[mask_index] = x0.clone()
|
| 488 |
else:
|
| 489 |
+
if alg == 'maskgit_plus' or (rcr and conf_alg == 'maskgit_plus'):
|
| 490 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 491 |
+
elif alg == 'topk_margin' or (rcr and conf_alg == 'topk_margin'):
|
| 492 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
|
| 493 |
+
elif alg == 'entropy' or (rcr and conf_alg == 'entropy'):
|
| 494 |
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
| 495 |
else:
|
| 496 |
raise RuntimeError(f"Unknown alg: {alg}")
|
| 497 |
+
|
| 498 |
+
# Apply RCR logic if enabled
|
| 499 |
+
if rcr:
|
| 500 |
+
self._apply_rcr_logic(x, x0, confidence, mask_index, overtime_confidence,
|
| 501 |
+
mask_token_id, i, steps, s, t)
|
| 502 |
+
else:
|
| 503 |
+
# Original Dream sampling logic
|
| 504 |
+
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 505 |
+
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
|
| 506 |
+
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
| 507 |
+
full_confidence[mask_index] = confidence
|
| 508 |
+
if number_transfer_tokens > 0:
|
| 509 |
+
if alg_temp is None or alg_temp == 0:
|
| 510 |
+
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
| 511 |
+
else:
|
| 512 |
+
full_confidence = full_confidence / alg_temp
|
| 513 |
+
full_confidence = F.softmax(full_confidence, dim=-1)
|
| 514 |
+
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
| 515 |
+
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
| 516 |
+
x_[mask_index] = x0.clone()
|
| 517 |
+
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
| 518 |
+
x[row_indices,transfer_index] = x_[row_indices,transfer_index]
|
| 519 |
|
| 520 |
# this allows user-defined token control of the intermediate steps
|
| 521 |
x = generation_tokens_hook_func(i, x, logits)
|