"""Unified molecule sampling with quality-guided planning.

Supports 4 quality modes and optional RND (importance weight) computation.

Quality modes:
    "none"           - No planner, no remasking (policy-only)
    "both"           - Both unmasking + insertion planners active
    "unmasking_only" - Only unmasking/remasking planner (insertion planner disabled)
    "insertion_only" - Only insertion planner (unmasking planner disabled)

RND toggle:
    compute_rnd=True  - Run pretrained model in parallel, compute step-wise log importance weights
    compute_rnd=False - Run policy model only (use with ELBO-based RND or eval)
"""
|
|
| import torch |
| import numpy as np |
| import pandas as pd |
| import torch.nn.functional as F |
| from sampling import SamplingResult, SamplingTraceDatapoint, _sample_tokens |
| from remasking_scheduleaware import apply_schedule_aware_remasking, apply_schedule_aware_insertion |
| from mol_utils.utils_chem import batch_safe_to_smiles, batch_validate_and_extract |
| from tdc import Evaluator, Oracle |
|
|
# Closed set of accepted values for the ``quality_mode`` argument.
QUALITY_MODES = {
    "none",
    "both",
    "unmasking_only",
    "insertion_only",
}
|
|
|
|
@torch.no_grad()
def _diffusion_loop(
    model, steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    compute_rnd=False,
    pretrained=None,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    return_trace=False,
    unmask_quality_threshold=None,
):
    """Core discrete diffusion sampling loop for molecule generation.

    Starts from an all-pad batch and, for ``steps`` Euler steps of size
    ``1 / steps``, alternates (1) unmasking of mask tokens, (2) optional
    planner-guided remasking and (3) Poisson-sampled insertion of new mask
    tokens (skipped on the final step).

    Args:
        model: Finetuned policy model. Must expose ``parameters()``,
            ``interpolant.to_actual_rate`` and, for planner-guided
            remasking, a ``planner`` attribute.
        steps: Number of diffusion steps.
        mask: Mask token ID. The zero-probability fallback on the final step
            spreads mass uniformly over token IDs ``[0, mask)``.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: One of "none", "both", "unmasking_only", "insertion_only".
        compute_rnd: Whether to accumulate step-wise log importance weights.
        pretrained: Frozen pretrained model (required if compute_rnd=True).
        remasking_mode: Remasking strategy. Only "schedule_aware" is
            implemented in this function; any other value raises ValueError
            at the first remasking step.
        num_remasking: Number of tokens to remask per step in the top-k
            remasking path (not reached by "schedule_aware").
        quality_threshold: Forwarded to ``apply_schedule_aware_insertion``.
            None if schedule-driven.
        temperature: Sampling temperature (1.0 = no scaling). Applied to the
            policy transition probabilities only.
        return_trace: Whether to record sampling trace.
        unmask_quality_threshold: Forwarded unchanged to
            ``apply_schedule_aware_remasking``.

    Returns:
        (xt, log_rnd, sampling_trace)
        log_rnd is a (batch_size,) tensor holding the sum over all steps of
        log p_pretrained - log p_policy for the sampled unmasking and
        insertion events; it is None when compute_rnd=False.
        sampling_trace is a per-sequence list of SamplingTraceDatapoint, or
        None when return_trace=False.

    Raises:
        AssertionError: On an unknown quality_mode, or compute_rnd=True
            without a pretrained model.
        ValueError: If remasking is active and remasking_mode is not
            "schedule_aware".
    """
    assert quality_mode in QUALITY_MODES, f"quality_mode must be one of {QUALITY_MODES}"
    if compute_rnd:
        assert pretrained is not None, "pretrained model required when compute_rnd=True"

    # Translate the quality mode into the individual feature switches.
    use_remasking = quality_mode != "none"
    disable_unmasking_planner = quality_mode in ("none", "insertion_only")
    disable_insertion_planner = quality_mode in ("none", "unmasking_only")

    device = next(model.parameters()).device

    # Every sequence starts empty (all pad); length grows through insertions.
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # Index grids reused by the insertion step to scatter tokens to new slots.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    neg_inf = torch.tensor(-np.inf, device=device)

    if use_remasking and remasking_mode == "remdm_conf":
        remasking_score = torch.zeros((batch_size, max_length), device=device)

    # Accumulator for the path-wise log importance weight. It must be summed
    # over steps; assigning it afresh inside the loop would keep only the
    # contribution of the last step.
    log_rnd = torch.zeros(batch_size, device=device) if compute_rnd else None

    for i in range(steps):
        # --- Policy rates -------------------------------------------------
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate
        len_rate = pred_rate.length_rate

        # --- Pretrained rates (only needed for the importance weight) -----
        if compute_rnd:
            pretrained_pred = pretrained(xt, t)
            pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
            pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()
            pretrained_len_rate = pretrained_rate.length_rate

        # --- Unmasking transition probabilities ---------------------------
        # Only masked positions may change. The mask->mask entry is set to
        # minus the total outgoing rate; the clamp below turns that negative
        # entry into 0, so the "stay" mass comes from the scatter_add further
        # down.
        # NOTE(review): rows at masked positions therefore sum to
        # 1 + sum(rate * dt) rather than 1 before any renormalisation -
        # confirm this is intended, since log_rnd reads these values directly.
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        if compute_rnd:
            pretrained_unmask_rate[xt != mask] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = 0
            pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
            pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        # Add probability 1 of keeping the current token. Pad positions are
        # pointed at the mask column here; whatever is sampled for them is
        # overwritten with pad again after sampling.
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )
        if compute_rnd:
            pretrained_trans_prob.scatter_add_(
                2,
                _xt.unsqueeze(-1),
                torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
            )

        # Temperature scaling (policy only).
        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        # On the last step no mask token may survive: drop the stay-masked
        # option, then either renormalise or fall back to a uniform
        # distribution for rows that are left with zero total mass.
        if i == steps - 1:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

            prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True)
            mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0)
            if mask_has_zero_prob.any():
                num_zero_prob = mask_has_zero_prob.sum().item()
                uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype)
                uniform_prob[:, :mask] = 1.0 / mask
                trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob
            else:
                trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        # Already-revealed tokens are never changed by the unmasking step.
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # Confidence bookkeeping for "remdm_conf": remember the probability
        # of each freshly revealed token.
        if use_remasking and remasking_mode == "remdm_conf" and i < steps - 1:
            token_probs = F.softmax(unmask_rate, dim=-1)
            chosen_probs = torch.gather(token_probs, dim=-1, index=new_xt.unsqueeze(-1)).squeeze(-1)
            changed_mask_to_token = (xt == mask) & (new_xt != mask) & (new_xt != pad)
            remasking_score = torch.where(changed_mask_to_token, chosen_probs, remasking_score)

        # --- Remasking (never on the final step) --------------------------
        if use_remasking and i < steps - 1:
            if disable_unmasking_planner or not (hasattr(model, 'planner') and model.planner is not None):
                remasking_conf = torch.zeros((batch_size, max_length), device=device)
            else:
                planner_out = model.planner(new_xt, t)
                remasking_conf = planner_out["remasking_conf"].squeeze(-1)

            clean_index = (new_xt != mask) & (new_xt != pad)

            if remasking_mode == "schedule_aware":
                new_xt = apply_schedule_aware_remasking(
                    model, new_xt, t, dt, remasking_conf, clean_index,
                    mask, neg_inf, batch_size,
                    unmask_quality_threshold=unmask_quality_threshold,
                )
                remasking_score_temp = None
            else:
                raise ValueError(f"Unknown remasking_mode: {remasking_mode}")

            # Generic top-k remasking on an explicit score. The only
            # implemented mode sets the score to None, so this path is
            # currently not taken.
            if remasking_score_temp is not None:
                remasking_score_temp = torch.where(clean_index, remasking_score_temp, neg_inf)
                for j in range(batch_size):
                    k = min(num_remasking, int(clean_index[j].sum().item()))
                    if k > 0:
                        _, select_indices = torch.topk(remasking_score_temp[j], k=k)
                        new_xt[j, select_indices] = mask

            # Record every clean token that was turned back into a mask.
            if return_trace:
                for batch_idx in range(batch_size):
                    for pos in range(max_length):
                        if clean_index[batch_idx, pos] and new_xt[batch_idx, pos] == mask:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="change",
                                    position=pos,
                                    token=mask,
                                )
                            )

        # --- Importance weight: unmasking contribution --------------------
        if compute_rnd:
            lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
            lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

            # Only positions that went mask -> real token (and were not
            # remasked afterwards) contribute.
            changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

            log_policy_step = (lp * changed_mask).sum(dim=1)
            log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)

            log_rnd = log_rnd + (log_pretrained_step - log_policy_step)

        # --- Insertion (skipped on the final step) ------------------------
        if i != steps - 1:
            ext = torch.poisson(len_rate * dt).long()

            # Keep only insertions into gaps that exist for the current
            # length, and drop all insertions of a sequence that would
            # otherwise exceed max_length.
            xt_len = xt.ne(pad).sum(dim=1)
            gaps = torch.arange(max_length + 1, device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            total_ext = ext.sum(dim=1)
            valid = xt_len + total_ext <= max_length
            ext = ext * valid.view(batch_size, 1).long()

            # Cumulative insertion count = shift applied to each old position.
            ext_ex = ext.int().cumsum(dim=1)
            new_len = xt_len + total_ext

            # Fill the new length with masks, then copy the existing tokens
            # to their shifted positions; what remains masked is inserted.
            xt_tmp = torch.full_like(xt, pad)
            mask_fill = pos_idx_L < new_len.view(batch_size, 1)
            xt_tmp[mask_fill] = mask

            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            flat_b = batch_idx_L[orig_mask]
            flat_p = new_pos_orig[orig_mask]
            xt_tmp[flat_b, flat_p] = new_xt[orig_mask]

            if use_remasking and not disable_insertion_planner:
                xt_tmp = apply_schedule_aware_insertion(
                    model, xt_tmp, new_xt, t, dt, ext, mask, pad, max_length,
                    orig_mask, new_pos_orig, quality_threshold
                )

                if compute_rnd:
                    # The length after the planner call may differ from
                    # xt_len + total_ext. Rescale the per-gap counts by the
                    # fraction of insertions that remain; this is a
                    # truncating approximation, not an exact per-gap count.
                    ext_corrected = torch.zeros_like(ext)
                    for b in range(batch_size):
                        after_len = xt_tmp[b].ne(pad).sum().item()
                        orig_len = xt_len[b].item()
                        surviving_insertions = after_len - orig_len
                        if total_ext[b] > 0:
                            ratio = surviving_insertions / total_ext[b].item()
                            ext_corrected[b] = (ext[b].float() * ratio).long()
                else:
                    ext_corrected = ext
            else:
                ext_corrected = ext

            # --- Importance weight: insertion contribution ----------------
            # Poisson log-likelihood k*log(lambda) - lambda per gap; the
            # log(k!) term is identical for both models and cancels.
            if compute_rnd:
                insertion_rate = (len_rate * dt).clamp(min=1e-10)
                pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)

                log_policy_insert = (ext_corrected * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
                log_pretrained_insert = (ext_corrected * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)

                log_rnd = log_rnd + (log_pretrained_insert - log_policy_insert)
        else:
            xt_tmp = new_xt

        if return_trace:
            for batch_idx in range(batch_size):
                # Token changes produced by this step.
                for j in range(max_length):
                    if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]:
                        sampling_trace[batch_idx].append(
                            SamplingTraceDatapoint(
                                t=t[batch_idx].item(),
                                event_type="change",
                                position=j,
                                token=new_xt[batch_idx, j].item(),
                            )
                        )

                # Insertions, recorded from the highest gap index down.
                if i != steps - 1:
                    for gap in range(max_length - 1, -1, -1):
                        if ext[batch_idx, gap]:
                            sampling_trace[batch_idx].append(
                                SamplingTraceDatapoint(
                                    t=t[batch_idx].item(),
                                    event_type="insertion",
                                    position=gap,
                                    token=mask,
                                )
                            )

        xt = xt_tmp
        t = t + dt

    return xt, log_rnd, sampling_trace
