import os
import json
import random

import numpy as np
import torch.utils.data as data
from tqdm import tqdm
from transformers import AutoTokenizer

from .utils import cached_property
from src.tools.utils import load_yaml_config


class FlexCATHDataset(data.Dataset):
    """CATH protein-structure dataset with optional per-residue flexibility labels.

    Reads a JSONL file of chains (sequence + backbone N/CA/C/O coordinates and,
    when ``use_dynamics`` is set, ground-truth flexibility / ENM values), filters
    out chains with non-canonical residues or length above ``max_length``, and
    partitions entries into train/valid/test via ``chain_set_splits.json``.

    Parameters
    ----------
    path : str
        Directory containing the JSONL file and split JSONs.
    split : str
        'train' / 'valid' / 'test' / 'predict' ('predict' falls back to the
        validation split).
    max_length : int
        Maximum chain length kept at load time; longer items passed in via
        ``data`` are randomly cropped in ``__getitem__``.
    test_name : str
        'All' (default), 'L100', or 'sc' — selects an alternative test split.
    data : list | None
        Pre-built item list; when given, no file loading happens.
    removeTS : int
        If truthy, drop entries listed in ``remove.json``.
    version : float
        CATH release (4.2 or 4.3); both currently read the same split file.
    data_jsonl_name : str
        Name of the chain JSONL file (joined onto ``path``).
    use_dynamics : bool
        If True, each item also carries 'gt_flex' / 'enm_vals' (and optionally
        'original_gt_flex' / 'eng_mask') from the JSONL entries.
    """

    def __init__(self, path='./', split='train', max_length=500, test_name='All',
                 data=None, removeTS=0, version=4.3,
                 data_jsonl_name='/chain_set.jsonl', use_dynamics=True):
        self.version = version
        self.path = path
        self.mode = split
        self.max_length = max_length
        self.test_name = test_name
        self.removeTS = removeTS
        self.data_jsonl_name = data_jsonl_name
        self.using_dynamics = use_dynamics
        print(self.data_jsonl_name)

        if self.removeTS:
            # Blacklist of entry names to exclude from the dataset.
            self.remove = json.load(open(self.path + '/remove.json', 'r'))['remove']

        if data is None:
            if split == 'predict':
                _split = 'valid'
                print('In predict mode for CATH4.3 using VALIDATION split as the data. Consider switching to TEST set.')
            else:
                _split = split
            self.data = self.cache_data[_split]
        else:
            self.data = data

        self.tokenizer = AutoTokenizer.from_pretrained(
            "facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")

    @cached_property
    def cache_data(self):
        """Load, filter, and split the JSONL chains.

        Returns a dict ``{'train': [...], 'valid': [...], 'test': [...]}`` of
        item dicts (title, seq, N/CA/C/O coords, chain mask/encoding, and
        dynamics fields when enabled). Computed once and cached.
        """
        alphabet = 'ACDEFGHIKLMNPQRSTVWY'  # 20 canonical amino acids
        alphabet_set = set(alphabet)
        print("path is: ", self.path)
        if not os.path.exists(self.path):
            # BUG FIX: the original `raise "..."` raised a str, which is a
            # TypeError in Python 3 — raise a proper exception instead.
            raise FileNotFoundError("no such file:{} !!!".format(self.path))
        with open(self.path + '/' + self.data_jsonl_name) as f:
            lines = f.readlines()

        data_list = []
        for line in tqdm(lines):
            entry = json.loads(line)
            if self.removeTS and entry['name'] in self.remove:
                continue
            seq = entry['seq']
            for key, val in entry['coords'].items():
                entry['coords'][key] = np.asarray(val)

            # Skip chains containing residues outside the canonical alphabet.
            bad_chars = set(seq).difference(alphabet_set)
            if len(bad_chars) == 0:
                if len(entry['seq']) <= self.max_length:
                    chain_length = len(entry['seq'])
                    chain_mask = np.ones(chain_length)
                    data_list.append({
                        'title': entry['name'],
                        'seq': entry['seq'],
                        'CA': entry['coords']['CA'],
                        'C': entry['coords']['C'],
                        'O': entry['coords']['O'],
                        'N': entry['coords']['N'],
                        'chain_mask': chain_mask,
                        # Single-chain data: encoding is 1 everywhere.
                        'chain_encoding': 1 * chain_mask,
                    })
                    if self.using_dynamics:
                        data_list[-1]['gt_flex'] = entry['gt_flex']
                        data_list[-1]['enm_vals'] = entry['enm_vals']
                        if 'original_gt_flex' in entry:
                            data_list[-1]['original_gt_flex'] = entry['original_gt_flex']
                        if 'eng_mask' in entry:
                            data_list[-1]['eng_mask'] = entry['eng_mask']

        # Both CATH versions currently read the same split file.
        if self.version == 4.2:
            with open(self.path + '/chain_set_splits.json') as f:
                dataset_splits = json.load(f)
        if self.version == 4.3:
            with open(self.path + '/chain_set_splits.json') as f:
                dataset_splits = json.load(f)

        # Optional alternative test splits.
        if self.test_name == 'L100':
            with open(self.path + '/test_split_L100.json') as f:
                test_splits = json.load(f)
            dataset_splits['test'] = test_splits['test']
        if self.test_name == 'sc':
            with open(self.path + '/test_split_sc.json') as f:
                test_splits = json.load(f)
            dataset_splits['test'] = test_splits['test']

        # Map entry name -> split; note the JSON uses 'validation' as key.
        name2set = {}
        name2set.update({name: 'train' for name in dataset_splits['train']})
        name2set.update({name: 'valid' for name in dataset_splits['validation']})
        name2set.update({name: 'test' for name in dataset_splits['test']})

        data_dict = {'train': [], 'valid': [], 'test': []}
        for item in data_list:
            if name2set.get(item['title']):
                if name2set[item['title']] == 'train':
                    data_dict['train'].append(item)
                if name2set[item['title']] == 'valid':
                    data_dict['valid'].append(item)
                if name2set[item['title']] == 'test':
                    # NOTE(review): 'Unkown' typo kept — downstream code may
                    # compare against this exact string.
                    item['category'] = 'Unkown'
                    item['score'] = 100.0
                    data_dict['test'].append(item)
        return data_dict

    def change_mode(self, mode):
        """Switch the active split ('train' / 'valid' / 'test') in place."""
        self.data = self.cache_data[mode]

    def __len__(self):
        return len(self.data)

    def get_item(self, index):
        """Return the raw (un-cropped) item dict at ``index``."""
        return self.data[index]

    def __getitem__(self, index):
        """Return the item at ``index``, randomly cropped to ``max_length``.

        Items longer than ``max_length`` (possible when ``data`` was supplied
        externally) get a random contiguous crop applied to the sequence,
        coordinates, masks, and — when present — the dynamics arrays.
        """
        item = self.data[index]
        L = len(item['seq'])
        if L > self.max_length:
            # Pick a random crop start; randint is inclusive on both ends.
            max_index = L - self.max_length
            truncate_index = random.randint(0, max_index)
            stop = truncate_index + self.max_length
            # BUG FIX: truncate a shallow copy instead of mutating the cached
            # item — the original overwrote self.data[index] in place, losing
            # the full-length data and freezing the crop after the first epoch.
            item = dict(item)
            for key in ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding'):
                item[key] = item[key][truncate_index:stop]
            # BUG FIX: dynamics keys exist only when use_dynamics=True; guard
            # to avoid KeyError in the dynamics-free configuration.
            for key in ('gt_flex', 'enm_vals'):
                if key in item:
                    item[key] = item[key][truncate_index:stop]
        return item