import torch

from models.diffusion import Diffusion
from configs.config import Config
from utils.esm_utils import load_esm2_model

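# Generation entry points for the latent diffusion model. Three modes are
# exposed by the CLI at the bottom of this file: "scaffold" (connect peptides
# with a generated linker region), "fill" (fill 'X'-marked regions of a given
# sequence), and "de_novo" (generate a complete sequence from scratch).
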
def mask_sequence(sequence, mask_char='X'):
    """Replaces every `mask_char` in the sequence with a '[MASK]' placeholder.

    Returns the masked sequence and the indices of the masked positions
    (indices refer to positions in the original, unmasked sequence).
    """
    mask_indices = [i for i, char in enumerate(sequence) if char == mask_char]
    masked_sequence = sequence.replace(mask_char, '[MASK]')
    return masked_sequence, mask_indices

def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices):
    """Generates the filled sequence for the masked regions."""
    # Encode with the tokenizer's own mask token (the ESM-2 tokenizer uses
    # '<mask>' rather than '[MASK]').
    inputs = tokenizer(masked_sequence.replace('[MASK]', tokenizer.mask_token), return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)

    # Sample a noise level, noise the latents, then denoise them with the
    # reverse diffusion process.
    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Restore single-character placeholders so that mask_indices (positions in
    # the original sequence) line up, then overwrite each masked position with
    # the residue decoded from its denoised latent.
    filled_sequence = list(masked_sequence.replace('[MASK]', 'X'))
    for idx in mask_indices:
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])

    return ''.join(filled_sequence)

def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
    """Generates a scaffold sequence to connect multiple peptides."""
    total_peptide_length = sum(len(peptide) for peptide in peptides)
    scaffold_length = final_length - total_peptide_length
    if scaffold_length <= 0:
        raise ValueError("Final length must be greater than the combined length of the peptides.")

    # Insert a fully masked scaffold between the first peptide and the rest.
    scaffold = tokenizer.mask_token * scaffold_length
    masked_sequence = "".join(peptides[:1] + [scaffold] + peptides[1:])

    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)

    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Rebuild the sequence at residue resolution ('X' marks scaffold positions)
    # and fill the scaffold region with the decoded residues.
    filled_sequence = list(peptides[0] + 'X' * scaffold_length + ''.join(peptides[1:]))
    scaffold_start = len(peptides[0])
    scaffold_end = scaffold_start + scaffold_length
    for idx in range(scaffold_start, scaffold_end):
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])

    return ''.join(filled_sequence)

def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
    """Generates a de novo protein sequence of the specified length."""
    # Start from a fully masked input of the requested length.
    masked_sequence = tokenizer.mask_token * sequence_length

    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)

    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Every position is masked, so decode a residue for each one.
    filled_sequence = []
    for idx in range(sequence_length):
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence.append(tokenizer.decode([token_id]))

    return ''.join(filled_sequence)

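# Example invocations (illustrative only: the script name, peptide strings,
# and lengths below are placeholders, and a trained checkpoint is assumed to
# exist at config.training["save_dir"] + "example.ckpt"):
#
#   python generate.py scaffold MKTAYIAKQR GSHMLEDPST 60
#   python generate.py fill MKTXXXAYIAKQRXXGS
#   python generate.py de_novo 100
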
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Generate protein sequences using a latent diffusion model.")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    parser_scaffold = subparsers.add_parser("scaffold", help="Generate a scaffold to connect multiple peptides.")
    parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
    parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")

    parser_fill = subparsers.add_parser("fill", help="Fill in specified regions of a given protein sequence.")
    parser_fill.add_argument("sequence", help="Protein sequence with the regions to fill marked by 'X'.")

    parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
    parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")

    args = parser.parse_args()

    config = Config()

    # Load the ESM-2 tokenizer/encoder and the trained diffusion checkpoint.
    tokenizer, esm_model = load_esm2_model(config.model_name)
    diffusion_model = Diffusion.load_from_checkpoint(
        config.training["save_dir"] + "example.ckpt",
        config=config,
        latent_dim=config.latent_dim,
    )
    diffusion_model.eval()

    if args.mode == "scaffold":
        peptides = args.peptides
        final_length = args.final_length
        filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
        print(f"Peptides: {' '.join(peptides)}")
        print(f"Final Length: {final_length}")
        print(f"Generated Protein: {filled_sequence}")

    elif args.mode == "fill":
        sequence = args.sequence
        masked_sequence, mask_indices = mask_sequence(sequence)
        filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
        print(f"Original Sequence: {sequence}")
        print(f"Masked Sequence: {masked_sequence}")
        print(f"Filled Sequence: {filled_sequence}")

    elif args.mode == "de_novo":
        sequence_length = args.sequence_length
        filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
        print(f"De Novo Sequence Length: {sequence_length}")
        print(f"Generated Protein: {filled_sequence}")