| | import gc |
| | import os |
| | import math |
| | from re import L |
| | import torch |
| |
|
| | import lightning as pl |
| | import torch.nn.functional as F |
| |
|
| | from transformers import AutoModel |
| |
|
| | from src.madsbm.wt_peptide.control_field import PeptideControlField |
| | from src.PeptiVerse.inference import PeptiVersePredictor |
| | from src.utils.model_utils import CosineWarmup, _print, compute_grad_norms |
| |
|
| |
|
class MadSBM(pl.LightningModule):
    """Lightning module that trains a ``PeptideControlField`` by masked-token noising.

    Each step samples one time ``t`` per sequence, then replaces each maskable
    token (not pad / cls / eos) with the mask token independently with
    probability ``t``. The model is trained with cross-entropy against the
    clean sequence.
    """

    def __init__(self, config, guidance=None):
        """Build the model, cache tokenizer ids, and read time-sampling settings.

        Args:
            config: Nested config object. This method reads ``model.esm_model``
                and ``time_embed.{time_schedule, anneal_frac, min_time}``;
                ``PeptideControlField`` may read more.
            guidance: Accepted for interface compatibility; unused in this class.
        """
        super().__init__()

        self.config = config
        self.model = PeptideControlField(config)
        self.tokenizer = self.model.tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id

        # Frozen pretrained encoder. NOTE(review): it is not referenced elsewhere
        # in this class. Lightning may switch submodules back to train mode
        # during fit, so confirm it stays in eval mode wherever it is used.
        self.embed_model = AutoModel.from_pretrained(config.model.esm_model)
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False

        self.time_schedule = config.time_embed.time_schedule
        self.anneal_frac = config.time_embed.anneal_frac
        self.eps = float(config.time_embed.min_time)
        self.t_max = 1.0 - self.eps

    def forward(self, input_ids, attention_mask, t):
        """Run the control field on noised token ids ``input_ids`` at times ``t``."""
        return self.model(xt=input_ids, attention_mask=attention_mask, t=t)

    def step(self, batch):
        """Compute the denoising loss for one batch.

        Args:
            batch: Mapping with ``'input_ids'`` (2-D: batch x seq_len) and
                ``'attention_mask'`` tensors.

        Returns:
            Tuple ``(loss, ppl, max_u_logit, max_esm_logit)``:

            - ``loss``: the mean over sequences of the per-sequence average
              cross-entropy on maskable positions.
            - ``ppl``: ``exp(loss)``.
            - ``max_u_logit`` and ``max_esm_logit``: the largest ``'dit'`` and
              ``'esm'`` output logits, as Python floats.
        """
        x1 = batch['input_ids']
        attn_mask = batch['attention_mask']
        maskable = self.is_maskable(x1)

        t = self.sample_t(x1)
        xt = self.noise_seq(x1, t, maskable_mask=maskable)

        outs = self.forward(xt, attn_mask, t)
        logits = outs['dit'] if self.config.model.ablate else outs['madsbm']
        # .item() forces a host/device sync on every step (training included).
        max_u_logit = outs['dit'].max().item()
        max_esm_logit = outs['esm'].max().item()

        # reshape (not view) tolerates non-contiguous model outputs.
        loss_token = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            x1.reshape(-1),
            reduction='none',
            ignore_index=self.pad_id,
        ).view(x1.size(0), x1.size(1))

        # Average over every maskable position, including ones left unmasked.
        # NOTE(review): confirm this is intended rather than a masked-only loss.
        weight = maskable.float()
        sample_loss = (loss_token * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1.0)

        loss = sample_loss.mean()
        ppl = torch.exp(loss)
        return loss, ppl, max_u_logit, max_esm_logit

    def noise_seq(self, x1, t, maskable_mask):
        """Mask each maskable token of ``x1`` independently with probability ``t``.

        Args:
            x1: Clean token ids, shape (B, L).
            t: Per-sequence masking probability, shape (B,).
            maskable_mask: Boolean (B, L) tensor of positions eligible for masking.

        Returns:
            A new tensor equal to ``x1`` with the selected positions set to ``mask_id``.
        """
        B, L = x1.shape
        u = torch.rand((B, L), device=x1.device)
        masked = (u < t.unsqueeze(1)) & maskable_mask
        # masked_fill is out-of-place, so x1 itself is left untouched.
        return x1.masked_fill(masked, self.mask_id)

    def sample_t(self, x1):
        """Sample one time per sequence in ``x1`` using ``self.time_schedule``.

        Raises:
            ValueError: If the schedule is not 'linear', 'exponential' or 'uniform'.
        """
        ts = self.time_schedule
        if ts == 'linear':
            return self.sample_linear_t(x1)
        elif ts == 'exponential':
            return self.sample_t_exponential(x1)
        elif ts == 'uniform':
            return self.sample_uni_t(x1)
        else:
            raise ValueError(f"Unrecognized time scheduler type: {ts}")

    def sample_uni_t(self, x1):
        """Sample t uniformly from {1/T, 2/T, ..., 1}, clamped to [eps, t_max].

        ``T`` is ``config.time_embed.n_timesteps``.
        """
        B = x1.size(0)
        T = self.config.time_embed.n_timesteps

        discrete_ts = torch.randint(1, T + 1, (B,), device=x1.device)
        timesteps = discrete_ts.float() / float(T)
        # NOTE(review): prints the full tensor on every call; consider removing.
        _print(f'timesteps: {timesteps}')
        return timesteps.clamp(min=self.eps, max=self.t_max)

    def _train_progress(self):
        """Return the fraction of training completed, ``global_step / tot_steps``.

        Raises:
            RuntimeError: If ``configure_optimizers`` has not yet set ``tot_steps``.
        """
        tot_steps = getattr(self, "tot_steps", None)
        if tot_steps is None:
            raise RuntimeError(
                "tot_steps is unset; configure_optimizers() must run before "
                "sampling annealed timesteps"
            )
        # Guard against an empty schedule (tot_steps == 0).
        return float(self.global_step) / max(1.0, float(tot_steps))

    def sample_linear_t(self, x1):
        """Sample t with a linearly annealed lower bound.

        For the first ``anneal_frac`` fraction of training, t ~ U[t_min, t_max],
        where t_min rises linearly from ``eps`` to ``t_max``. After that,
        t ~ U[eps, t_max].
        """
        B = x1.size(0)
        eps = self.eps
        t_max = 1.0 - eps
        anneal_frac = float(self.anneal_frac)

        frac = self._train_progress()
        if frac < anneal_frac:
            prog = frac / max(1e-12, anneal_frac)
            t_min = eps + prog * (t_max - eps)
        else:
            t_min = eps

        t = t_min + (t_max - t_min) * torch.rand(B, device=x1.device)
        return t.clamp(min=eps, max=t_max)

    def sample_t_exponential(self, x1, t_min=1e-6, t_max=1.0 - 1e-6):
        """Sample t ~ N(center, sigma^2), clamped to [t_min, t_max].

        The center moves from ``t_min`` toward ``t_max`` as
        ``t_min + (1 - exp(-k * progress)) * (t_max - t_min)``. Here
        ``progress = global_step / tot_steps``, ``k = training.exp_time_k``
        and ``sigma = training.time_sigma``.
        """
        k = float(self.config.training.exp_time_k)
        sigma = float(self.config.training.time_sigma)

        progress = self._train_progress()
        frac = 1.0 - math.exp(-k * progress)
        center = t_min + frac * (t_max - t_min)

        # Sample on the input's device so t can broadcast against x1-derived tensors.
        t = center + sigma * torch.randn(x1.size(0), device=x1.device)
        return t.clamp(min=t_min, max=t_max)

    def training_step(self, batch):
        """Compute the training loss and log train loss/perplexity per step."""
        # step() returns four values; only loss and ppl are logged here.
        loss, ppl, _, _ = self.step(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl, on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch):
        """Compute the validation loss and log epoch-level loss/perplexity."""
        loss, ppl, _, _ = self.step(batch)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        """Compute the test loss and log loss, perplexity and max-logit diagnostics."""
        loss, ppl, max_u, max_esm = self.step(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        # NOTE(review): max_u comes from the 'dit' output despite the metric name.
        self.log("test/max_madsbm_logit", max_u, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        self.log("test/max_esm_logit", max_esm, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def on_after_backward(self):
        """Log the gradient norm after backward, before any gradient clipping."""
        pre_norm = compute_grad_norms(self.parameters())
        self.log('train/grad_norm_PRE_clip', pre_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        """Build AdamW over ``self.model`` with a per-step cosine warmup schedule.

        Side effect: stores ``self.tot_steps``, which the annealed time samplers
        read.
        """
        optimizer = torch.optim.AdamW(
            params=self.model.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
            betas=(self.config.optim.beta1, self.config.optim.beta2),
        )

        self.tot_steps = self.trainer.estimated_stepping_batches
        # Convert warmup length from epochs to optimizer steps.
        warmup_steps = int(self.config.optim.warmup_epochs * self.tot_steps / self.config.training.n_epochs)

        lr_scheduler = CosineWarmup(
            optimizer=optimizer,
            warmup_steps=warmup_steps,
            total_steps=self.tot_steps,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def on_save_checkpoint(self, checkpoint: dict):
        """Drop state-dict entries whose keys start with ``score_model.`` before saving.

        NOTE(review): this class defines no ``score_model`` attribute, so this
        is a no-op unless one is attached elsewhere. The frozen ``embed_model``
        weights are still saved.
        """
        sd = checkpoint.get('state_dict', None)
        if sd is None:
            return
        for k in [k for k in sd if k.startswith("score_model.")]:
            del sd[k]
        checkpoint['state_dict'] = sd

    def is_maskable(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask that is True where a token is not pad, cls or eos."""
        return (
            (input_ids != self.tokenizer.pad_token_id)
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def validate_config(self):
        """Assert that the checkpoint directory exists and the embedding dims are even."""
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.model.hidden_dim % 2 == 0, 'odd value for embedding dim'
        assert self.config.time_embed.time_dim % 2 == 0, 'odd value for time dim'
        assert self.config.time_embed.fourier_dim % 2 == 0, 'odd value for fourier dim'

    def get_state_dict(self, ckpt_path, map_location=None):
        """Load a checkpoint and return its state dict with a leading ``model.`` removed.

        Args:
            ckpt_path: Path to a file readable by ``torch.load``. The state dict
                is taken from its ``'state_dict'`` entry if present, otherwise
                the loaded dict itself is used.
            map_location: Passed to ``torch.load``. Defaults to ``'cpu'``, which
                works on any machine; ``load_state_dict`` later copies the
                tensors onto the module's device.

        Returns:
            The state dict. Only keys that *start* with ``model.`` are renamed,
            so keys such as ``embed_model.*`` are left intact.
        """
        prefix = "model."
        if map_location is None:
            map_location = "cpu"

        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith(prefix) for k in state_dict.keys()):
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

        return state_dict

    def cleanup(self):
        """Release cached CUDA memory and run Python garbage collection."""
        torch.cuda.empty_cache()
        gc.collect()