|
|
|
|
def _decode_and_validate(model, tokenizer, samples):
    """Turn sampled token IDs into SMILES strings and keep the usable ones.

    The token IDs are decoded with ``tokenizer.batch_decode`` (special tokens
    skipped) and converted with ``batch_safe_to_smiles``; the
    ``use_bracket_safe`` flag is read from ``model.config.training``. Entries
    whose conversion result is falsy are dropped. For every remaining entry
    only the longest '.'-separated fragment is kept.

    Returns:
        (validSequences, valid_indices): the kept SMILES strings and, aligned
        with them, their positions in the input batch.
    """
    decoded = tokenizer.batch_decode(samples, skip_special_tokens=True)
    bracket_safe = model.config.training.get('use_bracket_safe', False)
    smiles_list = batch_safe_to_smiles(decoded, use_bracket_safe=bracket_safe, fix=True)

    kept = [(position, smiles) for position, smiles in enumerate(smiles_list) if smiles]
    valid_indices = [position for position, _ in kept]
    # max() over the reversed fragments returns the last of several equally
    # long fragments, which is what a stable sort followed by [-1] yields.
    validSequences = [
        max(reversed(smiles.split('.')), key=len)
        for _, smiles in kept
    ]
    return validSequences, valid_indices
|
|
|
|
@torch.no_grad()
def sample_mol_buffer(
    model, pretrained, reward_model, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    alpha=0.1,
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    use_quality_filter=True,
    unmask_quality_threshold=None,
):
    """Generate molecules for training buffer. Always computes step-wise RND.

    Args:
        model: Finetuned policy model.
        pretrained: Frozen pretrained model.
        reward_model: Molecule scoring function, called as
            ``reward_model(input_seqs=list_of_smiles)``. Its result is
            indexed as a 2-D array with one row per molecule; column 0 is
            used as QED.
        tokenizer: SAFE tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        alpha: RND scaling factor; rewards are divided by it.
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering. None if schedule-driven.
        temperature: Sampling temperature.
        use_quality_filter: If True, filter to QED>=0.6 and SA<=4.
        unmask_quality_threshold: Forwarded to the sampling loop, mirroring
            the parameter of the same name on ``sample_mol_eval``.

    Returns:
        (valid_x, log_rnd, scalar_rewards, sampling_trace)
        The first three are tensors with one row/entry per retained
        molecule (zero-length when nothing is retained). log_rnd is the
        sampling-loop importance weight plus ``scalar_rewards / alpha``.
        sampling_trace is whatever the sampling loop returned for it.
    """
    xt, log_rnd, trace = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=True,
        pretrained=pretrained,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        temperature=temperature,
        unmask_quality_threshold=unmask_quality_threshold,
    )
    device = xt.device

    def _empty_result():
        # Shared zero-row result for "nothing valid" and "nothing passed".
        return (
            torch.empty((0, max_length), dtype=torch.long, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            torch.empty((0,), dtype=torch.float32, device=device),
            trace,
        )

    validSequences, valid_indices = _decode_and_validate(model, tokenizer, xt)
    print("len valid sequences:", len(validSequences))

    if not validSequences:
        print("[WARNING] No valid molecules generated in this batch")
        return _empty_result()

    # Scalar reward = sum of all score components of a molecule.
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = np.sum(score_vectors, axis=-1)
    scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)
    print(f"scalar reward dim{len(scalar_rewards)}")

    # Restrict tokens and importance weights to the valid molecules.
    keep = torch.as_tensor(valid_indices, dtype=torch.long, device=device)
    valid_x_final = xt[keep]
    log_rnd = log_rnd[keep] + (scalar_rewards / alpha)

    if not use_quality_filter:
        print(f"No quality filtering applied - using all {len(validSequences)} valid molecules")
        return valid_x_final, log_rnd, scalar_rewards, trace

    qed_scores = score_vectors[:, 0]
    if score_vectors.shape[1] > 1:
        # NOTE(review): column 1 is compared against the SA<=4 cut-off below;
        # this assumes the reward model reports SA on the raw scale - confirm.
        sa_scores = score_vectors[:, 1]
    else:
        _oracle_sa = Oracle('sa')
        sa_scores = np.array(_oracle_sa(validSequences))
    quality_mask = (qed_scores >= 0.6) & (sa_scores <= 4)

    n_quality = quality_mask.sum()
    print(f"Quality filtering: {n_quality}/{len(validSequences)} sequences pass (QED>=0.6, SA<=4)")

    if n_quality == 0:
        print("[WARNING] No quality molecules in this batch")
        return _empty_result()

    quality_mask_torch = torch.as_tensor(quality_mask, dtype=torch.bool, device=device)
    return (
        valid_x_final[quality_mask_torch],
        log_rnd[quality_mask_torch],
        scalar_rewards[quality_mask_torch],
        trace,
    )
