|
|
import argparse |
|
|
import math |
|
|
import os |
|
|
from functools import partial |
|
|
from collections import Counter |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from datasets import load_from_disk |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import LambdaLR |
|
|
from torch.utils.data import DataLoader |
|
|
import pytorch_lightning as pl |
|
|
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
|
|
from pytorch_lightning.loggers import WandbLogger |
|
|
from pytorch_lightning.strategies import DDPStrategy |
|
|
from rdkit import Chem |
|
|
|
|
|
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
|
|
from peptide_analyzer import PeptideAnalyzer |
|
|
import dataloading_for_dynamic_batching as dynamic_dataloader |
|
|
|
|
|
|
|
|
class RotaryPositionalEmbedding(nn.Module):
    """Produce rotary position embedding (RoPE) cos/sin tables.

    Args:
        dim: Size of the feature axis being rotated (the per-head dimension).
        max_position_embeddings: Stored on the module; not used by ``forward``.
        base: Base of the geometric progression of rotation frequencies.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # One inverse frequency per channel pair: 1 / base ** (2i / dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each shaped ``(1, seq_len, 2 * len(inv_freq))``.

        Args:
            x: Reference tensor. Supplies the device and, when ``seq_len`` is
                None, the number of positions via ``x.shape[1]``.
            seq_len: Optional explicit number of positions.
        """
        length = x.shape[1] if seq_len is None else seq_len
        positions = torch.arange(length, device=x.device, dtype=self.inv_freq.dtype)
        # Outer product position x frequency, via broadcasting.
        angles = positions[:, None] * self.inv_freq[None, :]
        # Duplicate the frequency columns so the table covers both halves.
        table = angles.repeat(1, 2)
        return table.cos().unsqueeze(0), table.sin().unsqueeze(0)
