| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import os |
| from functools import partial |
| from os.path import join as opjoin |
| from typing import Callable, Tuple |
|
|
# Module-level lookup tables used by rematch():
#   seq_to_pdbid:       sequence -> list of (pdb_id, entity_id) pairs
#   first_pdbid_to_seq: "pdbid_entityid" of the FIRST pair -> sequence
#   seq_to_pdb_index:   sequence -> pdb index used as the output directory name
with open("./scripts/msa/data/pdb_seqs/seq_to_pdb_id_entity_id.json", "r") as fh:
    seq_to_pdbid = json.load(fh)
first_pdbid_to_seq = {"_".join(ids[0]): seq for seq, ids in seq_to_pdbid.items()}


with open("./scripts/msa/data/pdb_seqs/seq_to_pdb_index.json", "r") as fh:
    seq_to_pdb_index = json.load(fh)
|
|
|
|
def rematch(pdb_line: str) -> Tuple[str, str]:
    """Map a raw ``>pdbid_entityid`` header line back to its pdb index.

    Strips the leading ``>`` and the trailing newline from *pdb_line*,
    recovers the original query sequence for that id, and returns the
    tuple ``(pdb_index, original_query_sequence)``.
    """
    pdbid_key = pdb_line[1:-1]  # drop ">" prefix and "\n" suffix
    query_seq = first_pdbid_to_seq[pdbid_key]
    return seq_to_pdb_index[query_seq], query_seq
|
|
|
|
def write_log(
    msg: str,
    fname: str,
    log_root: str,
) -> None:
    """Record *msg* for *fname* by touching an empty marker file.

    The marker is named ``<basename>-<msg>`` inside *log_root*, where
    basename is everything before the first "." of *fname*.  The file is
    created (or truncated) and left empty — its existence is the log entry.
    """
    basename = fname.split(".")[0]
    marker_path = opjoin(log_root, f"{basename}-{msg}")
    with open(marker_path, "w"):
        pass
|
|
|
|
def process_one_file(
    fname: str, msa_root: str, save_root: str, logger: Callable
) -> None:
    """Split one A3M-style MSA file into UniRef100 hits and other hits.

    Expected input layout: line 0 is the ``>pdbid_entityid`` header,
    line 1 is the query sequence, and every following pair of lines is
    one hit (header line, sequence line).  Writes two files under
    ``save_root/<pdb_index>/`` — ``uniref100_hits.a3m`` and
    ``mmseqs_other_hits.a3m`` — each starting with a ``>query`` record
    carrying the original query sequence recovered via ``rematch``.

    Malformed inputs are reported through *logger* (which drops a marker
    file) and the file is skipped; previously they raised NameError /
    IndexError / AssertionError mid-run.
    """
    file_path = opjoin(msa_root, fname)
    # Read the file once up front; the original opened and read it twice.
    with open(file_path, "r") as f:
        lines = f.readlines()

    # Need at least the pdb header plus the query line, and hits come in
    # (header, sequence) pairs, so a well-formed file has an even count.
    if len(lines) < 2 or len(lines) % 2 != 0:
        logger("truncated_file", fname)
        return

    pdb_line = lines[0]
    # A bare "\n" on line 1 means the query sequence is empty.
    if len(lines[1]) == 1:
        logger("empty_query_seq", fname)
        return

    save_fname, origin_query_seq = rematch(pdb_line)

    os.makedirs(sub_dir_path := opjoin(save_root, f"{save_fname}"), exist_ok=True)
    uniref100_lines = [">query\n", f"{origin_query_seq}\n"]
    other_lines = [">query\n", f"{origin_query_seq}\n"]

    # Hits start at line 2: even index = header, odd index = sequence.
    for i in range(2, len(lines), 2):
        header = lines[i]
        if not header.startswith(">"):
            logger(f"bad_header_{i}", fname)
            return
        seq = lines[i + 1]
        if header.startswith(">UniRef100"):
            uniref100_lines.extend([header, seq])
        else:
            other_lines.extend([header, seq])

    # Both lists start with the 2-line query record, so together they must
    # account for every input line plus one duplicated query record.  This
    # holds by construction after the parity check above.
    assert len(other_lines) + len(uniref100_lines) - 2 == len(lines)

    # Drop the first non-UniRef100 hit (entries 2 and 3) — presumably the
    # query's own self-hit; TODO confirm against the upstream MSA format.
    other_lines = other_lines[0:2] + other_lines[4:]

    # Every remaining "other" header must carry a tab-separated taxid
    # column; report and skip instead of asserting (asserts vanish under -O).
    for i, header in enumerate(other_lines):
        if i > 0 and i % 2 == 0 and "\t" not in header:
            logger(f"missing_taxid_{i}", fname)
            return

    with open(opjoin(sub_dir_path, "uniref100_hits.a3m"), "w") as f:
        f.writelines(uniref100_lines)
    with open(opjoin(sub_dir_path, "mmseqs_other_hits.a3m"), "w") as f:
        f.writelines(other_lines)
|
|
|
|
if __name__ == "__main__":
    # Input MSAs (with taxids), output directory, and error-marker directory.
    msa_root = "./scripts/msa/data/mmcif_msa_with_taxid"
    save_root = "./scripts/msa/data/mmcif_msa"
    log_root = "./scripts/msa/data/mmcif_msa_log"

    for directory in (log_root, save_root):
        os.makedirs(directory, exist_ok=True)

    print("Loading file names...")

    # Bind the marker directory once; per-file errors call logger(msg, fname).
    logger = partial(write_log, log_root=log_root)
    for fname in os.listdir(msa_root):
        process_one_file(
            fname=fname, msa_root=msa_root, save_root=save_root, logger=logger
        )
|
|