sgoel30 commited on
Commit
906ae08
·
verified ·
1 Parent(s): f8e0708

Delete src/sampling/guided_generator.py

Browse files
Files changed (1) hide show
  1. src/sampling/guided_generator.py +0 -90
src/sampling/guided_generator.py DELETED
@@ -1,90 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- import sys
4
- import os
5
- import torch
6
- import pandas as pd
7
- from tqdm import tqdm
8
- from datetime import datetime
9
- from omegaconf import OmegaConf
10
- from transformers import AutoTokenizer, AutoModelForMaskedLM
11
-
12
- from src.lm.memdlm.diffusion_module import MembraneFlow
13
- from src.utils.model_utils import _print
14
- from src.sampling.guided_sampler import GuidedSampler
15
- from src.utils.generate_utils import (
16
- mask_for_scaffold,
17
- calc_blosum_score,
18
- calc_ppl
19
- )
20
-
21
- config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/guidance.yaml")
22
-
23
- os.chdir(f'/home/a03-sgoel/MeMDLM_v2/results/infilling/guided/{config.lm.ft_evoflow}/test_set/')
24
- todays_date = datetime.today().strftime('%Y-%m-%d')
25
- csv_save_path = f'./{todays_date}_boltzmann-soft_new_clf_data_cleaned/'
26
- try: os.makedirs(csv_save_path, exist_ok=False)
27
- except FileExistsError: pass
28
-
29
-
30
- def main():
31
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
-
33
- tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_esm)
34
- esm_model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_esm).eval().to(device)
35
-
36
- diffusion = MembraneFlow(config).to(device)
37
- state_dict = diffusion.get_state_dict(f"/home/a03-sgoel/MeMDLM_v2/checkpoints/{config.lm.ft_evoflow}/best_model.ckpt")
38
- diffusion.load_state_dict(state_dict)
39
- diffusion.eval().to(device)
40
-
41
- sampler = GuidedSampler(config, esm_model, tokenizer, diffusion, device)
42
-
43
- df = pd.read_csv('/home/a03-sgoel/MeMDLM_v2/data/classifier/test.csv')
44
- sequences = df['Sequence'].tolist()
45
-
46
- gen_seqs, ppls, blosums = [], [], []
47
-
48
-
49
- for seq in tqdm(sequences, desc='Infilling Sequences'):
50
- masked_seq = mask_for_scaffold(seq, generate_type='uppercase', mask_token='<mask>')
51
- tokens = tokenizer(masked_seq, return_tensors='pt')
52
- input_ids, attn_masks = tokens['input_ids'].to(device), tokens['attention_mask'].to(device)
53
-
54
- soluble_idxs = [i for i in range(len(seq)) if seq[i].isupper()]
55
- infilled_tokens = sampler.optimize_sequence(
56
- input_ids=input_ids,
57
- attn_masks=attn_masks,
58
- soluble_indices=soluble_idxs,
59
- )
60
- infilled_seq = tokenizer.decode(infilled_tokens).replace(" ", "")[5:-5]
61
-
62
- bl = calc_blosum_score(seq.upper(), infilled_seq, soluble_idxs)
63
- try:
64
- ppl = calc_ppl(esm_model, tokenizer, infilled_seq, [i for i in range(len(seq))], model_type='esm')
65
- except:
66
- ppl = float('inf')
67
-
68
- gen_seqs.append(infilled_seq)
69
- ppls.append(ppl)
70
- blosums.append(bl)
71
-
72
- _print(seq)
73
- _print(infilled_seq)
74
- _print(ppl)
75
- _print(bl)
76
- _print('\n')
77
-
78
-
79
- df['MeMDLM Sequence'] = gen_seqs
80
- df['MeMDLM PPL'] = ppls
81
- df['MeMDLM BLOSUM'] = blosums
82
-
83
- _print(df)
84
- df.to_csv(f'./{csv_save_path}/t=0.7_new-data-cleaned_infilled_seqs.csv', index=False)
85
-
86
-
87
-
88
- if __name__ == "__main__":
89
- main()
90
-