# AReUReDi / peptide / generation.py
# Tong Chen — "add files" (commit d2693e0)
import argparse
import math
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoTokenizer
# --- Model Architecture ---
def modulate(x, shift, scale):
    """
    Apply the adaptive affine transform ``x * (1 + scale) + shift``.

    A singleton axis is inserted at dim 1 of both ``shift`` and ``scale``
    before they broadcast against ``x``. For example, with ``x`` of shape
    (B, L, D) and ``shift``/``scale`` of shape (B, D), every position of a
    sample receives that sample's shift and scale.
    """
    gain = 1 + scale[:, None]
    offset = shift[:, None]
    return x * gain + offset
class TimestepEmbedder(nn.Module):
    """
    Embed a continuous scalar timestep t (expected in [0, 1]) as a vector.

    Each timestep value is passed through a small MLP:
    Linear(1 -> hidden_size), SiLU, Linear(hidden_size -> hidden_size).

    Args:
        hidden_size: width of the produced embedding.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Layers are created left-to-right, in the same order as in nn.Sequential.
        layers = (
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Map timesteps of shape (batch,) to embeddings of shape (batch, hidden_size)."""
        # Add a trailing feature axis so the first Linear (in_features=1) applies.
        return self.mlp(t[..., None])
class DiTBlock(nn.Module):
    """
    One Diffusion-Transformer block conditioned through adaptive LayerNorm.

    The conditioning vector ``c`` is projected to six per-sample vectors:
    shift / scale / gate for the attention branch and shift / scale / gate
    for the MLP branch. Each branch normalises ``x`` (LayerNorm without a
    learned affine), applies its shift/scale via ``modulate``, runs the
    sub-layer, and adds the gated output back onto ``x``.

    Args:
        hidden_size: model width D.
        n_heads: number of attention heads for ``nn.MultiheadAttention``.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        ffn_width = 4 * hidden_size
        # Submodule names and creation order determine state_dict keys and
        # random init; keep them stable so checkpoints keep loading.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, ffn_width),
            nn.GELU(),
            nn.Linear(ffn_width, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """
        Args:
            x: hidden states, (batch, seq, hidden_size) since attention is batch_first.
            c: conditioning vectors, (batch, hidden_size).

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        (shift_attn, scale_attn, gate_attn,
         shift_ffn, scale_ffn, gate_ffn) = self.adaLN_modulation(c).chunk(6, dim=1)

        # Self-attention branch (query = key = value).
        h = modulate(self.norm1(x), shift_attn, scale_attn)
        attn_out = self.attn(h, h, h)[0]
        x = x + gate_attn.unsqueeze(1) * attn_out

        # Position-wise feed-forward branch.
        h = modulate(self.norm2(x), shift_ffn, scale_ffn)
        x = x + gate_ffn.unsqueeze(1) * self.mlp(h)
        return x
class MDLM(nn.Module):
    """
    Masked Diffusion Language Model (MDLM) built on a stack of DiT blocks.

    Token ids are embedded with ``vocab_size + 1`` rows. The extra id
    ``vocab_size`` is reserved for the mask token. A learned positional
    embedding is added, and the sequence is passed through ``n_layers``
    ``DiTBlock``s conditioned on the timestep embedding. The output head
    predicts logits over the ``vocab_size`` regular tokens only, with no
    logit for the mask id.

    Args:
        vocab_size: number of regular tokens.
        seq_len: maximum sequence length (size of the positional table).
        model_dim: hidden width.
        n_heads: attention heads per block.
        n_layers: number of DiT blocks.
    """

    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_dim = model_dim
        # The mask token takes the id just past the regular vocabulary.
        self.mask_token_id = vocab_size
        # Creation order below fixes state_dict layout and random init; keep it.
        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
        self.time_embedder = TimestepEmbedder(model_dim)
        self.transformer_blocks = nn.ModuleList(
            DiTBlock(model_dim, n_heads) for _ in range(n_layers)
        )
        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

    def forward(self, x, t):
        """
        Args:
            x: token ids, (batch, L) with L <= ``self.seq_len``.
            t: timesteps, (batch,).

        Returns:
            Logits of shape (batch, L, vocab_size).
        """
        length = x.size(1)
        hidden = self.token_embedder(x) + self.pos_embedder[:, :length]
        cond = self.time_embedder(t)
        for blk in self.transformer_blocks:
            hidden = blk(hidden, cond)
        return self.lm_head(self.final_norm(hidden))