|
|
|
|
@torch.no_grad()
def sample_mol_eval(
    model, reward_model, tokenizer,
    steps, mask, pad, batch_size, max_length,
    quality_mode="both",
    remasking_mode="schedule_aware",
    num_remasking=1,
    quality_threshold=1,
    temperature=1.0,
    evaluator=None,
    dataframe=False,
    unmask_quality_threshold=None,
):
    """Generate molecules for evaluation.

    Args:
        model: Finetuned policy model.
        reward_model: Molecule scoring function, called as
            ``reward_model(input_seqs=list_of_smiles)``; column 0 of its
            result is used as QED.
        tokenizer: SAFE tokenizer for decoding.
        steps: Number of diffusion steps.
        mask: Mask token ID.
        pad: Pad token ID.
        batch_size: Number of sequences to generate.
        max_length: Maximum sequence length.
        quality_mode: "none", "both", "unmasking_only", or "insertion_only".
        remasking_mode: Remasking strategy.
        num_remasking: Number of tokens to remask per step.
        quality_threshold: Threshold for insertion quality filtering. Pass None
            to use schedule-driven deletion with no threshold gate.
        temperature: Sampling temperature.
        evaluator: TDC Evaluator for diversity (created if None).
        dataframe: If True, include a pandas DataFrame in the return.
        unmask_quality_threshold: Forwarded to the sampling loop.

    Returns:
        Without dataframe:
            (validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction)
        With dataframe:
            (validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df)
        Note that the two tuples order their metrics differently.
        validSequences is the raw list including duplicates; qed/sa are scored
        on the unique set, in first-occurrence order. When there is no valid
        molecule, qed and sa are ``[0.0]`` and quality is 0.0. The dataframe
        (when requested) has one row per unique molecule, i.e. zero rows when
        nothing valid was generated. quality is the number of unique
        molecules with QED>=0.6 and SA<=4 divided by batch_size.
    """
    if evaluator is None:
        evaluator = Evaluator('diversity')

    xt, _, _ = _diffusion_loop(
        model, steps, mask, pad, batch_size, max_length,
        quality_mode=quality_mode,
        compute_rnd=False,
        remasking_mode=remasking_mode,
        num_remasking=num_remasking,
        quality_threshold=quality_threshold,
        temperature=temperature,
        unmask_quality_threshold=unmask_quality_threshold,
    )

    # Same decoding and largest-fragment selection as the training buffer.
    validSequences, _ = _decode_and_validate(model, tokenizer, xt)

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size
    # dict.fromkeys de-duplicates while keeping first-occurrence order, so
    # the order of qed/sa and of the dataframe rows is reproducible.
    uniqueSequences = list(dict.fromkeys(validSequences))
    uniqueness = len(uniqueSequences) / len(validSequences) if validSequences else 0
    diversity = evaluator(uniqueSequences) if uniqueSequences else 0

    if uniqueSequences:
        score_vectors = reward_model(input_seqs=uniqueSequences)
        qed = score_vectors[:, 0]

        # SA always comes from the TDC oracle here, not from the reward model.
        _oracle_sa = Oracle('sa')
        sa = np.array(_oracle_sa(uniqueSequences))

        quality_count = ((qed >= 0.6) & (sa <= 4)).sum()
        quality = quality_count / batch_size
        print(f'Quality:\t{quality}')
    else:
        qed = [0.0]
        sa = [0.0]
        quality = 0.0

    if dataframe:
        # All columns must have the same length: with no unique molecule the
        # score columns have to be empty too, otherwise pandas raises.
        has_rows = bool(uniqueSequences)
        df = pd.DataFrame({
            "Mol Sequence": uniqueSequences,
            "QED": qed if has_rows else [],
            "SA": sa if has_rows else [],
        })
        return validSequences, qed, sa, valid_fraction, uniqueness, diversity, quality, df

    return validSequences, qed, sa, uniqueness, diversity, quality, valid_fraction
|
|