Update generation_utils.py
Browse files
- generation_utils.py +13 -8
generation_utils.py
CHANGED
|
@@ -179,20 +179,25 @@ class DreamGenerationMixin:
|
|
| 179 |
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 180 |
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
|
| 181 |
|
| 182 |
-
#
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
# RCR: Select tokens based on cumulative confidence
|
| 187 |
for j in range(batch_size):
|
| 188 |
if number_transfer_tokens > 0:
|
| 189 |
-
|
| 190 |
-
batch_mask_index = mask_index[j]
|
| 191 |
|
| 192 |
# Select top confident tokens to transfer
|
| 193 |
-
_, select_indices = torch.topk(
|
| 194 |
-
x[j, select_indices] =
|
| 195 |
-
overtime_confidence[j, select_indices] =
|
| 196 |
|
| 197 |
# RCR: Re-mask lowest confidence tokens for next steps
|
| 198 |
if step < total_steps - 1:
|
|
|
|
| 179 |
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 180 |
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
|
| 181 |
|
| 182 |
+
# Create full confidence tensor matching x dimensions
|
| 183 |
+
full_confidence = torch.full_like(x, -torch.inf, device=x.device, dtype=confidence.dtype)
|
| 184 |
+
|
| 185 |
+
# Create temporary tensor for x0 that matches x dimensions
|
| 186 |
+
x_temp = torch.zeros_like(x, device=x.device, dtype=torch.long) + mask_token_id
|
| 187 |
+
|
| 188 |
+
# Fill masked positions with x0 and confidence
|
| 189 |
+
x_temp[mask_index] = x0.clone()
|
| 190 |
+
full_confidence[mask_index] = confidence
|
| 191 |
|
| 192 |
# RCR: Select tokens based on cumulative confidence
|
| 193 |
for j in range(batch_size):
|
| 194 |
if number_transfer_tokens > 0:
|
| 195 |
+
batch_full_confidence = full_confidence[j]
|
|
|
|
| 196 |
|
| 197 |
# Select top confident tokens to transfer
|
| 198 |
+
_, select_indices = torch.topk(batch_full_confidence, k=number_transfer_tokens, largest=True)
|
| 199 |
+
x[j, select_indices] = x_temp[j, select_indices]
|
| 200 |
+
overtime_confidence[j, select_indices] = batch_full_confidence[select_indices].clone().float()
|
| 201 |
|
| 202 |
# RCR: Re-mask lowest confidence tokens for next steps
|
| 203 |
if step < total_steps - 1:
|