# MadSBM — src/madsbm/wt_peptide/sbm_module.py
# Author: Shrey Goel — cleaned training code (commit 0fa2d2b)
import gc
import os
import math
from re import L
import torch
import lightning as pl
import torch.nn.functional as F
from transformers import AutoModel
from src.madsbm.wt_peptide.control_field import PeptideControlField
from src.PeptiVerse.inference import PeptiVersePredictor
from src.utils.model_utils import CosineWarmup, _print, compute_grad_norms
class MadSBM(pl.LightningModule):
    """Masked discrete-diffusion / SBM training module for peptide sequences.

    Wraps a ``PeptideControlField`` denoiser and a frozen ESM embedding model,
    and implements the masked-token cross-entropy training objective: sample a
    time ``t``, mask each maskable token with probability ``t``, and train the
    model to reconstruct the clean sequence.
    """

    def __init__(self, config, guidance=None):
        super().__init__()
        self.config = config
        self.model = PeptideControlField(config)
        self.tokenizer = self.model.tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id
        # Frozen ESM encoder: eval mode + no gradients, used only for embeddings.
        self.embed_model = AutoModel.from_pretrained(config.model.esm_model)
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False
        self.time_schedule = config.time_embed.time_schedule
        self.anneal_frac = config.time_embed.anneal_frac
        self.eps = float(config.time_embed.min_time)
        self.t_max = 1.0 - self.eps

    # -------# Main Training Logic #-------- #
    def forward(self, input_ids, attention_mask, t):
        return self.model(xt=input_ids, attention_mask=attention_mask, t=t)

    def step(self, batch):
        """One noising + denoising pass.

        Returns:
            (loss, ppl, max_u_logit, max_esm_logit) — scalar loss averaged
            over maskable positions, its exp (perplexity), and the max logit
            of each head (diagnostics).
        """
        x1 = batch['input_ids']
        attn_mask = batch['attention_mask']
        maskable = self.is_maskable(x1)
        t = self.sample_t(x1)
        xt = self.noise_seq(x1, t, maskable_mask=maskable)
        outs = self.forward(xt, attn_mask, t)
        # Ablation uses the plain DiT head; otherwise the MadSBM head.
        logits = outs['dit'] if self.config.model.ablate else outs['madsbm']
        # Diagnostics: track logit magnitude of each head (both always present).
        max_u_logit = outs['dit'].max().item()
        max_esm_logit = outs['esm'].max().item()
        loss_token = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            x1.view(-1),
            reduction='none',
            ignore_index=self.pad_id,
        )
        loss_token = loss_token.view(x1.size(0), x1.size(1))
        # Per-sample mean over maskable positions only; clamp avoids 0-division
        # for sequences with no maskable tokens.
        denom = maskable.float().sum(dim=1).clamp(min=1.0)
        sample_loss = (loss_token * maskable.float()).sum(dim=1) / denom
        loss = sample_loss.mean()
        ppl = torch.exp(loss)
        return loss, ppl, max_u_logit, max_esm_logit

    def noise_seq(self, x1, t, maskable_mask):
        """Mask each maskable token of x1 independently with probability t.

        t is a per-sample tensor of shape (B,); broadcast across the sequence.
        """
        B, L = x1.shape
        t = t.unsqueeze(1)  # (B, 1) -> broadcasts over sequence length
        u = torch.rand((B, L), device=x1.device)
        # Mask if u < t, keep (reveal) if u >= t — so larger t == more masking.
        masked = (u < t) & maskable_mask
        xt = x1.clone()
        xt = xt.masked_fill(masked, self.mask_id)
        return xt

    # -------# Time Schedules #-------- #
    def sample_t(self, x1):
        """Dispatch to the configured time-sampling schedule."""
        ts = self.time_schedule
        if ts == 'linear':
            return self.sample_linear_t(x1)
        elif ts == 'exponential':
            # BUG FIX: previously called self.sample_exp_t, which doesn't exist.
            return self.sample_t_exponential(x1)
        elif ts == 'uniform':
            return self.sample_uni_t(x1)
        else:
            raise ValueError(f"Unrecognized time scheduler type: {ts}")

    def sample_uni_t(self, x1):
        """Uniformly sample discrete timesteps in {1..T}, rescaled to (0, 1)."""
        B = x1.size(0)
        T = self.config.time_embed.n_timesteps
        discrete_ts = torch.randint(1, T + 1, (B,), device=x1.device)
        timesteps = discrete_ts.float() / float(T)
        _print(f'timesteps: {timesteps}')
        return timesteps.clamp(min=self.eps, max=self.t_max)

    def sample_linear_t(self, x1):
        """Sample t with a lower bound that anneals linearly over training.

        During the first ``anneal_frac`` of training, t is drawn from
        [t_min(progress), 1-eps] where t_min rises linearly from eps to 1-eps;
        afterwards t is uniform over [eps, 1-eps].
        NOTE: relies on ``self.tot_steps`` being set in configure_optimizers,
        which Lightning calls before the first training step.
        """
        B = x1.size(0)
        eps = self.eps
        # fraction of total training steps completed
        frac = float(self.global_step) / float(self.tot_steps)
        t_max = 1.0 - eps
        if frac < self.anneal_frac:
            # normalize progress within the anneal window: [0, anneal_frac) -> [0, 1)
            prog = frac / max(1e-12, self.anneal_frac)
            t_min = eps + prog * (t_max - eps)  # linear increase from eps to 1.0-eps
            t = t_min + (t_max - t_min) * torch.rand(B, device=x1.device)
        else:
            # after the anneal window, uniform over the entire range [eps, 1.0-eps]
            t = eps + (t_max - eps) * torch.rand(B, device=x1.device)
        return t.clamp(min=eps, max=t_max)

    def sample_t_exponential(self, x1, t_min=1e-6, t_max=1.0 - 1e-6):
        """Exponentially anneal the center of t from t_min to t_max over training.

        For annealing over a very large number of training steps the exponential
        curve approximates the linear schedule; use if linear isn't expressive
        enough. Gaussian jitter keeps t from collapsing to a point mass.
        """
        # k controls how fast the curve rises.
        k = self.config.training.exp_time_k
        # BUG FIX: self.trainer.step does not exist; use Lightning's global_step.
        progress = float(self.global_step) / float(self.tot_steps)
        frac = 1.0 - math.exp(-k * progress)
        center = t_min + frac * (t_max - t_min)
        # small jitter around the center (drawn on the input's device)
        sigma = self.config.training.time_sigma
        t = torch.randn(x1.size(0), device=x1.device) * sigma + center
        return t.clamp(min=t_min, max=t_max)

    # -------# Model Training / Evaluation #-------- #
    def training_step(self, batch):
        # BUG FIX: step() returns 4 values; previously unpacked only 2.
        loss, ppl, _, _ = self.step(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl, on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch):
        # BUG FIX: step() returns 4 values; previously unpacked only 2.
        loss, ppl, _, _ = self.step(batch)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        loss, ppl, max_u, max_esm = self.step(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        self.log("test/max_madsbm_logit", max_u, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        self.log("test/max_esm_logit", max_esm, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def on_after_backward(self):
        """Log the pre-clip gradient norm (clipping itself is currently disabled)."""
        pre_norm = compute_grad_norms(self.parameters())
        self.log('train/grad_norm_PRE_clip', pre_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        """AdamW + cosine-warmup schedule stepped every optimizer step.

        Also caches ``self.tot_steps`` for the time schedules above.
        """
        optimizer = torch.optim.AdamW(
            params=self.model.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
            betas=(self.config.optim.beta1, self.config.optim.beta2),
        )
        self.tot_steps = self.trainer.estimated_stepping_batches
        # warmup_epochs expressed as a fraction of total steps
        warmup_steps = int(self.config.optim.warmup_epochs * self.tot_steps / self.config.training.n_epochs)
        lr_scheduler = CosineWarmup(
            optimizer=optimizer,
            warmup_steps=warmup_steps,
            total_steps=self.tot_steps,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def on_save_checkpoint(self, checkpoint: dict):
        """Drop the classifier model used for FBD calculation from the ckpt."""
        sd = checkpoint.get('state_dict', None)
        if sd is None:
            return
        keys_to_remove = [k for k in sd.keys() if k.startswith("score_model.")]
        for k in keys_to_remove:
            sd.pop(k, None)
        checkpoint['state_dict'] = sd

    # -------# Helper methods #-------- #
    def is_maskable(self, input_ids: torch.Tensor):
        """Boolean mask of positions eligible for masking (not pad/cls/eos)."""
        return (
            (input_ids != self.tokenizer.pad_token_id)
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def validate_config(self):
        """Sanity-check config values that would otherwise fail deep in the model."""
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.model.hidden_dim % 2 == 0, 'odd value for embedding dim'
        assert self.config.time_embed.time_dim % 2 == 0, 'odd value for time dim'
        assert self.config.time_embed.fourier_dim % 2 == 0, 'odd value for fourier dim'

    def get_state_dict(self, ckpt_path):
        """Load a checkpoint and strip the Lightning ``model.`` prefix from keys."""
        def remove_model_prefix(state_dict):
            # BUG FIX: str.replace returns a new string — the result was
            # previously discarded, so keys were never renamed. Rebuild the
            # dict, stripping only a leading "model." prefix.
            return {
                (k.replace('model.', '', 1) if k.startswith('model.') else k): v
                for k, v in state_dict.items()
            }
        # BUG FIX: don't pin to a specific GPU index ('cuda:3'); use the
        # default CUDA device when available.
        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = checkpoint.get("state_dict", checkpoint)
        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)
        return state_dict

    def cleanup(self):
        """Free cached GPU memory and trigger Python garbage collection."""
        torch.cuda.empty_cache()
        gc.collect()