|
|
|
|
|
|
|
|
import sys |
|
|
import os |
|
|
import torch |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
from datetime import datetime |
|
|
from omegaconf import OmegaConf |
|
|
from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
|
|
|
|
from src.lm.memdlm.diffusion_module import MembraneFlow |
|
|
from src.utils.model_utils import _print |
|
|
from src.sampling.guided_sampler import GuidedSampler |
|
|
from src.utils.generate_utils import ( |
|
|
mask_for_scaffold, |
|
|
calc_blosum_score, |
|
|
calc_ppl |
|
|
) |
|
|
|
|
|
# Load the guidance/inference configuration (model names, checkpoint ids, etc.).
# NOTE(review): absolute paths assume a specific cluster layout — confirm before reuse.
config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/guidance.yaml")

# Work from the results directory for this fine-tuned model so all relative
# paths below (csv_save_path) land next to the experiment's other outputs.
os.chdir(f'/home/a03-sgoel/MeMDLM_v2/results/infilling/guided/{config.lm.ft_evoflow}/test_set/')

todays_date = datetime.today().strftime('%Y-%m-%d')

# Per-run output directory, dated so repeated runs don't collide.
csv_save_path = f'./{todays_date}_boltzmann-soft_new_clf_data_cleaned/'

# exist_ok=True replaces the manual try/except-FileExistsError pattern:
# identical behavior (no-op if the directory is already there).
os.makedirs(csv_save_path, exist_ok=True)
|
|
|
|
|
|
|
|
def main():
    """Infill masked (soluble) regions of test-set sequences with a guided
    diffusion sampler, then score each infilled sequence.

    For every sequence in the classifier test CSV:
      1. Mask the uppercase (soluble) residues and tokenize.
      2. Run the guided sampler to infill the masked positions.
      3. Compute a BLOSUM score over the infilled positions and an ESM
         pseudo-perplexity over the full sequence.
    Results are appended as new columns and written to a dated CSV under
    ``csv_save_path``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Frozen ESM model: used both for guidance inside the sampler and for PPL scoring.
    tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_esm)
    esm_model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_esm).eval().to(device)

    # Diffusion LM restored from the fine-tuned checkpoint for this experiment.
    diffusion = MembraneFlow(config).to(device)
    state_dict = diffusion.get_state_dict(f"/home/a03-sgoel/MeMDLM_v2/checkpoints/{config.lm.ft_evoflow}/best_model.ckpt")
    diffusion.load_state_dict(state_dict)
    diffusion.eval().to(device)

    sampler = GuidedSampler(config, esm_model, tokenizer, diffusion, device)

    df = pd.read_csv('/home/a03-sgoel/MeMDLM_v2/data/classifier/test.csv')
    sequences = df['Sequence'].tolist()

    gen_seqs, ppls, blosums = [], [], []

    for seq in tqdm(sequences, desc='Infilling Sequences'):
        # Uppercase residues mark the region to regenerate; they become <mask>.
        masked_seq = mask_for_scaffold(seq, generate_type='uppercase', mask_token='<mask>')
        tokens = tokenizer(masked_seq, return_tensors='pt')
        input_ids, attn_masks = tokens['input_ids'].to(device), tokens['attention_mask'].to(device)

        # Positions (0-based, over the raw sequence) that were masked/infilled.
        soluble_idxs = [i for i in range(len(seq)) if seq[i].isupper()]
        infilled_tokens = sampler.optimize_sequence(
            input_ids=input_ids,
            attn_masks=attn_masks,
            soluble_indices=soluble_idxs,
        )
        # Strip spaces from the decoded string, then drop the special tokens at
        # both ends ([5:-5] assumes 5-char <cls>/<eos> markers — TODO confirm
        # this matches the tokenizer's special-token strings).
        infilled_seq = tokenizer.decode(infilled_tokens).replace(" ", "")[5:-5]

        bl = calc_blosum_score(seq.upper(), infilled_seq, soluble_idxs)
        try:
            # PPL over every position of the sequence.
            ppl = calc_ppl(esm_model, tokenizer, infilled_seq, list(range(len(seq))), model_type='esm')
        except Exception:
            # Narrowed from a bare except: still best-effort (score as inf on
            # failure) but no longer swallows KeyboardInterrupt/SystemExit.
            ppl = float('inf')

        gen_seqs.append(infilled_seq)
        ppls.append(ppl)
        blosums.append(bl)

        _print(seq)
        _print(infilled_seq)
        _print(ppl)
        _print(bl)
        _print('\n')

    df['MeMDLM Sequence'] = gen_seqs
    df['MeMDLM PPL'] = ppls
    df['MeMDLM BLOSUM'] = blosums

    _print(df)
    # csv_save_path already carries its './' prefix; don't prepend another one.
    df.to_csv(f'{csv_save_path}/t=0.7_new-data-cleaned_infilled_seqs.csv', index=False)
|
|
|
|
|
|
|
|
|
|
|
# Run the infilling pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|
|
|
|
|