# AReUReDi / peptide / moo.py
# Tong Chen: "add files" (commit d2693e0)
import argparse
import math
import random
from collections import Counter
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer
from peptide_classifiers import *
# --- Model Architecture (Must match the trained model) ---
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` receive a new axis at position 1 so that they
    broadcast over the second dimension of ``x`` (e.g. ``(B, D)`` parameters
    against ``(B, L, D)`` features).
    """
    gain = 1 + scale[:, None]
    offset = shift[:, None]
    return x * gain + offset
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps with a two-layer SiLU MLP.

    ``forward`` appends a trailing axis of size 1 to ``t`` and maps it to
    ``hidden_size`` features, so an input of shape ``S`` yields
    ``S + (hidden_size,)``. The layers live in ``self.mlp`` (an
    ``nn.Sequential``) so the ``state_dict`` keys are ``mlp.0.*`` / ``mlp.2.*``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Linear layers are created in the same order as the Sequential lists
        # them, so parameter initialisation draws from the RNG identically.
        in_proj = nn.Linear(1, hidden_size, bias=True)
        out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        self.mlp = nn.Sequential(in_proj, nn.SiLU(), out_proj)

    def forward(self, t):
        # Add the feature axis of size 1 expected by the first Linear layer.
        return self.mlp(t[..., None])
class DiTBlock(nn.Module):
    """Transformer block with adaLN-style conditioning (DiT layout).

    The conditioning vector ``c`` is projected to six per-feature vectors:
    shift/scale/gate for the self-attention branch and shift/scale/gate for
    the MLP branch. Each branch applies a non-affine LayerNorm, the
    ``modulate`` affine transform, the sub-layer, and a gated residual add.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        # Attribute names and creation order must stay as-is: they define the
        # checkpoint's state_dict keys and the parameter-initialisation order.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        # Projects c to shift/scale/gate for both branches (6 * hidden_size).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Apply the block.

        Args:
            x: features of shape ``(batch, length, hidden_size)`` (the
                attention layer is batch-first).
            c: conditioning of shape ``(batch, hidden_size)``; its projection
                is split into six chunks along dim 1.

        Returns:
            Tensor with the same shape as ``x``.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
        # The attention-weight matrix is discarded, so don't request it:
        # need_weights=False skips computing/averaging the weights and lets
        # PyTorch use its optimized scaled_dot_product_attention path.
        attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1, need_weights=False)
        x = x + gate_msa.unsqueeze(1) * attn_output
        x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
        mlp_output = self.mlp(x_norm2)
        x = x + gate_mlp.unsqueeze(1) * mlp_output
        return x
class MDLM(nn.Module):
    """Transformer denoiser over token ids, conditioned on a timestep.

    The name suggests a masked-diffusion language model. The embedding table
    has ``vocab_size + 1`` rows; the extra id ``vocab_size`` is exposed as
    ``mask_token_id``. The output head produces logits over ``vocab_size``
    ids, so the mask id itself is never predicted.

    Attribute names must match the trained checkpoints' ``state_dict`` keys.
    """

    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_dim = model_dim
        self.mask_token_id = vocab_size
        # Sub-modules are created in this fixed order so parameter
        # initialisation consumes the RNG reproducibly.
        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
        self.time_embedder = TimestepEmbedder(model_dim)
        self.transformer_blocks = nn.ModuleList([DiTBlock(model_dim, n_heads) for _ in range(n_layers)])
        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

    def forward(self, x, t):
        """Predict token logits for every position.

        Args:
            x: integer token ids of shape ``(batch, length)`` with
                ``length <= seq_len``.
            t: timesteps fed to ``time_embedder``; expected shape ``(batch,)``
                so the resulting conditioning is ``(batch, model_dim)``.

        Returns:
            Logits of shape ``(batch, length, vocab_size)``.

        Raises:
            ValueError: if ``length`` exceeds the learned positional table.
        """
        seq_len = x.shape[1]
        if seq_len > self.seq_len:
            # Without this check the positional add fails with an opaque
            # broadcasting error.
            raise ValueError(
                f"Input length {seq_len} exceeds the model's maximum sequence length {self.seq_len}."
            )
        x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
        t_embed = self.time_embedder(t)
        for block in self.transformer_blocks:
            x_embed = block(x_embed, t_embed)
        x_embed = self.final_norm(x_embed)
        logits = self.lm_head(x_embed)
        return logits
class MOGGenerator:
    """Multi-objective guided (MOG) sequence optimiser driven by a proposal model.

    Starting from random tokens, each step mutates one shared position of
    every sequence in the batch:

    - the model proposes replacement tokens, pruned with top-p;
    - each candidate is scored by every objective;
    - the model probabilities are reweighted with a Barker balancing factor
      on the weighted-minimum score;
    - the sampled candidate is kept only if the weighted sum of objective
      scores does not decrease.
    """

    # Token ids never proposed as replacements. Ids 0 and 2 are the values
    # generate() writes into the first and last positions.
    _EXCLUDED_TOKEN_IDS = (0, 1, 2, 3)

    def __init__(self, model, device, objectives, args):
        """Store the proposal model, objectives and run configuration.

        Args:
            model: network called as ``model(tokens, timesteps)`` returning
                per-position logits; must expose ``vocab_size``.
            device: torch device on which sequences are built and scored.
            objectives: callables applied to a batch of token sequences; each
                is expected to return one score per sequence.
            args: namespace providing ``num_samples``, ``gen_len``,
                ``optimization_steps``, ``weights``, ``eta_min``, ``eta_max``
                and ``top_p``.
        """
        self.model = model
        self.device = device
        self.objectives = objectives
        self.args = args
        self.num_objectives = len(objectives)

    def _get_scores(self, x_batch):
        """Score a batch of sequences with every objective.

        Returns:
            The per-objective outputs stacked along dim 0, i.e.
            ``(num_objectives, batch)`` when each objective returns one score
            per sequence.
        """
        scores = []
        for obj_func in self.objectives:
            scores.append(obj_func(x_batch.to(self.device)))
        return torch.stack(scores, dim=0)

    def _barker_g(self, u):
        """Barker balancing function ``g(u) = u / (1 + u)``.

        Kept for compatibility; generate() uses the equivalent, overflow-safe
        form ``g(exp(d)) == sigmoid(d)``.
        """
        return u / (1 + u)

    def _resolve_weights(self):
        """Return the objective weights as a ``(num_objectives, 1)`` tensor.

        Uses ``args.weights`` when given, otherwise equal weights.

        Raises:
            ValueError: if the number of weights differs from the number of
                objectives.
        """
        if self.args.weights is None:
            weights = torch.full((self.num_objectives,), 1 / self.num_objectives, device=self.device)
        else:
            weights = torch.tensor(self.args.weights, device=self.device)
        weights = weights.view(-1, 1)
        if len(weights) != self.num_objectives:
            raise ValueError("Number of weights must match number of objectives.")
        return weights

    def _candidate_tokens(self, probs, excluded):
        """Return top-p candidate token ids for one position.

        Tokens are kept in descending-probability order until the cumulative
        mass first exceeds ``args.top_p``; the most likely token is always
        kept. Ids in ``excluded`` are dropped, as are zero-probability
        tokens, which could never be sampled anyway.

        Args:
            probs: 1-D probability vector over the vocabulary.
            excluded: 1-D long tensor of token ids that must not be proposed.
        """
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        remove = cumulative > self.args.top_p
        # Shift right so the token that crosses the threshold is still kept.
        remove[1:] = remove[:-1].clone()
        remove[0] = False
        candidates = sorted_indices[~remove & (sorted_probs > 0)]
        return candidates[~torch.isin(candidates, excluded)]

    def generate(self):
        """Run the guided optimisation and return the final sequences.

        Performs ``optimization_steps * gen_len`` single-position mutation
        steps on a batch of ``num_samples`` random sequences.

        Returns:
            Long tensor of shape ``(num_samples, gen_len + 2)`` whose first
            column is token 0 and last column token 2.

        Raises:
            ValueError: if ``args.weights`` has a length different from the
                number of objectives.
        """
        args = self.args
        num_samples, gen_len = args.num_samples, args.gen_len
        steps_per_pass = args.optimization_steps

        x = torch.randint(5, self.model.vocab_size, (num_samples, gen_len + 2),
                          dtype=torch.long, device=self.device)
        x[:, 0] = 0
        x[:, -1] = 2

        weights = self._resolve_weights()
        flat_weights = weights.view(-1)
        print(f"Weights: {weights}")
        # NOTE(review): args.min_threshold is accepted on the command line but
        # is not applied anywhere in this sampler.

        total_steps = steps_per_pass * gen_len
        # Denominator of the linear eta schedule; max() avoids a division by
        # zero when only a single step is run.
        anneal_span = max(total_steps - 1, 1)
        rows = torch.arange(num_samples, device=self.device)
        excluded = torch.tensor(self._EXCLUDED_TOKEN_IDS, device=self.device)

        with torch.no_grad():
            for t in tqdm(range(total_steps), desc="MOG Generation"):
                # Linearly anneal the guidance strength from eta_min to eta_max.
                eta_t = args.eta_min + (args.eta_max - args.eta_min) * (t / anneal_span)
                # One position (never the fixed first/last token) is mutated
                # for the whole batch.
                mut_idx = random.randint(1, gen_len)
                # Cycle through the model timesteps so every one is visited.
                time_t = torch.full((num_samples,), (t % steps_per_pass) / steps_per_pass,
                                    device=self.device)

                # Model proposal distribution, needed at the chosen position only.
                logits = self.model(x, time_t)
                pos_probs = F.softmax(logits[:, mut_idx, :], dim=-1)
                # Zero each row's *own* current token (one entry per row).
                # Indexing as [:, tokens] would zero every sample's current
                # token in every row, wrongly excluding other samples' tokens.
                current_tokens = x[:, mut_idx]
                pos_probs[rows, current_tokens] = 0

                current_scores = self._get_scores(x)
                current_s_omega = torch.min(weights * current_scores, dim=0).values
                current_sums = (weights * current_scores).sum(dim=0)

                new_tokens = current_tokens.clone()
                for i in range(num_samples):
                    candidates = self._candidate_tokens(pos_probs[i], excluded)
                    if candidates.numel() == 0:
                        continue  # nothing to propose; keep the current token

                    proposals = x[i].repeat(candidates.numel(), 1)
                    proposals[:, mut_idx] = candidates
                    proposal_scores = self._get_scores(proposals)
                    proposal_s_omega = torch.min(weights * proposal_scores, dim=0).values

                    # Barker factor g(w_prop / w_cur) with w = exp(eta * s).
                    # g(exp(d)) == sigmoid(d), which cannot overflow like exp().
                    balance = torch.sigmoid(eta_t * (proposal_s_omega - current_s_omega[i]))
                    tilde_q = pos_probs[i, candidates] * balance
                    if not tilde_q.sum() > 0:
                        continue  # degenerate (all-zero or NaN) distribution
                    # multinomial accepts unnormalised non-negative weights.
                    index = torch.multinomial(tilde_q, 1).item()

                    proposal_sum = torch.sum(flat_weights * proposal_scores[:, index])
                    # Greedy filter: accept only if the weighted sum does not drop.
                    if proposal_sum >= current_sums[i]:
                        new_tokens[i] = candidates[index]
                        print(f"Previous Weighted Sum: {current_sums[i]}")
                        print(f"Previous Scores: {current_scores[:,i]}")
                        print(f"New Weighted Sum: {proposal_sum}")
                        print(f"New Scores: {proposal_scores[:,index]}")

                x[:, mut_idx] = new_tokens
        return x
# --- Main Execution ---
def main(args):
    """Generate peptides with multi-objective guidance and append them to a CSV file.

    Loads the ReDi/MDLM checkpoint given by ``args.checkpoint``. Builds five
    objectives in this order: hemolysis, nonfouling, solubility, half-life,
    and affinity (the last conditioned on the tokenized ``args.target``).
    Then, for each of ``args.num_batches`` batches, runs ``MOGGenerator`` and
    appends one CSV row per sequence to ``args.output_file``. Each row holds
    the decoded peptide followed by its objective scores in that order.

    Returns early, after printing the error, if the checkpoint cannot be loaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the generator checkpoint first so a bad path fails before the
    # expensive predictor models are loaded.
    print(f"Loading checkpoint from {args.checkpoint}...")
    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects
        # (presumably required because checkpoint['args'] is a namespace
        # object); only load trusted checkpoint files.
        checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
        model_args = checkpoint['args']
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return

    print("Initializing model...")
    model = MDLM(
        vocab_size=model_args.vocab_size,
        seq_len=model_args.seq_len,
        model_dim=model_args.model_dim,
        n_heads=model_args.n_heads,
        n_layers=model_args.n_layers
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference only
    print("Model loaded successfully.")

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    target_sequence = tokenizer(args.target, return_tensors='pt')['input_ids'].to(device)
    affinity_predictor = load_affinity_predictor('/scratch/pranamlab/tong/ReDi_discrete/peptides/classifier_ckpt/binding_affinity_unpooled.pt', device)
    affinity_model = AffinityModel(affinity_predictor, target_sequence)
    hemolysis_model = HemolysisModel(device=device)
    nonfouling_model = NonfoulingModel(device=device)
    solubility_model = SolubilityModel(device=device)
    halflife_model = HalfLifeModel(device=device)

    # This order defines the score columns in the CSV and the order of --weights.
    OBJECTIVE_FUNCTIONS = [hemolysis_model, nonfouling_model, solubility_model, halflife_model, affinity_model]
    mog_generator = MOGGenerator(model, device, OBJECTIVE_FUNCTIONS, args)

    for _ in range(args.num_batches):
        generated_tokens = mog_generator.generate()
        # Final scoring is inference-only; avoid building autograd graphs.
        with torch.no_grad():
            final_scores = mog_generator._get_scores(generated_tokens).detach().cpu().numpy()
        # Append (not overwrite) so rows from every batch and run accumulate.
        with open(args.output_file, 'a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            for sample_tokens, scores in zip(generated_tokens, final_scores.T):
                print(sample_tokens)
                # Decode only the interior tokens: the first and last positions
                # hold the fixed ids 0 and 2 written by MOGGenerator.generate().
                sequence_str = tokenizer.decode(
                    sample_tokens[1:-1].tolist(), skip_special_tokens=False
                ).replace(" ", "")
                row = [sequence_str] + scores.tolist()
                writer.writerow(row)
                print(row)
    print("Generation complete.")
if __name__ == "__main__":
    # Command-line entry point. `args` stays a module-level name on purpose:
    # it is the parsed configuration handed to main().
    parser = argparse.ArgumentParser(description="Multi-Objective Generation with LBP-MOG-ReDi (Single Mutation).")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained ReDi model checkpoint.")
    parser.add_argument("--num_samples", type=int, default=10, help="Number of sequences generated per batch.")
    parser.add_argument("--num_batches", type=int, default=10, help="Number of batches to generate (each of --num_samples sequences).")
    parser.add_argument("--output_file", type=str, default="./mog_peptides.txt", help="CSV file to which generated sequences and their scores are appended.")
    parser.add_argument("--gen_len", type=int, default=50, help="Number of tokens to generate (two fixed boundary tokens are added).")
    parser.add_argument("--optimization_steps", type=int, default=16, help="Number of passes over the sequence.")
    parser.add_argument("--weights", type=float, nargs='+', required=False, help="Weights for the objectives, one per objective in the order hemolysis nonfouling solubility halflife affinity (e.g., 0.2 0.2 0.2 0.2 0.2); default: equal weights.")
    parser.add_argument("--min_threshold", type=float, nargs='+', required=False, help="minimum threshold for the objectives (e.g., 0.2 0.2); currently not applied by the sampler.")
    parser.add_argument("--eta_min", type=float, default=1.0, help="Minimum guidance strength for annealing.")
    parser.add_argument("--eta_max", type=float, default=20.0, help="Maximum guidance strength for annealing.")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p for pruning candidate tokens.")
    parser.add_argument("--target", type=str, required=True, help="Target protein sequence used by the binding-affinity objective.")
    args = parser.parse_args()
    main(args)