# Source: MemDLM/src/sampling/guided_generator.py
# Author: Shrey Goel — commit d04a061 ("adding code"), 2.95 kB
#!/usr/bin/env python3
import sys
import os
import torch
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from omegaconf import OmegaConf
from transformers import AutoTokenizer, AutoModelForMaskedLM
from src.lm.memdlm.diffusion_module import MembraneFlow
from src.utils.model_utils import _print
from src.sampling.guided_sampler import GuidedSampler
from src.utils.generate_utils import (
mask_for_scaffold,
calc_blosum_score,
calc_ppl
)
# Load the guidance config, then run from this model's test-set results
# directory so that every relative path used below resolves under it.
config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/guidance.yaml")
os.chdir(f'/home/a03-sgoel/MeMDLM_v2/results/infilling/guided/{config.lm.ft_evoflow}/test_set/')

# Per-run output directory, stamped with today's date (relative to the cwd above).
todays_date = datetime.today().strftime('%Y-%m-%d')
csv_save_path = f'./{todays_date}_boltzmann-soft_new_clf_data_cleaned/'
# Reuse the directory if it already exists (e.g. a second run on the same day).
# Unlike the previous try/except FileExistsError, this still raises if the path
# exists as a regular file, instead of failing later when the CSV is written.
os.makedirs(csv_save_path, exist_ok=True)
def _load_models(device):
    """Load the ESM tokenizer/model and the fine-tuned MembraneFlow, wrapped in a GuidedSampler.

    Args:
        device: torch device every model is moved to.

    Returns:
        Tuple ``(tokenizer, esm_model, sampler)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_esm)
    esm_model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_esm).eval().to(device)

    diffusion = MembraneFlow(config).to(device)
    ckpt_path = f"/home/a03-sgoel/MeMDLM_v2/checkpoints/{config.lm.ft_evoflow}/best_model.ckpt"
    state_dict = diffusion.get_state_dict(ckpt_path)
    diffusion.load_state_dict(state_dict)
    diffusion.eval().to(device)

    sampler = GuidedSampler(config, esm_model, tokenizer, diffusion, device)
    return tokenizer, esm_model, sampler


def _infill_and_score(seq, sampler, tokenizer, esm_model, device):
    """Infill the uppercase residues of ``seq`` with the guided sampler and score the result.

    Args:
        seq: input sequence; uppercase residues are the positions to redesign
            (passed to the sampler as ``soluble_indices``).
        sampler: GuidedSampler used for infilling.
        tokenizer: tokenizer matching ``esm_model``.
        esm_model: masked LM used by ``calc_ppl``.
        device: torch device the token tensors are moved to.

    Returns:
        Tuple ``(infilled_seq, ppl, blosum)``. ``ppl`` is ``float('inf')`` when
        ``calc_ppl`` raises.
    """
    # Replace the uppercase residues with <mask> before tokenizing.
    masked_seq = mask_for_scaffold(seq, generate_type='uppercase', mask_token='<mask>')
    tokens = tokenizer(masked_seq, return_tensors='pt')
    input_ids = tokens['input_ids'].to(device)
    attn_masks = tokens['attention_mask'].to(device)

    # Indices into the raw string, with no offset for special tokens.
    # NOTE(review): confirm optimize_sequence expects string indices rather than
    # token positions; the [5:-5] slice below suggests input_ids has a leading
    # special token.
    soluble_idxs = [i for i, residue in enumerate(seq) if residue.isupper()]

    infilled_tokens = sampler.optimize_sequence(
        input_ids=input_ids,
        attn_masks=attn_masks,
        soluble_indices=soluble_idxs,
    )
    # Strip spaces between decoded tokens, then drop a fixed 5 characters from
    # each end. This presumably removes ESM's '<cls>'/'<eos>' markers; verify if
    # the tokenizer changes.
    infilled_seq = tokenizer.decode(infilled_tokens).replace(" ", "")[5:-5]

    blosum = calc_blosum_score(seq.upper(), infilled_seq, soluble_idxs)
    try:
        ppl = calc_ppl(esm_model, tokenizer, infilled_seq, list(range(len(seq))), model_type='esm')
    except Exception as err:
        # Best effort: one failed perplexity must not abort the whole run. But
        # report why, and let KeyboardInterrupt/SystemExit propagate (a bare
        # `except:` would swallow them).
        _print(f'calc_ppl failed ({type(err).__name__}: {err}); recording inf')
        ppl = float('inf')
    return infilled_seq, ppl, blosum


def main():
    """Infill every test-set sequence with guided MeMDLM sampling and save the results.

    Reads ``data/classifier/test.csv`` (column ``Sequence``). Each sequence is
    infilled and scored, and three columns are appended to the frame:
    ``MeMDLM Sequence``, ``MeMDLM PPL`` and ``MeMDLM BLOSUM``. The frame is then
    written under ``csv_save_path``, which is relative to the working directory
    set at import time.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer, esm_model, sampler = _load_models(device)

    df = pd.read_csv('/home/a03-sgoel/MeMDLM_v2/data/classifier/test.csv')
    sequences = df['Sequence'].tolist()

    gen_seqs, ppls, blosums = [], [], []
    for seq in tqdm(sequences, desc='Infilling Sequences'):
        infilled_seq, ppl, bl = _infill_and_score(seq, sampler, tokenizer, esm_model, device)
        gen_seqs.append(infilled_seq)
        ppls.append(ppl)
        blosums.append(bl)
        _print(seq)
        _print(infilled_seq)
        _print(ppl)
        _print(bl)
        _print('\n')

    df['MeMDLM Sequence'] = gen_seqs
    df['MeMDLM PPL'] = ppls
    df['MeMDLM BLOSUM'] = blosums
    _print(df)
    # os.path.join avoids the doubled './' and '//' of the old f-string path;
    # it names the same file.
    df.to_csv(os.path.join(csv_save_path, 't=0.7_new-data-cleaned_infilled_seqs.csv'), index=False)


if __name__ == "__main__":
    main()