"""
Schedule-aware remasking and insertion logic that keeps the number of masked
tokens, and the length of sequences, in line with the interpolant schedule.
"""
import torch
import numpy as np
def apply_schedule_aware_insertion(
    model,
    xt_tmp,
    new_xt,
    t,
    dt,
    ext,
    mask,
    pad,
    max_length,
    orig_mask,
    new_pos_orig,
    quality_threshold=1,
):
    """
    Remove low-quality insertions based on insertion confidence while respecting
    the interpolant schedule for expected sequence length.

    Freshly inserted mask tokens whose planner confidence is strictly below
    ``quality_threshold`` are dropped, lowest confidence first. No more are
    dropped than would make a row shorter than the length the insertion
    schedule predicts for ``t + dt``.

    Args:
        model: Model exposing ``planner(x, t)`` (returning a dict that may
            contain ``"insertion_conf"``) and
            ``interpolant.insertion_schedule.at(t)``.
        xt_tmp: Sequence after insertion [B, L]
        new_xt: Sequence before insertion [B, L]; its non-pad length anchors
            the schedule-based length estimate.
        t: Current time [B]
        dt: Time step size
        ext: Number of insertions per gap [B, L+1]
        mask: Mask token ID
        pad: Pad token ID
        max_length: Maximum sequence length (not used by this function)
        orig_mask: Mask of original token positions [B, L]
        new_pos_orig: New positions of original tokens in ``xt_tmp`` [B, L]
        quality_threshold: Inserted tokens with confidence strictly below this
            value are candidates for removal.

    Returns:
        xt_tmp: Sequence with low-quality insertions removed, left-compacted and
            right-padded with ``pad``. Returned unchanged when nothing was
            inserted, the planner provides no ``insertion_conf``, or no
            insertion qualifies for removal.
    """
    device = xt_tmp.device
    _, L = xt_tmp.shape

    # Nothing was inserted this step, so there is nothing to prune.
    if ext.sum() == 0:
        return xt_tmp

    # The insertion head is trained with the pre-step time t (see
    # loss_insert_planner_flexible), so condition on t here too; t + dt is
    # used below for the length schedule.
    planner_out = model.planner(xt_tmp, t)
    insertion_conf = planner_out.get("insertion_conf", None)
    if insertion_conf is None:
        return xt_tmp
    insertion_conf = insertion_conf.squeeze(-1)  # [B, L]

    # Mark positions in xt_tmp that hold original tokens (from new_xt) so that
    # only freshly inserted masks are eligible for removal. Fancy-indexing
    # scatter avoids a per-batch python loop.
    valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
    valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)
    is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
    is_original[valid_b, valid_p] = True
    inserted_positions = (xt_tmp == mask) & ~is_original

    candidates = inserted_positions & (insertion_conf < quality_threshold)
    if not candidates.any():
        return xt_tmp

    # Scheduled length at t + dt: estimate the final length from the
    # pre-insertion length at time t, then scale it by the schedule at t + dt.
    # Estimating from the post-insertion length and multiplying back by the
    # same progress value would cancel to the current length and forbid every
    # removal.
    t_next = t + dt
    schedule = model.interpolant.insertion_schedule
    length_before = new_xt.ne(pad).sum(dim=1).float()  # [B]
    length_after = xt_tmp.ne(pad).sum(dim=1)  # [B], long
    progress_now = schedule.at(t).clamp(min=0.1)  # [B]
    progress_next = schedule.at(t_next)  # [B]
    expected_length = length_before / progress_now * progress_next  # [B]
    # Round (as the remasking path does) rather than truncate, so float error
    # such as 5.9999 does not lower the minimum by a whole token.
    min_length = expected_length.round().long().clamp(min=1)  # [B]

    # Drop every bad insertion unless that would undershoot the schedule, in
    # which case drop only as many as the schedule allows.
    num_bad = candidates.sum(dim=1)  # [B], long
    max_removable = (length_after - min_length).clamp(min=0)
    k_per_row = torch.minimum(num_bad, max_removable)  # [B]

    # Select the k lowest-confidence candidates per row via a sort. Every
    # candidate has conf < quality_threshold, so its score is never -inf and
    # it always ranks ahead of the non-candidate filler.
    neg_inf = torch.tensor(float("-inf"), device=device, dtype=insertion_conf.dtype)
    scores = torch.where(candidates, -insertion_conf, neg_inf)  # higher = worse
    _, order = scores.sort(dim=1, descending=True)
    positions = torch.arange(L, device=device).unsqueeze(0)  # [1, L]
    take = positions < k_per_row.unsqueeze(1)  # [B, L]
    final_bad = torch.zeros_like(candidates)
    final_bad.scatter_(1, order, take)
    if not final_bad.any():
        return xt_tmp

    # Compact each row to the left: a stable sort on the drop flag moves the
    # dropped tokens to the end while preserving the order of kept ones; the
    # tail beyond the kept count is then overwritten with pad.
    _, perm = torch.sort(final_bad.long(), dim=1, stable=True)
    xt_tmp = torch.gather(xt_tmp, 1, perm)
    num_keep = (~final_bad).sum(dim=1)  # [B]
    tail = positions >= num_keep.unsqueeze(1)  # [B, L]
    return xt_tmp.masked_fill(tail, pad)
