# AReUReDi / smiles / train.py
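#
# Trains a ReDi-style discrete diffusion model over SMILES tokens: a DiT-like
# transformer with rotary positional embeddings, conditioned on a continuous
# timestep, optimized with PyTorch Lightning under DDP.
#
# Example invocation (illustrative; paths and hyperparameter values depend on your setup):
#   python train.py --model_dim 1024 --n_heads 8 --n_layers 6 \
#       --learning_rate 1e-4 --validity_weight 0.1 --gamma 2.0 --epochs 50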
import argparse
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from peptide_analyzer import PeptideAnalyzer
import dataloading_for_dynamic_batching as dynamic_dataloader
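# --- Rotary Positional Embeddings ---
# RoPE encodes position by rotating pairs of query/key features; the cos/sin tables
# computed here broadcast over the batch and head dimensions when applied in attention.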
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
    def forward(self, x, seq_len=None):
        if seq_len is None:
            # x arrives in (batch, n_heads, seq_len, head_dim) layout, so the
            # sequence length lives on the second-to-last dimension.
            seq_len = x.shape[-2]
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_emb = emb.cos()[None, :, :]
sin_emb = emb.sin()[None, :, :]
return cos_emb, sin_emb
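# Swap and negate the two halves of the last dimension; combined with the cos/sin
# tables above, this applies the 2D rotation to each query/key feature pair.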
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# --- Model Architecture with RoPE ---
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
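# Embed the scalar diffusion time t (in [0, 1]) into a conditioning vector via a small MLP.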
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(1, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(self, t):
return self.mlp(t.unsqueeze(-1))
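# Standard multi-head self-attention, with RoPE applied to queries and keys
# before the scaled dot-product.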
class MultiHeadAttentionWithRoPE(nn.Module):
def __init__(self, hidden_size, n_heads):
super().__init__()
self.hidden_size = hidden_size
self.n_heads = n_heads
self.head_dim = hidden_size // n_heads
assert self.head_dim * n_heads == hidden_size, "hidden_size must be divisible by n_heads"
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.rope = RotaryPositionalEmbedding(self.head_dim)
def forward(self, x):
batch_size, seq_len, hidden_size = x.shape
q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rope(q, seq_len)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
output = self.out_proj(attn_output)
return output
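# DiT-style block: pre-norm attention and MLP sub-layers, each modulated by
# adaLN shift/scale/gate parameters predicted from the conditioning vector c.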
class DiTBlock(nn.Module):
def __init__(self, hidden_size, n_heads):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = MultiHeadAttentionWithRoPE(hidden_size, n_heads)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, 4 * hidden_size),
nn.GELU(),
nn.Linear(4 * hidden_size, hidden_size)
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
attn_output = self.attn(x_norm1)
x = x + gate_msa.unsqueeze(1) * attn_output
x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
mlp_output = self.mlp(x_norm2)
x = x + gate_mlp.unsqueeze(1) * mlp_output
return x
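# Denoiser backbone: token embedding -> stack of DiT blocks conditioned on the
# time embedding -> LM head producing per-position logits over the vocabulary.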
class MDLM(nn.Module):
def __init__(self, vocab_size, model_dim, n_heads, n_layers):
super().__init__()
self.vocab_size = vocab_size
self.model_dim = model_dim
self.mask_token_id = vocab_size
self.token_embedder = nn.Embedding(vocab_size, model_dim)
self.time_embedder = TimestepEmbedder(model_dim)
self.transformer_blocks = nn.ModuleList([
DiTBlock(model_dim, n_heads) for _ in range(n_layers)
])
self.final_norm = nn.LayerNorm(model_dim)
self.lm_head = nn.Linear(model_dim, vocab_size)
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
if module.bias is not None:
module.bias.data.zero_()
if module.weight is not None:
module.weight.data.fill_(1.0)
def forward(self, x, t):
x_embed = self.token_embedder(x)
t_embed = self.time_embedder(t)
for block in self.transformer_blocks:
x_embed = block(x_embed, t_embed)
x_embed = self.final_norm(x_embed)
logits = self.lm_head(x_embed)
return logits
# --- PyTorch Lightning Module ---
class MDLMLightningModule(pl.LightningModule):
def __init__(self, args, tokenizer):
super().__init__()
self.save_hyperparameters(ignore=['tokenizer'])
self.args = args
self.tokenizer = tokenizer
self.peptide_analyzer = PeptideAnalyzer()
# Initialize model
self.model = MDLM(
vocab_size=tokenizer.vocab_size,
model_dim=args.model_dim,
n_heads=args.n_heads,
n_layers=args.n_layers
)
        # Lightning drives the optimization loop; keep a buffer for validation outputs if needed.
        self.automatic_optimization = True
        self.validation_step_outputs = []
def forward(self, x, t):
return self.model(x, t)
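    # Validity penalty: greedily decode each sequence and, if PeptideAnalyzer rejects it,
    # penalize the probability mass assigned to the decoded tokens. Gradients flow through
    # the gathered softmax probabilities (argmax itself is non-differentiable).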
def _compute_invalid_loss(self, logits):
batch_token_ids = torch.argmax(logits, dim=-1) # (batch_size, seq_length)
sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)
penalties = torch.tensor(
[1 if not self.peptide_analyzer.is_peptide(seq) else 0 for seq in sampled_sequences],
dtype=torch.float32,
device=self.device
)
sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)
scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length)
return scaled_penalty
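    # Total loss: per-token cross-entropy plus the weighted validity penalty,
    # averaged over non-padding tokens via the attention mask.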
def _loss(self, logits, x_1, attn_mask):
# Standard cross-entropy loss
ce_loss = F.cross_entropy(
logits.view(-1, self.model.vocab_size),
x_1.view(-1),
reduction='none'
).view(x_1.shape[0], -1)
# ce_loss = (ce_loss * attn_mask).sum() / attn_mask.sum()
# validity_weight = self.args.validity_weight * min(1.0, (self.current_epoch + 1) / self.trainer.max_epochs)
invalid_loss = self._compute_invalid_loss(logits) # (batch_size, seq_length)
loss = ce_loss + self.args.validity_weight * invalid_loss
nlls = loss * attn_mask
num_tokens = attn_mask.sum()
batch_nll = nlls.sum()
token_nll = batch_nll / num_tokens
return token_nll, (ce_loss*attn_mask).sum() / num_tokens, (invalid_loss*attn_mask).sum() / num_tokens
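    # ReDi-style corruption: start from uniformly random tokens x_0 and reveal the target
    # token x_1 at each position with probability t; peptide-bond positions use t**gamma,
    # so with gamma > 1 they are revealed later in the schedule.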
def training_step(self, batch, batch_idx):
x_1 = batch['input_ids'].clone().detach().to(self.device)
attn_mask = batch['attention_mask'].clone().detach().to(self.device)
bond_mask = batch['bond_mask'].clone().detach().to(self.device).bool()
batch_size, _ = x_1.shape
# ReDi approach: random start -> target
x_0 = torch.randint(0, self.model.vocab_size, x_1.shape, device=self.device)
t_continuous = torch.rand(batch_size, device=self.device)
# mask = torch.rand(x_1.shape, device=self.device) < t_continuous.view(-1, 1)
# x_t = torch.where(mask, x_1, x_0)
peptide_bond_prob = t_continuous.view(-1, 1) ** self.args.gamma # slower increase
non_peptide_prob = t_continuous.view(-1, 1) # linear increase
masking_prob = torch.where(bond_mask, peptide_bond_prob, non_peptide_prob)
mask = torch.rand(x_1.shape, device=self.device) < masking_prob
x_t = torch.where(mask, x_1, x_0)
logits = self.model(x_t, t_continuous)
token_nll, ce_loss, invalid_loss = self._loss(logits, x_1, attn_mask)
# Logging
self.log('train/token_nll', token_nll.item(), on_step=True, on_epoch=True, prog_bar=True, batch_size=x_1.size(0), sync_dist=True)
self.log('train/ce_loss', ce_loss.item(), on_step=True, on_epoch=True, batch_size=x_1.size(0), sync_dist=True)
self.log('train/invalid_loss', invalid_loss.item(), on_step=True, on_epoch=True, batch_size=x_1.size(0), sync_dist=True)
# self.log('train/validity_weight', validity_weight, on_step=False, on_epoch=True, batch_size=x_1.size(0))
return token_nll
def validation_step(self, batch, batch_idx):
x_1 = batch['input_ids'].clone().detach().to(self.device)
attn_mask = batch['attention_mask'].clone().detach().to(self.device)
bond_mask = batch['bond_mask'].clone().detach().to(self.device).bool()
batch_size, _ = x_1.shape
# ReDi approach: random start -> target
x_0 = torch.randint(0, self.model.vocab_size, x_1.shape, device=self.device)
t_continuous = torch.rand(batch_size, device=self.device)
peptide_bond_prob = t_continuous.view(-1, 1) ** self.args.gamma # slower increase
non_peptide_prob = t_continuous.view(-1, 1) # linear increase
masking_prob = torch.where(bond_mask, peptide_bond_prob, non_peptide_prob)
mask = torch.rand(x_1.shape, device=self.device) < masking_prob
x_t = torch.where(mask, x_1, x_0)
logits = self.model(x_t, t_continuous)
token_nll, ce_loss, invalid_loss = self._loss(logits, x_1, attn_mask)
self.log('val/token_nll', token_nll.item(), on_step=True, on_epoch=True, prog_bar=True, batch_size=x_1.size(0), sync_dist=True)
self.log('val/ce_loss', ce_loss.item(), on_step=True, on_epoch=True, batch_size=x_1.size(0), sync_dist=True)
self.log('val/invalid_loss', invalid_loss.item(), on_step=True, on_epoch=True, batch_size=x_1.size(0), sync_dist=True)
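    # AdamW with a linear warmup from 10% of the peak LR over the first 10% of steps,
    # followed by cosine decay back down to 10% of the peak LR.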
def configure_optimizers(self):
optimizer = AdamW(
self.parameters(),
lr=self.args.learning_rate,
weight_decay=self.args.weight_decay
)
# Calculate total steps
if hasattr(self.trainer, 'estimated_stepping_batches'):
num_training_steps = self.trainer.estimated_stepping_batches
else:
# Fallback calculation
num_training_steps = len(self.trainer.datamodule.train_dataloader()) * self.trainer.max_epochs
warmup_steps = int(num_training_steps * 0.1)
def lr_lambda(current_step):
if current_step < warmup_steps:
lr_range = self.args.learning_rate - (self.args.learning_rate * 0.1)
lr = (self.args.learning_rate * 0.1) + lr_range * (current_step / warmup_steps)
return lr / self.args.learning_rate
else:
progress = (current_step - warmup_steps) / (num_training_steps - warmup_steps)
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
lr_range = self.args.learning_rate - (self.args.learning_rate * 0.1)
lr = (self.args.learning_rate * 0.1) + lr_range * cosine_decay
return lr / self.args.learning_rate
scheduler = LambdaLR(optimizer, lr_lambda)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
"frequency": 1,
},
}
# --- Main Execution ---
def main(args):
# Set up checkpoint directory
    checkpoint_dir = os.path.join(
        args.checkpoint_dir,
        f"correct_lr{args.learning_rate}_wd{args.weight_decay}_layer{args.n_layers}_"
        f"head{args.n_heads}_valweight{args.validity_weight}"
    )
print(f"Saving to {checkpoint_dir}")
os.makedirs(checkpoint_dir, exist_ok=True)
print("Loading tokenizer...")
tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt',
'/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt')
print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")
# Initialize data module
data_module = dynamic_dataloader.CustomDataModule('./data/11M_smiles_old_tokenizer_no_limit/', tokenizer)
# Initialize model
model = MDLMLightningModule(args, tokenizer)
# Set up logger
logger = WandbLogger(
project="smiles-redi-training", # or your preferred project name
entity="programmablebio",
name=f"lr{args.learning_rate}_dim{args.model_dim}_head{args.n_heads}_layer{args.n_layers}",
save_dir=checkpoint_dir
)
# Set up callbacks
callbacks = [
ModelCheckpoint(
dirpath=checkpoint_dir,
filename='best',
monitor='val/token_nll',
mode='min',
save_top_k=1,
save_last=True,
            # every_n_train_steps=10000  # would also save every 10000 steps in addition to val/token_nll improvements
),
LearningRateMonitor(logging_interval='step')
]
# Initialize trainer
trainer = pl.Trainer(
max_epochs=args.epochs,
devices=torch.cuda.device_count(),
accelerator='gpu',
strategy=DDPStrategy(find_unused_parameters=False),
num_nodes=int(os.environ.get("SLURM_NNODES", 1)),
precision="bf16",
gradient_clip_val=args.grad_clip if args.grad_clip > 0 else None,
callbacks=callbacks,
logger=logger,
log_every_n_steps=100,
        check_val_every_n_epoch=1,
# val_check_interval=10000,
accumulate_grad_batches=8,
enable_progress_bar=True,
enable_model_summary=True
)
print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")
print("Starting training...")
# Train the model
trainer.fit(model, data_module)
print("Training complete.")
print(f"Best checkpoint saved at: {trainer.checkpoint_callback.best_model_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train ReDi model for SMILES generation with RoPE using PyTorch Lightning")
# Model arguments
parser.add_argument("--model_dim", type=int, default=1024)
parser.add_argument("--n_heads", type=int, default=8)
parser.add_argument("--n_layers", type=int, default=6)
# Training arguments
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=1e-5)
parser.add_argument("--label_smoothing", type=float, default=0)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--gamma", type=float, default=2.0)
# Validity arguments
parser.add_argument("--validity_weight", type=float, default=0.1)
parser.add_argument("--validity_check_freq", type=int, default=10)
parser.add_argument("--validity_eval_batches", type=int, default=20)
# Logging arguments
parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints_smiles")
args = parser.parse_args()
main(args)