|
|
|
|
|
def rotate_half(x):
    """Swap the two halves of the last axis of ``x``, negating the moved-up half.

    With the last axis split at ``size // 2`` into ``[a, b]``, the result is
    ``[-b, a]``. For an odd size the second half is the longer one.
    """
    size = x.shape[-1]
    front, back = x.split([size // 2, size - size // 2], dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys with precomputed rotary cos/sin tables.

    Each input ``v`` becomes ``v * cos + rot(v) * sin``. Here ``rot`` moves the
    second half of the last axis to the front and negates it.

    Returns:
        Tuple ``(q_rotated, k_rotated)``, broadcast against ``cos``/``sin``.
    """
    def _rotate(v):
        split = v.shape[-1] // 2
        return torch.cat((-v[..., split:], v[..., :split]), dim=-1)

    return tuple((v * cos) + (_rotate(v) * sin) for v in (q, k))
|
|
|
|
|
|
|
|
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton axis inserted at position 1. This
    lets them broadcast across the second axis of ``x``.
    """
    scale_b = scale[:, None]
    shift_b = shift[:, None]
    return x * (1 + scale_b) + shift_b
|
|
|
|
|
class TimestepEmbedder(nn.Module):
    """Map a scalar time per sample to a ``hidden_size`` conditioning vector.

    The scalar is fed as a single input feature to a two-layer MLP:
    Linear(1 -> hidden) -> SiLU -> Linear(hidden -> hidden).
    """

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Embed ``t``. A trailing feature axis of size 1 is added before the MLP."""
        features = t[..., None]
        return self.mlp(features)
|
|
|
|
|
class MultiHeadAttentionWithRoPE(nn.Module):
    """Multi-head self-attention with rotary position embeddings on queries and keys.

    NOTE(review): ``forward`` accepts no mask. Every position attends to every
    other position, including any padding positions present in the batch.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.head_dim = hidden_size // n_heads
        assert self.head_dim * n_heads == hidden_size, "hidden_size must be divisible by n_heads"

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x):
        """Self-attend over ``x`` shaped ``(batch, seq_len, hidden)``; returns the same shape."""
        bsz, n_tok, width = x.shape

        def to_heads(projected):
            # (B, L, H*D) -> (B, H, L, D)
            return projected.view(bsz, n_tok, self.n_heads, self.head_dim).transpose(1, 2)

        query = to_heads(self.q_proj(x))
        key = to_heads(self.k_proj(x))
        value = to_heads(self.v_proj(x))

        # Sequence length is passed explicitly: axis 1 of `query` is heads, not positions.
        cos, sin = self.rope(query, n_tok)
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        scores = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = scores.softmax(dim=-1)
        context = weights @ value

        # (B, H, L, D) -> (B, L, H*D)
        merged = context.transpose(1, 2).reshape(bsz, n_tok, width)
        return self.out_proj(merged)
|
|
|
|
|
class DiTBlock(nn.Module):
    """Transformer block with adaptive-LayerNorm (adaLN) conditioning.

    The conditioning vector ``c`` is projected into six per-channel vectors.
    These are shift, scale and gate for the attention branch, then shift,
    scale and gate for the MLP branch. Both LayerNorms have no learned affine
    parameters.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttentionWithRoPE(hidden_size, n_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        expanded = 4 * hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, expanded),
            nn.GELU(),
            nn.Linear(expanded, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Apply the block to ``x`` conditioned on ``c``.

        ``x`` must be 3-D ``(batch, seq_len, hidden)`` because the attention
        unpacks three dims. ``c`` is presumably ``(batch, hidden)`` (TODO
        confirm); ``chunk(6, dim=1)`` and the ``[:, None]`` broadcasts rely on it.
        """
        shift_attn, scale_attn, gate_attn, shift_ff, scale_ff, gate_ff = (
            self.adaLN_modulation(c).chunk(6, dim=1)
        )

        h = modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn[:, None] * self.attn(h)

        h = modulate(self.norm2(x), shift_ff, scale_ff)
        x = x + gate_ff[:, None] * self.mlp(h)
        return x
|
|
|
|
|
class MDLM(nn.Module):
    """Time-conditioned Transformer denoiser over token ids.

    The model embeds tokens and conditions every DiT block on an embedding of
    ``t``. It then layer-normalises the final hidden states and projects them
    to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.model_dim = model_dim
        # Not referenced elsewhere in this file. The embedding table has only
        # ``vocab_size`` rows, so this id cannot be embedded as-is.
        self.mask_token_id = vocab_size

        self.token_embedder = nn.Embedding(vocab_size, model_dim)
        self.time_embedder = TimestepEmbedder(model_dim)
        self.transformer_blocks = nn.ModuleList(
            DiTBlock(model_dim, n_heads) for _ in range(n_layers)
        )
        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Re-initialise one submodule.

        Linear and Embedding weights get N(0, 0.02) and Linear biases get zero.
        LayerNorms with affine parameters get weight 1 and bias 0. All other
        modules are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, t):
        """Return logits with the shape of ``x`` plus a trailing ``vocab_size`` axis."""
        hidden = self.token_embedder(x)
        cond = self.time_embedder(t)
        for block in self.transformer_blocks:
            hidden = block(hidden, cond)
        return self.lm_head(self.final_norm(hidden))
|
|
|
|
|
|
|
|
class MDLMLightningModule(pl.LightningModule):
    """Lightning wrapper that trains :class:`MDLM` to recover clean token ids.

    Each step samples a per-sequence time ``t ~ U(0, 1)``. A position keeps its
    clean token ``x_1`` when a uniform draw falls below ``t``; positions flagged
    in ``bond_mask`` use ``t ** gamma`` instead. Every other position is
    replaced by a uniformly random id ``x_0``.

    The model predicts ``x_1`` from the corrupted sequence. The objective is
    token-level cross-entropy plus ``validity_weight`` times a validity
    penalty, averaged over positions where ``attention_mask`` is set.
    """

    def __init__(self, args, tokenizer):
        super().__init__()
        self.save_hyperparameters(ignore=['tokenizer'])
        self.args = args
        self.tokenizer = tokenizer
        self.peptide_analyzer = PeptideAnalyzer()

        self.model = MDLM(
            vocab_size=tokenizer.vocab_size,
            model_dim=args.model_dim,
            n_heads=args.n_heads,
            n_layers=args.n_layers,
        )

        self.automatic_optimization = True

        # Kept for backward compatibility; nothing in this class fills it.
        self.validation_step_outputs = []

    def forward(self, x, t):
        return self.model(x, t)

    def _compute_invalid_loss(self, logits):
        """Per-position validity penalty.

        Greedily decodes ``logits`` (argmax per position). Each decoded sequence
        rejected by ``peptide_analyzer.is_peptide`` gets penalty 1.0, all others
        0.0. The return value is ``penalty * p(argmax token)``.

        Only the softmax probability carries gradient; decoding and the validity
        check are not differentiable. Assumes ``logits`` is
        ``(batch, seq_len, vocab)`` (TODO confirm), giving a ``(batch, seq_len)``
        result.
        """
        batch_token_ids = torch.argmax(logits, dim=-1)
        sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)
        penalties = torch.tensor(
            [0.0 if self.peptide_analyzer.is_peptide(seq) else 1.0 for seq in sampled_sequences],
            dtype=torch.float32,
            device=logits.device,
        )
        sampled_probs = (
            torch.softmax(logits, dim=-1)
            .gather(dim=-1, index=batch_token_ids.unsqueeze(-1))
            .squeeze(-1)
        )
        return penalties[:, None] * sampled_probs

    def _loss(self, logits, x_1, attn_mask):
        """Return ``(total, ce, invalid)``, each averaged over attended tokens.

        ``total = ce + validity_weight * invalid``, computed per position and
        then masked by ``attn_mask``. The token count is clamped to at least 1,
        so a fully masked batch yields 0 instead of NaN.
        """
        # args.label_smoothing is a CLI flag; 0.0 (the default) matches plain CE.
        label_smoothing = float(getattr(self.args, 'label_smoothing', 0.0) or 0.0)
        ce_loss = F.cross_entropy(
            logits.reshape(-1, self.model.vocab_size),
            x_1.reshape(-1),
            reduction='none',
            label_smoothing=label_smoothing,
        ).view(x_1.shape[0], -1)

        invalid_loss = self._compute_invalid_loss(logits)
        loss = ce_loss + self.args.validity_weight * invalid_loss

        mask = attn_mask.to(loss.dtype)
        num_tokens = mask.sum().clamp(min=1.0)
        token_nll = (loss * mask).sum() / num_tokens
        return (
            token_nll,
            (ce_loss * mask).sum() / num_tokens,
            (invalid_loss * mask).sum() / num_tokens,
        )

    def _shared_step(self, batch):
        """Corrupt ``batch``, run the model, and return ``(token_nll, ce, invalid, batch_size)``."""
        x_1 = batch['input_ids'].to(self.device)
        attn_mask = batch['attention_mask'].to(self.device)
        bond_mask = batch['bond_mask'].to(self.device).bool()
        batch_size, _ = x_1.shape

        # RNG draw order (noise ids, t, keep draws) matches the original implementation.
        x_0 = torch.randint(0, self.model.vocab_size, x_1.shape, device=self.device)
        t_continuous = torch.rand(batch_size, device=self.device)

        t_col = t_continuous.view(-1, 1)
        # Bond positions use t ** gamma as their keep probability; all others use t.
        keep_prob = torch.where(bond_mask, t_col ** self.args.gamma, t_col)
        keep_clean = torch.rand(x_1.shape, device=self.device) < keep_prob
        x_t = torch.where(keep_clean, x_1, x_0)

        logits = self.model(x_t, t_continuous)
        token_nll, ce_loss, invalid_loss = self._loss(logits, x_1, attn_mask)
        return token_nll, ce_loss, invalid_loss, batch_size

    def _log_metrics(self, stage, token_nll, ce_loss, invalid_loss, batch_size):
        """Log the three loss terms under ``{stage}/...``.

        Tensors are passed directly (detached), which avoids a per-step
        ``.item()`` host sync.
        """
        common = dict(on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
        self.log(f'{stage}/token_nll', token_nll.detach(), prog_bar=True, **common)
        self.log(f'{stage}/ce_loss', ce_loss.detach(), **common)
        self.log(f'{stage}/invalid_loss', invalid_loss.detach(), **common)

    def training_step(self, batch, batch_idx):
        token_nll, ce_loss, invalid_loss, batch_size = self._shared_step(batch)
        self._log_metrics('train', token_nll, ce_loss, invalid_loss, batch_size)
        return token_nll

    def validation_step(self, batch, batch_idx):
        token_nll, ce_loss, invalid_loss, batch_size = self._shared_step(batch)
        self._log_metrics('val', token_nll, ce_loss, invalid_loss, batch_size)

    def configure_optimizers(self):
        """AdamW with linear warmup, then cosine decay, stepped every optimizer step.

        Warmup covers the first 10% of steps and ramps from 0.1x to 1x the base
        LR. Cosine decay then runs back down to 0.1x. Progress is clamped, so
        running past the estimated step count holds the LR at the floor instead
        of rising again.
        """
        optimizer = AdamW(
            self.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )

        if hasattr(self.trainer, 'estimated_stepping_batches'):
            num_training_steps = self.trainer.estimated_stepping_batches
        else:
            # Fallback for Lightning versions without `estimated_stepping_batches`.
            num_training_steps = len(self.trainer.datamodule.train_dataloader()) * self.trainer.max_epochs
        num_training_steps = int(num_training_steps)

        warmup_steps = int(num_training_steps * 0.1)
        decay_steps = max(num_training_steps - warmup_steps, 1)
        min_ratio = 0.1

        def lr_lambda(current_step):
            # Returns a multiplier on the base LR, as LambdaLR expects.
            if current_step < warmup_steps:
                return min_ratio + (1.0 - min_ratio) * (current_step / warmup_steps)
            progress = min(max((current_step - warmup_steps) / decay_steps, 0.0), 1.0)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_ratio + (1.0 - min_ratio) * cosine_decay

        scheduler = LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
|
|
|
|
|
|
|
|
|
|
|
def main(args):
    """Build the tokenizer, data module, model, logger and trainer, then train.

    Optional attributes on ``args`` fall back to the previously hard-coded
    values when absent: ``vocab_path``, ``splits_path``, ``data_dir`` and
    ``accumulate_grad_batches``.
    """
    run_name = (f"correct_lr{args.learning_rate}_wd{args.weight_decay}_layer{args.n_layers}_"
                f"head{args.n_heads}_valweight{args.validity_weight}")
    # os.path.join supplies the separator that plain concatenation dropped
    # ("./checkpoints_smiles" + "correct_..." -> "./checkpoints_smilescorrect_...").
    checkpoint_dir = os.path.join(args.checkpoint_dir, run_name)
    print(f"Saving to {checkpoint_dir}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    vocab_path = (getattr(args, "vocab_path", None)
                  or '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt')
    splits_path = (getattr(args, "splits_path", None)
                   or '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt')
    data_dir = getattr(args, "data_dir", None) or './data/11M_smiles_old_tokenizer_no_limit/'
    accumulate_grad_batches = getattr(args, "accumulate_grad_batches", None) or 8

    print("Loading tokenizer...")
    tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
    print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")

    data_module = dynamic_dataloader.CustomDataModule(data_dir, tokenizer)

    model = MDLMLightningModule(args, tokenizer)

    logger = WandbLogger(
        project="smiles-redi-training",
        entity="programmablebio",
        name=f"lr{args.learning_rate}_dim{args.model_dim}_head{args.n_heads}_layer{args.n_layers}",
        save_dir=checkpoint_dir,
    )

    callbacks = [
        ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename='best',
            monitor='val/token_nll',
            mode='min',
            save_top_k=1,
            save_last=True,
        ),
        LearningRateMonitor(logging_interval='step'),
    ]

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        devices=torch.cuda.device_count(),
        accelerator='gpu',
        strategy=DDPStrategy(find_unused_parameters=False),
        num_nodes=int(os.environ.get("SLURM_NNODES", 1)),
        precision="bf16",
        gradient_clip_val=args.grad_clip if args.grad_clip > 0 else None,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=100,
        # An int is expected here; the original passed True, which only worked because True == 1.
        check_val_every_n_epoch=1,
        accumulate_grad_batches=accumulate_grad_batches,
        enable_progress_bar=True,
        enable_model_summary=True,
    )

    print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")
    print("Starting training...")

    trainer.fit(model, data_module)

    print("Training complete.")
    print(f"Best checkpoint saved at: {trainer.checkpoint_callback.best_model_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train ReDi model for SMILES generation with RoPE using PyTorch Lightning")

    # Model architecture
    parser.add_argument("--model_dim", type=int, default=1024)
    parser.add_argument("--n_heads", type=int, default=8)
    parser.add_argument("--n_layers", type=int, default=6)

    # Optimisation
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--label_smoothing", type=float, default=0.0)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--accumulate_grad_batches", type=int, default=8)
    # Exponent on t used as the keep probability for positions flagged in bond_mask.
    parser.add_argument("--gamma", type=float, default=2.0)

    # Validity penalty
    parser.add_argument("--validity_weight", type=float, default=0.1)
    # NOTE(review): the next two flags are not read anywhere in this file.
    parser.add_argument("--validity_check_freq", type=int, default=10)
    parser.add_argument("--validity_eval_batches", type=int, default=20)

    # Paths
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints_smiles")
    parser.add_argument("--vocab_path", type=str,
                        default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt')
    parser.add_argument("--splits_path", type=str,
                        default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt')
    parser.add_argument("--data_dir", type=str, default='./data/11M_smiles_old_tokenizer_no_limit/')

    args = parser.parse_args()
    main(args)