File size: 2,952 Bytes
d04a061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3

import sys
import os
import torch
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from omegaconf import OmegaConf
from transformers import AutoTokenizer, AutoModelForMaskedLM

from src.lm.memdlm.diffusion_module import MembraneFlow
from src.utils.model_utils import _print
from src.sampling.guided_sampler import GuidedSampler
from src.utils.generate_utils import (
    mask_for_scaffold,
    calc_blosum_score,
    calc_ppl
)

# Guidance configuration. This script reads `config.lm.pretrained_esm` and
# `config.lm.ft_evoflow` directly and hands the whole object to MembraneFlow
# and GuidedSampler.
config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/guidance.yaml")

# NOTE: import-time side effect -- the process working directory is switched to
# the per-checkpoint results folder, so every relative path below (including
# the output CSV directory) resolves inside it.
os.chdir(f'/home/a03-sgoel/MeMDLM_v2/results/infilling/guided/{config.lm.ft_evoflow}/test_set/')
todays_date = datetime.today().strftime('%Y-%m-%d')
# Output directory for this run, prefixed with today's date (YYYY-MM-DD).
csv_save_path = f'./{todays_date}_boltzmann-soft_new_clf_data_cleaned/'
# Re-using the directory on a same-day rerun is fine; if the path exists but is
# NOT a directory, makedirs raises here instead of failing later at to_csv.
os.makedirs(csv_save_path, exist_ok=True)


def main():
    """Infill every test-set sequence with guided MeMDLM sampling and save a CSV.

    For each entry of the ``Sequence`` column of the classifier test CSV:

    1. ``mask_for_scaffold(..., generate_type='uppercase', mask_token='<mask>')``
       builds the masked input (presumably masking the uppercase residues --
       confirm against ``mask_for_scaffold``).
    2. The positions of uppercase residues are collected and passed to
       ``GuidedSampler.optimize_sequence`` as ``soluble_indices``.
    3. The sampler output is decoded back into a residue string.
    4. ``calc_blosum_score`` (over the uppercase positions, against the
       upper-cased original) and ``calc_ppl`` (ESM model, all positions of the
       infilled sequence) are computed. If ``calc_ppl`` raises, the perplexity
       is recorded as ``inf`` and processing continues.

    The dataframe gains the columns ``MeMDLM Sequence``, ``MeMDLM PPL`` and
    ``MeMDLM BLOSUM`` and is written to ``csv_save_path`` (relative to the
    working directory set at import time).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ESM masked-LM: used by the guided sampler and for scoring perplexity.
    tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_esm)
    esm_model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_esm).eval().to(device)

    # Fine-tuned diffusion model restored from its best checkpoint.
    diffusion = MembraneFlow(config).to(device)
    ckpt_path = f"/home/a03-sgoel/MeMDLM_v2/checkpoints/{config.lm.ft_evoflow}/best_model.ckpt"
    diffusion.load_state_dict(diffusion.get_state_dict(ckpt_path))
    diffusion.eval().to(device)

    sampler = GuidedSampler(config, esm_model, tokenizer, diffusion, device)

    df = pd.read_csv('/home/a03-sgoel/MeMDLM_v2/data/classifier/test.csv')
    sequences = df['Sequence'].tolist()

    gen_seqs, ppls, blosums = [], [], []

    for seq in tqdm(sequences, desc='Infilling Sequences'):
        masked_seq = mask_for_scaffold(seq, generate_type='uppercase', mask_token='<mask>')
        tokens = tokenizer(masked_seq, return_tensors='pt')
        input_ids = tokens['input_ids'].to(device)
        attn_masks = tokens['attention_mask'].to(device)

        # Indices of uppercase residues in the original sequence; the sampler
        # receives them as `soluble_indices` and BLOSUM is scored on them.
        soluble_idxs = [i for i, residue in enumerate(seq) if residue.isupper()]
        infilled_tokens = sampler.optimize_sequence(
            input_ids=input_ids,
            attn_masks=attn_masks,
            soluble_indices=soluble_idxs,
        )
        # decode() separates tokens with spaces; remove them, then drop 5
        # characters from each end (presumably the '<cls>' / '<eos>' special
        # tokens of the ESM tokenizer -- confirm if the tokenizer changes).
        infilled_seq = tokenizer.decode(infilled_tokens).replace(" ", "")[5:-5]

        bl = calc_blosum_score(seq.upper(), infilled_seq, soluble_idxs)
        try:
            # Score every position of the sequence actually being evaluated.
            ppl = calc_ppl(
                esm_model, tokenizer, infilled_seq,
                list(range(len(infilled_seq))), model_type='esm',
            )
        except Exception as err:
            # Best-effort: keep the run going with an infinite perplexity, but
            # report why. Catching Exception (not a bare `except:`) lets
            # KeyboardInterrupt / SystemExit still stop the script.
            _print(f'PPL computation failed, recording inf: {err!r}')
            ppl = float('inf')

        gen_seqs.append(infilled_seq)
        ppls.append(ppl)
        blosums.append(bl)

        _print(seq)
        _print(infilled_seq)
        _print(ppl)
        _print(bl)
        _print('\n')

    df['MeMDLM Sequence'] = gen_seqs
    df['MeMDLM PPL'] = ppls
    df['MeMDLM BLOSUM'] = blosums

    _print(df)
    # os.path.join avoids the '././/...//' path produced by string gluing;
    # it resolves to the same file inside csv_save_path.
    df.to_csv(os.path.join(csv_save_path, 't=0.7_new-data-cleaned_infilled_seqs.csv'), index=False)

    
if __name__ == "__main__":
    # Script entry point: run the infilling pipeline only when executed
    # directly, never as a side effect of importing this module.
    main()