import glob
import hashlib
import logging
import os

import numpy as np
import pandas as pd
import torch

from src.data.esm.sdk.api import ESMProtein
from src.utils.utils import pmap_multi
| |
|
| |
|
def read_data(aa_seq, pdb_path, target, unique_id, position=None, msa_aas=None):
    """Build one dataset record from a CSV row, or return ``None`` on failure.

    Args:
        aa_seq: Amino-acid sequence stored under ``'seq'``.
        pdb_path: Path handed to ``ESMProtein.from_pdb``. If falsy, no
            structure is loaded and ``'X'`` is ``None``.
        target: Value stored under ``'name'``.
        unique_id: Record identifier. If ``None``, a deterministic id (sha1
            hex digest of ``aa_seq``) is used instead.
        position: Optional comma-separated string of integers (e.g. ``"3,17"``).
            If truthy it is converted to a ``torch`` tensor; otherwise it is
            passed through unchanged.
        msa_aas: Passed through unchanged.

    Returns:
        A dict with keys ``name``, ``seq``, ``X``, ``label``, ``unique_id``,
        ``pdb_path``, ``position`` and ``msa_aas`` (``label`` is always
        ``tensor([0])``). Returns ``None`` if any step raises; the failure is
        logged, and callers are expected to drop ``None`` records.
    """
    try:
        if unique_id is None:
            # str hashes are salted per process (PYTHONHASHSEED), so builtin
            # hash() would give different ids across runs and spawned workers.
            unique_id = hashlib.sha1(str(aa_seq).encode("utf-8")).hexdigest()

        if position:
            position = torch.tensor([int(ele) for ele in position.split(",")])

        coordinates = ESMProtein.from_pdb(pdb_path).coordinates if pdb_path else None

        return {
            'name': target,
            'seq': aa_seq,
            'X': coordinates,
            'label': torch.tensor([0]),
            'unique_id': unique_id,
            'pdb_path': pdb_path,
            'position': position,
            'msa_aas': msa_aas,
        }
    except Exception:
        # Best-effort by design (the dataset filters out None), but record why
        # this row was skipped instead of swallowing the error silently.
        logging.getLogger(__name__).warning(
            "read_data failed for target=%r pdb_path=%r", target, pdb_path, exc_info=True
        )
        return None
| | |
| |
|
class MSADataset(torch.utils.data.Dataset):
    """In-memory dataset of records built from an MSA CSV file.

    Only rows whose ``type`` column equals the requested ``type`` are kept.

    For ``type == "center"``, ``position`` and ``msa_aas`` are ``None``, since
    there is no reference to diff against.

    For any other ``type``, each row's gapped ``aa_seq_ori`` is compared with
    the ``aa_seq_ori`` of the ``center`` row whose ``unique_id`` equals the
    row's ``target``. The differences are stored in ``position`` / ``msa_aas``.
    A ``target`` with no matching center row raises ``KeyError``.

    Each kept row is converted with ``read_data``; rows for which it returns
    ``None`` are dropped.
    """

    def __init__(
        self,
        msa_csv_path: str,
        type: str = "center",
        n_jobs: int = 8,
    ):
        # NOTE: ``type`` shadows the builtin; the name is kept so keyword
        # callers (``type="msa"``) keep working.
        self.msa_csv_path = msa_csv_path
        msa_df = pd.read_csv(self.msa_csv_path)
        # Copy so the column assignments below do not write into a slice view.
        selected = msa_df[msa_df["type"] == type].copy()

        if type == "center":
            # Center rows are the reference: nothing to diff. (The diff helpers
            # previously ran here too and raised NameError on the missing
            # center lookup table.)
            selected["position"] = None
            selected["msa_aas"] = None
        else:
            center_df = msa_df[msa_df["type"] == "center"]
            id_seq_dict = dict(zip(center_df["unique_id"], center_df["aa_seq_ori"]))
            diffs = [
                self._diff_against_center(seq_with_gap, id_seq_dict[target])
                for seq_with_gap, target in zip(selected["aa_seq_ori"], selected["target"])
            ]
            selected["position"] = [pos for pos, _ in diffs]
            selected["msa_aas"] = [aas for _, aas in diffs]

        columns = ("aa_seq", "pdb_path", "target", "unique_id", "position", "msa_aas")
        # A missing column yields None for every row, matching the per-row
        # Series.get() lookup this replaces.
        column_values = [
            selected[col].tolist() if col in selected.columns else [None] * len(selected)
            for col in columns
        ]
        path_list = list(zip(*column_values))

        records = pmap_multi(read_data, path_list, n_jobs=n_jobs)
        self.data = [d for d in records if d is not None]

    @staticmethod
    def _diff_against_center(seq_with_gap, center_seq):
        """Return ``(positions, residues)`` where an MSA sequence differs from its center.

        Gap characters (``'-'``) in ``seq_with_gap`` are skipped. Each remaining
        residue is compared case-insensitively with the next residue of
        ``center_seq``. Scanning stops once ``center_seq`` is exhausted.

        Returns:
            A pair of strings:
            - comma-joined indices into ``seq_with_gap`` of the differing
              residues;
            - the concatenated upper-cased differing residues.
            For example, ``("A-CD", "ABD")`` -> ``("2", "C")``.
        """
        # NOTE(review): indices refer to the *gapped* string, not to the
        # ungapped sequence or the center; confirm consumers expect that.
        positions, residues = [], []
        center_idx = 0
        for i, aa in enumerate(seq_with_gap):
            if aa == "-":
                continue
            if center_idx >= len(center_seq):
                break
            if aa.upper() != center_seq[center_idx].upper():
                positions.append(str(i))
                residues.append(aa.upper())
            center_idx += 1
        return ",".join(positions), "".join(residues)

    def __len__(self):
        """Number of records that ``read_data`` converted successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the record dict produced by ``read_data`` at ``idx``."""
        return self.data[idx]
| | |
| |
|
if __name__ == "__main__":
    # Smoke run: build the center-only view and the MSA-row view of one CSV.
    csv_path = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/zeroshot/msa/msa_samples_zeroshot_w_pdb.csv"

    msa_data_center = MSADataset(msa_csv_path=csv_path, type="center")
    print(f"length of msa dataset: {len(msa_data_center)}...")

    msa_data_msa = MSADataset(msa_csv_path=csv_path, type="msa")
    print(f"length of msa dataset: {len(msa_data_msa)}...")
| |
|