FlowProt / model /esm_eval.py
alibtsd's picture
Deploy FlowProt Docker Space
f34af6f verified
Raw
History Blame Contribute Delete
3.54 kB
"""
import os
from helpers.sequence import extract_seqs_from_dir
mpnn_dir = os.path.join("mpnn_results", "seqs")
mpnn_sequences = extract_seqs_from_dir(mpnn_dir, extension="fa")
print(mpnn_sequences)
"""
import torch
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from tqdm import tqdm
import pandas as pd
import os
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1", cache_dir="esm_cache")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True, cache_dir="esm_cache")
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
model.to(device)
print("Model is loaded!")
torch.backends.cuda.matmul.allow_tf32 = True
model.trunk.set_chunk_size(64)
def convert_outputs_to_pdb(outputs):
final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
final_atom_positions = final_atom_positions.cpu().numpy()
final_atom_mask = outputs["atom37_atom_exists"]
pdbs = []
for i in range(outputs["aatype"].shape[0]):
aa = outputs["aatype"][i]
pred_pos = final_atom_positions[i]
mask = final_atom_mask[i]
resid = outputs["residue_index"][i] + 1
pred = OFProtein(
aatype=aa,
atom_positions=pred_pos,
atom_mask=mask,
residue_index=resid,
b_factors=outputs["plddt"][i],
chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
)
pdbs.append(to_pdb(pred))
return pdbs
def write_pdb(pdb_path, seq_num, pdbcontent):
outfile_name = os.path.splitext(os.path.basename(pdb_path))[0]
outfile_name = f"{outfile_name}_{seq_num}.pdb"
out_path = os.path.join("esmfold_outputs", outfile_name)
with open(out_path, "w") as f:
f.write("".join(pdbcontent))
# Multiple files
"""
mpnn_dir = os.path.join("mpnn_results", "seqs")
mpnn_sequences = extract_seqs_from_dir(mpnn_dir, extension="fa")
"""
df = pd.read_csv("mpnn_sequences.csv")
sequences_tokenized = tokenizer(df.seq.tolist(), padding=False, add_special_tokens=False)["input_ids"]
# sequences_tokenized = tokenizer(mpnn_sequences, padding=False, add_special_tokens=False)["input_ids"]
# outputs = []
if not os.path.exists("esmfold_outputs"):
os.makedirs("esmfold_outputs")
tk = tqdm(zip(sequences_tokenized, df.pdb_path.tolist(), df.seq_num.tolist()), total=df.shape[0])
with torch.no_grad():
#for input_ids, pdb_path, seq_num in tqdm(sequences_tokenized):
for input_ids, pdb_path, seq_num in tk:
input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
output = model(input_ids)
# outputs.append({key: val.cpu() for key, val in output.items()})
# Write to a file
pdbfile = convert_outputs_to_pdb({key: val.cpu() for key, val in output.items()})
write_pdb(pdb_path, seq_num, pdbfile)
"""
pdb_list = [convert_outputs_to_pdb(output) for output in outputs]
for pdb_path, seq_num, pdb_content in zip(df.pdb_path.tolist(), df.seq_num.tolist(), pdb_list):
outfile_name = os.path.splitext(os.path.basename(pdb_path))[0]
outfile_name = f"{outfile_name}_{seq_num}.pdb"
out_path = os.path.join("esmfold_outputs", outfile_name)
with open(out_path, "w") as f:
f.write("".join(pdb_content))
"""