# --- Generation Function ---
def generate_samples(model, device, num_samples, seq_len, steps, temperature):
    """
    Generate token sequences by progressive, confidence-ranked refinement.

    Generation starts from tokens drawn uniformly from ``[0, model.vocab_size)``
    and performs ``steps`` forward passes. At step ``i``:

    - the model is called with a constant timestep ``t = i / steps``, so ``t``
      runs over ``[0, 1)``;
    - a complete candidate sequence is sampled from the temperature-scaled
      softmax;
    - the ``keep_schedule[i]`` positions whose sampled token has the highest
      probability take the candidate token, and every other position keeps
      its current token.

    On the last step the candidate sequence is returned as-is.

    Args:
        model: module with a ``vocab_size`` attribute, called as
            ``model(x, t)`` and expected to return logits of shape
            ``(num_samples, seq_len, model.vocab_size)``.
        device: device on which all tensors are created.
        num_samples: number of sequences to generate (batch size).
        seq_len: length of every generated sequence.
        steps: number of refinement steps; must be at least 1.
        temperature: divisor applied to the logits before softmax; must be
            positive. Values > 1 flatten the distribution, < 1 sharpen it.

    Returns:
        ``torch.long`` tensor of shape ``(num_samples, seq_len)``.

    Raises:
        ValueError: if ``steps < 1`` or ``temperature <= 0``.
    """
    # Without at least one step the random initial noise would be returned
    # silently; a non-positive temperature yields inf/NaN probabilities.
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    model.eval()
    shape = (num_samples, seq_len)
    x = torch.randint(0, model.vocab_size, shape, dtype=torch.long, device=device)

    # Number of *newly sampled* tokens adopted at each step. It follows
    # cos(pi/2 -> 0) * seq_len, so it rises from 0 (adopt none) to seq_len
    # (adopt all). It is converted to Python ints once so torch.topk gets a
    # plain int and the loop does not read a device tensor every step.
    keep_schedule = torch.cos(torch.linspace(math.pi / 2, 0, steps, device=device)) * seq_len
    keep_schedule = torch.round(keep_schedule).long().clamp_(0, seq_len).tolist()

    with torch.no_grad():
        for i in tqdm(range(steps), desc="Generating Samples"):
            # t = i / steps, i.e. 0 on the first step and (steps - 1) / steps on the last.
            t_continuous = torch.full((num_samples,), i / steps, device=device)
            logits = model(x, t_continuous)
            probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
            # Sample one token per position from the model's distribution.
            sampled_tokens = torch.multinomial(probs.view(-1, model.vocab_size), 1).view(shape)

            # The final step returns the fresh sample unchanged.
            if i == steps - 1:
                x = sampled_tokens
                break

            # Confidence = probability the model assigned to the token it sampled.
            confidence = torch.gather(probs, 2, sampled_tokens.unsqueeze(-1)).squeeze(-1)
            # NOTE(review): keep_schedule[0] is 0, so the first forward pass
            # adopts no tokens. Kept as-is to preserve sampling behaviour.
            _, indices_to_keep = torch.topk(confidence, keep_schedule[i], largest=True, dim=-1)
            keep_mask = torch.zeros_like(x, dtype=torch.bool).scatter_(1, indices_to_keep, True)
            # Adopt the confident new tokens; leave all other positions unchanged.
            x = torch.where(keep_mask, sampled_tokens, x)
    return x
