| import torch |
| import os, glob |
| import pandas as pd |
| import numpy as np |
| from src.data.esm.sdk.api import ESMProtein |
| from src.utils.utils import pmap_multi |
|
|
|
|
| def read_data(aa_seq, pdb_path, target, unique_id, position=None, msa_aas=None): |
| try: |
| if unique_id is None: |
| unique_id = str(hash(aa_seq)) |
| |
| if position: |
| position = torch.tensor([int(ele) for ele in position.split(",")]) |
| |
| if pdb_path: |
| structure = ESMProtein.from_pdb(pdb_path) |
| coordinates = structure.coordinates |
| else: |
| coordinates = None |
|
|
| return { |
| 'name': target, |
| 'seq': aa_seq, |
| 'X': coordinates, |
| 'label': torch.tensor([0]), |
| 'unique_id': unique_id, |
| 'pdb_path': pdb_path, |
| 'position': position, |
| 'msa_aas': msa_aas |
| } |
| except: |
| return None |
| |
|
|
| class MSADataset(torch.utils.data.Dataset): |
| def __init__( |
| self, |
| msa_csv_path: str, |
| type: str = "center" |
| ): |
| self.msa_csv_path = msa_csv_path |
| msa_df = pd.read_csv(self.msa_csv_path) |
| |
| if type == "center": |
| msa_df = msa_df[msa_df["type"] == type] |
| else: |
| center_df = msa_df[msa_df["type"] == "center"] |
| id_seq_dict = dict(zip(center_df['unique_id'], center_df['aa_seq_ori'])) |
| msa_df = msa_df[msa_df["type"] == type] |
|
|
| def compute_position_diff(x): |
| seq_with_gap = x["aa_seq_ori"] |
| center_seq = id_seq_dict[x["target"]] |
| |
| pos_diff, msa_pos_aa = [], [] |
| center_idx = 0 |
|
|
| for i, aa in enumerate(seq_with_gap): |
| if aa == "-": |
| continue |
|
|
| if center_idx >= len(center_seq): |
| break |
|
|
| if aa.upper() != center_seq[center_idx].upper(): |
| pos_diff.append(str(i)) |
| msa_pos_aa.append(aa.upper()) |
| center_idx += 1 |
|
|
| return ",".join(pos_diff) |
| |
| def compute_aa_diff(x): |
| seq_with_gap = x["aa_seq_ori"] |
| center_seq = id_seq_dict[x["target"]] |
| |
| pos_diff, msa_pos_aa = [], [] |
| center_idx = 0 |
|
|
| for i, aa in enumerate(seq_with_gap): |
| if aa == "-": |
| continue |
|
|
| if center_idx >= len(center_seq): |
| break |
|
|
| if aa.upper() != center_seq[center_idx].upper(): |
| pos_diff.append(str(i)) |
| msa_pos_aa.append(aa.upper()) |
| center_idx += 1 |
|
|
| return "".join(msa_pos_aa) |
| |
| def compute_aa_diff(x): |
| seq_with_gap = x["aa_seq_ori"] |
| center_seq = id_seq_dict[x["target"]] |
| |
| pos_diff, msa_pos_aa = [], [] |
| center_idx = 0 |
|
|
| for i, aa in enumerate(seq_with_gap): |
| if aa == "-": |
| continue |
|
|
| if center_idx >= len(center_seq): |
| break |
|
|
| if aa.upper() != center_seq[center_idx].upper(): |
| pos_diff.append(str(i)) |
| msa_pos_aa.append(aa.upper()) |
| center_idx += 1 |
|
|
| return "".join(msa_pos_aa) |
|
|
| msa_df["position"] = msa_df.apply(compute_position_diff, axis=1) |
| msa_df["msa_aas"] = msa_df.apply(compute_aa_diff, axis=1) |
| path_list = [] |
| for i in range(len(msa_df)): |
| path_list.append( |
| ( |
| msa_df.iloc[i].get('aa_seq'), |
| msa_df.iloc[i].get('pdb_path'), |
| msa_df.iloc[i].get('target'), |
| msa_df.iloc[i].get('unique_id'), |
| msa_df.iloc[i].get('position'), |
| msa_df.iloc[i].get('msa_aas'), |
| ) |
| ) |
| |
| self.data = pmap_multi(read_data, path_list, n_jobs=8) |
| self.data = [d for d in self.data if d is not None] |
|
|
| def __len__(self): |
| return len(self.data) |
| |
| def __getitem__( |
| self, |
| idx |
| ): |
| return self.data[idx] |
| |
|
|
| if __name__ == "__main__": |
| msa_data_center = MSADataset( |
| msa_csv_path = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/zeroshot/msa/msa_samples_zeroshot_w_pdb.csv", |
| type="center" |
| ) |
| print(f"length of msa dataset: {len(msa_data_center)}...") |
|
|
| msa_data_msa = MSADataset( |
| msa_csv_path = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/zeroshot/msa/msa_samples_zeroshot_w_pdb.csv", |
| type="msa" |
| ) |
| print(f"length of msa dataset: {len(msa_data_msa)}...") |
|
|