import gc
import os
import math
from re import L  # NOTE(review): appears unused (likely an accidental auto-import); kept to leave the module namespace unchanged
import torch
import lightning as pl
import torch.nn.functional as F

from transformers import AutoModel

from src.madsbm.wt_peptide.control_field import PeptideControlField
from src.PeptiVerse.inference import PeptiVersePredictor
from src.utils.model_utils import CosineWarmup, _print, compute_grad_norms


class MadSBM(pl.LightningModule):
    """Lightning module that trains a ``PeptideControlField`` as a masked-token denoiser.

    Each step draws a time ``t`` per sample, replaces a random, ``t``-dependent
    subset of maskable tokens in the clean batch ``x1`` with the mask token, and
    trains the model to recover ``x1`` with a token-level cross-entropy loss.
    """

    def __init__(self, config, guidance=None):
        """Build the control field, a frozen pretrained encoder, and time-sampling settings.

        Args:
            config: nested config object; reads ``model.*``, ``time_embed.*``,
                ``optim.*``, ``training.*`` and ``checkpointing.*`` fields.
            guidance: accepted for interface compatibility; not used in this class.
        """
        super().__init__()
        self.config = config

        self.model = PeptideControlField(config)
        self.tokenizer = self.model.tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id

        # Frozen pretrained encoder (no gradients). It is not referenced elsewhere in this class.
        self.embed_model = AutoModel.from_pretrained(config.model.esm_model)
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False

        self.time_schedule = config.time_embed.time_schedule
        self.anneal_frac = config.time_embed.anneal_frac
        # Sampled times are kept inside [eps, 1 - eps].
        self.eps = float(config.time_embed.min_time)
        self.t_max = 1.0 - self.eps


    # -------# Main Training Logic #-------- #
    def forward(self, input_ids, attention_mask, t):
        """Run the control field on (partially masked) tokens ``input_ids`` at times ``t``."""
        return self.model(xt=input_ids, attention_mask=attention_mask, t=t)

    def step(self, batch):
        """Compute the masked-denoising loss for one batch.

        Args:
            batch: mapping with ``'input_ids'`` and ``'attention_mask'`` tensors
                (``input_ids`` is treated as 2-D ``(B, L)``, see ``noise_seq``).

        Returns:
            ``(loss, ppl, max_u_logit, max_esm_logit)``. ``loss`` is the batch-mean
            of per-sample cross-entropy averaged over maskable positions, and
            ``ppl = exp(loss)``. The last two values are Python floats: the maxima
            of ``outs['dit']`` and ``outs['esm']``, used only for logging.
        """
        x1 = batch['input_ids']
        attn_mask = batch['attention_mask']
        maskable = self.is_maskable(x1)

        t = self.sample_t(x1)
        xt = self.noise_seq(x1, t, maskable_mask=maskable)

        outs = self.forward(xt, attn_mask, t)
        logits = outs['dit'] if self.config.model.ablate else outs['madsbm']

        # Diagnostics only; .item() copies each scalar to the host.
        max_u_logit = outs['dit'].max().item()
        max_esm_logit = outs['esm'].max().item()

        loss_token = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            x1.view(-1),
            reduction='none',
            ignore_index=self.pad_id,
        ).view(x1.size(0), x1.size(1))

        # NOTE(review): the average covers every maskable position, masked or not; confirm
        # this is intended rather than restricting the loss to masked positions only.
        weights = maskable.float()
        sample_loss = (loss_token * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        loss = sample_loss.mean()
        ppl = torch.exp(loss)
        return loss, ppl, max_u_logit, max_esm_logit

    def noise_seq(self, x1, t, maskable_mask):
        """Return a copy of ``x1`` with a random subset of maskable tokens set to ``mask_id``.

        For each position, ``u ~ U[0, 1)`` is drawn and the token is masked when
        ``u < t`` (and the position is maskable). Each maskable token is therefore
        masked independently with probability ``t`` for its sample.

        Args:
            x1: clean token ids, shape ``(B, L)``.
            t: per-sample times, shape ``(B,)``.
            maskable_mask: boolean tensor, shape ``(B, L)``; False positions are never masked.
        """
        B, L = x1.shape
        t = t.unsqueeze(1)  # (B, 1) so it broadcasts across the sequence dimension
        # Mask if u < t, reveal if u >= t. NOTE(review): an earlier comment here stated the
        # opposite convention; the code masks on u < t. Confirm the sampler agrees.
        u = torch.rand((B, L), device=x1.device)
        masked = (u < t) & maskable_mask
        # masked_fill is out-of-place, so x1 itself is left untouched.
        return x1.masked_fill(masked, self.mask_id)


    # -------# Time Schedules #-------- #
    def sample_t(self, x1):
        """Draw one time per sample in ``x1`` according to ``self.time_schedule``.

        Raises:
            ValueError: if the configured schedule name is not recognized.
        """
        ts = self.time_schedule
        if ts == 'linear':
            return self.sample_linear_t(x1)
        if ts == 'exponential':
            # Use the same [eps, 1 - eps] bounds as the other schedules.
            return self.sample_t_exponential(x1, t_min=self.eps, t_max=self.t_max)
        if ts == 'uniform':
            return self.sample_uni_t(x1)
        raise ValueError(f"Unrecognized time scheduler type: {ts}")

    def sample_uni_t(self, x1):
        """Sample t uniformly from the grid {1/T, 2/T, ..., 1}, clamped to [eps, 1 - eps].

        ``T`` is ``config.time_embed.n_timesteps``.
        """
        B = x1.size(0)
        T = self.config.time_embed.n_timesteps
        discrete_ts = torch.randint(1, T + 1, (B,), device=x1.device)
        timesteps = discrete_ts.float() / float(T)
        return timesteps.clamp(min=self.eps, max=self.t_max)

    def sample_linear_t(self, x1):
        """Sample t with a linearly annealed lower bound.

        During the first ``anneal_frac`` of training, the lower bound rises linearly
        from ``eps`` toward ``1 - eps``, and t is drawn uniformly from
        ``[t_min, 1 - eps]``. Samples therefore concentrate near ``1 - eps`` as the
        window progresses. After the window, t ~ U[eps, 1 - eps].

        Requires ``self.tot_steps``, which is set in ``configure_optimizers``.
        """
        B = x1.size(0)
        eps = self.eps
        t_max = self.t_max

        # Fraction of total training steps completed (guarded against tot_steps == 0).
        frac = float(self.global_step) / float(max(1, self.tot_steps))

        if frac < self.anneal_frac:
            # Progress within the anneal window, mapped to [0, 1).
            prog = frac / max(1e-12, self.anneal_frac)
            t_min = eps + prog * (t_max - eps)
        else:
            t_min = eps

        t = t_min + (t_max - t_min) * torch.rand(B, device=x1.device)
        return t.clamp(min=eps, max=t_max)

    def sample_t_exponential(self, x1, t_min=1e-6, t_max=1.0 - 1e-6):
        """Sample t around a center that moves exponentially from ``t_min`` to ``t_max``.

        The center is ``t_min + (1 - exp(-k * progress)) * (t_max - t_min)``, where
        ``progress = global_step / tot_steps`` and ``k = config.training.exp_time_k``.
        Gaussian jitter with std ``config.training.time_sigma`` is added so that the
        samples do not collapse onto a single value. The result is clamped to
        ``[t_min, t_max]``.

        Requires ``self.tot_steps``, which is set in ``configure_optimizers``.
        """
        k = self.config.training.exp_time_k
        progress = float(self.global_step) / float(max(1, self.tot_steps))
        frac = 1.0 - math.exp(-k * progress)
        center = t_min + frac * (t_max - t_min)

        jitter = torch.randn(x1.size(0), device=x1.device) * self.config.training.time_sigma
        t = jitter + center
        return t.clamp(min=t_min, max=t_max)


    # -------# Model Training / Evaluation #-------- #
    def training_step(self, batch):
        """Compute and log training loss and perplexity for one batch."""
        loss, ppl, _, _ = self.step(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl, on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch):
        """Compute and log epoch-level validation loss and perplexity."""
        loss, ppl, _, _ = self.step(batch)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        """Compute and log test loss and perplexity, plus the logit-maximum diagnostics."""
        loss, ppl, max_u, max_esm = self.step(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        # NOTE(review): max_u is the max of outs['dit'], logged under a "madsbm" key.
        self.log("test/max_madsbm_logit", max_u, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        self.log("test/max_esm_logit", max_esm, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def on_after_backward(self):
        """Log the gradient norm after backward (before any gradient clipping)."""
        pre_norm = compute_grad_norms(self.parameters())
        self.log('train/grad_norm_PRE_clip', pre_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)
        # torch.nn.utils.clip_grad_norm_(self.parameters(), float(self.config.training.grad_clip_val))
        # post_norm = compute_grad_norms(self.parameters())
        # self.log('train/grad_norm_POST_clip', post_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        """Build AdamW over ``self.model`` parameters with a per-step cosine-warmup schedule.

        Also sets ``self.tot_steps``, which the time samplers read.
        """
        optimizer = torch.optim.AdamW(
            params=self.model.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
            betas=(self.config.optim.beta1, self.config.optim.beta2),
        )

        self.tot_steps = self.trainer.estimated_stepping_batches
        # Convert warmup epochs into optimizer steps.
        warmup_steps = int(self.config.optim.warmup_epochs * self.tot_steps / self.config.training.n_epochs)

        lr_scheduler = CosineWarmup(
            optimizer=optimizer,
            warmup_steps=warmup_steps,
            total_steps=self.tot_steps,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def on_save_checkpoint(self, checkpoint: dict):
        """Drop every ``score_model.*`` entry from the checkpoint's state dict before saving.

        NOTE(review): no ``score_model`` attribute is defined in this class; confirm
        which submodule this filter is meant to exclude.
        """
        sd = checkpoint.get('state_dict', None)
        if sd is None:
            return
        keys_to_remove = [k for k in sd.keys() if k.startswith("score_model.")]
        for k in keys_to_remove:
            sd.pop(k, None)
        checkpoint['state_dict'] = sd


    # -------# Helper methods #-------- #
    def is_maskable(self, input_ids: torch.Tensor):
        """Return a boolean tensor that is True wherever the token is not pad, cls or eos."""
        return (
            (input_ids != self.tokenizer.pad_token_id)
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def validate_config(self):
        """Sanity-check the config; raises AssertionError (unless Python runs with -O)."""
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.model.hidden_dim % 2 == 0, 'odd value for embedding dim'
        assert self.config.time_embed.time_dim % 2 == 0, 'odd value for time dim'
        assert self.config.time_embed.fourier_dim % 2 == 0, 'odd value for fourier dim'

    def get_state_dict(self, ckpt_path, map_location=None):
        """Load a checkpoint and return its state dict with a leading ``model.`` prefix stripped.

        Args:
            ckpt_path: path to a checkpoint. It is either a dict with a ``'state_dict'``
                entry (as Lightning writes) or a bare state dict.
            map_location: forwarded to ``torch.load``. Defaults to ``'cpu'``; the
                tensors are copied to the right device later by ``load_state_dict``.
                This replaces a previously hard-coded ``'cuda:3'``.

        Returns:
            The state dict. If any key starts with ``model.``, that leading prefix is
            removed from every key that has it; other keys (e.g. ``embed_model.*``)
            are kept unchanged.
        """
        prefix = "model."

        def remove_model_prefix(state_dict):
            # Build a new dict: strip only a *leading* prefix, so keys such as
            # "embed_model.x" are left untouched.
            return {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

        if map_location is None:
            map_location = 'cpu'
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        state_dict = checkpoint.get("state_dict", checkpoint)
        if any(k.startswith(prefix) for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)
        return state_dict

    def cleanup(self):
        """Release cached CUDA memory and run Python garbage collection."""
        torch.cuda.empty_cache()
        gc.collect()