| """ |
| import os |
| from helpers.sequence import extract_seqs_from_dir |
| |
| mpnn_dir = os.path.join("mpnn_results", "seqs") |
| mpnn_sequences = extract_seqs_from_dir(mpnn_dir, extension="fa") |
| print(mpnn_sequences) |
| """ |
| import torch |
|
|
| from transformers import AutoTokenizer, EsmForProteinFolding |
| from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein |
| from transformers.models.esm.openfold_utils.feats import atom14_to_atom37 |
| from tqdm import tqdm |
|
|
| import pandas as pd |
|
|
| import os |
|
|
|
|
# Fetch the ESMFold tokenizer and folding model, keeping the downloaded files
# in the local "esm_cache" directory.
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/esmfold_v1",
    cache_dir="esm_cache",
)
model = EsmForProteinFolding.from_pretrained(
    "facebook/esmfold_v1",
    cache_dir="esm_cache",
    low_cpu_mem_usage=True,
)

# Use the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model.to(device)
print("Model is loaded!")

# Permit TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
# Have the folding trunk work in chunks of 64.
# NOTE(review): per the Hugging Face ESMFold docs this lowers memory use at
# some cost in speed -- confirm 64 suits the target GPU.
model.trunk.set_chunk_size(64)
|
|
def convert_outputs_to_pdb(outputs):
    """Render every structure in a batch of folding outputs as PDB text.

    Args:
        outputs: Mapping of output names to torch tensors. The keys read here
            are "positions", "atom37_atom_exists", "aatype", "residue_index",
            "plddt" and, when present, "chain_index".

    Returns:
        A list holding one ``to_pdb`` result per entry along the first axis
        of ``outputs["aatype"]``.
    """
    # The atom14 -> atom37 conversion needs the tensor-valued mapping, so it
    # runs before the mapping is turned into NumPy arrays.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    arrays = {name: tensor.to("cpu").numpy() for name, tensor in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()

    atom_mask = arrays["atom37_atom_exists"]
    has_chain_index = "chain_index" in arrays

    def _protein_at(idx):
        # Assemble the protein record for batch entry ``idx``.
        return OFProtein(
            aatype=arrays["aatype"][idx],
            atom_positions=atom37_positions[idx],
            atom_mask=atom_mask[idx],
            # Residue indices are shifted by one before being written out
            # (presumably to make the numbering 1-based -- TODO confirm).
            residue_index=arrays["residue_index"][idx] + 1,
            # Per-atom pLDDT values go into the B-factor field.
            b_factors=arrays["plddt"][idx],
            chain_index=arrays["chain_index"][idx] if has_chain_index else None,
        )

    batch_size = arrays["aatype"].shape[0]
    return [to_pdb(_protein_at(idx)) for idx in range(batch_size)]
|
|
|
|
def write_pdb(pdb_path, seq_num, pdbcontent, out_dir="esmfold_outputs"):
    """Write PDB text to ``<out_dir>/<stem of pdb_path>_<seq_num>.pdb``.

    Args:
        pdb_path: Path whose base name, minus its extension, becomes the
            prefix of the output file name. The file itself is never opened.
        seq_num: Value appended to the prefix after an underscore.
        pdbcontent: Iterable of strings; the pieces are concatenated without
            a separator and written as the file body.
        out_dir: Directory that receives the file. It is created (including
            missing parents) when it does not exist yet. Defaults to
            "esmfold_outputs".

    Returns:
        The path of the file that was written.
    """
    stem = os.path.splitext(os.path.basename(pdb_path))[0]
    out_path = os.path.join(out_dir, f"{stem}_{seq_num}.pdb")
    # Create the target directory here so the function does not depend on a
    # caller having prepared it; an empty out_dir means the current directory.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w") as f:
        f.write("".join(pdbcontent))
    return out_path
|
|
| |
| |
| """ |
| mpnn_dir = os.path.join("mpnn_results", "seqs") |
| mpnn_sequences = extract_seqs_from_dir(mpnn_dir, extension="fa") |
| """ |
# Load the sequences to fold. The columns read below are "seq", "pdb_path"
# and "seq_num".
df = pd.read_csv("mpnn_sequences.csv")
# Tokenize each sequence on its own: no padding and no special tokens.
sequences_tokenized = tokenizer(df.seq.tolist(), padding=False, add_special_tokens=False)["input_ids"]

# exist_ok=True replaces the former exists()-then-makedirs() pair, which could
# fail if the directory appeared between the check and the creation.
os.makedirs("esmfold_outputs", exist_ok=True)

tk = tqdm(zip(sequences_tokenized, df.pdb_path.tolist(), df.seq_num.tolist()), total=df.shape[0])
# Inference only: no gradients are needed.
with torch.no_grad():
    for input_ids, pdb_path, seq_num in tk:
        # One sequence per forward pass; unsqueeze adds the leading batch axis.
        input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
        output = model(input_ids)

        # Move every output tensor to the CPU before building the PDB text.
        pdbfile = convert_outputs_to_pdb({key: val.cpu() for key, val in output.items()})
        write_pdb(pdb_path, seq_num, pdbfile)
| |
|
|
| """ |
| pdb_list = [convert_outputs_to_pdb(output) for output in outputs] |
| |
| for pdb_path, seq_num, pdb_content in zip(df.pdb_path.tolist(), df.seq_num.tolist(), pdb_list): |
| outfile_name = os.path.splitext(os.path.basename(pdb_path))[0] |
| outfile_name = f"{outfile_name}_{seq_num}.pdb" |
| out_path = os.path.join("esmfold_outputs", outfile_name) |
| with open(out_path, "w") as f: |
| f.write("".join(pdb_content)) |
| """ |
|
|
|
|