| """Unified peptide sampling with quality-guided planning. |
| |
| Supports 4 quality modes and optional RND (importance weight) computation. |
| |
| Quality modes: |
| "none" - No planner, no remasking (policy-only) |
| "both" - Both unmasking + insertion planners active |
| "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled) |
| "insertion_only" - Only insertion planner (unmasking planner disabled) |
| |
| RND toggle: |
| compute_rnd=True - Run pretrained model in parallel, compute step-wise log importance weights |
| compute_rnd=False - Run policy model only (use with ELBO-based RND or eval) |
| """ |
|
|
| import os |
| import torch |
| import numpy as np |
| import pandas as pd |
| import torch.nn.functional as F |
| from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens |
| from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion |
|
|
# The ``quality_mode`` values accepted by ``_diffusion_loop``.
QUALITY_MODES = {
    "none",
    "both",
    "unmasking_only",
    "insertion_only",
}

# Per-step quality logging switch: any value of A2D2_QUALITY_DEBUG other than
# "", "0", "false" or "False" turns it on.
_QUALITY_DEBUG = os.getenv("A2D2_QUALITY_DEBUG", "") not in {"", "0", "false", "False"}
|
|
|
|
def _to_transition_probs(unmask_rate, xt, mask, pad, mask_pos, dt):
    """Convert per-position unmasking rates into one-step transition weights.

    ``unmask_rate`` is modified in place: every rate at a non-mask position is
    zeroed and, at each mask position, the mask->mask entry is set to minus the
    sum of that row's rates. The returned tensor is ``clamp(rate * dt, 0, 1)``
    with a unit mass added at each position's current token (pad positions are
    treated as holding ``mask``).

    NOTE(review): the clamp runs before the unit mass is added, so if the
    outgoing rates are non-negative the mask->mask entry ends up exactly 1 and
    mask rows sum to at least 1. Sampling presumably normalizes these weights
    inside ``_sample_tokens`` -- TODO confirm.
    """
    diag = mask_pos + (mask,)
    unmask_rate[xt != mask] = 0
    unmask_rate[diag] = 0
    unmask_rate[diag] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    current = xt.clone()
    current[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        current.unsqueeze(-1),
        torch.ones_like(current.unsqueeze(-1), dtype=trans_prob.dtype),
    )
    return trans_prob


def _drop_mask_and_renormalize(trans_prob, mask_pos, mask):
    """Forbid staying masked at ``mask_pos`` and renormalize those rows in place.

    Each affected row is divided by its own sum. Rows whose remaining mass is
    exactly zero instead become uniform over token ids ``[0, mask)``.
    """
    trans_prob[mask_pos + (mask,)] = 0.0
    rows = trans_prob[mask_pos]
    row_sum = rows.sum(dim=-1, keepdim=True)
    zero_rows = row_sum.squeeze(-1) == 0.0
    # Divide empty rows by 1 (avoids 0/0); they are overwritten just below.
    # Normalizing every non-empty row, even when some rows are empty, keeps the
    # log-probabilities used for the RND consistent.
    rows = rows / torch.where(zero_rows.unsqueeze(-1), torch.ones_like(row_sum), row_sum)
    if zero_rows.any():
        uniform = torch.zeros_like(rows[zero_rows])
        uniform[:, :mask] = 1.0 / mask
        rows[zero_rows] = uniform
    trans_prob[mask_pos] = rows


