"""CSV-backed protein dataset: sample readers, the Dataset class and padding helpers."""
import logging

import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from src.data.esm.sdk.api import ESMProtein
from src.data.protein import Protein
from src.utils.utils import pmap_multi

logger = logging.getLogger(__name__)


def _parse_label(label, task_type, num_classes):
    """Convert a raw label value into a tensor according to ``task_type``.

    Args:
        label: Raw label. A comma-separated string of class indices for
            ``multi_labels_classification``, a path accepted by ``torch.load``
            for ``contact``, a bracketed whitespace-separated string of ints
            for ``residual_classification``, otherwise anything
            ``torch.tensor`` accepts.
        task_type: Name of the task type; selects the parsing rule.
        num_classes: Number of classes; only read for
            ``multi_labels_classification``.

    Returns:
        The label as a ``torch.Tensor`` (or whatever ``torch.load`` returns for
        ``contact``).
    """
    if task_type == "multi_labels_classification":
        # "3,7" -> multi-hot vector of length num_classes.
        mlb = MultiLabelBinarizer(classes=range(int(num_classes)))
        class_ids = [int(ele) for ele in str(label).split(",")]
        return torch.tensor(mlb.fit_transform([class_ids]).flatten().tolist())
    if task_type == "contact":
        # The label is a file path here; weights_only=True also disables the
        # torch.load warning.
        return torch.load(label, weights_only=True)
    if task_type == "residual_classification":
        # "[0 1 1\n 0]" -> tensor([0, 1, 1, 0])
        tokens = label.strip('[]').replace('\n', ' ').split()
        return torch.tensor([int(tok) for tok in tokens])
    return torch.tensor(label)


def read_data(aa_seq, name, label, task_type, num_classes, csv_name="",
              pdb_path=None, unique_id=None, smiles=None):
    """Build one sample dict from a CSV row, optionally loading PDB structures.

    Previously ``pdb_path``, ``unique_id`` and ``smiles`` were read without
    being parameters, so the function raised inside its own ``try`` and always
    returned ``None``. They are now optional keyword parameters placed after
    the existing ones, so positional callers are unaffected.

    Args:
        aa_seq: Amino-acid sequence string.
        name: Sample name. When ``None`` a hash-based name is generated.
        label: Raw label, parsed by ``_parse_label``.
        task_type: Task type name controlling label parsing.
        num_classes: Number of classes (multi-label tasks only).
        csv_name: Source CSV name; if it contains "flip" (case-insensitive)
            ``aa_seq`` is kept as the sequence even when a structure is loaded.
        pdb_path: Optional PDB path, or several paths joined by "|".
        unique_id: Optional identifier; defaults to ``str(hash(aa_seq))``.
        smiles: Optional value passed through unchanged under ``'smiles'``.

    Returns:
        A dict with keys ``name``, ``seq``, ``X``, ``label``, ``unique_id``,
        ``pdb_path`` and ``smiles``, or ``None`` if anything failed while
        building the sample (the failure is logged).
    """
    try:
        # NOTE: hash() of a str is salted per interpreter process unless
        # PYTHONHASHSEED is fixed, so these fallback ids are not stable
        # across runs.
        if unique_id is None:
            unique_id = str(hash(aa_seq))
        if name is None:
            name = str(hash(pdb_path if pdb_path is not None else aa_seq))

        label = _parse_label(label, task_type, num_classes)

        if pdb_path is None:
            # Sequence-only sample: no coordinates available.
            seq, coords = aa_seq, None
        elif "|" not in pdb_path:
            # Parse a single PDB file.
            structure = ESMProtein.from_pdb(pdb_path)
            seq = aa_seq if "flip" in csv_name.lower() else structure.sequence
            coords = structure.coordinates
        else:
            # PPI: several structures; coordinates are kept as a list and the
            # chain sequences are joined with "|".
            coords, sequences = [], []
            for _pdb_path in pdb_path.split("|"):
                structure = ESMProtein.from_pdb(_pdb_path)
                coords.append(structure.coordinates)
                sequences.append(structure.sequence)
            seq = "|".join(sequences)

        return {
            'name': name,
            'seq': seq,
            'X': coords,
            'label': label,
            'unique_id': unique_id,
            'pdb_path': pdb_path,
            'smiles': smiles,
        }
    except Exception:
        # Best-effort loading: a bad row is skipped, but no longer silently.
        logger.warning("read_data: skipping sample %r", name, exc_info=True)
        return None


