import os import torch import torch.nn as nn import pytorch_lightning as pl from omegaconf import DictConfig import torch.nn.functional as F from model.transformer import AnyOrderMaskInsertionFlow from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction from .bregman import jump_kernel_elbo, mse from .schedule import get_schedule_from_config from lightning_modules.any_order import AnyOrderInsertionFlowModule from model.model_wrapper import RemaskingAnyOrder from sampling import _sample_tokens import re from typing import Dict, Any from dataclasses import dataclass def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]: """ Returns a new state_dict where any key containing '._orig_mod.' is replaced by removing the '_orig_mod' segment, e.g. 'model._orig_mod.vocab_embed.embedding' becomes 'model.vocab_embed.embedding' """ new_state_dict: Dict[str, Any] = {} for key, value in state_dict.items(): # remove all occurrences of '._orig_mod.' clean_key = re.sub(r"\._orig_mod\.", ".", key) new_state_dict[clean_key] = value return new_state_dict @torch.no_grad() def _binary_auc(scores: torch.Tensor, labels: torch.Tensor) -> float: """Rank-based AUROC (Mann-Whitney U statistic). AUC = P(score[pos] > score[neg]); 0.5 means no discrimination. Returns NaN when only one class is present (AUC undefined). Ties are not averaged, which is fine for continuous logits used here. """ scores = scores.float().reshape(-1) labels = labels.float().reshape(-1) n_pos = labels.sum() n_neg = labels.numel() - n_pos if n_pos == 0 or n_neg == 0: return float("nan") order = torch.argsort(scores) ranks = torch.empty_like(scores) ranks[order] = torch.arange(1, scores.numel() + 1, device=scores.device, dtype=scores.dtype) auc = (ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg) return auc.item() class AnyOrderInsertionFlowModuleFT(AnyOrderInsertionFlowModule): """ Wrapper around AnyOrderInsertionFlowModule that adds adaptive schedule model for fine-tuning. Can load a pretrained AnyOrderInsertionFlowModule checkpoint and add the schedule model on top. """ def __init__(self, config, args, pretrained_checkpoint, insertion_planner=False): # Initialize parent class first super().__init__(config) self.args = args self.insertion_planner = insertion_planner # Save hyperparameters for this class (overrides parent's save) self.save_hyperparameters(ignore=['pretrained_checkpoint', 'args']) # Load pretrained model weights BEFORE initializing planner to avoid circular reference if pretrained_checkpoint is not None: self.load_pretrained_model(pretrained_checkpoint) # Initialize adaptive schedule model AFTER loading pretrained weights self.planner = RemaskingAnyOrder( backbone=self, d_model=self.config.model.hidden_size, insertion_planner=insertion_planner) def load_pretrained_model(self, checkpoint_path: str): """ Load pretrained AnyOrderInsertionFlowModule weights. Only loads the base model and interpolant, not the schedule model. """ print(f"Loading pretrained model from {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) # Extract state dict - handle different checkpoint formats if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] else: state_dict = checkpoint # Strip _orig_mod keys if present state_dict = strip_orig_mod_keys(state_dict) # Filter out planner keys (if any exist from a previous FT checkpoint) base_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('planner.')} # Load the base model weights # Use strict=False to ignore missing schedule_model keys incompatible_keys = self.load_state_dict(base_state_dict, strict=False) # Filter out expected missing planner keys for cleaner output unexpected_missing = [k for k in incompatible_keys.missing_keys if not k.startswith('planner.')] planner_missing = [k for k in incompatible_keys.missing_keys if k.startswith('planner.')] if unexpected_missing: print(f"Warning: Unexpected missing keys from pretrained checkpoint: {unexpected_missing}") if planner_missing: print(f"Note: Planner will be trained from scratch ({len(planner_missing)} parameters)") if incompatible_keys.unexpected_keys: print(f"Warning: Unexpected keys in pretrained checkpoint: {incompatible_keys.unexpected_keys}") # Freeze base model if specified if self.config.training.get('freeze_base_model', False): print("Freezing base model parameters") for name, param in self.named_parameters(): if not name.startswith('planner.'): param.requires_grad = False def forward(self, x, t, return_features=False): # Use parent class forward method return super().forward(x, t, return_features=return_features) def training_loss(self, x1, t): # Use parent class training_loss for base model loss # Planner is trained separately via loss_planner_flexible with reward gradients unmask_loss, insertion_loss, total_loss = super().training_loss(x1, t) return unmask_loss, insertion_loss, total_loss def training_step(self, batch, batch_idx): # Extract input data if isinstance(batch, dict): batch = batch["input_ids"] x1 = batch t = self.sample_time(x1.shape[0], x1.device) # Calculate the base model loss (planner trained separately, not here) unmask_loss, len_loss, loss = self.training_loss(x1, t) # Log component losses self.log("train/unmask_loss", unmask_loss, prog_bar=True) self.log("train/len_loss", len_loss, prog_bar=True) self.log("train/total_loss", loss, prog_bar=True) return loss def validation_step(self, batch, batch_idx): if isinstance(batch, dict): batch = batch["input_ids"] x1 = batch t = self.sample_time(x1.shape[0], x1.device) unmask_loss, len_loss, loss = self.training_loss(x1, t) self.log("val/unmask_loss", unmask_loss, prog_bar=True, sync_dist=True) self.log("val/len_loss", len_loss, prog_bar=True, sync_dist=True) self.log("val_loss", loss, prog_bar=True, sync_dist=True) return loss @classmethod def load_from_checkpoint(cls, checkpoint_path, map_location=None, strict=True, **kwargs): """ Custom checkpoint loading that handles finetuned checkpoints wrapped by PeptideFinetuner. Extracts config from original pretrained checkpoint and loads finetuned weights. """ print(f"Loading finetuned checkpoint from {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location=map_location or 'cpu', weights_only=False) # Check if this is a wrapped checkpoint (from PeptideFinetuner) hparams = checkpoint.get('hyper_parameters', {}) state_dict = checkpoint.get('state_dict', {}) # Check for policy_model prefix in state_dict (indicates PeptideFinetuner wrapper) has_policy_prefix = any(k.startswith('policy_model.') for k in state_dict.keys()) if has_policy_prefix: # Detect model type (molecule vs peptide) based on vocab size in checkpoint # Molecule models have vocab size ~1882, peptide models have ~587 vocab_size = None for k, v in state_dict.items(): if 'vocab_embed.embedding' in k: vocab_size = v.shape[0] break is_molecule_model = vocab_size is not None and vocab_size > 1000 model_type = "MolFinetuner" if is_molecule_model else "PeptideFinetuner" print(f"Detected wrapped finetuned checkpoint ({model_type}, vocab_size={vocab_size})") # Extract args from hyperparameters if 'args' not in hparams: raise ValueError(f"Cannot find 'args' in hyperparameters. This checkpoint may not be from {model_type}.") args = hparams['args'] print(f"Found args in hyperparameters, type: {type(args)}") # Get original checkpoint path from args # Handle both Namespace (hasattr) and dict (get) access patterns original_ckpt_path = None if hasattr(args, 'checkpoint_path'): original_ckpt_path = args.checkpoint_path elif isinstance(args, dict) and 'checkpoint_path' in args: original_ckpt_path = args['checkpoint_path'] # If checkpoint_path is not set or is None, use default pretrained checkpoint # Select appropriate default based on detected model type if original_ckpt_path is None: _repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if is_molecule_model: original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_mol.ckpt') print(f"Warning: checkpoint_path not found in args, using default molecule pretrained checkpoint") else: original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_pep.ckpt') print(f"Warning: checkpoint_path not found in args, using default peptide pretrained checkpoint") # Try to load config directly from checkpoint first (new checkpoints) # Fall back to loading from original checkpoint (old checkpoints) if 'config' in checkpoint: print("Found config directly in checkpoint") config = checkpoint['config'] else: print(f"Config not in checkpoint, loading from original checkpoint: {original_ckpt_path}") # Load config from original pretrained checkpoint orig_ckpt = torch.load(original_ckpt_path, map_location='cpu', weights_only=False) if 'config' not in orig_ckpt: raise ValueError(f"Original checkpoint {original_ckpt_path} does not contain config") config = orig_ckpt['config'] # Ensure adaptive schedule is enabled # Need to disable struct mode to add new keys to OmegaConf config from omegaconf import OmegaConf if hasattr(config, 'training'): OmegaConf.set_struct(config, False) config.training.use_adaptive_schedule = True OmegaConf.set_struct(config, True) # Create args object if needed if not hasattr(args, '__dict__'): # Convert dict to object with attributes class Args: pass args_obj = Args() for k, v in args.items(): setattr(args_obj, k, v) args = args_obj # Initialize model with config and args model = cls( config=config, args=args, pretrained_checkpoint=None, # Don't reload pretrained, weights already in checkpoint insertion_planner=getattr(args, 'insertion_planner', False) ) # Extract policy_model weights from state_dict policy_state = {} for k, v in state_dict.items(): if k.startswith('policy_model.'): # Strip 'policy_model.' prefix new_key = k[len('policy_model.'):] policy_state[new_key] = v # Load the finetuned weights incompatible = model.load_state_dict(policy_state, strict=False) if incompatible.missing_keys or incompatible.unexpected_keys: print(f"Warning: Incompatible keys when loading finetuned weights:") if incompatible.missing_keys: print(f" Missing: {incompatible.missing_keys[:5]}...") if incompatible.unexpected_keys: print(f" Unexpected: {incompatible.unexpected_keys[:5]}...") # Initialize or load EMA params if model.use_ema: if "ema_params" in checkpoint: # Load EMA params from checkpoint model.ema_params = checkpoint["ema_params"] print("Loaded EMA params from checkpoint") else: # Initialize empty EMA params (will be populated if needed) model.ema_params = { name: param.clone().detach() for name, param in model.named_parameters() } print("Initialized EMA params from current model state") else: model.ema_params = {} # Load planner state if it exists if "planner_state" in checkpoint and hasattr(model, 'planner'): model.planner.load_state_dict(checkpoint["planner_state"], strict=False) print("Loaded planner state from checkpoint") return model else: # Not a wrapped checkpoint, use default Lightning loading # But we still need to provide required __init__ arguments raise NotImplementedError( "Direct finetuned checkpoints (not wrapped by PeptideFinetuner) are not yet supported. " "Please provide config and args as kwargs." ) def on_save_checkpoint(self, checkpoint): """Save config and EMA params, including planner state.""" # Call parent to save config and base model EMA super().on_save_checkpoint(checkpoint) # Explicitly save planner state if hasattr(self, 'planner'): checkpoint["planner_state"] = self.planner.state_dict() def on_load_checkpoint(self, checkpoint): """Load config and reinitialize interpolant, including planner.""" # For finetuned checkpoints loaded via custom load_from_checkpoint, # config may not be in checkpoint (it's loaded from original checkpoint) if "config" in checkpoint: # Call parent to restore config and interpolant super().on_load_checkpoint(checkpoint) else: # Config already set during __init__ via load_from_checkpoint # Just restore EMA params if they exist if self.use_ema and "ema_params" in checkpoint: self.ema_params = checkpoint["ema_params"] # Restore planner state if it exists in checkpoint if hasattr(self, 'planner') and "planner_state" in checkpoint: self.planner.load_state_dict(checkpoint["planner_state"]) print("Loaded planner from checkpoint") def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0): r""" Weighted denoising cross entropy loss X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X) log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer) x: [B, L] (no mask) num_replicates: R, number of replicates of each row in x weight_func: w(lambda) for each sample, 1/lambda by default centering_strength: float, controls how much of the mean is subtracted (DMPO-style) softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights) """ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L] batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B] if centering: batch_weights = batch_weights - centering_strength * batch_weights.mean() batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R] lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R] t = lamda # compute unmasking and insertion loss interpolant_sample = self.interpolant.sample_interpolant(t, batch) unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch) prediction: ModelPrediction = self(interpolant_sample.xt, t) scale_factor = self.config.interpolant.max_length match self.unmask_loss_fn: case "elbo": mask_indices = interpolant_sample.mask_indices unmask_loss_all = torch.zeros_like(unmask_weight) # [B*R, L] unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy( prediction.token_logits[mask_indices], interpolant_sample.unmasked[mask_indices], reduction="none", ) unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor # [B*R] case _: raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}") match self.insert_loss_fn: case "expectation": gaps, gaps_mask = interpolant_sample.gaps_and_mask insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1] insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo( gaps[gaps_mask], prediction.expected_gaps[gaps_mask] ) insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R] case "distribution": gaps, gaps_mask = interpolant_sample.gaps_and_mask insertion_loss_all = torch.zeros_like(insert_weight) # [B*R, L+1] insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy( prediction.length_posterior[gaps_mask], gaps[gaps_mask] ) insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor # [B*R] total_loss = unmask_loss + insertion_loss # [B*R] # end compute unmasking and insertion loss weighted_loss = total_loss * batch_weights # [B*R] return weighted_loss.mean() def one_step_sampler(self, xt, t, pred_rate=None): """ Sample one step of unmasking using model predictions. Args: xt: Current state [B, L] t: Time [B] pred_rate: Optional pre-computed ModelPrediction. If None, will compute from model. Returns: new_xt: Next state [B, L] update_ids: Boolean mask of updated positions [B, L] """ mask = self.interpolant.mask_token pad = self.interpolant.pad_token batch_size, L = xt.shape device = xt.device steps = self.args.total_num_steps dt = 1.0 / steps max_length = self.interpolant.max_length # Use actual tensor dimension L instead of max_length to handle replicated batches batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, L) ) pos_idx_L = ( torch.arange(L, device=device) .view(1, L) .expand(batch_size, L) ) # ——— predict and convert rates ——— if pred_rate is None: pred_rate = self(xt, t) pred_rate = self.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # (B, L, V) len_rate = pred_rate.length_rate # (B, L+1) # ——— unmask step (Euler) ——— mask_pos = (xt == self.interpolant.mask_token).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) # add "stay" probability _xt = xt.clone() _xt[xt == pad] = mask trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step # Renormalize probabilities to ensure they sum to 1 prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) # Avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) if mask_has_zero_prob.any(): # Create uniform distribution over valid tokens (excluding mask and pad) num_zero_prob = mask_has_zero_prob.sum().item() uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype) uniform_prob[:, :mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob else: # Normalize to sum to 1 trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum new_xt = _sample_tokens(trans_prob) new_xt[xt == pad] = pad new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) # update indices--boolean tensor of shape (B, max_length) # A position is updated if: # 1. The token changed (xt != new_xt) # 2. It's not a pad position # 3. It WAS a mask token that got unmasked (so we check xt == mask, not xt != mask) # Debug before fix old_update_ids = (xt != new_xt) & (xt != pad) & (xt != mask) # Correct logic: updated positions are where mask tokens were changed update_ids = (xt != new_xt) & (xt != pad) if self.insertion_planner is False: return new_xt, update_ids # ——— Poisson insertion (tau-leaping) — can insert multiple masks per gap ——— ext = torch.poisson(len_rate * dt).long() # (B, L+1) xt_len = xt.ne(pad).sum(dim=1) # (B,) # Use ext.shape[1] to get the actual max_length dimension from the data actual_max_length = ext.shape[1] - 1 # ext is (B, L+1), so L = ext.shape[1] - 1 gaps = torch.arange(ext.shape[1], device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= actual_max_length ext = ext * valid.view(batch_size, 1).long() ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) xt_tmp = torch.full_like(xt, pad) # Create position indices that match xt_tmp's shape pos_idx_for_fill = torch.arange(xt_tmp.shape[1], device=device).view(1, -1).expand(batch_size, -1) mask_fill = pos_idx_for_fill < new_len.view(batch_size, 1) xt_tmp[mask_fill] = mask new_pos_orig = pos_idx_L + ext_ex[:, :actual_max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] new_ins_xt = xt_tmp # Newly inserted masks: positions that are mask now but weren't before. newly_inserted_masks = (new_ins_xt == mask) & (xt != mask) & (xt != pad) update_ins_ids = newly_inserted_masks return new_xt, update_ids, new_ins_xt, update_ins_ids def loss_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0): r""" Weighted denoising cross entropy loss X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X) log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer) x: [B, L] (no mask) num_replicates: R, number of replicates of each row in x weight_func: w(lambda) for each sample, 1/lambda by default centering_strength: float, controls how much of the mean is subtracted (DMPO-style) softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights) """ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L] batch_size = batch.shape[0] batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B] if centering: batch_weights = batch_weights - centering_strength * batch_weights.mean() batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R] lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R] t = lamda scale_factor = self.config.interpolant.max_length # compute unmasking and insertion loss interpolant_sample = self.interpolant.sample_interpolant(t, batch) unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch) prediction: ModelPrediction = self(interpolant_sample.xt, t) with torch.no_grad(): # no need to compute gradient in this step sampler_out = self.one_step_sampler(interpolant_sample.xt, t, prediction) # one_step_sampler returns (xs, update_ids) or (xs, update_ids, new_ins_xt, update_ins_ids) xs, update_ids = sampler_out[0], sampler_out[1] # The remasking head scores the freshly-decoded tokens to decide which to # remask, so it reads the POST-unmask state xs (matching inference, which # calls the planner on the decoded new_xt). planner = self.planner(xs, t) remasking_conf = planner["remasking_conf"] # [B*R, L, 1] # Compute per-sample loss # IMPORTANT: interpolant_sample.xt has been reordered via st permutation # We need to map back to the original positions to compare with batch st = interpolant_sample.st # [B*R, L] permutation indices batch_reordered = torch.gather(batch, 1, st) # Apply same permutation to ground truth binary_label = (xs == batch_reordered).float() # Only compute loss on positions that were updated per_token_loss = F.binary_cross_entropy_with_logits( remasking_conf.squeeze(-1), # [B*R, L] binary_label, # [B*R, L] reduction="none" # [B*R, L] ) per_token_loss = per_token_loss * update_ids.float() # [B*R, L] # Mask out non-updated positions and average per sample per_sample_loss = per_token_loss.sum(dim=1) / (update_ids.sum(dim=1).float() + 1e-8) # [B*R] # Weight by importance sampling weights weighted_loss = per_sample_loss * batch_weights # [B*R] # ——— AUC / label-balance diagnostics (see loss_insert_planner_flexible) ——— with torch.no_grad(): metrics = {} sel_u = update_ids.bool() if sel_u.any(): u_scores = remasking_conf.squeeze(-1)[sel_u] u_labels = binary_label[sel_u] metrics["unmask_auc"] = _binary_auc(u_scores, u_labels) metrics["unmask_label_mean"] = u_labels.mean().item() metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item() metrics["unmask_n"] = float(sel_u.sum().item()) self._last_planner_metrics = metrics return weighted_loss.mean() def loss_insert_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0): r""" Weighted denoising cross entropy loss X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X) log_rnd: [B] — pre-computed importance weights x: [B, L] (no mask) num_replicates: R, number of replicates of each row in x weight_func: w(lambda) for each sample, 1/lambda by default centering_strength: float, controls how much of the mean is subtracted (DMPO-style) softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights) """ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L] batch_size = batch.shape[0] batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) # [B] if centering: batch_weights = batch_weights - centering_strength * batch_weights.mean() batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R] lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R] t = lamda scale_factor = self.config.interpolant.max_length # compute unmasking and insertion loss # deleted mask: binary tensor [B*R, L] where true tokens in batch were deleted # gap_assignment: [B*R, max_gaps, L] maps x1 positions to gap indices interpolant_sample, deleted_mask, gap_assignment = self.interpolant.sample_interpolant_plan(t, batch) unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch) prediction: ModelPrediction = self(interpolant_sample.xt, t) with torch.no_grad(): # no need to compute gradient in this step xs_unmask, update_unmask_ids, xs_insert, update_ins_ids = self.one_step_sampler(interpolant_sample.xt, t, prediction) # The remasking head scores the freshly-decoded tokens to decide which to # remask, so it must see the POST-unmask state xs_unmask (matching # inference in inference_quality.py, which calls the planner on the # decoded new_xt). Grad stays on here since this head is what we train. planner = self.planner(xs_unmask, t) remasking_conf = planner["remasking_conf"] # [B*R, L, 1] # The insertion-quality head scores the freshly-inserted mask tokens, so # it must see the POST-insertion state xs_insert (aligned with # update_ins_ids / insertion_quality below, and matching inference in # remasking_scheduleaware.apply_schedule_aware_insertion). Grad stays on # here since this head is what we are training. if self.planner.insertion_planner: insertion_conf = self.planner(xs_insert, t)["insertion_conf"] # [B*R, L, 1] else: insertion_conf = None # Compute per-sample loss # IMPORTANT: interpolant_sample.xt has been reordered via st permutation # We need to map back to the original positions to compare with batch # Use the st (permutation) to get the ground truth in the reordered space st = interpolant_sample.st # [B*R, L] permutation indices batch_reordered = torch.gather(batch, 1, st) # Apply same permutation to ground truth # Now compare in the reordered space binary_label = (xs_unmask == batch_reordered).float() # Only compute loss on positions that were updated per_token_loss = F.binary_cross_entropy_with_logits( remasking_conf.squeeze(-1), # [B*R, L] binary_label, # [B*R, L] reduction="none" # [B*R, L] ) per_token_loss = per_token_loss * update_unmask_ids.float() # [B*R, L] # Mask out non-updated positions and average per sample unmask_per_sample_loss = per_token_loss.sum(dim=1) / (update_unmask_ids.sum(dim=1).float() + 1e-8) # [B*R] # compute insertion planner loss # For positions where masks were inserted, we evaluate the quality of insertion # by computing the probability that the ground truth token would be predicted at that position # IMPORTANT: We need to recompute predictions using xs_insert since that's where the masks were inserted # The original prediction was computed from xt (before insertion) with torch.no_grad(): prediction_after_insert: ModelPrediction = self(xs_insert, t) # Get the token prediction probabilities at inserted mask positions # prediction_after_insert.token_logits: [B*R, L, V] - logits for all positions in xs_insert token_probs = F.softmax(prediction_after_insert.token_logits, dim=-1) # [B*R, L, V] # For each gap where masks were inserted, compute the sum of probabilities # of the ground truth tokens that were deleted in that specific gap # gap_assignment: [B*R, max_gaps, L] - maps x1 positions to gap indices # batch: [B*R, L] - ground truth tokens in original space (before permutation) vocab_size = token_probs.shape[-1] L = token_probs.shape[1] max_gaps = gap_assignment.shape[1] # For each gap, create a vocabulary mask of tokens that belong to that gap # gap_vocab_mask[b, gap_idx, token_id] = 1 if token_id was deleted in gap gap_idx gap_vocab_mask = torch.zeros(batch_size, max_gaps, vocab_size, device=batch.device, dtype=torch.float) # Vectorized: gather tokens from batch for all gaps at once # tokens_expanded[b, gap_idx, pos] = batch[b, pos] for all positions tokens_expanded = batch.unsqueeze(1).expand(batch_size, max_gaps, L) # [B*R, max_gaps, L] # valid_mask[b, gap_idx, pos] = 1 if position pos belongs to gap gap_idx and is not pad valid_mask = (gap_assignment > 0) & (tokens_expanded != self.interpolant.pad_token) # [B*R, max_gaps, L] # Scatter tokens into vocabulary dimension: mark which tokens appear in each gap gap_vocab_mask.scatter_add_( 2, # scatter along vocabulary dimension tokens_expanded.clamp(0, vocab_size - 1), # token indices [B*R, max_gaps, L] valid_mask.float() # values to add [B*R, max_gaps, L] ) # Binarize: a token either appears in the gap or not gap_vocab_mask = (gap_vocab_mask > 0).float() # [B*R, max_gaps, V] # For each insertion position in xs_insert, determine which gap it corresponds to # Position p in xs_insert corresponds to gap p (insertions occur between existing tokens) # Vectorized: compute for all positions at once # token_probs: [B*R, L, V] # gap_vocab_mask[:, :L, :]: [B*R, L, V] - vocab mask for gaps 0 to L-1 insertion_quality_full = (token_probs * gap_vocab_mask[:, :L, :]).sum(dim=-1) # [B*R, L] # Only consider quality at positions where masks were actually inserted insertion_quality = insertion_quality_full * update_ins_ids.float() # [B*R, L] # Compute insertion planner loss only if insertion_planner is enabled if insertion_conf is not None: # The planner predicts insertion confidence with insertion_conf # We want to train it to predict high confidence when insertion_quality is high # Use Bernoulli cross-entropy: treat insertion_quality as the "success probability" # Binary cross-entropy with insertion_quality as continuous labels in [0,1] ins_per_token_loss = F.binary_cross_entropy_with_logits( insertion_conf.squeeze(-1), # [B*R, L] - planner's insertion confidence logits insertion_quality, # [B*R, L] - ground truth token probability as quality metric reduction="none" ) # Only compute loss where masks were actually inserted ins_per_token_loss = ins_per_token_loss * update_ins_ids.float() # Average per sample ins_per_sample_loss = ins_per_token_loss.sum(dim=1) / (update_ins_ids.sum(dim=1).float() + 1e-8) else: # No insertion planner - set loss to zero ins_per_sample_loss = torch.zeros_like(unmask_per_sample_loss) # Add to total loss per_sample_loss = unmask_per_sample_loss + ins_per_sample_loss # Weight by importance sampling weights weighted_loss = per_sample_loss * batch_weights # [B*R] # ——— AUC / label-balance diagnostics (the loss alone hides degenerate # targets; near-0 BCE can mean "all labels one class", not "learned") ——— with torch.no_grad(): metrics = {} sel_u = update_unmask_ids.bool() if sel_u.any(): u_scores = remasking_conf.squeeze(-1)[sel_u] u_labels = binary_label[sel_u] metrics["unmask_auc"] = _binary_auc(u_scores, u_labels) metrics["unmask_label_mean"] = u_labels.mean().item() metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item() metrics["unmask_n"] = float(sel_u.sum().item()) if insertion_conf is not None: sel_i = update_ins_ids.bool() if sel_i.any(): i_scores = insertion_conf.squeeze(-1)[sel_i] i_targets = insertion_quality[sel_i] i_labels = (i_targets > 0.5).float() metrics["insert_auc"] = _binary_auc(i_scores, i_labels) metrics["insert_target_mean"] = i_targets.mean().item() metrics["insert_conf_mean"] = torch.sigmoid(i_scores).mean().item() metrics["insert_n"] = float(sel_i.sum().item()) self._last_planner_metrics = metrics return unmask_per_sample_loss.mean(), ins_per_sample_loss.mean(), weighted_loss.mean()