File size: 17,166 Bytes
bd082dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
from pathlib import Path
from typing import Tuple, Optional, Dict
import numpy as np
import torch
from torch import nn
import lightning.pytorch as pl
import logging
import huggingface_hub

from .ligands.rdkit_utils import validate_smile, calc_chem_desc, tanimoto_smiles
from .ligands.smiles_tokenizer import ChemformerTokenizer
from .noise_schedule import _sample_t, q_xt, _sample_categorical, LogLinearNoise
from .decoder_rope import Decoder_RoPE

# Module-level logger. It uses the fixed name "lightning" (not __name__), so its
# records go to whatever handlers/level are configured for that logger.
logger = logging.getLogger("lightning")


class ModelGenerator(pl.LightningModule):
    """
    ProtoBind-Diff model with SMILES and ESM-2 protein encodings.

    Lightning wrapper around ``ProtobindMaskedDiffusion``. It builds the
    tokenizer and network from keyword arguments, trains with the masked
    diffusion (SUBS) objective, logs loss/Tanimoto/validity metrics, and
    samples ligand token sequences with a ReMDM-cap style sampler.
    """
    @staticmethod
    def get_exp_dir(
            exp_dir: str | None,
            output_dir: str,
            exp_dir_prefix: str,
            split: str
    ) -> Path:
        """Determine the experiment directory path.

        Args:
            exp_dir: Explicit experiment directory; used verbatim when truthy.
            output_dir: Root directory used when ``exp_dir`` is not given.
            exp_dir_prefix: Last path component of the derived directory.
            split: Split name placed between ``output_dir`` and the prefix.

        Returns:
            ``Path(exp_dir)`` if given, else ``output_dir / split / exp_dir_prefix``.
        """
        if exp_dir:
            return Path(exp_dir)
        return Path(output_dir) / split / exp_dir_prefix

    def __init__(self, *args, **kwargs):
        """Initialize the Lightning module, save hyperparameters and build the model.

        Keyword Args:
            load: Required. When truthy, ``save_hyperparameters`` is skipped
                (presumably because hyperparameters are being restored from a
                checkpoint — confirm with the caller).
            data_dir: Directory holding the tokenizer JSON when no explicit
                ``tokenizer_path`` is given.
            exp_dir, output_dir, exp_dir_prefix, split: Forwarded to
                ``get_exp_dir`` to build ``self.exp_dir``.
            **kwargs: Remaining model/optimizer settings consumed by
                ``configure_model_params``.
        """
        super().__init__()

        is_load = kwargs['load']
        if not is_load:
            self.save_hyperparameters()

        self.data_dir = Path(kwargs["data_dir"])
        exp_dir = kwargs.get('exp_dir', None)
        self.exp_dir = self.get_exp_dir(
            exp_dir=exp_dir,
            output_dir=kwargs["output_dir"],
            exp_dir_prefix=kwargs["exp_dir_prefix"],
            split=kwargs["split"]
        )

        self.configure_model_params(**kwargs)

    def configure_model_params(self, **kwargs):
        """Parse keyword arguments to configure the model, tokenizer and training parameters.

        Required keys: ``learning_rate``, ``weight_decay``, ``num_heads_decoder``,
        ``num_decoder_layers``, ``decoder_hidd_dim``, ``expand_feedforward``,
        ``decoder_name``, ``seq_embedding_dim``, ``dropout`` and, unless
        ``tokenizer_path`` is given, ``tokenizer_json_name``.
        Optional keys: ``tokenizer_path``, ``optimizer`` (default ``'Adam'``).
        """

        self.learning_rate = kwargs.pop('learning_rate')
        self.weight_decay = float(kwargs.pop('weight_decay'))

        # Decoder params for masked diffusion
        decoder_params = {
            'nhead': kwargs['num_heads_decoder'],
            'n_layers': kwargs['num_decoder_layers'],
            'hidden_size': kwargs['decoder_hidd_dim'],
            'expand_feedforward': kwargs['expand_feedforward'],
            'decoder_name': kwargs['decoder_name'],
        }
        # Tokenizer: an explicit path wins; otherwise <data_dir>/<tokenizer_json_name>.json
        tokenizer_path = kwargs.get('tokenizer_path')
        if tokenizer_path:
            self.tokenizer = ChemformerTokenizer(filename=tokenizer_path)
        else:
            self.tokenizer = ChemformerTokenizer(filename=self.data_dir / f"{kwargs['tokenizer_json_name']}.json")

        # Masking params
        self.noise = LogLinearNoise()
        self.mask_index = self.tokenizer.mask_token_id

        # Sampler params (hard-coded; not read from kwargs)
        self.model_length = 170       # number of token positions generated per sample
        self.noise_removal = True     # run a final denoising pass + argmax after sampling
        self.nucleus_p = 0.9          # top-p threshold applied to p(x0)
        self.eta = 0.1                # ReMDM-cap remasking cap
        self.sampling_steps = 100
        self.time_conditioning = False

        self.return_attention = False

        self.model = ProtobindMaskedDiffusion(
            embedding_dim=kwargs['seq_embedding_dim'],
            mask_index=self.mask_index,
            vocab_size=self.tokenizer.vocab_size,
            decoder_params=decoder_params,
            dropout=kwargs['dropout'],
        )
        # NOTE(review): stored but not read by configure_optimizers, which picks
        # Adam/AdamW from weight_decay alone.
        self.optimizer = kwargs.get('optimizer', 'Adam')

    def generate_mols(self, sequence: Tuple[torch.Tensor, torch.Tensor],
                      return_attention=False) -> Tuple[np.ndarray, torch.Tensor, np.ndarray]:
        """Generate and validate SMILES strings for a given protein sequence.

        This method calls the internal sampler, decodes the generated tokens into
        SMILES strings, and filters out any invalid molecules.

        Args:
            sequence (Tuple[torch.Tensor, torch.Tensor]): The conditioned protein sequence
                embedding and its length.
            return_attention (bool): Whether to return attention maps from the sampler.

        Returns:
            Tuple[np.ndarray, torch.Tensor, np.ndarray]: The valid SMILES strings,
            the corresponding attention maps (or ``None``), and the boolean mask of
            valid indices over the original batch.
        """
        samples, attention = self._sample(sequence, return_attention=return_attention)
        text_samples = self.tokenizer.decode(samples.long())
        text_samples = np.array([validate_smile(smile) for smile in text_samples])

        # Element-wise comparisons: `!= None` is intentional here (numpy broadcasting),
        # `is not None` would compare the whole array object instead.
        mask_invalid = (text_samples != None) & (text_samples != '.') & (text_samples != '')
        text_samples = text_samples[mask_invalid]
        if attention is not None:
            attention = attention[mask_invalid]

        return text_samples, attention, mask_invalid

    def predict_step(self, batch, batch_idx):
        """Sample ligands for a batch and return (valid SMILES, attention, matching seq ids)."""
        sequence, smiles, seq_id, smi_id = batch
        gen_samples, attention, mask_invalid = self.generate_mols(
            sequence, return_attention=self.return_attention)
        seq_id = seq_id[mask_invalid]
        return gen_samples, attention, seq_id

    def training_step(self, batch, batch_idx):
        """Training step; delegates to ``common_step`` with the ``'train'`` prefix."""
        return self.common_step(batch, "train", batch_idx)

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Validation step; delegates to ``common_step`` with the ``'val'`` prefix."""
        # dataloader_idx to predict on several validation sets
        return self.common_step(batch, "val", batch_idx, dataloader_idx)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Test step; delegates to ``common_step`` with the ``'test'`` prefix."""
        return self.common_step(batch, "test", batch_idx, dataloader_idx)

    def common_step(self, batch, description, batch_idx, dataloader_idx=None):
        """Perform a shared training, validation or test step.

        Noises the ligand tokens at a random diffusion time, runs the model,
        computes the MDLM loss and logs metrics. Every 50th batch it also
        decodes argmax predictions and logs mean Tanimoto similarity to the
        reference SMILES and the percentage of valid predicted molecules.

        Args:
            batch (Tuple): ``(sequence, smiles, seq_id, smi_id)`` from the dataloader,
                where ``smiles`` is ``(token_ids, length)``.
            description (str): Metric prefix (e.g. ``'train'``, ``'val'``).
            batch_idx (int): The index of the batch.
            dataloader_idx (Optional[int]): Accepted for Lightning's multi-dataloader
                API; not used.

        Returns:
            torch.Tensor: The calculated loss for the batch.
        """
        sequence, smiles, seq_id, smi_id = batch

        # Get data and apply noise
        X, length = smiles
        bs = X.shape[0]
        X = X.squeeze(-1)
        padding_mask = (X != 0).float()  # 0 is pad token id
        t = _sample_t(bs, X.device)
        sigma, dsigma = self.noise(t)
        move_chance = 1 - torch.exp(-sigma[:, None])
        xt = q_xt(X, move_chance, self.mask_index)
        xt = xt.unsqueeze(dim=2)
        smiles_t = (xt, length, None)

        pred_x, _ = self.model(sequence, smiles_t, sigma, padding_mask)
        # NOTE(review): padding_mask is not passed, so the loss averages over all
        # positions including padding; presumably intentional so the sampler learns
        # to emit pad tokens at fixed model_length — confirm.
        total_loss = self.loss_mdlm(X.long(), pred_x, sigma, dsigma, padding_mask=None)

        if batch_idx % 50 == 0:
            # argmax * float mask yields a float tensor; cast back to integer ids
            # like every other decode call in this class.
            tokens = (pred_x.argmax(dim=-1) * padding_mask).long()
            true_smiles = self.tokenizer.decode(X.long())
            pred_smiles_valid = [validate_smile(smile) for smile in self.tokenizer.decode(tokens)]

            try:
                tanimoto = np.asarray([tanimoto_smiles(mol_pred, mol_ref) for mol_pred, mol_ref
                                       in zip(pred_smiles_valid, true_smiles) if mol_pred is not None])
                tanimoto_mean = float(np.mean(tanimoto)) if len(tanimoto) > 0 else 0.0
                num_mols_valid = len(tanimoto)
            except Exception:
                # Metrics are best-effort: never abort training over them, but do not hide the failure.
                logger.warning("Failed to compute %s Tanimoto metrics for batch %d",
                               description, batch_idx, exc_info=True)
                num_mols_valid = 0
                tanimoto_mean = 0.0

            self.log(f"{description}_tanimoto", tanimoto_mean, prog_bar=True,
                     on_epoch=True, sync_dist=True, batch_size=bs)
            self.log(f"{description}_perc_of_valid", num_mols_valid / bs * 100, prog_bar=True,
                     on_epoch=True, sync_dist=True, batch_size=bs)

        self.log(f"{description}_loss", total_loss, prog_bar=True, on_epoch=True,
                 sync_dist=True, batch_size=bs)
        return total_loss

    def configure_optimizers(self):
        """Use AdamW when ``weight_decay > 0``, otherwise plain Adam."""
        if self.weight_decay > 0.:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def loss_mdlm(self, x_0, model_output, sigma, dsigma, padding_mask=None):
        """Loss for SUBS parameterization, continuous time case.

        Args:
            x_0: Clean token ids, shape (batch, length).
            model_output: Log-probabilities over the vocabulary, shape
                (batch, length, vocab).
            sigma, dsigma: Per-sample noise level and its time derivative, shape (batch,).
            padding_mask: Optional (batch, length) weights; when given the loss
                is averaged only over positions with non-zero weight.

        Returns:
            Scalar loss: ``-log p(x_0) * dsigma / expm1(sigma)`` averaged over positions.
        """
        log_p_theta = torch.gather(
            input=model_output,
            dim=-1,
            index=x_0[:, :, None]).squeeze(-1)

        loss = - log_p_theta * (dsigma / torch.expm1(sigma))[:, None]

        if padding_mask is not None:
            return (loss * padding_mask).sum() / padding_mask.sum()
        return loss.mean()

    def _sample_prior(self, *batch_dims):
        """Return an int64 CPU tensor of shape ``batch_dims`` filled with the mask token id."""
        return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)

    def _ddpm_caching_update(self, sequence, x, t, dt, p_x0=None, conf=None,
                             return_attention=False):
        """Run one reverse step of the ReMDM-cap sampler from time ``t`` to ``t - dt``.

        Args:
            sequence: Protein conditioning forwarded to the model.
            x: Current token ids, shape (batch, length).
            t: Current time, shape (batch,) or (batch, 1).
            dt: Step size.
            p_x0: Cached per-token probabilities from the previous step; when
                given, the model is not called.
            conf: Passed through unchanged.
            return_attention: Whether to request attention maps from the model.

        Returns:
            ``(p_x0_cache, xs, conf, attention)``. ``p_x0_cache`` is the
            probabilities to reuse next step (``None`` when ``xs`` differs from
            ``x`` or time conditioning is on). ``attention`` is ``None`` whenever
            the cached ``p_x0`` was used.
        """
        attention = None
        if t.ndim > 1:
            t = t.squeeze(-1)
        sigma_t, _ = self.noise(t)
        assert t.ndim == 1
        # Move chance is taken as t itself (linear in time).
        move_chance_t = t[:, None, None]
        move_chance_s = (t - dt)[:, None, None]
        assert move_chance_t.ndim == 3, move_chance_t.shape
        padding_mask = (x != 0).float()

        if p_x0 is None:
            p_x0, attention = self.model(sequence, (x.unsqueeze(dim=2), None, None), sigma_t,
                                         padding_mask, return_attention=return_attention)
            p_x0 = p_x0.exp()
            if self.nucleus_p < 1:
                # Top-p filtering: keep the smallest prefix of sorted probs whose
                # cumulative mass is <= nucleus_p (always keep the top token), renormalize.
                sorted_probs, sorted_indices = torch.sort(p_x0, descending=True, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                top_p_mask = cumulative_probs <= self.nucleus_p
                top_p_mask[..., 0] = True
                nucleus_probs = sorted_probs * top_p_mask
                nucleus_probs /= nucleus_probs.sum(dim=-1, keepdim=True)
                p_x0 = torch.zeros_like(p_x0).scatter_(-1, sorted_indices, nucleus_probs)

        assert move_chance_t.ndim == p_x0.ndim

        # Use remdm-cap sampler. All samples share the same t, so element 0 is representative.
        alpha_t = (1 - move_chance_t)[0].item()
        alpha_s = (1 - move_chance_s)[0].item()
        if alpha_t > 0:
            sigma = min(self.eta, (1 - alpha_s) / alpha_t)
        else:
            sigma = self.eta
        # Already-unmasked tokens: keep with prob (1 - sigma), remask with prob sigma.
        q_xs = p_x0 * (1 - sigma)
        q_xs[..., self.mask_index] = sigma
        # Still-masked tokens: unmask according to p_x0, or stay masked.
        q_xs_2 = p_x0 * ((alpha_s - (1 - sigma) * alpha_t) / (1 - alpha_t))
        q_xs_2[..., self.mask_index] = (1 - alpha_s - sigma * alpha_t) / (1 - alpha_t)
        copy_flag = (x != self.mask_index).to(torch.bool)
        q_xs = torch.where(copy_flag.unsqueeze(-1), q_xs, q_xs_2)
        xs = _sample_categorical(q_xs)

        # The model output can be reused only if the input did not change.
        if torch.allclose(xs, x) and not self.time_conditioning:
            p_x0_cache = p_x0
        else:
            p_x0_cache = None

        return p_x0_cache, xs, conf, attention

    @torch.no_grad()
    def _sample(self, sequence, eps=1e-3, return_attention=False):
        """Generate token samples from the model.

        Starts from an all-mask sequence of length ``self.model_length`` and runs
        ``self.sampling_steps`` reverse steps from t=1 down to t=``eps``. If
        ``self.noise_removal`` is set, a final model pass at t=``eps`` replaces
        every position with its argmax token.

        Returns:
            ``(x, attention)``: token ids of shape (batch, model_length) and the
            attention from the last model call (may be ``None``).
        """
        num_steps = self.sampling_steps
        bs = sequence[0].shape[0]
        x = self._sample_prior(bs, self.model_length).to(self.device)

        timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        dt = (1 - eps) / num_steps
        p_x0_cache = None

        min_t = timesteps[-1].item()
        confident_score = - torch.ones_like(x, device=self.device) * torch.inf

        for i in range(num_steps):
            t = timesteps[i] * torch.ones(bs, 1, device=self.device)
            p_x0_cache, x_next, confident_score, attention = self._ddpm_caching_update(
                sequence, x, t, dt, p_x0=p_x0_cache, conf=confident_score,
                return_attention=return_attention)

            if (not torch.allclose(x_next, x)):
                p_x0_cache = None
            x = x_next

        if self.noise_removal:
            t = min_t * torch.ones(bs, 1, device=self.device)
            unet_conditioning = self.noise(t)[0]
            padding_mask = (x != 0).float()
            x, attention = self.model(sequence, (x, None, None), unet_conditioning.squeeze(-1),
                                      padding_mask, return_attention=return_attention)
            x = x.argmax(dim=-1)
        return x, attention


class ProtobindMaskedDiffusion(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """The core Protobind-Diff model, which uses a Transformer decoder with RoPE.

    This model is designed for a masked diffusion task and supports conditioning
    on ESM-2 protein embeddings and generating ligands with a ChemformerTokenizer.
    """


    def __init__(self,
                 embedding_dim: int,
                 mask_index: int,
                 vocab_size: int,
                 decoder_params: Optional[dict] = None,
                 dropout: float = 0.2,
                 parametrization_strategy: str = 'subs',
                 **kwargs) -> None:
        """Initialize the ProtobindMaskedDiffusion model.

        Args:
            embedding_dim (int): The dimension of the protein sequence embeddings.
            mask_index (int): The token ID for the MASK token.
            vocab_size (int): The size of the ligand's vocabulary.
            decoder_params (Optional[dict]): Parameters for the internal decoder.
                Must contain ``'decoder_name'`` and ``'expand_feedforward'``; all
                other entries are forwarded to ``Decoder_RoPE``. The dict passed
                in is not modified.
            dropout (float): The dropout rate.
            parametrization_strategy (str): The diffusion parameterization to use.
                Currently only 'subs' is supported.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``decoder_params`` is missing or names an unsupported decoder.
            KeyError: If a required key is missing from ``decoder_params``.
        """
        super().__init__()

        self.neg_infinity = -1000000.0
        self.parametrization_strategy = parametrization_strategy

        if decoder_params is None:
            raise ValueError("decoder_params is required and must contain "
                             "'decoder_name' and 'expand_feedforward'")
        # Pop from a copy so the caller's dict keeps its keys. NOTE(review): the same
        # object may also be what PyTorchModelHubMixin records as constructor config;
        # mutating it would drop these keys from a saved config — confirm.
        decoder_params = dict(decoder_params)
        self.decoder_name = decoder_params.pop('decoder_name')
        expand_feedforward = decoder_params.pop('expand_feedforward')
        self.mask_index = mask_index

        # Decoder options
        if self.decoder_name == 'decoder_re':
            self.decoder = Decoder_RoPE(vocab_size, embedding_dim, expand_feedforward=expand_feedforward,
                                        dropout=dropout, **decoder_params)
        else:
            raise ValueError(f"Model only supports decoder with rotary embeddings ('decoder_re'), got: {self.decoder_name}")

    def forward(self,
                sequence: Tuple[torch.Tensor, torch.Tensor],
                ligands: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
                sigma: torch.Tensor,
                mask_ligand: torch.Tensor,
                return_attention: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Perform the main forward pass of the diffusion model.

        Args:
            sequence (Tuple[torch.Tensor, torch.Tensor]): The conditioning protein
                sequence embeddings and their lengths.
            ligands (Tuple): ``(xt, lengths, _)`` — the noised ligand tokens, their
                lengths, and an unused third element.
            sigma (torch.Tensor): The diffusion timestep (noise level).
            mask_ligand (torch.Tensor): The padding mask for the ligand.
                NOTE(review): currently not used; the decoder receives
                ``lig_padding_mask=None``.
            return_attention (bool): If True, returns attention weights from the decoder.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: The SUBS-parameterized
            log-probabilities and the attention weights.
        """

        sequence, sequence_lengths = sequence
        xt, ligand_lengths, _ = ligands

        # Decode ligand
        ligand_masked = xt.squeeze(-1).long()
        ligand_decoded, attention = self.decoder(ligand_masked,
                                                 sigma,
                                                 sequence,
                                                 sequence_lengths,
                                                 lig_padding_mask=None,
                                                 return_attention=return_attention)

        # Apply parametrization
        ligand_decoded = self.parametrization(ligand_decoded, xt)

        return ligand_decoded, attention

    def parametrization(self, logits, xt):
        """Apply the chosen parameterization to the model's output logits.

        For the 'subs' strategy:
          * the MASK token is given (effectively) zero probability;
          * logits are normalized into log-probabilities;
          * every position whose input token is not MASK becomes a one-hot
            log-distribution (0 on the input token, ``neg_infinity`` elsewhere),
            so unmasked tokens are carried over unchanged.

        Args:
            logits (torch.Tensor): Raw decoder logits, shape (batch, length, vocab).
                Modified in place.
            xt (torch.Tensor): The noised input tokens, (batch, length) or
                (batch, length, 1).

        Returns:
            torch.Tensor: The re-parameterized log-probabilities.

        Raises:
            NotImplementedError: For any strategy other than 'subs'.
        """
        if self.parametrization_strategy == 'subs':
            # log prob at the mask index = - infinity
            logits[:, :, self.mask_index] += self.neg_infinity

            # Normalize the logits
            logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

            # Apply updates for unmasked tokens
            xt = xt.squeeze(-1)
            unmasked_indices = (xt != self.mask_index)
            logits[unmasked_indices] = self.neg_infinity
            logits[unmasked_indices, xt[unmasked_indices].long()] = 0
        else:
            raise NotImplementedError(f'Parametrization strategy {self.parametrization_strategy} not implemented')
        return logits