def read_data_new(aa_seq, name, label, task_type, num_classes, csv_name=""):
    """Build a minimal sample dict (name, sequence, parsed label).

    Args:
        aa_seq: Amino-acid sequence string.
        name: Sample name, passed through unchanged.
        label: Raw label, parsed by ``_parse_label``.
        task_type: Task type name controlling label parsing.
        num_classes: Number of classes (multi-label tasks only).
        csv_name: Unused; kept for signature parity with ``read_data``.

    Returns:
        A dict with keys ``name``, ``seq`` and ``label``, or ``None`` if the
        label could not be parsed (the failure is logged).
    """
    try:
        return {
            'name': name,
            'seq': aa_seq,
            'label': _parse_label(label, task_type, num_classes),
        }
    except Exception:
        logger.warning("read_data_new: skipping sample %r", name, exc_info=True)
        return None


class ProteinDataset(Dataset):
    """Dataset of protein samples read from a CSV file.

    Each item is a dict. When a ``pretrain_model_interface`` is supplied, items
    carry padded ``embedding`` / ``attention_mask`` tensors; otherwise they
    carry the raw sequence and structure fields produced by ``read_data``.
    """

    def __init__(self, csv_file, pretrain_model_name='esm2_650m', max_length=1022, pretrain_model_interface=None, task_name='pretrain', task_type='classification', num_classes=None):
        """
        Args:
            csv_file (str): Path to the CSV file containing protein sequences,
                structures and related information. A ``label`` column is
                required; ``aa_seq`` and ``name`` are read if present.
            pretrain_model_name (str): Stored on the instance; not otherwise
                used here.
            max_length (int): Upper bound on the padded length. It is lowered
                to (longest sequence + 2) when the data is shorter.
            pretrain_model_interface: Optional object exposing
                ``inference_datasets(data, task_name=...)``; when given, its
                return value replaces the loaded samples.
            task_name (str): Task name forwarded to the model interface.
            task_type (str): Task type controlling label parsing and padding.
            num_classes: Number of classes (multi-label tasks only).

        Raises:
            ValueError: If no sample could be loaded from ``csv_file``.
        """
        super().__init__()
        self.max_length = max_length
        self.pretrain_model_name = pretrain_model_name
        self.task_name = task_name
        self.task_type = task_type
        self.num_classes = num_classes

        # Read the CSV data. This used to be gated on
        # task_name == "deep_loc_binary", which left every other task
        # (including the default) failing with a NameError below. The row
        # layout read here is not task specific, so load it unconditionally.
        csv_data = pd.read_csv(csv_file)

        path_list = []
        for i in range(len(csv_data)):
            row = csv_data.iloc[i]  # fetch the row once instead of per field
            # Entries must be tuples, otherwise parallel loading fails in
            # debug mode.
            path_list.append((row.get('aa_seq'), row.get('name'), row['label'], task_type, num_classes, csv_file))
        # For a fast debug run, truncate here: path_list = path_list[:10]

        loaded = pmap_multi(read_data, path_list, n_jobs=-1)
        self.data = [d for d in loaded if d is not None]
        if not self.data:
            raise ValueError(f"ProteinDataset: no samples could be loaded from {csv_file!r}")

        # +2 presumably leaves room for BOS/EOS tokens added by the model
        # tokenizer -- TODO confirm.
        longest_seq = max(len(d['seq']) for d in self.data)
        self.max_length = min(self.max_length, longest_seq + 2)

        self.pretrain_model_interface = pretrain_model_interface
        if pretrain_model_interface is not None:
            self.data = pretrain_model_interface.inference_datasets(self.data, task_name=self.task_name)

        print(f"ProteinDataset: {len(self.data)} samples loaded.")

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.data)

    def pad_data(self, data, dim=0, pad_value=0, max_length=1022):
        """Pad or truncate ``data`` along ``dim`` to exactly ``max_length``.

        Args:
            data: Tensor to resize.
            dim: Axis to pad/truncate.
            pad_value: Fill value used when padding (appended at the end).
            max_length: Target size along ``dim``.

        Returns:
            A tensor whose size along ``dim`` is ``max_length``.
        """
        current = data.shape[dim]
        if current < max_length:
            return dynamic_pad(data, [0, max_length - current], dim=dim, pad_value=pad_value)
        # Truncate along the requested axis (the old slice always cut dim 0).
        return data.narrow(dim, 0, max_length)

    def _prepare_label(self, label):
        """Apply the task-specific label post-processing shared by both item paths."""
        target_len = self.max_length
        if self.task_type == 'binary_classification':
            # Add a leading axis and cast to float.
            label = label[None].float()
        elif self.task_type == 'contact':
            # Entries equal to 0 become 1, everything else 0.
            # NOTE(review): presumably the stored map uses 0 for "contact" --
            # confirm against the data producer.
            label = (label == 0).int()
            extra = target_len - label.shape[0]
            # Pad both of the last two axes at the end.
            label = F.pad(label, [0, extra, 0, extra])
        elif self.task_type == 'residual_classification':
            label = F.pad(label, [0, target_len - label.shape[0]])
        return label

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        With a pretrain model interface the dict holds ``name``, ``embedding``,
        ``attention_mask`` and ``label`` (plus ``smiles`` when present and not
        ``None``). Without one it holds ``name``, ``seq``, ``X``, ``label``,
        ``unique_id``, ``pdb_path`` and ``smiles``.
        """
        item = self.data[idx]
        label = self._prepare_label(item['label'])

        if self.pretrain_model_interface is None:
            return {
                'name': item["name"],
                'seq': item["seq"],
                'X': item["X"],
                'label': label,
                'unique_id': item["unique_id"],
                'pdb_path': item["pdb_path"],
                'smiles': item["smiles"],
            }

        target_len = self.max_length
        result = {
            'name': item['name'],
            'embedding': self.pad_data(item['embedding'], dim=0, pad_value=0, max_length=target_len),
            'attention_mask': self.pad_data(item['attention_mask'], dim=0, pad_value=0, max_length=target_len),
            'label': label,
        }
        smiles = item.get('smiles')
        if smiles is not None:
            result['smiles'] = smiles
        return result


def dynamic_pad(tensor: torch.Tensor, pad_size, dim: int = 0, pad_value=0) -> torch.Tensor:
    """Pad ``tensor`` with a constant along a single axis.

    Args:
        tensor: Tensor to pad.
        pad_size: Pair ``(before, after)`` of pad amounts along ``dim``.
        dim: Axis to pad; negative values count from the last axis.
        pad_value: Constant fill value.

    Returns:
        The padded tensor.

    Raises:
        IndexError: If ``dim`` is out of range for ``tensor``.
    """
    num_dims = tensor.dim()
    if not -num_dims <= dim < num_dims:
        raise IndexError(f"dim {dim} is out of range for a tensor with {num_dims} dimension(s)")
    dim %= num_dims  # normalise negative axes

    pad_before, pad_after = pad_size
    # F.pad expects (before, after) pairs ordered from the LAST axis to the
    # first, hence the reversed index.
    pad = [0] * (2 * num_dims)
    pad_index = 2 * (num_dims - dim - 1)
    pad[pad_index] = pad_before
    pad[pad_index + 1] = pad_after

    return F.pad(tensor, pad, mode="constant", value=pad_value)


# Example usage
if __name__ == "__main__":
    dataset = ProteinDataset("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/fold_prediction/fold_prediction_with_glmfold_structure_test.csv")
    sample = dataset[0]
    # Print fields that __getitem__ actually returns.
    print(sample['name'])
    print(sample['seq'])
    print(sample['label'])