import inspect

import torch
from transformers import AutoTokenizer

# Provides the loaders, MotifModel/AffinityModel, and parse_motifs used below.
from models.peptide_classifiers import *
from utils.parsing import parse_guidance_args

args = parse_guidance_args()
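
# Command-line arguments consumed below (defined in utils.parsing):
# args.length, args.target_protein, args.motifs, args.n_batches,
# args.weights, and args.output_file.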

# Sampling hyperparameters.
step_size = 1 / 100              # Euler step size of the sampler
n_samples = 1                    # sequences generated per batch
length = args.length             # peptide length, excluding special tokens
target = args.target_protein
motifs = args.motifs
vocab_size = 24                  # ESM-2 ids 0-3 are special tokens, 4-23 the 20 amino acids
source_distribution = "uniform"
device = 'cuda:0'

# Tokenize the target protein with the ESM-2 tokenizer and parse the motif spec.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)
motifs = parse_motifs(motifs).to(device)
print(motifs)

# Generative backbone: a flow-matching solver restored from its checkpoint.
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)

# Guidance models, both conditioned on the target sequence: a motif scorer
# built on the fine-tuned BindEvaluator (penalty=True makes it return an
# additional penalty score) and a binding-affinity predictor.
bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=True)

affinity_predictor = load_affinity_predictor('./classifier_ckpt/binding_affinity_unpooled.pt', device)
affinity_model = AffinityModel(affinity_predictor, target_sequence)

score_models = [motif_model, affinity_model]
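
# Interface the guidance and scoring code below assumes of each entry in
# score_models (a sketch inferred from this script, not from the model
# classes themselves): a callable mapping a batch of token ids to scalar
# scores, optionally taking a time argument `t`, and possibly returning a
# tuple with one score per objective, e.g.
#
#     class ExampleScorer(torch.nn.Module):   # hypothetical
#         def forward(self, x, t):
#             return motif_score, penalty     # one scalar per objective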

for _ in range(args.n_batches):
    # Draw the initial sequence x_0 from the source distribution.
    if source_distribution == "uniform":
        # Uniform over the 20 amino-acid tokens (ids 4..vocab_size-1).
        x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
    elif source_distribution == "mask":
        # Every position starts as token id 3, used here as the mask token.
        x_init = torch.full((n_samples, length), 3, dtype=torch.long, device=device)
    else:
        raise NotImplementedError

    # Wrap each sequence in ESM-2 special tokens: <cls> (id 0) and <eos> (id 2).
    cls_col = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
    eos_col = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
    x_init = torch.cat([cls_col, x_init, eos_col], dim=1)

    # Guided sampling from t=0 to t≈1. num_objectives is 3 because the motif
    # model contributes two scores (motif match and penalty) and the affinity
    # model one.
    x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                       step_size=step_size,
                                       verbose=True,
                                       time_grid=torch.tensor([0.0, 1.0 - 1e-3]),
                                       score_models=score_models,
                                       num_objectives=3,
                                       weights=args.weights)

    # Decode token ids back to amino-acid strings, stripping the literal
    # "<cls>" and "<eos>" markers (five characters each) after removing spaces.
    samples = x_1.tolist()
    samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples]
    print(samples)

    # Re-score the final samples with each guidance model; scorers whose
    # signature accepts a time argument are evaluated at t=1.
    scores = []
    for s in score_models:
        sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
        if 't' in sig.parameters:
            candidate_scores = s(x_1, 1)
        else:
            candidate_scores = s(x_1)

        # A scorer may return one score per objective as a tuple.
        if isinstance(candidate_scores, tuple):
            for score in candidate_scores:
                scores.append(score.item())
        else:
            scores.append(candidate_scores.item())
    print(scores)

    # Append one CSV row per batch: the sequence followed by its scores.
    with open(args.output_file, 'a') as f:
        f.write(samples[0])
        for score in scores:
            f.write(f",{score}")
        f.write('\n')
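
# Example invocation (script name and flag spellings are assumptions; the
# actual flags are whatever parse_guidance_args defines):
#
#     python sample_guidance.py --length 12 --target_protein <SEQUENCE> \
#         --motifs <MOTIF_SPEC> --n_batches 10 --weights 1.0 1.0 1.0 \
#         --output_file samples.csv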