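"""Per-residue protein solubility classifier.

A frozen pretrained ESM encoder extracts residue embeddings, which a small
Transformer encoder plus MLP head maps to per-residue solubility logits.
Implemented as a PyTorch Lightning module with a masked BCE loss and
AUROC/accuracy test metrics.
"""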
import gc
import torch
import torch.nn as nn
import lightning.pytorch as pl

from omegaconf import OmegaConf
from transformers import AutoModel
from torchmetrics.classification import BinaryAUROC, BinaryAccuracy

from src.utils.model_utils import _print
from src.guidance.utils import CosineWarmup


# Module-level config (note the hardcoded scratch path); the Lightning module
# below receives its own `config` argument, so this is only a convenience default.
config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")

class SolubilityClassifier(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        self.auroc = BinaryAUROC()
        self.accuracy = BinaryAccuracy()

        # Frozen pretrained ESM encoder, used only as a feature extractor
        self.esm_model = AutoModel.from_pretrained(self.config.lm.pretrained_esm)
        for p in self.esm_model.parameters():
            p.requires_grad = False

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.model.d_model,
            nhead=config.model.num_heads,
            dropout=config.model.dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.model.num_layers)
        self.layer_norm = nn.LayerNorm(config.model.d_model)
        self.dropout = nn.Dropout(config.model.dropout)
        self.mlp = nn.Sequential(
            nn.Linear(config.model.d_model, config.model.d_model // 2),
            nn.ReLU(),
            nn.Dropout(config.model.dropout),
            nn.Linear(config.model.d_model // 2, 1),
        )

    # -------# Classifier step #-------- #
    def forward(self, batch):
        """Predict per-residue solubility logits.

        Accepts either raw token ids ('input_ids') or precomputed ESM
        embeddings ('embeds'); an 'attention_mask' is required in both cases.
        """
        if 'input_ids' in batch:
            esm_embeds = self.get_esm_embeddings(batch['input_ids'], batch['attention_mask'])
        elif 'embeds' in batch:
            esm_embeds = batch['embeds']
        else:
            raise KeyError("batch must contain either 'input_ids' or 'embeds'")
        encodings = self.encoder(esm_embeds, src_key_padding_mask=(batch['attention_mask'] == 0))
        encodings = self.dropout(self.layer_norm(encodings))
        logits = self.mlp(encodings).squeeze(-1)  # (batch_size, seq_len)
        return logits

    
    # -------# Training / Evaluation #-------- #
    def training_step(self, batch, batch_idx):
        train_loss, _ = self.compute_loss(batch)
        self.log(name="train/loss", value=train_loss.item(), on_step=True, on_epoch=False, logger=True, sync_dist=True)
        self.save_ckpt()  # manual periodic checkpointing (see save_ckpt below)
        return train_loss

    def validation_step(self, batch, batch_idx):
        val_loss, _ = self.compute_loss(batch)
        self.log(name="val/loss", value=val_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        test_loss, preds = self.compute_loss(batch)
        auroc, accuracy = self.get_metrics(batch, preds)
        self.log(name="test/loss", value=test_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/AUROC", value=auroc.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/accuracy", value=accuracy.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return test_loss

    def on_test_epoch_end(self):
        self.auroc.reset()
        self.accuracy.reset()
    
    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        # Aggressively release memory after every update; trades some speed
        # for GPU memory headroom
        gc.collect()
        torch.cuda.empty_cache()

    def configure_optimizers(self):
        train_cfg = self.config.training
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.optim.lr)
        lr_scheduler = CosineWarmup(
            optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.max_steps,
        )
        scheduler_dict = {
            "scheduler": lr_scheduler,
            "interval": "step",
            "frequency": 1,
            "monitor": "val/loss",
            "name": "learning_rate",
        }
        return [optimizer], [scheduler_dict]
    
    def save_ckpt(self):
        """Manually save a checkpoint every `training.val_check_interval` steps."""
        curr_step = self.global_step
        save_every = self.config.training.val_check_interval
        if curr_step % save_every == 0 and curr_step > 0:
            ckpt_path = f"{self.config.checkpointing.save_dir}/step={curr_step}.ckpt"
            self.trainer.save_checkpoint(ckpt_path)
    
    # -------# Loss and Test Set Metrics #-------- #
    @torch.no_grad()
    def get_esm_embeddings(self, input_ids, attention_mask):
        """Run the frozen ESM encoder and return its last hidden state."""
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state
        return embeddings

    def compute_loss(self, batch):
        """Helper method to handle loss calculation"""
        labels = batch['labels']
        preds = self.forward(batch)
        loss = self.loss_fn(preds, labels)
        loss_mask = (labels != self.config.model.label_pad_value) # only calculate loss over non-pad tokens
        loss = (loss * loss_mask).sum() / loss_mask.sum()
        return loss, preds

    def get_metrics(self, batch, preds):
        """Helper method to compute test-set metrics over non-pad tokens"""
        labels = batch['labels']

        valid_mask = (labels != self.config.model.label_pad_value)
        labels = labels[valid_mask].int()  # torchmetrics expects integer targets
        preds = preds[valid_mask]

        _print(f"labels {labels.shape}")
        _print(f"preds {preds.shape}")

        auroc = self.auroc(preds, labels)       # logits are sigmoided internally
        accuracy = self.accuracy(preds, labels)
        return auroc, accuracy

    # -------# Helper Functions #-------- #
    def get_state_dict(self, ckpt_path):
        """Helper method to load and process a trained model's state dict from saved checkpoint"""
        def remove_model_prefix(state_dict):
            # str.replace returns a new string, so rebuild the dict instead of
            # mutating it during iteration; strip only a *leading* "model." so
            # keys such as "esm_model.…" are left intact
            return {
                (k[len("model."):] if k.startswith("model.") else k): v
                for k, v in state_dict.items()
            }

        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)
        
        return state_dict
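

# -------# Usage sketch #-------- #
# A minimal, illustrative example of running the classifier on one sequence;
# not part of the training pipeline. Assumptions: the module-level `config`
# above resolves on this machine, `config.model.d_model` matches the hidden
# size of the pretrained ESM checkpoint, and the peptide below is arbitrary.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = SolubilityClassifier(config).eval()

    tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_esm)
    batch = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")

    with torch.no_grad():
        logits = model({"input_ids": batch["input_ids"],
                        "attention_mask": batch["attention_mask"]})
        probs = torch.sigmoid(logits)  # per-residue solubility probabilities

    print(probs.shape)  # -> (1, seq_len)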