File size: 6,039 Bytes
d04a061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import math
import sys

import torch.nn.functional as F
import pandas as pd
import numpy as np

from omegaconf import OmegaConf
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer

from src.lm.memdlm.diffusion_module import MembraneFlow
from src.lm.dplm.diffusion_module import DPLM
from src.utils.model_utils import get_latents, _print
from src.sampling.unconditional_sampler import UnconditionalSampler
from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler

# Module-wide configuration; the infill helpers below read `config.sampling.n_steps` from it.
# NOTE(review): this is an absolute, user-specific path that is loaded at import time —
# presumably importing this module fails on machines where the file does not exist; confirm.
config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/lm.yaml")

# -------# Masking #-------- #
def mask_for_de_novo(sequence_length, mask_token="<mask>"):
    """Build a fully masked sequence string for de novo generation.

    Args:
        sequence_length: Number of mask tokens to emit.
        mask_token: Token text repeated for every position. Defaults to
            "<mask>"; pass a different value for tokenizers that use
            another mask symbol.

    Returns:
        ``mask_token`` repeated ``sequence_length`` times. The result is an
        empty string when ``sequence_length`` is zero or negative.
    """
    return mask_token * sequence_length

def mask_for_scaffold(sequence, generate_type, mask_token):
    """Replace the residues selected by letter case with ``mask_token``.

    Args:
        sequence: Residue string whose letter case marks the region to mask.
        generate_type: "uppercase" masks every uppercase character and
            upper-cases the remaining ones; "lowercase" masks every lowercase
            character and leaves the remaining ones untouched.
        mask_token: Text substituted for each masked character.

    Returns:
        The masked string, or ``sequence`` unchanged when ``generate_type`` is
        neither of the two recognised values.
    """
    if generate_type == "uppercase":
        pieces = (mask_token if ch.isupper() else ch.upper() for ch in sequence)
        return ''.join(pieces)
    if generate_type == "lowercase":
        pieces = (mask_token if ch.islower() else ch for ch in sequence)
        return ''.join(pieces)
    return sequence


# -------# Generation #-------- #
def memflow_infill_uncond(masked_seq, tokenizer, model: MembraneFlow):
    """Fill in a masked sequence with the unconditional MembraneFlow sampler.

    Args:
        masked_seq: Sequence string containing mask tokens to be generated.
        tokenizer: Callable tokenizer returning a mapping with 'input_ids';
            also used to decode the sampled tokens.
        model: Model handed to ``UnconditionalSampler``; its ``device``
            attribute decides where the token ids are placed.

    Returns:
        The decoded sequence with spaces removed and the first and last five
        characters dropped.
    """
    sampler = UnconditionalSampler(tokenizer, model)
    encoded = tokenizer(masked_seq, return_tensors='pt')
    xt = encoded['input_ids'].to(model.device)

    # Number of sampling steps comes from the module-level config.
    sampled = sampler.sample_unconditional(xt, config.sampling.n_steps)
    token_ids = sampled[0].squeeze()

    decoded = tokenizer.decode(token_ids).replace(" ", "")
    # presumably strips the 5-character "<cls>" / "<eos>" markers — confirm against the tokenizer
    return decoded[5:-5]


def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
    """
    Infill the lowercase positions of ``motif_seq`` one residue at a time.

    Following the given evodiff example
    https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb

    Args:
        motif_seq: Sequence string; lowercase characters mark the positions to
            regenerate, every other character is kept.
        tokenizer: Object providing ``tokenize``, ``untokenize``, ``mask_id``
            and ``all_aas``.
        model: Callable ``model(sample, timestep)`` returning per-position logits.
        device: Device on which the tensors are placed.
        batch_size: Only sizes the placeholder timestep tensor; the sequence
            itself is processed as a single-row batch.

    Returns:
        The untokenized first (and only) row after all masked positions have
        been sampled.
    """
    # Manual masking of infilling sequence; mask token is "#" in evodiff tokenizer
    motif_seq = ''.join("#" if aa.islower() else aa for aa in motif_seq)
    sample = torch.as_tensor(tokenizer.tokenize([motif_seq])).to(device)

    # Positions to fill, visited in a random order
    loc = torch.arange(0, len(motif_seq)).to(device)[sample == tokenizer.mask_id].cpu().numpy()
    np.random.shuffle(loc)

    sample = sample.unsqueeze(0)

    # Loop invariants: the timestep is a placeholder that never changes, and
    # the slice bound drops the last six vocabulary entries (only canonical).
    timestep = torch.zeros(batch_size, dtype=torch.long, device=device)
    n_canonical = len(tokenizer.all_aas) - 6

    with torch.no_grad():
        for i in loc:
            prediction = model(sample, timestep)
            p = F.softmax(prediction[:, i, :n_canonical], dim=1)  # softmax over logits
            p_sample = torch.multinomial(p, num_samples=1)  # sample from categorical distribution
            sample[:, i] = p_sample.squeeze()

    output = [tokenizer.untokenize(s) for s in sample]
    return output[0]


