| import torch |
| import torch.distributed as dist |
| from models import DiT_models |
| from download import find_model |
| from diffusion import create_diffusion |
| import argparse |
| from transformers import T5ForConditionalGeneration, T5Tokenizer |
| from train_autoencoder import ldmol_autoencoder |
| from utils import AE_SMILES_decoder, molT5_encoder, AE_SMILES_encoder, regexTokenizer |
| from rdkit import Chem |
| import time |
| import random |
| from torch.optim.sgd import SGD |
|
|
|
|
@torch.no_grad()
def main(args):
    """
    Text-guided molecule editing (DDS-style) with a pretrained LDMol DiT.

    Loads the DiT denoiser, the frozen SMILES autoencoder, and the frozen
    MolT5 text encoder, then iteratively nudges the latent of
    ``--input-smiles`` so its description moves from ``--source-text`` to
    ``--target-text``, decoding and printing an intermediate SMILES every
    20 steps.

    Raises:
        ValueError: if ``--ckpt`` is omitted or ``--input-smiles`` is not
            a parsable SMILES string.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # TF32 matmuls: big speedup on Ampere+
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU."
    torch.set_grad_enabled(False)

    # One process per GPU; the seed differs per rank so noise draws are independent.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        raise ValueError("Please specify a checkpoint path with --ckpt.")

    # DiT denoiser operating on the autoencoder's latent grid. These sizes
    # must match the checkpoint: 127 tokens x 64 latent channels,
    # 768-dim cross-attention keys/values, 1024-dim condition projection.
    latent_size = 127
    in_channels = 64
    cross_attn = 768
    condition_dim = 1024
    model = DiT_models[args.model](
        input_size=latent_size,
        in_channels=in_channels,
        cross_attn=cross_attn,
        condition_dim=condition_dim,
    ).to(device)

    ckpt_path = args.ckpt
    state_dict = find_model(ckpt_path)
    msg = model.load_state_dict(state_dict, strict=False)
    if rank == 0: print('DiT from ', ckpt_path, msg)
    model.eval()  # frozen; only the latent is updated below

    # Frozen SMILES autoencoder: maps SMILES <-> latent sequences.
    ae_config = {
        'bert_config_decoder': './config_decoder.json',
        'bert_config_encoder': './config_encoder.json',
        'embed_dim': 256,
    }
    tokenizer = regexTokenizer(vocab_path='./vocab_bpe_300_sc.txt', max_len=127)
    ae_model = ldmol_autoencoder(config=ae_config, no_train=True, tokenizer=tokenizer)
    if args.vae:
        checkpoint = torch.load(args.vae, map_location='cpu')
        # Checkpoints store the weights under either 'model' or 'state_dict';
        # catch only the expected KeyError (was a bare except).
        try:
            state_dict = checkpoint['model']
        except KeyError:
            state_dict = checkpoint['state_dict']
        msg = ae_model.load_state_dict(state_dict, strict=False)
        if rank == 0: print('autoencoder', args.vae, msg)
    for param in ae_model.parameters():
        param.requires_grad = False
    ae_model = ae_model.to(device)
    ae_model.eval()
    if rank == 0: print(f'AE #parameters: {sum(p.numel() for p in ae_model.parameters())}, #trainable: {sum(p.numel() for p in ae_model.parameters() if p.requires_grad)}')

    # Frozen MolT5 text encoder; the decoder half is unused, drop it to save memory.
    text_encoder = T5ForConditionalGeneration.from_pretrained('laituan245/molt5-large-caption2smiles').to(device)
    text_tokenizer = T5Tokenizer.from_pretrained("laituan245/molt5-large-caption2smiles", model_max_length=512)
    del text_encoder.decoder

    for param in text_encoder.parameters():
        param.requires_grad = False
    text_encoder.eval()
    if rank == 0:
        print(f'text encoder #parameters: {sum(p.numel() for p in text_encoder.parameters())}, #trainable: {sum(p.numel() for p in text_encoder.parameters() if p.requires_grad)}')

    dist.barrier()

    diffusion = create_diffusion(timestep_respacing="")
    st = time.time()

    x_smiles = args.input_smiles
    y_s = [args.source_text]
    y_t = [args.target_text]

    # Canonicalize the input; fail loudly on an unparsable SMILES instead of
    # letting MolToSmiles crash on None.
    mol = Chem.MolFromSmiles(x_smiles)
    if mol is None:
        raise ValueError(f"Invalid input SMILES: {x_smiles!r}")
    x_smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
    loss_scale = 0.4   # step size of the latent update
    n_iter = 200       # number of editing iterations
    cfg = 2.           # classifier-free-guidance scale
    print(x_smiles, ':', y_s[0], '=>', y_t[0])

    # Latent of the source molecule, shaped (B, C, L, 1) after permute/unsqueeze.
    # The target latent starts at the source and is edited in place below.
    # (The original code also set requires_grad and built an SGD optimizer that
    # was never stepped; grads are globally disabled, so both were dead code.)
    x_source = AE_SMILES_encoder([x_smiles], ae_model).permute((0, 2, 1)).unsqueeze(-1)
    x_target = x_source.clone()

    # Null/unconditional description used for classifier-free guidance.
    # NOTE(review): y_n is a bare string while y_s/y_t are 1-element lists —
    # assumes molT5_encoder accepts both forms; confirm against its signature.
    y_n = 'no description.'

    biot5_embed_s, pad_mask_s = molT5_encoder(y_s, text_encoder, text_tokenizer, args.description_length, device)
    biot5_embed_t, pad_mask_t = molT5_encoder(y_t, text_encoder, text_tokenizer, args.description_length, device)
    biot5_embed_n, pad_mask_n = molT5_encoder(y_n, text_encoder, text_tokenizer, args.description_length, device)

    y_cond_s = biot5_embed_s.to(device).type(torch.float32)
    pad_mask_s = pad_mask_s.to(device).bool()
    y_cond_t = biot5_embed_t.to(device).type(torch.float32)
    pad_mask_t = pad_mask_t.to(device).bool()
    y_cond_n = biot5_embed_n.to(device).type(torch.float32)
    pad_mask_n = pad_mask_n.to(device).bool()

    model_kwargs_s = dict(y=y_cond_s, pad_mask=pad_mask_s)  # source prompt
    model_kwargs_t = dict(y=y_cond_t, pad_mask=pad_mask_t)  # target prompt
    model_kwargs_n = dict(y=y_cond_n, pad_mask=pad_mask_n)  # null prompt (CFG)

    for step in range(n_iter):
        # Random mid-range timestep: skip the first/last 100 steps where the
        # noise estimate is least informative.
        t = random.randint(100, diffusion.num_timesteps - 100) * torch.ones((x_source.shape[0],), device=device).int()

        # Same noise for source and target so the shared noise term cancels
        # in their score difference (Delta Denoising Score).
        noise = torch.randn_like(x_target)
        x_target_t = diffusion.q_sample(x_target, t, noise=noise)
        x_source_t = diffusion.q_sample(x_source, t, noise=noise)

        # The model output has twice the input channels; keep the first
        # (epsilon) half and discard the rest.
        model_output_s = model(x_source_t, t, **model_kwargs_s)
        model_output_s, _ = torch.split(model_output_s, x_source_t.shape[1], dim=1)
        model_output_sn = model(x_source_t, t, **model_kwargs_n)
        model_output_sn, _ = torch.split(model_output_sn, x_source_t.shape[1], dim=1)

        model_output_t = model(x_target_t, t, **model_kwargs_t)
        model_output_t, _ = torch.split(model_output_t, x_target_t.shape[1], dim=1)
        model_output_tn = model(x_target_t, t, **model_kwargs_n)
        model_output_tn, _ = torch.split(model_output_tn, x_target_t.shape[1], dim=1)

        # Classifier-free guidance for both prompts.
        model_output_s = model_output_sn + cfg*(model_output_s - model_output_sn)
        model_output_t = model_output_tn + cfg*(model_output_t - model_output_tn)

        # Editing direction: difference of guided noise predictions.
        grad = (model_output_t - model_output_s).detach()

        # Plain in-place gradient step on the latent; no optimizer needed.
        x_target -= grad*loss_scale

        # Periodically decode the current latent back to SMILES for monitoring.
        if step % 20 == 0:
            output = x_target.squeeze(-1).permute((0, 2, 1))
            output = AE_SMILES_decoder(output, ae_model, stochastic=False, k=1)
            print(f'{step}\t: {output[0]}')

    dist.barrier()
    print('time:', time.time() - st)
    print(f'{x_smiles}')
    print(f'{output[0]}')  # last decoded SMILES (n_iter > 0, so step 0 always set it)
    dist.destroy_process_group()
|
|
|
|
if __name__ == "__main__":
    # CLI for text-guided molecule editing. Flag names, types, and defaults
    # are unchanged; only help text is corrected (--ckpt is effectively
    # required: main() raises ValueError when it is omitted, it does not
    # auto-download anything).
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="LDMol")
    parser.add_argument("--vae", type=str, default="./Pretrain/checkpoint_autoencoder.ckpt",
                        help="Path to the SMILES autoencoder checkpoint.")
    parser.add_argument("--input-smiles", type=str, default="C[C@H](CCc1ccccc1)Nc1ccc(C#N)cc1F",
                        help="SMILES of the molecule to edit.")
    parser.add_argument("--source-text", type=str, default="This molecule contains fluorine.",
                        help="Description of the input molecule.")
    parser.add_argument("--target-text", type=str, default="This molecule contains bromine.",
                        help="Desired description after editing.")
    parser.add_argument("--text-encoder-name", type=str, default="molt5")
    parser.add_argument("--description-length", type=int, default=256,
                        help="Max token length for encoded descriptions.")
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Path to a DiT checkpoint (required; main() raises an error if omitted).")
    args = parser.parse_args()
    main(args)
|
|
|
|