A2D2 / a2d2_mol /remasking_scheduleaware.py
Sophia
initial commit
8019be0
Raw
History Blame Contribute Delete
6.99 kB
"""
Schedule-aware remasking and insertion logic that ensures the number of masked tokens
follows the interpolant schedule.
"""
import torch
import numpy as np
def apply_schedule_aware_insertion(
    model,
    xt_tmp,
    new_xt,
    t,
    dt,
    ext,
    mask,
    pad,
    max_length,
    orig_mask,
    new_pos_orig,
    quality_threshold=1,
):
    """
    Remove low-confidence insertions while keeping the length at or above the schedule.

    Freshly inserted mask tokens in ``xt_tmp`` are scored with the planner's
    ``insertion_conf`` head. Those scoring strictly below ``quality_threshold``
    are deletion candidates. Per row, the lowest-confidence candidates are
    removed, but never so many that the non-pad length drops below the minimum
    length derived from ``model.interpolant.insertion_schedule.at(t + dt)``.
    Surviving tokens are compacted to the left and the tail is padded.

    Args:
        model: Model with ``planner`` and ``interpolant.insertion_schedule``.
        xt_tmp: Sequence after insertion [B, L].
        new_xt: Sequence before insertion [B, L]. Not read by this function;
            kept for interface compatibility.
        t: Current time [B]. Used to condition the planner.
        dt: Time step size. ``t + dt`` drives the length schedule.
        ext: Number of insertions per gap [B, L+1]; only its total is checked.
        mask: Mask token ID.
        pad: Pad token ID.
        max_length: Maximum sequence length. Not read by this function.
        orig_mask: Mask of original token positions [B, L].
        new_pos_orig: New positions (in ``xt_tmp``) of original tokens [B, L].
        quality_threshold: Insertions with confidence strictly below this value
            are deletion candidates. Must be comparable with a tensor (a number
            or a broadcastable tensor); ``None`` is not supported.

    Returns:
        The sequence with the selected insertions removed and the tail padded
        with ``pad``. ``xt_tmp`` itself is returned when nothing is removed.
    """
    device = xt_tmp.device
    _, L = xt_tmp.shape

    # Only proceed if there were insertions.
    if ext.sum() == 0:
        return xt_tmp

    # The insertion head is trained with the pre-step time t (see
    # loss_insert_planner_flexible), so condition the planner on t; the length
    # schedule below uses t + dt.
    planner_out = model.planner(xt_tmp, t)
    insertion_conf = planner_out.get("insertion_conf", None)
    if insertion_conf is None:
        return xt_tmp
    insertion_conf = insertion_conf.squeeze(-1)  # (B, L)

    # Mark the positions in xt_tmp that original tokens were moved to; any
    # other mask token is a fresh insertion. Fancy-indexing avoids a python loop.
    valid_b, valid_l = orig_mask.nonzero(as_tuple=True)
    valid_p = new_pos_orig[valid_b, valid_l].long().clamp_(0, L - 1)
    is_original = torch.zeros_like(xt_tmp, dtype=torch.bool)
    is_original[valid_b, valid_p] = True
    inserted_positions = (xt_tmp == mask) & ~is_original

    candidates = inserted_positions & (insertion_conf < quality_threshold)
    if not candidates.any():
        return xt_tmp

    # Scheduled minimum length at t + dt.
    current_length_after = xt_tmp.ne(pad).sum(dim=1)  # [B], long
    expected_progress = model.interpolant.insertion_schedule.at(t + dt)  # [B]
    # Same quantity as (len / max(p, 0.1)) * p, but evaluated as
    # len * (p / max(p, 0.1)). The ratio is exactly 1.0 whenever p is at or
    # above the clamp floor, so float rounding can no longer make .long()
    # truncate the minimum to one below the current length.
    # NOTE(review): because the final-length estimate is derived from the
    # post-insertion length itself, expected_length equals the current length
    # whenever p >= 0.1, so deletions are only permitted while p < 0.1.
    # Confirm this is the intended schedule constraint.
    scale = expected_progress / expected_progress.clamp(min=0.1)
    expected_length = current_length_after.float() * scale  # [B]
    min_length = expected_length.long().clamp(min=1)  # [B]

    # Remove every candidate unless that would undercut min_length, in which
    # case remove only as many as the schedule allows.
    num_bad = candidates.sum(dim=1)  # [B]
    max_removable = (current_length_after - min_length).clamp(min=0)
    k_per_row = torch.minimum(num_bad, max_removable)  # [B]
    if k_per_row.sum() == 0:
        return xt_tmp

    # Select the k lowest-confidence candidates per row. Non-candidates score
    # -inf, so (for finite confidences) they sort after every candidate.
    neg_inf = torch.tensor(float("-inf"), device=device, dtype=insertion_conf.dtype)
    scores = torch.where(candidates, -insertion_conf, neg_inf)  # higher = worse
    _, sorted_indices = scores.sort(dim=1, descending=True)
    positions = torch.arange(L, device=device).unsqueeze(0)  # [1, L]
    keep_in_topk = positions < k_per_row.unsqueeze(1)  # [B, L]
    final_bad = torch.zeros_like(candidates)
    final_bad.scatter_(1, sorted_indices, keep_in_topk)

    # Compact each row to the left. A stable sort on the bad flag keeps the
    # survivors in their original order and pushes removed positions to the
    # end, where they are overwritten with pad.
    _, perm = torch.sort(final_bad.long(), dim=1, stable=True)
    xt_tmp = torch.gather(xt_tmp, 1, perm)
    num_keep = (~final_bad).sum(dim=1)  # [B]
    tail_mask = positions >= num_keep.unsqueeze(1)  # [B, L]
    return xt_tmp.masked_fill(tail_mask, pad)