def dplm_infill(masked_seq, tokenizer, model: DPLM, device):
    """Fill in a masked sequence with the DPLM unconditional sampler.

    Args:
        masked_seq: Sequence string containing mask tokens to be generated.
        tokenizer: Callable tokenizer returning a mapping with 'input_ids';
            also used to decode the sampled tokens.
        model: Model handed to ``DPLMUnconditionalSampler``; its ``device``
            attribute decides where the token ids are placed.
        device: Not used by this function (``model.device`` is used instead).

    Returns:
        The decoded sequence with spaces removed and the first and last five
        characters dropped.
    """
    sampler = DPLMUnconditionalSampler(tokenizer, model)
    encoded = tokenizer(masked_seq, return_tensors='pt')
    xt = encoded['input_ids'].to(model.device)

    # Number of sampling steps comes from the module-level config.
    sampled = sampler.sample_unconditional(xt, config.sampling.n_steps)
    token_ids = sampled[0].squeeze()

    decoded = tokenizer.decode(token_ids).replace(" ", "")
    # presumably strips the 5-character "<cls>" / "<eos>" markers — confirm against the tokenizer
    return decoded[5:-5]


# -------# Metrics #-------- #
def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
    """Return the causal-LM perplexity of ``target`` under ``model``.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``logits`` attribute.
        tokenizer: Not used by this function.
        target: Token-id tensor; it is shifted along its first dimension, so
            it is presumably a single unbatched sequence — confirm with callers.
        device: Not used by this function.
        fp16: Whether CUDA autocast is enabled for the forward pass and loss.

    Returns:
        ``exp`` of the mean next-token cross-entropy, as a Python float.
    """
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=fp16):
        outputs = model(
            input_ids=target,
            attention_mask=torch.ones_like(target),
        )
        # Align each position's prediction with the token that follows it.
        shifted_logits = outputs.logits[:-1, ...]
        next_tokens = target[1:]
        nll = F.cross_entropy(
            input=shifted_logits,
            target=next_tokens,
            reduction='mean',
        )
        return torch.exp(nll).item()


def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
    """Compute a masked-LM pseudo-perplexity over the selected positions.

    Each index in ``mask_token_indices`` is masked in turn, the model scores
    the original token at that position, and the per-position losses are
    summed.

    Args:
        model: Model exposing ``device``. For ``model_type == 'esm'`` it is
            called as ``model(ids, labels=...)`` and must return an object
            with ``loss``; for ``'flow'`` its ``forward(ids, attention_mask=...)``
            must return logits.
        tokenizer: Provides ``encode(..., return_tensors='pt')`` and
            ``mask_token_id``.
        generated_sequence: Sequence string to score.
        mask_token_indices: Token positions (into the encoded ids) to mask
            one at a time.
        model_type: Either 'esm' or 'flow'.

    Returns:
        ``exp(total_loss / len(generated_sequence))`` as a float.

    Raises:
        ValueError: If a position is scored with an unrecognised ``model_type``.
        ZeroDivisionError: If ``generated_sequence`` is empty.
    """
    tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
    attn_mask = torch.ones_like(tensor_input).to(model.device)

    total_loss = 0.0
    for i in mask_token_indices:
        masked_input = tensor_input.clone()
        masked_input[0, i] = tokenizer.mask_token_id

        # Only the masked position carries a label; -100 is ignored by the loss.
        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        with torch.no_grad():
            if model_type == 'esm':
                loss = model(masked_input, labels=labels).loss.item()
            elif model_type == 'flow':
                logits = model.forward(masked_input, attention_mask=attn_mask)
                # The batch has a single row, so flat index i is position i.
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    reduction='none',
                    ignore_index=-100,
                )[i].item()
            else:
                # Previously fell through with `loss` unbound (or stale).
                raise ValueError(
                    f"Unsupported model_type {model_type!r}; expected 'esm' or 'flow'"
                )

        total_loss += loss

    # NOTE(review): the sum is normalised by the full sequence length rather
    # than by the number of scored positions — confirm this is intended.
    avg_loss = total_loss / len(generated_sequence)
    return math.exp(avg_loss)


def calc_blosum_score(og_seq, gen_seq, indices):
    """Average the BLOSUM62 substitution score over the given positions.

    Args:
        og_seq: Original sequence, indexed by position.
        gen_seq: Generated sequence, indexed by position.
        indices: Positions at which the two sequences are compared.

    Returns:
        The mean score across ``indices``, or 0 when ``indices`` is falsy.
        A residue pair missing from the matrix contributes -4.
    """
    import blosum as bl

    substitution = bl.BLOSUM(62)
    running_total = 0
    for pos in indices:
        ref_res = og_seq[pos]
        new_res = gen_seq[pos]
        try:
            running_total += substitution[ref_res][new_res]
        except KeyError:
            # -4 is the lowest BLOSUM score, used here to flag an implausible pair
            running_total += -4

    if not indices:
        return 0
    return running_total / len(indices)


def calc_cos_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
    """Return the mean cosine similarity between the latents of two sequences.

    Args:
        original_sequence: Reference sequence; upper-cased before embedding.
        generated_sequence: Generated sequence; embedded as given.
        tokenizer: Tokenizer forwarded to ``get_latents``.
        esm_model: Model forwarded to ``get_latents``.
        device: Device forwarded to ``get_latents``.

    Returns:
        Cosine similarity taken along the last dimension of the two latent
        tensors, averaged over all remaining entries, as a Python float.
    """
    reference = get_latents(esm_model, tokenizer, original_sequence.upper(), device)
    candidate = get_latents(esm_model, tokenizer, generated_sequence, device)
    similarity = F.cosine_similarity(reference, candidate, dim=-1)
    return similarity.mean().item()