# --- Main Execution ---
def main(args):
    """
    Command-line entry point: load a checkpoint, generate, decode and save.

    1. Load ``args.checkpoint`` with ``torch.load`` and read the saved
       training hyper-parameters from ``checkpoint['args']``.
    2. Rebuild ``MDLM`` from those hyper-parameters and load
       ``checkpoint['model_state_dict']``.
    3. Generate ``args.num_samples`` sequences of length ``args.gen_len``
       (default: the model's ``seq_len``) with ``generate_samples``.
    4. Decode each sequence with the ``facebook/esm2_t33_650M_UR50D``
       tokenizer, write one sequence per line to ``args.output_file`` and
       echo it to stdout.

    Checkpoint-loading failures are reported with ``print`` and ``main``
    returns ``None`` without generating.

    Raises:
        ValueError: if ``args.num_samples`` < 1, or if the generation length
            is < 1 or exceeds the model's trained ``seq_len``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Fail fast on arguments that do not depend on the checkpoint.
    if args.num_samples < 1:
        raise ValueError(f"Number of samples ({args.num_samples}) must be at least 1.")

    print(f"Loading checkpoint from {args.checkpoint}...")
    try:
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load trusted files.
        # NOTE(review): presumably needed because checkpoint['args'] is a
        # pickled namespace-like object (read via attributes below). Confirm
        # before switching to weights_only=True.
        checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
        model_args = checkpoint['args']
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {args.checkpoint}")
        return
    except Exception as e:
        # Top-level boundary: report any other load failure and stop.
        print(f"Error loading checkpoint: {e}")
        return

    print("Initializing model...")
    model = MDLM(
        vocab_size=model_args.vocab_size,
        seq_len=model_args.seq_len,
        model_dim=model_args.model_dim,
        n_heads=model_args.n_heads,
        n_layers=model_args.n_layers,
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded successfully.")

    gen_len = args.gen_len if args.gen_len is not None else model_args.seq_len
    if gen_len > model_args.seq_len:
        raise ValueError(f"Requested generation length ({gen_len}) is greater than the model's max length ({model_args.seq_len}).")
    if gen_len < 1:
        raise ValueError(f"Requested generation length ({gen_len}) must be at least 1.")
    print(f"Generating sequences of length {gen_len}.")

    generated_tokens = generate_samples(
        model=model,
        device=device,
        num_samples=args.num_samples,
        seq_len=gen_len,
        steps=args.gen_steps,
        temperature=args.temperature,
    )

    print("Decoding and saving samples...")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    with open(args.output_file, 'w', encoding="utf-8") as f:
        for sample_tokens in generated_tokens:
            sequence = tokenizer.decode(sample_tokens.tolist(), skip_special_tokens=False)
            # Remove spaces, then drop the first and last 5 characters.
            # NOTE(review): presumably this strips a leading "<cls>" and a
            # trailing "<eos>". If a sample does not start/end with exactly
            # those 5-char tokens, real residues are cut. Confirm against the
            # training data format.
            clean_sequence = sequence.replace(" ", "")[5:-5]
            f.write(clean_sequence + "\n")
            print(clean_sequence)
    print(f"Generation complete. {args.num_samples} sequences saved to {args.output_file}")
if __name__ == "__main__":
    # Command-line interface; the parsed namespace is handed straight to main().
    cli = argparse.ArgumentParser(
        description="Generate samples from a trained ReDi (MDLM) model starting from random noise."
    )
    cli.add_argument(
        "--checkpoint", type=str, required=True,
        help="Path to the model checkpoint file.",
    )
    cli.add_argument(
        "--num_samples", type=int, default=128,
        help="Number of samples to generate.",
    )
    cli.add_argument(
        "--output_file", type=str, default="./generated_peptides.txt",
        help="File to save the generated peptide sequences.",
    )
    cli.add_argument(
        "--gen_steps", type=int, default=16,
        help="Number of steps for the progressive refinement process.",
    )
    cli.add_argument(
        "--gen_len", type=int, default=None,
        help="Desired length of the generated sequences. Defaults to the model's maximum trained length.",
    )
    cli.add_argument(
        "--temperature", type=float, default=1.0,
        help="Sampling temperature. >1 increases diversity, <1 decreases it.",
    )
    main(cli.parse_args())