@torch.no_grad()
def _diffusion_loop(
    model, steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    unmask_quality_threshold=None,
    unmask_all=False,
    freq_penalty=0.0,
    return_trace=False,
):
    """Core discrete diffusion sampling loop for peptide generation.

    Starts from an all-pad canvas. Each step unmasks tokens, optionally remasks
    clean tokens (planner/ReMDM), and inserts new mask tokens into the gaps
    according to Poisson-sampled insertion counts.

    Args:
        model: Finetuned policy model. Must expose ``parameters()``,
            ``interpolant.to_actual_rate`` and, optionally, ``planner``.
        steps: Number of diffusion steps; the time increment is ``1 / steps``.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length (width of the token canvas).
        quality_mode: One of ``QUALITY_MODES``: "none", "both",
            "unmasking_only" or "insertion_only".
        compute_rnd: Whether to accumulate log importance weights.
        pretrained: Frozen pretrained model (required if compute_rnd=True).
        remasking_mode: "schedule_aware", "remdm" or "remdm_conf". An unknown
            value raises ValueError only when remasking actually runs.
        num_remasking: Tokens remasked per sequence per step in the "remdm"
            and "remdm_conf" modes.
        quality_threshold: Forwarded to ``apply_schedule_aware_insertion``.
        unmask_quality_threshold: Forwarded to ``apply_schedule_aware_remasking``.
        unmask_all: If True, every step forbids staying masked, exactly as the
            final step does.
        freq_penalty: If > 0, each token's sampling weight is multiplied by
            ``exp(-freq_penalty * count)``, where ``count`` is how often that
            token already appears as a clean token in the sequence. The RND
            log-probabilities use the unpenalized probabilities.
        return_trace: Whether to record a per-sequence sampling trace.

    Returns:
        ``(xt, log_rnd, sampling_trace)``. ``xt`` is a
        ``(batch_size, max_length)`` int64 tensor. ``log_rnd`` is a
        ``(batch_size,)`` tensor summing, over all steps, the log-probability
        difference (pretrained minus policy) of the newly unmasked tokens plus
        the Poisson insertion log-likelihood difference; it is None when
        compute_rnd=False. ``sampling_trace`` is None unless return_trace=True.
    """
    assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
    if compute_rnd:
        assert pretrained is not None, "pretrained model required when compute_rnd=True"

    use_remasking = quality_mode != "none"
    disable_unmasking_planner = quality_mode in ("none", "insertion_only")
    disable_insertion_planner = quality_mode in ("none", "unmasking_only")

    device = next(model.parameters()).device

    # All-pad canvas: positions only become (masked) tokens through insertion.
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Broadcast (batch, position) index grids used to scatter tokens on insertion.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None
    neg_inf = torch.tensor(-np.inf, device=device)

    if use_remasking and remasking_mode == "remdm_conf":
        remasking_score = torch.zeros((batch_size, max_length), device=device)

    # Trajectory-level log importance weight, accumulated across ALL steps
    # (previously it was overwritten each step, keeping only the last one).
    log_rnd = torch.zeros(batch_size, device=device) if compute_rnd else None

    dbg_total_remasked = 0
    dbg_total_proposed_ins = 0
    dbg_total_filtered = 0

    for i in range(steps):
        step_remasked = 0
        step_proposed_ins = 0
        step_filtered = 0
        is_last_step = i == steps - 1

        # ---- Rates from the policy (and the pretrained reference) ----------
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate
        len_rate = pred_rate.length_rate

        if compute_rnd:
            pretrained_pred = pretrained(xt, t)
            pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
            pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()
            pretrained_len_rate = pretrained_rate.length_rate

        # ---- Unmasking transition probabilities ---------------------------
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        # NOTE: unmask_rate is modified in place; the remdm_conf branch below
        # reads the modified rates.
        trans_prob = _to_transition_probs(unmask_rate, xt, mask, pad, mask_pos, dt)
        if compute_rnd:
            pretrained_trans_prob = _to_transition_probs(
                pretrained_unmask_rate, xt, mask, pad, mask_pos, dt
            )

        if is_last_step or unmask_all:
            if is_last_step:
                print("Final step, removing mask token from sampling")
            _drop_mask_and_renormalize(trans_prob, mask_pos, mask)
            if compute_rnd:
                # Give the reference process the same forced-unmask treatment so
                # the step's log-probability ratio compares like with like.
                _drop_mask_and_renormalize(pretrained_trans_prob, mask_pos, mask)

        sample_prob = trans_prob
        if freq_penalty > 0.0:
            vocab_size = trans_prob.shape[-1]
            clean_tok = (xt != mask) & (xt != pad)
            counts = torch.zeros(batch_size, vocab_size, device=device, dtype=trans_prob.dtype)
            # Non-clean positions scatter weight 0 into index 0, i.e. no-ops.
            counts.scatter_add_(
                1,
                torch.where(clean_tok, xt, torch.zeros_like(xt)),
                clean_tok.to(trans_prob.dtype),
            )
            sample_prob = trans_prob * torch.exp(-freq_penalty * counts).unsqueeze(1)

        new_xt = _sample_tokens(sample_prob)
        new_xt[xt == pad] = pad
        # Already-clean tokens are never resampled.
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        if use_remasking and remasking_mode == "remdm_conf" and not is_last_step:
            # NOTE(review): remasking_score is updated here, but no code in this
            # function reads it; the remdm_conf branch below ranks by the
            # planner's remasking_conf instead. Confirm this is intended.
            token_probs = F.softmax(unmask_rate, dim=-1)
            chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1)
            changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
            remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)

        # ---- Remasking -------------------------------------------------------
        if use_remasking and not is_last_step:
            if disable_unmasking_planner or getattr(model, "planner", None) is None:
                remasking_conf = torch.zeros((batch_size, max_length), device=device)
            else:
                planner_out = model.planner(new_xt, t)
                remasking_conf = planner_out["remasking_conf"].squeeze(-1)

            clean_index = (new_xt != mask) & (new_xt != pad)

            if remasking_mode == "remdm":
                remasking_score_temp = torch.rand(remasking_conf.shape, device=device)
            elif remasking_mode == "remdm_conf":
                remasking_score_temp = -1.0 * remasking_conf
            elif remasking_mode == "schedule_aware":
                # The helper remasks directly; there is no top-k pass afterwards.
                if not disable_unmasking_planner:
                    new_xt = apply_schedule_aware_remasking(
                        model, new_xt, t, dt, remasking_conf, clean_index,
                        mask, neg_inf, batch_size,
                        unmask_quality_threshold=unmask_quality_threshold,
                    )
                remasking_score_temp = None
            else:
                raise ValueError(f"Unknown remasking_mode: {remasking_mode}")

            if remasking_score_temp is not None:
                # Only clean tokens are eligible; remask the top-k scores per row.
                remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
                for row in range(batch_size):
                    k = min(num_remasking, int(clean_index[row].sum().item()))
                    if k > 0:
                        _, select_indices = torch.topk(remasking_score_temp[row], k=k)
                        new_xt[row, select_indices] = mask

            if _QUALITY_DEBUG:
                step_remasked = int((clean_index & (new_xt == mask)).sum().item())

            if return_trace:
                for batch_idx in range(batch_size):
                    for pos in range(max_length):
                        if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="change",
                                    position=pos,
                                    token=mask,
                                )
                            )

        # ---- Unmasking contribution to the RND -----------------------------
        if compute_rnd:
            lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
            lp_pre = torch.gather(
                torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)
            ).squeeze(-1)
            # Count only positions that went mask -> token and were not remasked.
            changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)
            log_policy_step = (lp * changed_mask).sum(dim=1)
            log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
            log_rnd = log_rnd + (log_pretrained_step - log_policy_step)

        # ---- Insertion of new mask tokens (skipped on the last step) --------
        if not is_last_step:
            ext = torch.poisson(len_rate * dt).long()

            xt_len = xt.ne(pad).sum(dim=1)
            # Gap g sits before position g; only gaps 0..xt_len exist.
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # Reject all insertions for rows that would overflow max_length ...
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # ... and recount AFTER rejection. Counting before rejection made a
            # rejected row's new_len exceed max_length, padding it with masks.
            total_ext = ext.sum(dim=1)

            ext_ex = ext.int().cumsum(dim=1)
            new_len = xt_len + total_ext

            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            # Token p shifts right by the number of insertions in gaps 0..p.
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

            if _QUALITY_DEBUG:
                step_proposed_ins = int(ext.sum().item())

            ext_corrected = ext
            if use_remasking and not disable_insertion_planner:
                dbg_nonpad_before = int((xt_tmp != pad).sum().item()) if _QUALITY_DEBUG else 0

                xt_tmp = apply_schedule_aware_insertion(
                    model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
                    orig_mask, new_pos_orig, quality_threshold,
                )

                if _QUALITY_DEBUG:
                    step_filtered = dbg_nonpad_before - int((xt_tmp != pad).sum().item())

                if compute_rnd:
                    # Scale proposed per-gap counts by the fraction of inserted
                    # tokens that survived the planner's filtering.
                    ext_corrected = torch.zeros_like(ext)
                    for b in range(batch_size):
                        after_len = xt_tmp[b].ne(pad).sum().item()
                        orig_len = xt_len[b].item()
                        surviving_insertions = after_len - orig_len
                        if total_ext[b] > 0:
                            ratio = surviving_insertions / total_ext[b].item()
                            ext_corrected[b] = (ext[b].float() * ratio).long()

            if compute_rnd:
                # Poisson log-likelihood k*log(lam) - lam per gap; the log(k!)
                # terms are identical for both models and cancel.
                insertion_rate = (len_rate * dt).clamp(min=1e-10)
                pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)
                log_policy_insert = (
                    ext_corrected * torch.log(insertion_rate) - insertion_rate
                ).sum(dim=1)
                log_pretrained_insert = (
                    ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate
                ).sum(dim=1)
                log_rnd = log_rnd + (log_pretrained_insert - log_policy_insert)
        else:
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                if not is_last_step:
                    # Record insertions right-to-left; one event per non-empty gap.
                    for gap in reversed(range(max_length)):
                        if ext[batch_idx, gap]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="insertion",
                                    position=gap,
                                    token=mask,
                                )
                            )

        if _QUALITY_DEBUG:
            dbg_total_remasked += step_remasked
            dbg_total_proposed_ins += step_proposed_ins
            dbg_total_filtered += step_filtered
            print(
                f"[QUALITY {quality_mode}] step {i+1}/{steps}: "
                f"remasked {step_remasked} unmasked tokens -> mask | "
                f"insertions proposed {step_proposed_ins}, "
                f"filtered {step_filtered}, kept {step_proposed_ins - step_filtered}"
            )

        xt = xt_tmp
        t = t + dt

    if _QUALITY_DEBUG:
        print(
            f"[QUALITY {quality_mode}] TOTAL over {steps} steps (batch_size={batch_size}): "
            f"remasked {dbg_total_remasked} unmasked tokens | "
            f"insertions proposed {dbg_total_proposed_ins}, "
            f"filtered {dbg_total_filtered}, kept {dbg_total_proposed_ins - dbg_total_filtered}"
        )

    return xt, log_rnd, sampling_trace
