import math
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

from src.lm.dplm.diffusion_module import DPLM
from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler
from src.lm.memdlm.diffusion_module import MembraneFlow
from src.sampling.unconditional_sampler import UnconditionalSampler
from src.utils.model_utils import get_latents, _print

config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/lm.yaml")

# -------# Masking #-------- #
def mask_for_de_novo(sequence_length):
    """Build a fully masked sequence of the given length for de novo generation."""
    return "<mask>" * sequence_length

def mask_for_scaffold(sequence, generate_type, mask_token):
    """Mask either the uppercase or the lowercase residues of a scaffold sequence."""
    if generate_type == "uppercase":
        # Mask the uppercase residues and uppercase the remaining (lowercase) ones
        sequence = ''.join([mask_token if residue.isupper() else residue.upper() for residue in sequence])
    elif generate_type == "lowercase":
        # Mask the lowercase residues; leave the uppercase ones untouched
        sequence = ''.join([mask_token if residue.islower() else residue for residue in sequence])
    return sequence
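
# Example (illustrative) of the two masking helpers on a mixed-case scaffold:
#   >>> mask_for_scaffold("aKLm", "lowercase", "<mask>")
#   '<mask>KL<mask>'
#   >>> mask_for_scaffold("aKLm", "uppercase", "<mask>")
#   'A<mask><mask>M'
#   >>> mask_for_de_novo(3)
#   '<mask><mask><mask>'
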
# -------# Generation #-------- #
def memflow_infill_uncond(masked_seq, tokenizer, model: MembraneFlow):
    generator = UnconditionalSampler(tokenizer, model)  # initialize the generator object
    xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
    denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
    # Strip the 5-character special tokens ("<cls>"/"<eos>") that frame the decoded string
    generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
    return generated_sequence
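
# Example (hypothetical; assumes a trained MembraneFlow checkpoint and its tokenizer
# have been loaded elsewhere):
#   masked = mask_for_de_novo(100)
#   sequence = memflow_infill_uncond(masked, tokenizer, memflow_model)
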
def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
    """
    Following the given evodiff example
    https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb
    """
    # Manual masking of the infilling sequence; the mask token is "#" in the evodiff tokenizer
    motif_seq = ''.join(["#" if aa.islower() else aa for aa in motif_seq])
    tkns = tokenizer.tokenize([motif_seq])
    sample = torch.as_tensor(tkns).to(device)
    # Collect the masked positions and decode them in a random order
    loc = torch.arange(0, len(motif_seq)).to(device)[sample == tokenizer.mask_id].cpu().numpy()
    np.random.shuffle(loc)
    sample = sample.unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        for i in loc:
            timestep = torch.tensor([0] * batch_size).to(device)  # placeholder; not used by the model's forward pass
            prediction = model(sample, timestep)
            p = prediction[:, i, :len(tokenizer.all_aas) - 6]  # restrict to canonical amino acids
            p = F.softmax(p, dim=1)  # softmax over logits
            p_sample = torch.multinomial(p, num_samples=1)  # sample from the categorical distribution
            sample[:, i] = p_sample.squeeze()
    output = [tokenizer.untokenize(s) for s in sample]
    return output[0]
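
# Example (hypothetical; assumes an evodiff OA-DM checkpoint loaded as in the linked
# notebook, e.g. `model, collater, tokenizer, scheme = OA_DM_38M()` from evodiff.pretrained):
#   filled = evodiff_infill("MKTlllAYV", tokenizer, oadm_model, device)
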
def dplm_infill(masked_seq, tokenizer, model: DPLM, device):
    generator = DPLMUnconditionalSampler(tokenizer, model)
    xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
    denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
    # Strip the 5-character special tokens, mirroring memflow_infill_uncond
    generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
    return generated_sequence
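
# Usage mirrors memflow_infill_uncond, with a DPLM checkpoint and its tokenizer:
#   sequence = dplm_infill(masked, dplm_tokenizer, dplm_model, device)
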
# -------# Metrics #-------- #
def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
    """Compute causal-LM perplexity for a tokenized sequence (1-D tensor of token ids)."""
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=fp16):
            logits = model(
                input_ids=target,
                attention_mask=torch.ones_like(target)
            ).logits
            # Shift so that the logits at position i predict the token at position i+1
            logits = logits[:-1, ...]
            target = target[1:]
            loss = torch.nn.functional.cross_entropy(
                input=logits,
                target=target,
                reduction='mean'
            )
    return torch.exp(loss).item()
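
# Example (hypothetical; assumes a ProGen-style causal LM and tokenizer loaded elsewhere;
# note the sequence is passed without a batch dimension):
#   ids = progen_tokenizer(seq, return_tensors='pt')['input_ids'][0].to(device)
#   ppl = calc_progen_ppl(progen_model, progen_tokenizer, ids, device)
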
def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
    """Masked pseudo-perplexity: re-mask and score each given position one at a time."""
    total_loss = 0.0
    tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
    attn_mask = torch.ones_like(tensor_input).to(model.device)
    for i in mask_token_indices:
        masked_input = tensor_input.clone()
        masked_input[0, i] = tokenizer.mask_token_id
        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]
        with torch.no_grad():
            if model_type == 'esm':
                loss = model(masked_input, labels=labels).loss.item()
            elif model_type == 'flow':
                logits = model.forward(masked_input, attention_mask=attn_mask)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    reduction='none',
                    ignore_index=-100,
                )[i].item()
        total_loss += loss
    # Average over the positions actually scored, not the full sequence length
    avg_loss = total_loss / len(mask_token_indices)
    perplexity = math.exp(avg_loss)
    return perplexity
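
# The quantity computed above is a pseudo-perplexity over the scored set M:
#   PPL = exp( (1/|M|) * sum_{i in M} -log p(x_i | x with position i masked) )
# where each term is a single-position masked-LM loss, so only one token is hidden
# per forward pass.
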
def calc_blosum_score(og_seq, gen_seq, indices):
    import blosum as bl  # local import: only needed for this metric
    mat = bl.BLOSUM(62)
    tot_score = 0
    for i in indices:
        og_res, gen_res = og_seq[i], gen_seq[i]
        try:
            val = mat[og_res][gen_res]
            tot_score += val
        except KeyError:
            # -4 is the lowest BLOSUM62 score, indicating biological implausibility
            tot_score += -4
    return tot_score / len(indices) if indices else 0
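
# Example (illustrative): score only the redesigned positions of a generated sequence.
#   >>> calc_blosum_score("MKTAYV", "MKTSYV", [3])  # A -> S substitution, BLOSUM62 score 1
#   1.0
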
def calc_cos_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
    og_embeddings = get_latents(esm_model, tokenizer, original_sequence.upper(), device)
    new_embeddings = get_latents(esm_model, tokenizer, generated_sequence, device)
    # Per-position cosine similarity between the ESM embeddings, averaged over the sequence
    cosine_sim = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
    cosine_sim = torch.mean(cosine_sim).item()
    return cosine_sim
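
# Example (hypothetical; assumes an ESM-2 encoder loaded via the transformers imports above):
#   esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
#   esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#   sim = calc_cos_sim(og_seq, gen_seq, esm_tokenizer, esm_model, device)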