import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
from omegaconf import DictConfig
import torch.nn.functional as F
from model.transformer import AnyOrderMaskInsertionFlow
from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction
from .bregman import jump_kernel_elbo, mse
from .schedule import get_schedule_from_config
from lightning_modules.any_order import AnyOrderInsertionFlowModule
from model.model_wrapper import RemaskingAnyOrder
from sampling import _sample_tokens
import re
from typing import Dict, Any
from dataclasses import dataclass
def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a new state_dict with every ``_orig_mod`` key segment removed.

    A segment is removed wherever ``_orig_mod.`` appears as a whole dotted
    component, i.e. at the start of the key or right after a ``.``::

        'model._orig_mod.vocab_embed.embedding' -> 'model.vocab_embed.embedding'
        '_orig_mod.vocab_embed.embedding'       -> 'vocab_embed.embedding'

    (Such segments are presumably introduced by a ``torch.compile`` wrapper;
    confirm against the code that produced the checkpoint.)

    Args:
        state_dict: Mapping from parameter name to value. It is not modified.

    Returns:
        A new dict with cleaned keys and the same value objects. If two keys
        collapse to the same cleaned key, the later one wins.
    """
    # The lookbehind keeps the preceding '.' out of the match, so consecutive
    # segments ('a._orig_mod._orig_mod.b') are all removed, and names that
    # merely end in '_orig_mod' ('my_orig_mod.') are left untouched.
    return {
        re.sub(r"(?:^|(?<=\.))_orig_mod\.", "", key): value
        for key, value in state_dict.items()
    }
@torch.no_grad()
def _binary_auc(scores: torch.Tensor, labels: torch.Tensor) -> float:
"""Rank-based AUROC (Mann-Whitney U statistic).
AUC = P(score[pos] > score[neg]); 0.5 means no discrimination. Returns NaN
when only one class is present (AUC undefined). Ties are not averaged, which
is fine for continuous logits used here.
"""
scores = scores.float().reshape(-1)
labels = labels.float().reshape(-1)
n_pos = labels.sum()
n_neg = labels.numel() - n_pos
if n_pos == 0 or n_neg == 0:
return float("nan")
order = torch.argsort(scores)
ranks = torch.empty_like(scores)
ranks[order] = torch.arange(1, scores.numel() + 1, device=scores.device, dtype=scores.dtype)
auc = (ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
return auc.item()
class AnyOrderInsertionFlowModuleFT(AnyOrderInsertionFlowModule):
"""
Wrapper around AnyOrderInsertionFlowModule that adds adaptive schedule model
for fine-tuning. Can load a pretrained AnyOrderInsertionFlowModule checkpoint
and add the schedule model on top.
"""
def __init__(self, config, args, pretrained_checkpoint, insertion_planner=False):
# Initialize parent class first
super().__init__(config)
self.args = args
self.insertion_planner = insertion_planner
# Save hyperparameters for this class (overrides parent's save)
self.save_hyperparameters(ignore=['pretrained_checkpoint', 'args'])
# Load pretrained model weights BEFORE initializing planner to avoid circular reference
if pretrained_checkpoint is not None:
self.load_pretrained_model(pretrained_checkpoint)
# Initialize adaptive schedule model AFTER loading pretrained weights
self.planner = RemaskingAnyOrder(
backbone=self,
d_model=self.config.model.hidden_size,
insertion_planner=insertion_planner)
def load_pretrained_model(self, checkpoint_path: str):
"""
Load pretrained AnyOrderInsertionFlowModule weights.
Only loads the base model and interpolant, not the schedule model.
"""
print(f"Loading pretrained model from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
# Extract state dict - handle different checkpoint formats
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# Strip _orig_mod keys if present
state_dict = strip_orig_mod_keys(state_dict)
# Filter out planner keys (if any exist from a previous FT checkpoint)
base_state_dict = {k: v for k, v in state_dict.items()
if not k.startswith('planner.')}
# Load the base model weights
# Use strict=False to ignore missing schedule_model keys
incompatible_keys = self.load_state_dict(base_state_dict, strict=False)
# Filter out expected missing planner keys for cleaner output
unexpected_missing = [k for k in incompatible_keys.missing_keys
if not k.startswith('planner.')]
planner_missing = [k for k in incompatible_keys.missing_keys
if k.startswith('planner.')]
if unexpected_missing:
print(f"Warning: Unexpected missing keys from pretrained checkpoint: {unexpected_missing}")
if planner_missing:
print(f"Note: Planner will be trained from scratch ({len(planner_missing)} parameters)")
if incompatible_keys.unexpected_keys:
print(f"Warning: Unexpected keys in pretrained checkpoint: {incompatible_keys.unexpected_keys}")
# Freeze base model if specified
if self.config.training.get('freeze_base_model', False):
print("Freezing base model parameters")
for name, param in self.named_parameters():
if not name.startswith('planner.'):
param.requires_grad = False
def forward(self, x, t, return_features=False):
# Use parent class forward method
return super().forward(x, t, return_features=return_features)
def training_loss(self, x1, t):
# Use parent class training_loss for base model loss
# Planner is trained separately via loss_planner_flexible with reward gradients
unmask_loss, insertion_loss, total_loss = super().training_loss(x1, t)
return unmask_loss, insertion_loss, total_loss
def training_step(self, batch, batch_idx):
# Extract input data
if isinstance(batch, dict):
batch = batch["input_ids"]
x1 = batch
t = self.sample_time(x1.shape[0], x1.device)
# Calculate the base model loss (planner trained separately, not here)
unmask_loss, len_loss, loss = self.training_loss(x1, t)
# Log component losses
self.log("train/unmask_loss", unmask_loss, prog_bar=True)
self.log("train/len_loss", len_loss, prog_bar=True)
self.log("train/total_loss", loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
if isinstance(batch, dict):
batch = batch["input_ids"]
x1 = batch
t = self.sample_time(x1.shape[0], x1.device)
unmask_loss, len_loss, loss = self.training_loss(x1, t)
self.log("val/unmask_loss", unmask_loss, prog_bar=True, sync_dist=True)
self.log("val/len_loss", len_loss, prog_bar=True, sync_dist=True)
self.log("val_loss", loss, prog_bar=True, sync_dist=True)
return loss
@classmethod
def load_from_checkpoint(cls, checkpoint_path, map_location=None, strict=True, **kwargs):
"""
Custom checkpoint loading that handles finetuned checkpoints wrapped by PeptideFinetuner.
Extracts config from original pretrained checkpoint and loads finetuned weights.
"""
print(f"Loading finetuned checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=map_location or 'cpu', weights_only=False)
# Check if this is a wrapped checkpoint (from PeptideFinetuner)
hparams = checkpoint.get('hyper_parameters', {})
state_dict = checkpoint.get('state_dict', {})
# Check for policy_model prefix in state_dict (indicates PeptideFinetuner wrapper)
has_policy_prefix = any(k.startswith('policy_model.') for k in state_dict.keys())
if has_policy_prefix:
# Detect model type (molecule vs peptide) based on vocab size in checkpoint
# Molecule models have vocab size ~1882, peptide models have ~587
vocab_size = None
for k, v in state_dict.items():
if 'vocab_embed.embedding' in k:
vocab_size = v.shape[0]
break
is_molecule_model = vocab_size is not None and vocab_size > 1000
model_type = "MolFinetuner" if is_molecule_model else "PeptideFinetuner"
print(f"Detected wrapped finetuned checkpoint ({model_type}, vocab_size={vocab_size})")
# Extract args from hyperparameters
if 'args' not in hparams:
raise ValueError(f"Cannot find 'args' in hyperparameters. This checkpoint may not be from {model_type}.")
args = hparams['args']
print(f"Found args in hyperparameters, type: {type(args)}")
# Get original checkpoint path from args
# Handle both Namespace (hasattr) and dict (get) access patterns
original_ckpt_path = None
if hasattr(args, 'checkpoint_path'):
original_ckpt_path = args.checkpoint_path
elif isinstance(args, dict) and 'checkpoint_path' in args:
original_ckpt_path = args['checkpoint_path']
# If checkpoint_path is not set or is None, use default pretrained checkpoint
# Select appropriate default based on detected model type
if original_ckpt_path is None:
_repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if is_molecule_model:
original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_mol.ckpt')
print(f"Warning: checkpoint_path not found in args, using default molecule pretrained checkpoint")
else:
original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_pep.ckpt')
print(f"Warning: checkpoint_path not found in args, using default peptide pretrained checkpoint")
# Try to load config directly from checkpoint first (new checkpoints)
# Fall back to loading from original checkpoint (old checkpoints)
if 'config' in checkpoint:
print("Found config directly in checkpoint")
config = checkpoint['config']
else:
print(f"Config not in checkpoint, loading from original checkpoint: {original_ckpt_path}")
# Load config from original pretrained checkpoint
orig_ckpt = torch.load(original_ckpt_path, map_location='cpu', weights_only=False)
if 'config' not in orig_ckpt:
raise ValueError(f"Original checkpoint {original_ckpt_path} does not contain config")
config = orig_ckpt['config']
# Ensure adaptive schedule is enabled
# Need to disable struct mode to add new keys to OmegaConf config
from omegaconf import OmegaConf
if hasattr(config, 'training'):
OmegaConf.set_struct(config, False)
config.training.use_adaptive_schedule = True
OmegaConf.set_struct(config, True)
# Create args object if needed
if not hasattr(args, '__dict__'):
# Convert dict to object with attributes
class Args:
pass
args_obj = Args()
for k, v in args.items():
setattr(args_obj, k, v)
args = args_obj
# Initialize model with config and args
model = cls(
config=config,
args=args,
pretrained_checkpoint=None, # Don't reload pretrained, weights already in checkpoint
insertion_planner=getattr(args, 'insertion_planner', False)
)
# Extract policy_model weights from state_dict
policy_state = {}
for k, v in state_dict.items():
if k.startswith('policy_model.'):
# Strip 'policy_model.' prefix
new_key = k[len('policy_model.'):]
policy_state[new_key] = v
# Load the finetuned weights
incompatible = model.load_state_dict(policy_state, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
print(f"Warning: Incompatible keys when loading finetuned weights:")
if incompatible.missing_keys:
print(f" Missing: {incompatible.missing_keys[:5]}...")
if incompatible.unexpected_keys:
print(f" Unexpected: {incompatible.unexpected_keys[:5]}...")
# Initialize or load EMA params
if model.use_ema:
if "ema_params" in checkpoint:
# Load EMA params from checkpoint
model.ema_params = checkpoint["ema_params"]
print("Loaded EMA params from checkpoint")
else:
# Initialize empty EMA params (will be populated if needed)
model.ema_params = {
name: param.clone().detach()
for name, param in model.named_parameters()
}
print("Initialized EMA params from current model state")
else:
model.ema_params = {}
# Load planner state if it exists
if "planner_state" in checkpoint and hasattr(model, 'planner'):
model.planner.load_state_dict(checkpoint["planner_state"], strict=False)
print("Loaded planner state from checkpoint")
return model
else:
# Not a wrapped checkpoint, use default Lightning loading
# But we still need to provide required __init__ arguments
raise NotImplementedError(
"Direct finetuned checkpoints (not wrapped by PeptideFinetuner) are not yet supported. "
"Please provide config and args as kwargs."
)
def on_save_checkpoint(self, checkpoint):
"""Save config and EMA params, including planner state."""
# Call parent to save config and base model EMA
super().on_save_checkpoint(checkpoint)
# Explicitly save planner state
if hasattr(self, 'planner'):
checkpoint["planner_state"] = self.planner.state_dict()
def on_load_checkpoint(self, checkpoint):
"""Load config and reinitialize interpolant, including planner."""
# For finetuned checkpoints loaded via custom load_from_checkpoint,
# config may not be in checkpoint (it's loaded from original checkpoint)
if "config" in checkpoint:
# Call parent to restore config and interpolant
super().on_load_checkpoint(checkpoint)
else:
# Config already set during __init__ via load_from_checkpoint
# Just restore EMA params if they exist
if self.use_ema and "ema_params" in checkpoint:
self.ema_params = checkpoint["ema_params"]
# Restore planner state if it exists in checkpoint
if hasattr(self, 'planner') and "planner_state" in checkpoint:
self.planner.load_state_dict(checkpoint["planner_state"])
print("Loaded planner from checkpoint")
def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
r"""
Weighted denoising cross entropy loss
X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
log_rnd: [B] β pre-computed importance weights (already softmax-normalized over the full buffer)
x: [B, L] (no mask)
num_replicates: R, number of replicates of each row in x
weight_func: w(lambda) for each sample, 1/lambda by default
centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)
"""
batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B]
if centering:
batch_weights = batch_weights - centering_strength * batch_weights.mean()
batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
t = lamda
# compute unmasking and insertion loss
interpolant_sample = self.interpolant.sample_interpolant(t, batch)
unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
prediction: ModelPrediction = self(interpolant_sample.xt, t)
scale_factor = self.config.interpolant.max_length
match self.unmask_loss_fn:
case "elbo":
mask_indices = interpolant_sample.mask_indices
unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L]
unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy(
prediction.token_logits[mask_indices],
interpolant_sample.unmasked[mask_indices],
reduction="none",
)
unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R]
case _:
raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
match self.insert_loss_fn:
case "expectation":
gaps, gaps_mask = interpolant_sample.gaps_and_mask
insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo(
gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
)
insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
case "distribution":
gaps, gaps_mask = interpolant_sample.gaps_and_mask
insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1]
insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy(
prediction.length_posterior[gaps_mask], gaps[gaps_mask]
)
insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R]
total_loss = unmask_loss + insertion_loss # [B*R]
# end compute unmasking and insertion loss
weighted_loss = total_loss * batch_weights # [B*R]
return weighted_loss.mean()
def one_step_sampler(self, xt, t, pred_rate=None):
"""
Sample one step of unmasking using model predictions.
Args:
xt: Current state [B, L]
t: Time [B]
pred_rate: Optional pre-computed ModelPrediction. If None, will compute from model.
Returns:
new_xt: Next state [B, L]
update_ids: Boolean mask of updated positions [B, L]
"""
mask = self.interpolant.mask_token
pad = self.interpolant.pad_token
batch_size, L = xt.shape
device = xt.device
steps = self.args.total_num_steps
dt = 1.0 / steps
max_length = self.interpolant.max_length
# Use actual tensor dimension L instead of max_length to handle replicated batches
batch_idx_L = (
torch.arange(batch_size, device=device)
.view(batch_size, 1)
.expand(batch_size, L)
)
pos_idx_L = (
torch.arange(L, device=device)
.view(1, L)
.expand(batch_size, L)
)
# βββ predict and convert rates βββ
if pred_rate is None:
pred_rate = self(xt, t)
pred_rate = self.interpolant.to_actual_rate(xt, pred_rate, t)
unmask_rate = pred_rate.unmask_rate # (B, L, V)
len_rate = pred_rate.length_rate # (B, L+1)
# βββ unmask step (Euler) βββ
mask_pos = (xt == self.interpolant.mask_token).nonzero(as_tuple=True)
unmask_rate[xt != mask] = 0
unmask_rate[mask_pos + (mask,)] = 0
unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)
# add "stay" probability
_xt = xt.clone()
_xt[xt == pad] = mask
trans_prob.scatter_add_(
2,
_xt.unsqueeze(-1),
torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
)
trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step
# Renormalize probabilities to ensure they sum to 1
prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
# Avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad)
mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
if mask_has_zero_prob.any():
# Create uniform distribution over valid tokens (excluding mask and pad)
num_zero_prob = mask_has_zero_prob.sum().item()
uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype)
uniform_prob[:, :mask] = 1.0 / mask # Uniform over tokens 0 to mask-1
trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
else:
# Normalize to sum to 1
trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum
new_xt = _sample_tokens(trans_prob)
new_xt[xt == pad] = pad
new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)
# update indices--boolean tensor of shape (B, max_length)
# A position is updated if:
# 1. The token changed (xt != new_xt)
# 2. It's not a pad position
# 3. It WAS a mask token that got unmasked (so we check xt == mask, not xt != mask)
# Debug before fix
old_update_ids = (xt != new_xt) & (xt != pad) & (xt != mask)
# Correct logic: updated positions are where mask tokens were changed
update_ids = (xt != new_xt) & (xt != pad)
if self.insertion_planner is False:
return new_xt, update_ids
# βββ Poisson insertion (tau-leaping) β can insert multiple masks per gap βββ
ext = torch.poisson(len_rate * dt).long() # (B, L+1)
xt_len = xt.ne(pad).sum(dim=1) # (B,)
# Use ext.shape[1] to get the actual max_length dimension from the data
actual_max_length = ext.shape[1] - 1 # ext is (B, L+1), so L = ext.shape[1] - 1
gaps = torch.arange(ext.shape[1], device=device).view(1, -1)
ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
total_ext = ext.sum(dim=1)
valid = xt_len + total_ext <= actual_max_length
ext = ext * valid.view(batch_size, 1).long()
ext_ex = ext.int().cumsum(dim=1) # (B, L+1)
new_len = xt_len + total_ext # (B,)
xt_tmp = torch.full_like(xt, pad)
# Create position indices that match xt_tmp's shape
pos_idx_for_fill = torch.arange(xt_tmp.shape[1], device=device).view(1, -1).expand(batch_size, -1)
mask_fill = pos_idx_for_fill < new_len.view(batch_size, 1)
xt_tmp[mask_fill] = mask
new_pos_orig = pos_idx_L + ext_ex[:, :actual_max_length] # (B, L)
orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
flat_b = batch_idx_L[orig_mask]
flat_p = new_pos_orig[orig_mask]
xt_tmp[flat_b, flat_p] = new_xt[orig_mask]
new_ins_xt = xt_tmp
# Newly inserted masks: positions that are mask now but weren't before.
newly_inserted_masks = (new_ins_xt == mask) & (xt != mask) & (xt != pad)
update_ins_ids = newly_inserted_masks
return new_xt, update_ids, new_ins_xt, update_ins_ids
def loss_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
r"""
Weighted denoising cross entropy loss
X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
log_rnd: [B] β pre-computed importance weights (already softmax-normalized over the full buffer)
x: [B, L] (no mask)
num_replicates: R, number of replicates of each row in x
weight_func: w(lambda) for each sample, 1/lambda by default
centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)
"""
batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
batch_size = batch.shape[0]
batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B]
if centering:
batch_weights = batch_weights - centering_strength * batch_weights.mean()
batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
t = lamda
scale_factor = self.config.interpolant.max_length
# compute unmasking and insertion loss
interpolant_sample = self.interpolant.sample_interpolant(t, batch)
unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
prediction: ModelPrediction = self(interpolant_sample.xt, t)
with torch.no_grad(): # no need to compute gradient in this step
sampler_out = self.one_step_sampler(interpolant_sample.xt, t, prediction)
# one_step_sampler returns (xs, update_ids) or (xs, update_ids, new_ins_xt, update_ins_ids)
xs, update_ids = sampler_out[0], sampler_out[1]
# The remasking head scores the freshly-decoded tokens to decide which to
# remask, so it reads the POST-unmask state xs (matching inference, which
# calls the planner on the decoded new_xt).
planner = self.planner(xs, t)
remasking_conf = planner["remasking_conf"] # [B*R, L, 1]
# Compute per-sample loss
# IMPORTANT: interpolant_sample.xt has been reordered via st permutation
# We need to map back to the original positions to compare with batch
st = interpolant_sample.st # [B*R, L] permutation indices
batch_reordered = torch.gather(batch, 1, st) # Apply same permutation to ground truth
binary_label = (xs == batch_reordered).float()
# Only compute loss on positions that were updated
per_token_loss = F.binary_cross_entropy_with_logits(
remasking_conf.squeeze(-1), # [B*R, L]
binary_label, # [B*R, L]
reduction="none" # [B*R, L]
)
per_token_loss = per_token_loss * update_ids.float() # [B*R, L]
# Mask out non-updated positions and average per sample
per_sample_loss = per_token_loss.sum(dim=1) / (update_ids.sum(dim=1).float() + 1e-8) # [B*R]
# Weight by importance sampling weights
weighted_loss = per_sample_loss * batch_weights # [B*R]
# βββ AUC / label-balance diagnostics (see loss_insert_planner_flexible) βββ
with torch.no_grad():
metrics = {}
sel_u = update_ids.bool()
if sel_u.any():
u_scores = remasking_conf.squeeze(-1)[sel_u]
u_labels = binary_label[sel_u]
metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
metrics["unmask_label_mean"] = u_labels.mean().item()
metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
metrics["unmask_n"] = float(sel_u.sum().item())
self._last_planner_metrics = metrics
return weighted_loss.mean()
def loss_insert_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0):
r"""
Weighted denoising cross entropy loss
X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
log_rnd: [B] β pre-computed importance weights
x: [B, L] (no mask)
num_replicates: R, number of replicates of each row in x
weight_func: w(lambda) for each sample, 1/lambda by default
centering_strength: float, controls how much of the mean is subtracted (DMPO-style)
softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights)
"""
batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
batch_size = batch.shape[0]
batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B]
if centering:
batch_weights = batch_weights - centering_strength * batch_weights.mean()
batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
t = lamda
scale_factor = self.config.interpolant.max_length
# compute unmasking and insertion loss
# deleted mask: binary tensor [B*R, L] where true tokens in batch were deleted
# gap_assignment: [B*R, max_gaps, L] maps x1 positions to gap indices
interpolant_sample, deleted_mask, gap_assignment = self.interpolant.sample_interpolant_plan(t, batch)
unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch)
prediction: ModelPrediction = self(interpolant_sample.xt, t)
with torch.no_grad(): # no need to compute gradient in this step
xs_unmask, update_unmask_ids, xs_insert, update_ins_ids = self.one_step_sampler(interpolant_sample.xt, t, prediction)
# The remasking head scores the freshly-decoded tokens to decide which to
# remask, so it must see the POST-unmask state xs_unmask (matching
# inference in inference_quality.py, which calls the planner on the
# decoded new_xt). Grad stays on here since this head is what we train.
planner = self.planner(xs_unmask, t)
remasking_conf = planner["remasking_conf"] # [B*R, L, 1]
# The insertion-quality head scores the freshly-inserted mask tokens, so
# it must see the POST-insertion state xs_insert (aligned with
# update_ins_ids / insertion_quality below, and matching inference in
# remasking_scheduleaware.apply_schedule_aware_insertion). Grad stays on
# here since this head is what we are training.
if self.planner.insertion_planner:
insertion_conf = self.planner(xs_insert, t)["insertion_conf"] # [B*R, L, 1]
else:
insertion_conf = None
# Compute per-sample loss
# IMPORTANT: interpolant_sample.xt has been reordered via st permutation
# We need to map back to the original positions to compare with batch
# Use the st (permutation) to get the ground truth in the reordered space
st = interpolant_sample.st # [B*R, L] permutation indices
batch_reordered = torch.gather(batch, 1, st) # Apply same permutation to ground truth
# Now compare in the reordered space
binary_label = (xs_unmask == batch_reordered).float()
# Only compute loss on positions that were updated
per_token_loss = F.binary_cross_entropy_with_logits(
remasking_conf.squeeze(-1), # [B*R, L]
binary_label, # [B*R, L]
reduction="none" # [B*R, L]
)
per_token_loss = per_token_loss * update_unmask_ids.float() # [B*R, L]
# Mask out non-updated positions and average per sample
unmask_per_sample_loss = per_token_loss.sum(dim=1) / (update_unmask_ids.sum(dim=1).float() + 1e-8) # [B*R]
# compute insertion planner loss
# For positions where masks were inserted, we evaluate the quality of insertion
# by computing the probability that the ground truth token would be predicted at that position
# IMPORTANT: We need to recompute predictions using xs_insert since that's where the masks were inserted
# The original prediction was computed from xt (before insertion)
with torch.no_grad():
prediction_after_insert: ModelPrediction = self(xs_insert, t)
# Get the token prediction probabilities at inserted mask positions
# prediction_after_insert.token_logits: [B*R, L, V] - logits for all positions in xs_insert
token_probs = F.softmax(prediction_after_insert.token_logits, dim=-1) # [B*R, L, V]
# For each gap where masks were inserted, compute the sum of probabilities
# of the ground truth tokens that were deleted in that specific gap
# gap_assignment: [B*R, max_gaps, L] - maps x1 positions to gap indices
# batch: [B*R, L] - ground truth tokens in original space (before permutation)
vocab_size = token_probs.shape[-1]
L = token_probs.shape[1]
max_gaps = gap_assignment.shape[1]
# For each gap, create a vocabulary mask of tokens that belong to that gap
# gap_vocab_mask[b, gap_idx, token_id] = 1 if token_id was deleted in gap gap_idx
gap_vocab_mask = torch.zeros(batch_size, max_gaps, vocab_size, device=batch.device, dtype=torch.float)
# Vectorized: gather tokens from batch for all gaps at once
# tokens_expanded[b, gap_idx, pos] = batch[b, pos] for all positions
tokens_expanded = batch.unsqueeze(1).expand(batch_size, max_gaps, L) # [B*R, max_gaps, L]
# valid_mask[b, gap_idx, pos] = 1 if position pos belongs to gap gap_idx and is not pad
valid_mask = (gap_assignment > 0) & (tokens_expanded != self.interpolant.pad_token) # [B*R, max_gaps, L]
# Scatter tokens into vocabulary dimension: mark which tokens appear in each gap
gap_vocab_mask.scatter_add_(
2, # scatter along vocabulary dimension
tokens_expanded.clamp(0, vocab_size - 1), # token indices [B*R, max_gaps, L]
valid_mask.float() # values to add [B*R, max_gaps, L]
)
# Binarize: a token either appears in the gap or not
gap_vocab_mask = (gap_vocab_mask > 0).float() # [B*R, max_gaps, V]
# For each insertion position in xs_insert, determine which gap it corresponds to
# Position p in xs_insert corresponds to gap p (insertions occur between existing tokens)
# Vectorized: compute for all positions at once
# token_probs: [B*R, L, V]
# gap_vocab_mask[:, :L, :]: [B*R, L, V] - vocab mask for gaps 0 to L-1
insertion_quality_full = (token_probs * gap_vocab_mask[:, :L, :]).sum(dim=-1) # [B*R, L]
# Only consider quality at positions where masks were actually inserted
insertion_quality = insertion_quality_full * update_ins_ids.float() # [B*R, L]
# Compute insertion planner loss only if insertion_planner is enabled
if insertion_conf is not None:
# The planner predicts insertion confidence with insertion_conf
# We want to train it to predict high confidence when insertion_quality is high
# Use Bernoulli cross-entropy: treat insertion_quality as the "success probability"
# Binary cross-entropy with insertion_quality as continuous labels in [0,1]
ins_per_token_loss = F.binary_cross_entropy_with_logits(
insertion_conf.squeeze(-1), # [B*R, L] - planner's insertion confidence logits
insertion_quality, # [B*R, L] - ground truth token probability as quality metric
reduction="none"
)
# Only compute loss where masks were actually inserted
ins_per_token_loss = ins_per_token_loss * update_ins_ids.float()
# Average per sample
ins_per_sample_loss = ins_per_token_loss.sum(dim=1) / (update_ins_ids.sum(dim=1).float() + 1e-8)
else:
# No insertion planner - set loss to zero
ins_per_sample_loss = torch.zeros_like(unmask_per_sample_loss)
# Add to total loss
per_sample_loss = unmask_per_sample_loss + ins_per_sample_loss
# Weight by importance sampling weights
weighted_loss = per_sample_loss * batch_weights # [B*R]
# βββ AUC / label-balance diagnostics (the loss alone hides degenerate
# targets; near-0 BCE can mean "all labels one class", not "learned") βββ
with torch.no_grad():
metrics = {}
sel_u = update_unmask_ids.bool()
if sel_u.any():
u_scores = remasking_conf.squeeze(-1)[sel_u]
u_labels = binary_label[sel_u]
metrics["unmask_auc"] = _binary_auc(u_scores, u_labels)
metrics["unmask_label_mean"] = u_labels.mean().item()
metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item()
metrics["unmask_n"] = float(sel_u.sum().item())
if insertion_conf is not None:
sel_i = update_ins_ids.bool()
if sel_i.any():
i_scores = insertion_conf.squeeze(-1)[sel_i]
i_targets = insertion_quality[sel_i]
i_labels = (i_targets > 0.5).float()
metrics["insert_auc"] = _binary_auc(i_scores, i_labels)
metrics["insert_target_mean"] = i_targets.mean().item()
metrics["insert_conf_mean"] = torch.sigmoid(i_scores).mean().item()
metrics["insert_n"] = float(sel_i.sum().item())
self._last_planner_metrics = metrics
return unmask_per_sample_loss.mean(), ins_per_sample_loss.mean(), weighted_loss.mean()