def apply_schedule_aware_remasking(
    model,
    new_xt,
    t,
    dt,
    remasking_conf,
    clean_index,
    mask,
    neg_inf,
    batch_size,
    unmask_quality_threshold=None,
):
    """
    Apply schedule-aware remasking: adjust number of masks to match expected count from schedule.

    Without a threshold, the number of clean tokens to remask per row is
    ``round(num_clean - unmask_schedule(t + dt) * (num_clean + num_masked))``,
    clamped to ``[0, num_clean]``. The clean tokens with the lowest
    ``remasking_conf`` are remasked first.

    Args:
        model: Model with interpolant that has an unmask_schedule
        new_xt: Current sequence [B, L]
        t: Current time [B]
        dt: Time step size
        remasking_conf: Confidence scores for tokens [B, L]
        clean_index: Boolean mask of clean tokens (not mask, not pad) [B, L]
        mask: Mask token ID
        neg_inf: Negative infinity tensor, used to rank non-clean positions last
        batch_size: Batch size (not used by this function)
        unmask_quality_threshold: Optional threshold. When set, it replaces the
            schedule-driven count: every clean token whose ``remasking_conf``
            is strictly below it is remasked. Higher threshold means more
            aggressive remasking.

    Returns:
        new_xt: Modified sequence with schedule-aware remasking applied
    """
    # Optional AJD threshold gate (overrides the schedule-driven count).
    if unmask_quality_threshold is not None:
        to_mask = clean_index & (remasking_conf < unmask_quality_threshold)
        return torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)

    t_next = t + dt
    num_clean = clean_index.sum(dim=1)  # [B], long
    num_masked = (new_xt == mask).sum(dim=1)  # [B], long
    current_seq_len = (num_clean + num_masked).float()  # [B]
    expected_unmasked_frac = model.interpolant.unmask_schedule.at(t_next)  # [B]
    expected_num_clean = expected_unmasked_frac * current_seq_len  # [B]
    masks_to_add = (num_clean.float() - expected_num_clean).round().long()  # [B]
    # Per-row k = min(masks_to_add, num_clean), clamped to >= 0.
    k_per_row = torch.minimum(masks_to_add.clamp(min=0), num_clean)  # [B]
    if k_per_row.sum() == 0:
        return new_xt

    # Rank clean tokens by ascending confidence (low conf = high score);
    # non-clean positions get neg_inf so they sort last.
    score = torch.where(clean_index, -1.0 * remasking_conf, neg_inf)
    _, order = score.sort(dim=1, descending=True)
    positions = torch.arange(score.shape[1], device=new_xt.device).unsqueeze(0)  # [1, L]
    take = positions < k_per_row.unsqueeze(1)  # [B, L]
    to_mask = torch.zeros_like(clean_index)
    to_mask.scatter_(1, order, take)
    # A clean token with +inf confidence also scores -inf and ties with the
    # non-clean filler. Tie order is unspecified, so a pad or mask slot could
    # be selected; never remask anything that was not clean.
    to_mask &= clean_index
    return torch.where(to_mask, torch.full_like(new_xt, mask), new_xt)