"""Unified peptide sampling with quality-guided planning. Supports 4 quality modes and optional RND (importance weight) computation. Quality modes: "none" - No planner, no remasking (policy-only) "both" - Both unmasking + insertion planners active "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled) "insertion_only" - Only insertion planner (unmasking planner disabled) RND toggle: compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights compute_rnd=False - Run policy model only (use with ELBO-based RND or eval) """ import os import torch import numpy as np import pandas as pd import torch.nn.functional as F from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion QUALITY_MODES = {"none", "both", "unmasking_only", "insertion_only"} # When set (e.g. A2D2_QUALITY_DEBUG=1), the diffusion loop prints, per step, how # many already-unmasked tokens get remasked and how many proposed insertions get # filtered by the quality planner, plus a per-batch total. Off by default so it # never spams training/eval runs. _QUALITY_DEBUG = os.environ.get("A2D2_QUALITY_DEBUG", "") not in ("", "0", "false", "False") @torch.no_grad() def _diffusion_loop( model, steps, mask, pad, batch_size, max_length, quality_mode="both", compute_rnd=False, pretrained=None, remasking_mode="schedule_aware", num_remasking=1, quality_threshold=1, unmask_quality_threshold=None, unmask_all=False, freq_penalty=0.0, return_trace=False, ): """Core discrete diffusion sampling loop for peptide generation. Args: model: Finetuned policy model. steps: Number of diffusion steps. mask: Mask token ID. pad: Pad token ID. batch_size: Number of sequences to generate. max_length: Maximum sequence length. quality_mode: One of "none", "both", "unmasking_only", "insertion_only". compute_rnd: Whether to compute step-wise log importance weights. pretrained: Frozen pretrained model (required if compute_rnd=True). remasking_mode: Remasking strategy ("schedule_aware", "remdm", "remdm_conf"). num_remasking: Number of tokens to remask per step. quality_threshold: Threshold for insertion quality filtering. None if schedule-driven. return_trace: Whether to record sampling trace. Returns: (xt, log_rnd, sampling_trace) log_rnd is None when compute_rnd=False. """ assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}" if compute_rnd: assert pretrained is not None, "pretrained model required when compute_rnd=True" # Derive flags from quality_mode use_remasking = quality_mode != "none" disable_unmasking_planner = quality_mode in ("none", "insertion_only") disable_insertion_planner = quality_mode in ("none", "unmasking_only") device = next(model.parameters()).device # Initialize all-pad sequence xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) dt = 1.0 / steps t = torch.zeros(batch_size, device=device) # Precompute index tensors batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, max_length) ) pos_idx_L = ( torch.arange(max_length, device=device) .view(1, max_length) .expand(batch_size, max_length) ) sampling_trace = [[] for _ in range(batch_size)] if return_trace else None neg_inf = torch.tensor(-np.inf, device=device) if use_remasking and remasking_mode == "remdm_conf": remasking_score = torch.zeros((batch_size, max_length), device=device) log_rnd = None dbg_total_remasked = 0 dbg_total_proposed_ins = 0 dbg_total_filtered = 0 for i in range(steps): step_remasked = 0 step_proposed_ins = 0 step_filtered = 0 # --- Policy model forward --- pred_rate = model(xt, t) pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # (B, L, V) len_rate = pred_rate.length_rate # (B, L+1) # --- Pretrained model forward (for RND) --- if compute_rnd: pretrained_pred = pretrained(xt, t) pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t) pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V) pretrained_len_rate = pretrained_rate.length_rate # (B, L+1) # --- Unmask step (Euler) --- mask_pos = (xt == mask).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) if compute_rnd: pretrained_unmask_rate[xt != mask] = 0 pretrained_unmask_rate[mask_pos + (mask,)] = 0 pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1) pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0) # Add "stay" probability _xt = xt.clone() _xt[xt == pad] = mask trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) if compute_rnd: pretrained_trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype), ) # Remove mask token from sampling so every masked position is decoded. # The final step always does this; unmask_all does it every step, so the # schedule-aware remasking below re-masks the lowest-quality tokens back # down to the schedule's expected mask count. if i == steps - 1 or unmask_all: if i == steps - 1: print("Final step, removing mask token from sampling") trans_prob[mask_pos + (mask,)] = 0.0 prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) if mask_has_zero_prob.any(): num_zero_prob = mask_has_zero_prob.sum().item() uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype) uniform_prob[:, :mask] = 1.0 / mask trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob else: trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum # --- Frequency penalty: down-weight residues already abundant in the # sequence so (re)decoded masked positions don't collapse onto the modal # token (glycine). Only masked positions are sampled; clean positions are # overwritten below, so penalizing the whole tensor is harmless. mask/pad # never accumulate counts, so their entries stay untouched. Applied to a # copy so trans_prob (used for RND log-probs) is unchanged. sample_prob = trans_prob if freq_penalty > 0.0: V = trans_prob.shape[-1] clean_tok = (xt != mask) & (xt != pad) # (B, L) counts = torch.zeros(batch_size, V, device=device, dtype=trans_prob.dtype) counts.scatter_add_(1, torch.where(clean_tok, xt, torch.zeros_like(xt)), clean_tok.to(trans_prob.dtype)) sample_prob = trans_prob * torch.exp(-freq_penalty * counts).unsqueeze(1) new_xt = _sample_tokens(sample_prob) new_xt[xt == pad] = pad new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) # Update remasking_score buffer for remdm_conf mode if use_remasking and remasking_mode == "remdm_conf" and i < steps - 1: token_probs = F.softmax(unmask_rate, dim=-1) # (B, L, V) chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1) # (B, L) changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad) remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score) # --- Remasking step --- if use_remasking and i < steps - 1: if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None): remasking_conf = torch.zeros((batch_size, max_length), device=device) else: planner_out = model.planner(new_xt, t) remasking_conf = planner_out["remasking_conf"].squeeze(-1) # (B, L) clean_index = (new_xt != mask) & (new_xt != pad) # (B, L) if remasking_mode == "remdm": remasking_score_temp = torch.rand(remasking_conf.shape, device=device) elif remasking_mode == "remdm_conf": remasking_score_temp = -1.0 * remasking_conf elif remasking_mode == "schedule_aware": # Only remask when the unmasking planner is active. Otherwise # (e.g. insertion_only / no_unmasking_planner) remasking_conf is # all zeros, so this would remask schedule-excess tokens by # position rather than by quality. if not disable_unmasking_planner: new_xt = apply_schedule_aware_remasking( model, new_xt, t, dt, remasking_conf, clean_index, mask, neg_inf, batch_size, unmask_quality_threshold=unmask_quality_threshold, ) remasking_score_temp = None else: raise ValueError(f"Unknown remasking_mode: {remasking_mode}") if remasking_score_temp is not None: remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf) for j in range(batch_size): k = min(num_remasking, int(clean_index[j].sum().item())) if k > 0: _, select_indices = torch.topk(remasking_score_temp[j], k=k) new_xt[j, select_indices] = mask if _QUALITY_DEBUG: # Positions that were clean before this remasking block and are # now mask are exactly the unmasked tokens that got remasked. step_remasked = int((clean_index & (new_xt == mask)).sum().item()) if return_trace: for batch_idx in range(batch_size): for pos in range(max_length): if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="change", position=pos, token=mask, ) ) # --- Compute log probabilities for RND --- if compute_rnd: lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad) log_policy_step = (lp * changed_mask).sum(dim=1) log_pretrained_step = (lp_pre * changed_mask).sum(dim=1) log_rnd = log_pretrained_step - log_policy_step # (B,) # --- Insertion step --- if i != steps - 1: ext = torch.poisson(len_rate * dt).long() # (B, L+1) xt_len = xt.ne(pad).sum(dim=1) # (B,) gaps = torch.arange(max_length + 1, device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= max_length ext = ext * valid.view(batch_size, 1).long() ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) xt_tmp = torch.full_like(xt, pad) mask_fill = pos_idx_L < new_len.view(batch_size, 1) xt_tmp[mask_fill] = mask new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] if _QUALITY_DEBUG: # ext has been masked by the max-length validity check above, so # this is the number of fresh mask tokens actually inserted. step_proposed_ins = int(ext.sum().item()) # Schedule-aware insertion quality filtering if use_remasking and not disable_insertion_planner: if compute_rnd: xt_tmp_before = xt_tmp.clone() dbg_nonpad_before = int((xt_tmp != pad).sum().item()) if _QUALITY_DEBUG else 0 xt_tmp = apply_schedule_aware_insertion( model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length, orig_mask, new_pos_orig, quality_threshold ) if _QUALITY_DEBUG: # Filtering only drops/compacts tokens, so the drop in # non-pad count is the number of insertions filtered out. step_filtered = dbg_nonpad_before - int((xt_tmp != pad).sum().item()) if compute_rnd: # Compute corrected ext based on what actually stayed ext_corrected = torch.zeros_like(ext) for b in range(batch_size): after_len = xt_tmp[b].ne(pad).sum().item() orig_len = xt_len[b].item() surviving_insertions = after_len - orig_len if total_ext[b] > 0: ratio = surviving_insertions / total_ext[b].item() ext_corrected[b] = (ext[b].float() * ratio).long() else: ext_corrected = ext else: ext_corrected = ext # Compute insertion log_rnd if compute_rnd: insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1) pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1) log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1) log_pretrained_insert = (ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) log_insert_diff = log_pretrained_insert - log_policy_insert log_rnd += log_insert_diff else: xt_tmp = new_xt if return_trace: for batch_idx in range(batch_size): for j in range(max_length): if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="change", position=j, token=new_xt[batch_idx, j].item(), ) ) if i != steps - 1: for j in range(max_length): id = max_length - j - 1 if ext[batch_idx, id]: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="insertion", position=id, token=mask, ) ) if _QUALITY_DEBUG: dbg_total_remasked += step_remasked dbg_total_proposed_ins += step_proposed_ins dbg_total_filtered += step_filtered print( f"[QUALITY {quality_mode}] step {i+1}/{steps}: " f"remasked {step_remasked} unmasked tokens -> mask | " f"insertions proposed {step_proposed_ins}, " f"filtered {step_filtered}, kept {step_proposed_ins - step_filtered}" ) xt = xt_tmp t = t + dt if _QUALITY_DEBUG: print( f"[QUALITY {quality_mode}] TOTAL over {steps} steps (batch_size={batch_size}): " f"remasked {dbg_total_remasked} unmasked tokens | " f"insertions proposed {dbg_total_proposed_ins}, " f"filtered {dbg_total_filtered}, kept {dbg_total_proposed_ins - dbg_total_filtered}" ) return xt, log_rnd, sampling_trace @torch.no_grad() def sample_peptides_buffer( model, reward_model, analyzer, tokenizer, steps, mask, pad, batch_size, max_length, quality_mode="both", compute_rnd=False, pretrained=None, alpha=0.1, remasking_mode="schedule_aware", num_remasking=1, quality_threshold=1, min_length=0, ): """Generate peptides for training buffer. Args: model: Finetuned policy model. reward_model: Multi-objective scoring function. analyzer: PeptideAnalyzer for validation. tokenizer: Tokenizer for decoding. steps: Number of diffusion steps. mask: Mask token ID. pad: Pad token ID. batch_size: Number of sequences to generate. max_length: Maximum sequence length. quality_mode: "none", "both", "unmasking_only", or "insertion_only". compute_rnd: If True, compute step-wise log importance weights (requires pretrained). If False, returns placeholder zero log_rnd (for ELBO-based RND). pretrained: Frozen pretrained model (required when compute_rnd=True). alpha: RND scaling factor. remasking_mode: Remasking strategy. num_remasking: Number of tokens to remask per step. quality_threshold: Threshold for insertion quality filtering. Returns: (valid_x, log_rnd, scalar_rewards, sampling_trace) """ xt, log_rnd, trace = _diffusion_loop( model, steps, mask, pad, batch_size, max_length, quality_mode=quality_mode, compute_rnd=compute_rnd, pretrained=pretrained, remasking_mode=remasking_mode, num_remasking=num_remasking, quality_threshold=quality_threshold, ) device = xt.device decoded_samples = tokenizer.batch_decode(xt) valid_x_final = [] validSequences = [] valid_log_rnd = [] for idx, seq in enumerate(decoded_samples): if not analyzer.is_peptide(seq): continue token_len = int((xt[idx] != pad).sum().item()) if min_length > 0 and token_len < min_length: continue valid_x_final.append(xt[idx]) validSequences.append(seq) if compute_rnd: valid_log_rnd.append(log_rnd[idx]) print("len valid sequences:", len(validSequences)) if len(validSequences) == 0: print("[WARNING] No valid peptides generated in this batch") empty_x = torch.empty((0, max_length), dtype=torch.long, device=device) empty_log_rnd = torch.empty((0,), dtype=torch.float32, device=device) empty_rewards = torch.empty((0,), dtype=torch.float32, device=device) return empty_x, empty_log_rnd, empty_rewards, trace score_vectors = reward_model(input_seqs=validSequences) scalar_rewards = np.sum(score_vectors, axis=-1) scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device) print(f"scalar reward dim{len(scalar_rewards)}") valid_x_final = torch.stack(valid_x_final, dim=0) if compute_rnd: valid_log_rnd = torch.stack(valid_log_rnd, dim=0) log_rnd_out = valid_log_rnd + (scalar_rewards / alpha) else: log_rnd_out = torch.zeros(len(validSequences), dtype=torch.float32, device=device) return valid_x_final, log_rnd_out, scalar_rewards, trace @torch.no_grad() def sample_peptides_eval( model, reward_model, analyzer, tokenizer, steps, mask, pad, batch_size, max_length, quality_mode="both", remasking_mode="schedule_aware", num_remasking=1, quality_threshold=1, unmask_quality_threshold=None, unmask_all=False, freq_penalty=0.0, dataframe=False, return_valid=False, ): """Generate peptides for evaluation. Args: model: Finetuned policy model. reward_model: Multi-objective scoring function. analyzer: PeptideAnalyzer for validation. tokenizer: Tokenizer for decoding. steps: Number of diffusion steps. mask: Mask token ID. pad: Pad token ID. batch_size: Number of sequences to generate. max_length: Maximum sequence length. quality_mode: "none", "both", "unmasking_only", or "insertion_only". remasking_mode: Remasking strategy. num_remasking: Number of tokens to remask per step. quality_threshold: Threshold for insertion quality filtering. dataframe: If True, include a pandas DataFrame in the return. return_valid: If True, return decoded valid sequences instead of raw token tensors. Returns: For multi-objective (5 objectives): (samples, affinity, sol, hemo, nf, permeability, valid_fraction[, df]) For single objective: (samples, sol, valid_fraction[, df]) When return_valid=True, samples is replaced with validSequences list. """ xt, _, trace = _diffusion_loop( model, steps, mask, pad, batch_size, max_length, quality_mode=quality_mode, compute_rnd=False, remasking_mode=remasking_mode, num_remasking=num_remasking, quality_threshold=quality_threshold, unmask_quality_threshold=unmask_quality_threshold, unmask_all=unmask_all, freq_penalty=freq_penalty, ) device = xt.device samples = xt.to(device) decoded_samples = tokenizer.batch_decode(samples) valid_x_final = [] validSequences = [] for idx, seq in enumerate(decoded_samples): if analyzer.is_peptide(seq): valid_x_final.append(samples[idx]) validSequences.append(seq) print("len valid sequences:", len(validSequences)) valid_fraction = len(validSequences) / batch_size # Determine number of objectives from reward model num_objectives = len(reward_model.score_func_names) if hasattr(reward_model, 'score_func_names') else 5 if len(validSequences) != 0: score_vectors = reward_model(input_seqs=validSequences) # (N, num_objectives) average_scores = score_vectors.T if num_objectives == 1: sol = average_scores[0] else: affinity = average_scores[0] sol = average_scores[1] hemo = average_scores[2] nf = average_scores[3] permeability = average_scores[4] else: zeros = [0.0] if num_objectives == 1: sol = zeros else: affinity = zeros sol = zeros hemo = zeros nf = zeros permeability = zeros if num_objectives == 1: if dataframe: df = pd.DataFrame({ "Peptide Sequence": validSequences, "Solubility": sol if len(validSequences) else [0.0], }) if return_valid: return validSequences, sol, valid_fraction, df return samples, sol, valid_fraction, df if return_valid: return validSequences, sol, valid_fraction return samples, sol, valid_fraction if dataframe: df = pd.DataFrame({ "Peptide Sequence": validSequences, "Binding Affinity": affinity if len(validSequences) else [0.0], "Solubility": sol if len(validSequences) else [0.0], "Hemolysis": hemo if len(validSequences) else [0.0], "Nonfouling": nf if len(validSequences) else [0.0], "Permeability": permeability if len(validSequences) else [0.0], }) if return_valid: return validSequences, affinity, sol, hemo, nf, permeability, valid_fraction, df return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df if return_valid: return validSequences, affinity, sol, hemo, nf, permeability, valid_fraction return samples, affinity, sol, hemo, nf, permeability, valid_fraction