|
|
|
|
@torch.no_grad()
def sample_peptides_buffer(
    model, reward_model, analyzer, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    alpha=0.1,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    min_length=0,
):
    """Sample a batch of peptides, keep the valid ones and score them.

    Args:
        model: Finetuned policy model.
        reward_model: Callable ``reward_model(input_seqs=[...])`` returning
            per-sequence score vectors; they are summed over the last axis.
        analyzer: Object whose ``is_peptide(seq)`` decides validity.
        tokenizer: Tokenizer whose ``batch_decode`` turns token rows into strings.
        steps, mask, pad, batch_size, max_length: Forwarded to ``_diffusion_loop``.
        quality_mode: "none", "both", "unmasking_only" or "insertion_only".
        compute_rnd: If True, ``_diffusion_loop`` computes log importance
            weights (requires ``pretrained``). If False, zero placeholders are
            returned instead (for ELBO-based RND).
        pretrained: Frozen pretrained model (required when compute_rnd=True).
        alpha: Reward temperature; ``reward / alpha`` is added to the log weight.
        remasking_mode, num_remasking, quality_threshold: Forwarded to
            ``_diffusion_loop``.
        min_length: When > 0, sequences with fewer non-pad tokens are dropped.

    Returns:
        ``(valid_x, log_rnd, scalar_rewards, sampling_trace)``: the kept token
        rows, their log weights (loop log weight + reward / alpha, or zeros),
        their summed rewards, and the trace from ``_diffusion_loop``. All three
        tensors are empty when no sequence is kept.
    """
    xt, log_rnd, trace = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=compute_rnd,
        pretrained=pretrained,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
    )
    device = xt.device

    # (row index, decoded string) for every sequence that passes both filters.
    kept = []
    for row, text in enumerate(tokenizer.batch_decode(xt)):
        if not analyzer.is_peptide(text):
            continue
        if min_length > 0 and int((xt[row] != pad).sum().item()) < min_length:
            continue
        kept.append((row, text))

    sequences = [text for _, text in kept]
    print("len valid sequences:", len(sequences))

    if not sequences:
        print("[WARNING] No valid peptides generated in this batch")
        return (
            torch.empty((0, max_length), dtype=torch.long, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            trace,
        )

    score_vectors = reward_model(input_seqs=sequences)
    rewards = torch.as_tensor(
        np.sum(score_vectors, axis=-1), dtype=torch.float32, device=device
    )
    print(f"scalar reward dim{len(rewards)}")

    kept_x = torch.stack([xt[row] for row, _ in kept], dim=0)

    if not compute_rnd:
        weights = torch.zeros(len(sequences), dtype=torch.float32, device=device)
    else:
        weights = torch.stack([log_rnd[row] for row, _ in kept], dim=0) + (rewards / alpha)

    return kept_x, weights, rewards, trace
|
|
|
|
@torch.no_grad()
def sample_peptides_eval(
    model, reward_model, analyzer, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    unmask_quality_threshold=None,
    unmask_all=False,
    freq_penalty=0.0,
    dataframe=False,
    return_valid=False,
):
    """Generate peptides for evaluation and score the valid ones.

    Args:
        model: Finetuned policy model.
        reward_model: Callable ``reward_model(input_seqs=[...])`` returning a
            2-D array of per-sequence scores (transposed to per-objective
            rows). If it has ``score_func_names``, its length sets the number
            of objectives; otherwise 5 is assumed.
        analyzer: Object whose ``is_peptide(seq)`` decides validity.
        tokenizer: Tokenizer whose ``batch_decode`` turns token rows into strings.
        steps, mask, pad, batch_size, max_length: Forwarded to ``_diffusion_loop``.
        quality_mode: "none", "both", "unmasking_only" or "insertion_only".
        remasking_mode, num_remasking, quality_threshold,
        unmask_quality_threshold, unmask_all, freq_penalty: Forwarded to
            ``_diffusion_loop``.
        dataframe: If True, append a pandas DataFrame of the valid sequences
            and their scores to the return tuple. With no valid sequences it
            has the same columns and zero rows.
        return_valid: If True, return the decoded valid sequences instead of
            the raw token tensor.

    Returns:
        For multi-objective (5 objectives):
            (samples, affinity, sol, hemo, nf, permeability, valid_fraction[, df])
        For a single objective:
            (samples, sol, valid_fraction[, df])
        ``samples`` is replaced by the list of valid sequences when
        return_valid=True. When no sequence is valid, every score is ``[0.0]``.
    """
    samples, _, _ = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=False,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        unmask_quality_threshold=unmask_quality_threshold,
        unmask_all=unmask_all,
        freq_penalty=freq_penalty,
    )

    decoded_samples = tokenizer.batch_decode(samples)
    validSequences = [seq for seq in decoded_samples if analyzer.is_peptide(seq)]
    print("len valid sequences:", len(validSequences))

    valid_fraction = len(validSequences) / batch_size

    num_objectives = (
        len(reward_model.score_func_names) if hasattr(reward_model, "score_func_names") else 5
    )
    if num_objectives == 1:
        column_names = ("Solubility",)
    else:
        # Row order of the transposed score matrix for the 5-objective model.
        column_names = ("Binding Affinity", "Solubility", "Hemolysis", "Nonfouling", "Permeability")

    if validSequences:
        per_objective = reward_model(input_seqs=validSequences).T
        scores = [per_objective[k] for k in range(len(column_names))]
    else:
        scores = [[0.0] for _ in column_names]

    head = validSequences if return_valid else samples
    result = (head, *scores, valid_fraction)

    if dataframe:
        # Every column must match the length of "Peptide Sequence". The old
        # [0.0] placeholders made pandas raise ValueError when nothing was valid.
        columns = {"Peptide Sequence": validSequences}
        for name, values in zip(column_names, scores):
            columns[name] = values if validSequences else []
        result += (pd.DataFrame(columns),)

    return result
|
|