from pathlib import Path from typing import Tuple, Optional, Dict import numpy as np import torch from torch import nn import lightning.pytorch as pl import logging import huggingface_hub from .ligands.rdkit_utils import validate_smile, calc_chem_desc, tanimoto_smiles from .ligands.smiles_tokenizer import ChemformerTokenizer from .noise_schedule import _sample_t, q_xt, _sample_categorical, LogLinearNoise from .decoder_rope import Decoder_RoPE logger = logging.getLogger("lightning") class ModelGenerator(pl.LightningModule): """ ProtoBind-Diff model with SMILES and ESM-2 protein encodings. """ @staticmethod def get_exp_dir( exp_dir: str | None, output_dir: str, exp_dir_prefix: str, split: str ) -> Path: """Determines the experiment directory path.""" if exp_dir: return Path(exp_dir) return Path(output_dir) / split / exp_dir_prefix def __init__(self, *args, **kwargs): """Initializes the Lightning Module, saves hyperparameters, and configures the model.""" super().__init__() is_load = kwargs['load'] if not is_load: self.save_hyperparameters() self.data_dir = Path(kwargs["data_dir"]) exp_dir = kwargs.get('exp_dir', None) self.exp_dir = self.get_exp_dir( exp_dir=exp_dir, output_dir=kwargs["output_dir"], exp_dir_prefix=kwargs["exp_dir_prefix"], split=kwargs["split"] ) self.configure_model_params(**kwargs) def configure_model_params(self, **kwargs): """Parses keyword arguments to configure the model, tokenizer, and training parameters.""" self.learning_rate = kwargs.pop('learning_rate') self.weight_decay = float(kwargs.pop('weight_decay')) # Decoder params for masked diffusion decoder_params = { 'nhead': kwargs['num_heads_decoder'], 'n_layers': kwargs['num_decoder_layers'], 'hidden_size': kwargs['decoder_hidd_dim'], 'expand_feedforward': kwargs['expand_feedforward'], 'decoder_name': kwargs['decoder_name'], } # Tokenizer params tokenizer_path = kwargs.get('tokenizer_path') if tokenizer_path: self.tokenizer = ChemformerTokenizer(filename=tokenizer_path) else: self.tokenizer = ChemformerTokenizer(filename=self.data_dir / f"{kwargs['tokenizer_json_name']}.json") # Masking params self.noise = LogLinearNoise() self.mask_index = self.tokenizer.mask_token_id # Sampler params self.model_length = 170 self.noise_removal = True self.nucleus_p = 0.9 self.eta = 0.1 self.sampling_steps = 100 self.time_conditioning = False self.return_attention = False self.model = ProtobindMaskedDiffusion( embedding_dim=kwargs['seq_embedding_dim'], mask_index=self.mask_index, vocab_size=self.tokenizer.vocab_size, decoder_params=decoder_params, dropout=kwargs['dropout'], ) self.optimizer = kwargs.get('optimizer', 'Adam') def generate_mols(self, sequence: Tuple[torch.Tensor, torch.Tensor], return_attention=False) -> Tuple[np.array, torch.Tensor,np.array]: """Generates and validates SMILES strings for a given protein sequence. This method calls the internal sampler, decodes the generated tokens into SMILES strings, and filters out any invalid molecules. Args: sequence (Tuple[torch.Tensor, torch.Tensor]): The conditioned protein sequence embedding and its length. return_attention (bool): Whether to return attention maps from the sampler. Returns: Tuple[np.array, torch.Tensor, np.array]: A tuple containing the valid SMILES strings, corresponding attention maps, and the mask of valid indices. """ samples, attention = self._sample(sequence, return_attention=return_attention) text_samples = self.tokenizer.decode(samples.long()) text_samples = np.array([validate_smile(smile) for smile in text_samples]) mask_invalid = (text_samples != None) & (text_samples != '.') & (text_samples != '') text_samples = text_samples[mask_invalid] if attention is not None: attention = attention[mask_invalid] return text_samples, attention, mask_invalid def predict_step(self, batch, batch_idx): sequence, smiles, seq_id, smi_id = batch gen_samples, attention, mask_invalid = self.generate_mols( sequence, return_attention=self.return_attention) seq_id = seq_id[mask_invalid] return gen_samples, attention, seq_id def training_step(self, batch, batch_idx): return self.common_step(batch, "train", batch_idx) def validation_step(self, batch, batch_idx, dataloader_idx=None): # dataloader_idx to predict on several validation sets return self.common_step(batch, "val", batch_idx, dataloader_idx) def test_step(self, batch, batch_idx, dataloader_idx=0): return self.common_step(batch, "test", batch_idx) def common_step(self, batch, description, batch_idx, dataloader_idx=None): """Performs a common training, validation, or test step. This method takes a batch, applies noise according to the diffusion timestep, runs the model forward, calculates the loss, and logs metrics. Args: batch (Tuple): The input batch from the dataloader. description (str): The step description (e.g., 'train', 'val'). batch_idx (int): The index of the batch. Returns: torch.Tensor: The calculated loss for the batch. """ sequence, smiles, seq_id, smi_id = batch # Get data and apply noise X, length = smiles bs = X.shape[0] X = X.squeeze(-1) padding_mask = (X != 0).float() # 0 is pad token id t = _sample_t(X.shape[0], X.device) sigma, dsigma = self.noise(t) move_chance = 1 - torch.exp(-sigma[:, None]) xt = q_xt(X, move_chance, self.mask_index) xt = xt.unsqueeze(dim=2) smiles_t = (xt, length, None) pred_x, _ = self.model(sequence, smiles_t, sigma, padding_mask) total_loss = self.loss_mdlm(X.long(), pred_x, sigma, dsigma, padding_mask=None) if batch_idx % 50 == 0: tokens = pred_x.argmax(dim=-1) * padding_mask true_smiles = self.tokenizer.decode(X.long()) pred_smiles = [smile for smile in self.tokenizer.decode(tokens)] pred_smiles_valid = [validate_smile(smile) for smile in pred_smiles] try: tanimoto = np.asarray([tanimoto_smiles(mol_pred, mol_ref) for mol_pred, mol_ref in zip(pred_smiles_valid, true_smiles) if mol_pred is not None]) tanimoto_mean = np.mean(tanimoto) if len(tanimoto) > 0 else 0 num_mols_valid = len(tanimoto) except: num_mols_valid = 0 tanimoto_mean = 0.0 self.log(f"{description}_tanimoto", tanimoto_mean, prog_bar=True, on_epoch=True, sync_dist=True) self.log(f"{description}_perc_of_valid", num_mols_valid / bs * 100, prog_bar=True, on_epoch=True, sync_dist=True) self.log(f"{description}_loss", total_loss, prog_bar=True, on_epoch=True, sync_dist=True, batch_size=bs) return total_loss def configure_optimizers(self): if self.weight_decay > 0.: optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) else: optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer def loss_mdlm(self, x_0, model_output, sigma, dsigma, padding_mask=None): """Loss for SUBS parameterization, continuous time case""" log_p_theta = torch.gather( input=model_output, dim=-1, index=x_0[:, :, None]).squeeze(-1) loss = - log_p_theta * (dsigma / torch.expm1(sigma))[:, None] if padding_mask is not None: return (loss * padding_mask).sum() / padding_mask.sum() return loss.mean() def _sample_prior(self, *batch_dims): return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64) def _ddpm_caching_update(self, sequence, x, t, dt, p_x0=None, conf=None, return_attention=False): attention = None if t.ndim > 1: t = t.squeeze(-1) sigma_t, _ = self.noise(t) assert t.ndim == 1 move_chance_t = t[:, None, None] move_chance_s = (t - dt)[:, None, None] assert move_chance_t.ndim == 3, move_chance_t.shape padding_mask = (x != 0).float() if p_x0 is None: p_x0, attention = self.model(sequence, (x.unsqueeze(dim=2), None, None), sigma_t, padding_mask, return_attention=return_attention) p_x0 = p_x0.exp() if self.nucleus_p < 1: sorted_probs, sorted_indices = torch.sort(p_x0, descending=True, dim=-1) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) top_p_mask = cumulative_probs <= self.nucleus_p top_p_mask[..., 0] = True nucleus_probs = sorted_probs * top_p_mask nucleus_probs /= nucleus_probs.sum(dim=-1, keepdim=True) p_x0 = torch.zeros_like(p_x0).scatter_(-1, sorted_indices, nucleus_probs) assert move_chance_t.ndim == p_x0.ndim # Use remdm-cap sampler alpha_t = (1 - move_chance_t)[0].item() alpha_s = (1 - move_chance_s)[0].item() if alpha_t > 0: sigma = min(self.eta, (1 - alpha_s) / alpha_t) else: sigma = self.eta q_xs = p_x0 * (1 - sigma) q_xs[..., self.mask_index] = sigma q_xs_2 = p_x0 * ((alpha_s - (1 - sigma) * alpha_t) / (1 - alpha_t)) q_xs_2[..., self.mask_index] = (1 - alpha_s - sigma * alpha_t) / (1 - alpha_t) copy_flag = (x != self.mask_index).to(torch.bool) q_xs = torch.where(copy_flag.unsqueeze(-1), q_xs, q_xs_2) xs = _sample_categorical(q_xs) if torch.allclose(xs, x) and not self.time_conditioning: p_x0_cache = p_x0 else: p_x0_cache = None return p_x0_cache, xs, conf, attention @torch.no_grad() def _sample(self, sequence, eps=1e-3, return_attention=False): """Generate samples from the model""" num_steps = self.sampling_steps bs = sequence[0].shape[0] x = self._sample_prior(bs, self.model_length).to(self.device) timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device) dt = (1 - eps) / num_steps p_x0_cache = None min_t = timesteps[-1].item() confident_score = - torch.ones_like(x, device=self.device) * torch.inf for i in range(num_steps): t = timesteps[i] * torch.ones(bs, 1, device=self.device) p_x0_cache, x_next, confident_score, attention = self._ddpm_caching_update( sequence, x, t, dt, p_x0=p_x0_cache, conf=confident_score, return_attention=return_attention) if (not torch.allclose(x_next, x)): p_x0_cache = None x = x_next if self.noise_removal: t = min_t * torch.ones(bs, 1, device=self.device) unet_conditioning = self.noise(t)[0] padding_mask = (x != 0).float() x, attention = self.model(sequence, (x, None, None), unet_conditioning.squeeze(-1), padding_mask, return_attention=return_attention) x = x.argmax(dim=-1) return x, attention class ProtobindMaskedDiffusion(nn.Module, huggingface_hub.PyTorchModelHubMixin): """The core Protobind-Diff model, which uses a Transformer decoder with RoPE. This model is designed for a masked diffusion task and supports conditioning on ESM-2 protein embeddings and generating ligands with a ChemformerTokenizer. """ def __init__(self, embedding_dim: int, mask_index: int, vocab_size: int, decoder_params: Optional[dict] = None, dropout: float = 0.2, parametrization_strategy: str = 'subs', **kwargs) -> None: """Initializes the ProtobindMaskedDiffusion model. Args: embedding_dim (int): The dimension of the protein sequence embeddings. mask_index (int): The token ID for the MASK token. vocab_size (int): The size of the ligand's vocabulary. decoder_params (Optional[dict]): A dictionary of parameters for the internal Transformer decoder (e.g., nhead, n_layers). dropout (float): The dropout rate. parametrization_strategy (str): The diffusion parameterization to use. Currently only 'subs' is supported. """ super().__init__() self.neg_infinity = -1000000.0 self.parametrization_strategy = parametrization_strategy self.decoder_name = decoder_params.pop('decoder_name') expand_feedforward = decoder_params.pop('expand_feedforward') self.mask_index = mask_index # Decoder options if self.decoder_name == 'decoder_re': self.decoder = Decoder_RoPE(vocab_size, embedding_dim, expand_feedforward=expand_feedforward, dropout=dropout, **decoder_params) else: raise ValueError(f"Model only supports decoder with rotary embeddings ('decoder_re'), got: {self.decoder_name}") def forward(self, sequence: Tuple[torch.Tensor, torch.Tensor], ligands: Tuple[torch.Tensor, torch.Tensor], sigma: torch.Tensor, mask_ligand: torch.Tensor, return_attention: bool = False) -> torch.Tensor: """Performs the main forward pass of the diffusion model. Args: sequence (Tuple[torch.Tensor, torch.Tensor]): A tuple of the conditioning protein sequence embeddings and their lengths. ligands (Tuple[torch.Tensor, torch.Tensor]): A tuple containing the noised ligand `xt`and its length. sigma (torch.Tensor): The diffusion timestep (noise level). mask_ligand (torch.Tensor): The padding mask for the ligand. return_attention (bool): If True, returns attention weights from the decoder. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing the final predicted logits and the attention weights. """ sequence, sequence_lengths = sequence xt, ligand_lengths, _ = ligands # Decode ligand ligand_masked = xt.squeeze(-1).long() ligand_decoded, attention = self.decoder(ligand_masked, sigma, sequence, sequence_lengths, lig_padding_mask=None, return_attention=return_attention) # Apply parametrization ligand_decoded = self.parametrization(ligand_decoded, xt) return ligand_decoded, attention def parametrization(self, logits, xt): """Applies the chosen parameterization to the model's output logits. The 'subs' strategy modifies the logits to represent the probability p(x_{t-1}|x_t), enforcing that unmasked tokens remain unchanged. Args: logits (torch.Tensor): The raw output logits from the decoder. xt (torch.Tensor): The noised input ligand at timestep t. Returns: torch.Tensor: The re-parameterized logits. """ if self.parametrization_strategy == 'subs': # log prob at the mask index = - infinity logits[:, :, self.mask_index] += self.neg_infinity # Normalize the logits logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) # Apply updates for unmasked tokens xt = xt.squeeze(-1) unmasked_indices = (xt != self.mask_index) logits[unmasked_indices] = self.neg_infinity logits[unmasked_indices, xt[unmasked_indices].long()] = 0 else: raise NotImplementedError(f'Parametrization strategy {self.parametrization_strategy} not implemented') return logits