def apply_schedule_aware_remasking(
    model,
    new_xt,
    t,
    dt,
    remasking_conf,
    clean_index,
    mask,
    neg_inf,
    batch_size,
    unmask_quality_threshold=None,
):
    """
    Re-mask the least confident clean tokens so the clean count tracks the schedule.

    Two modes:

    * ``unmask_quality_threshold`` given: every clean token whose confidence is
      strictly below the threshold is replaced by ``mask``; the schedule is not
      consulted. A higher threshold remasks more aggressively.
    * otherwise: per row, the target number of clean tokens is
      ``unmask_schedule.at(t + dt) * (#clean + #mask)``, and the
      ``round(#clean - target)`` lowest-confidence clean tokens (clamped to
      ``[0, #clean]``) are replaced by ``mask``.

    Args:
        model: Model with ``interpolant.unmask_schedule``.
        new_xt: Current sequence [B, L].
        t: Current time [B].
        dt: Time step size.
        remasking_conf: Confidence scores for tokens [B, L]; lowest is remasked first.
        clean_index: Boolean mask of clean tokens (not mask, not pad) [B, L].
        mask: Mask token ID.
        neg_inf: Negative infinity tensor, used as the score of non-clean positions.
        batch_size: Not read by this function; kept for interface compatibility.
        unmask_quality_threshold: Optional confidence threshold selecting the
            threshold mode described above.

    Returns:
        The sequence with the selected positions set to ``mask``. In schedule
        mode, ``new_xt`` itself is returned when nothing needs remasking.
    """
    # Threshold gate takes precedence over the schedule-driven count.
    if unmask_quality_threshold is not None:
        low_quality = clean_index & (remasking_conf < unmask_quality_threshold)
        return new_xt.masked_fill(low_quality, mask)

    n_clean = clean_index.sum(dim=1)  # [B]
    n_masked = (new_xt == mask).sum(dim=1)  # [B]
    live_len = (n_clean + n_masked).float()  # pad tokens are not counted
    target_frac = model.interpolant.unmask_schedule.at(t + dt)  # [B]
    surplus = (n_clean.float() - target_frac * live_len).round().long()
    budget = torch.minimum(surplus.clamp(min=0), n_clean)  # tokens to remask per row
    if budget.sum() == 0:
        return new_xt

    # Order positions from least to most confident clean token; non-clean
    # positions receive neg_inf so they fall to the end of the ordering.
    priority = torch.where(clean_index, remasking_conf.mul(-1.0), neg_inf)
    order = priority.argsort(dim=1, descending=True)
    # Inverting the permutation yields each position's rank in that ordering;
    # the first `budget` ranks of every row are remasked.
    rank = order.argsort(dim=1)
    chosen = rank < budget.unsqueeze(1)
    return new_xt.masked_fill(chosen, mask)