autoprogrammer committed on
Commit
85f58f1
·
verified ·
1 Parent(s): bb6d931

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +13 -8
generation_utils.py CHANGED
@@ -179,20 +179,25 @@ class DreamGenerationMixin:
179
  num_mask_token = mask_index.sum() / mask_index.shape[0]
180
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
181
 
182
- # Update predictions for masked positions only
183
- x0 = torch.where(mask_index, x0, x)
184
- confidence = torch.where(mask_index, confidence, torch.tensor(-float('inf'), device=x0.device))
 
 
 
 
 
 
185
 
186
  # RCR: Select tokens based on cumulative confidence
187
  for j in range(batch_size):
188
  if number_transfer_tokens > 0:
189
- batch_confidence = confidence[j]
190
- batch_mask_index = mask_index[j]
191
 
192
  # Select top confident tokens to transfer
193
- _, select_indices = torch.topk(batch_confidence, k=number_transfer_tokens, largest=True)
194
- x[j, select_indices] = x0[j, select_indices]
195
- overtime_confidence[j, select_indices] = batch_confidence[select_indices].clone().float()
196
 
197
  # RCR: Re-mask lowest confidence tokens for next steps
198
  if step < total_steps - 1:
 
179
  num_mask_token = mask_index.sum() / mask_index.shape[0]
180
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
181
 
182
+ # Create full confidence tensor matching x dimensions
183
+ full_confidence = torch.full_like(x, -torch.inf, device=x.device, dtype=confidence.dtype)
184
+
185
+ # Create temporary tensor for x0 that matches x dimensions
186
+ x_temp = torch.zeros_like(x, device=x.device, dtype=torch.long) + mask_token_id
187
+
188
+ # Fill masked positions with x0 and confidence
189
+ x_temp[mask_index] = x0.clone()
190
+ full_confidence[mask_index] = confidence
191
 
192
  # RCR: Select tokens based on cumulative confidence
193
  for j in range(batch_size):
194
  if number_transfer_tokens > 0:
195
+ batch_full_confidence = full_confidence[j]
 
196
 
197
  # Select top confident tokens to transfer
198
+ _, select_indices = torch.topk(batch_full_confidence, k=number_transfer_tokens, largest=True)
199
+ x[j, select_indices] = x_temp[j, select_indices]
200
+ overtime_confidence[j, select_indices] = batch_full_confidence[select_indices].clone().float()
201
 
202
  # RCR: Re-mask lowest confidence tokens for next steps
203
  if step < total_steps - 1: