File size: 4,110 Bytes
d04a061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3

import sys
import os

import random
import torch
import pandas as pd
import numpy as np

from tqdm import tqdm
from collections import Counter
from omegaconf import OmegaConf
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForMaskedLM

from MeMDLM_v2.src.lm.diffusion_module import MembraneFlow
from src.sampling.unconditional_sampler import UnconditionalSampler
from src.utils.generate_utils import mask_for_de_novo, calc_ppl
from src.utils.model_utils import _print


# Module-level setup: pick GPU when available, pin the working directory so all
# relative paths below resolve, and load the LM config used throughout.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.chdir('/home/a03-sgoel/MeMDLM_v2')  # NOTE(review): hard-coded user path — breaks on other machines
config = OmegaConf.load("./src/configs/lm.yaml")

# Date stamp used to name the output results directory.
date = datetime.now().strftime("%Y-%m-%d")



def generate_sequence(prior: str, tokenizer, generator, device):
    """Denoise a fully-masked prior into a plain residue string.

    Tokenizes `prior`, runs the unconditional sampler for the configured
    number of steps, and decodes the sampled ids back to a sequence with
    special tokens and inter-residue spaces stripped.
    """
    encoded = tokenizer(prior, return_tensors="pt").to(device)
    sampled_ids = generator.sample_unconditional(
        xt=encoded['input_ids'],
        num_steps=config.sampling.n_steps,
        return_logits=False,
        banned_token_ids=None,
        # banned_token_ids=[tokenizer.convert_tokens_to_ids("P"), tokenizer.convert_tokens_to_ids("C")]
    )
    decoded = tokenizer.decode(sampled_ids[0].squeeze())
    # Drop the 5-char bos/eos special tokens and the spaces between residues.
    return decoded[5:-5].replace(" ", "")


def main():
    """Generate de novo sequences whose lengths follow the test-set length
    distribution, score each with ESM and flow pseudo-perplexity, and save
    the results to CSV.
    """
    csv_save_path = f'./results/denovo/unconditional/{config.wandb.name}/{date}_tau=3.0_test-set_distribution'
    # exist_ok=True replaces the old try/except-FileExistsError dance.
    os.makedirs(csv_save_path, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)

    # Load the trained flow model from its best checkpoint.
    flow = MembraneFlow(config).to(device)
    state_dict = flow.get_state_dict(f"./checkpoints/{config.wandb.name}/best_model.ckpt")
    flow.load_state_dict(state_dict)
    flow.eval()

    # ESM serves only as an external perplexity scorer.
    esm_model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_esm).to(device)
    esm_model.eval()

    generator = UnconditionalSampler(tokenizer, flow)

    # Draw generation lengths from the empirical test-set length distribution:
    # count each observed length, normalize the counts into probabilities,
    # then resample one length per test sequence.
    df = pd.read_csv("./data/test.csv")
    test_lengths = [len(seq) for seq in df['Sequence'].tolist()]
    length_counts = Counter(test_lengths)  # {length: frequency}
    lengths = np.array(list(length_counts.keys()))
    probs = np.array([length_counts[l] for l in lengths], dtype=np.float64)
    probs /= probs.sum()  # normalize in one vectorized step
    seq_lengths = np.random.choice(lengths, size=len(test_lengths), p=probs)

    generation_results = []
    for seq_len in tqdm(seq_lengths, desc="Generating sequences: "):
        masked_seq = mask_for_de_novo(seq_len)  # prompt of all <mask> tokens
        gen_seq = ""
        attempts = 0

        # Retry up to 3 times if decoding yields a wrong-length sequence.
        while len(gen_seq) != seq_len and attempts < 3:
            gen_seq = generate_sequence(masked_seq, tokenizer, generator, device)
            attempts += 1

        if len(gen_seq) != seq_len:
            # Could not hit the target length; record the sequence with no scores.
            esm_ppl, flow_ppl = None, None
        else:
            # Score every position once; build the index list a single time.
            all_positions = list(range(len(gen_seq)))
            esm_ppl = calc_ppl(esm_model, tokenizer, gen_seq, all_positions, model_type='esm')
            flow_ppl = calc_ppl(flow, tokenizer, gen_seq, all_positions, model_type='flow')

        _print(f'gen seq: {gen_seq}')
        _print(f'esm ppl: {esm_ppl}')
        _print(f'flow ppl: {flow_ppl}')

        generation_results.append([gen_seq, esm_ppl, flow_ppl])

    df = pd.DataFrame(generation_results, columns=['Generated Sequence', 'ESM PPL', 'Flow PPL'])
    df.to_csv(csv_save_path + "/seqs_with_ppl.csv", index=False)


# Script entry point: run generation + scoring when invoked directly.
if __name__ == "__main__":
    main()