Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import os.path as osp | |
| import json | |
| import numpy as np | |
| import pickle as cPickle | |
| import torch.utils.data as data | |
| from src.datasets.utils import cached_property | |
class AlphaFoldDataset(data.Dataset):
    """Dataset of AlphaFold-predicted protein structures.

    Loads pickled structure records (``data_<upid>.pkl``) together with
    per-residue confidence scores (``data_<upid>_score.pkl``) from one
    proteome directory (or, with ``joint_data``, every ``*_v2`` proteome
    under ``path``) and partitions them into train/valid/test folds using
    a precomputed cluster-split JSON file.
    """

    def __init__(self, path='./', upid='', mode='train', max_length=500,
                 limit_length=1, joint_data=0):
        """
        Args:
            path: root directory holding one sub-directory per proteome.
            upid: proteome identifier (sub-directory name) to load.
            mode: 'train' | 'valid' | 'test' selects a single fold;
                'all' concatenates the three folds.
            max_length: upper bound on sequence length when filtering.
            limit_length: if truthy, use the length-filtered split file
                ('split_clu_l.json' instead of 'split_clu.json').
            joint_data: if truthy, merge every '*_v2' proteome under path.
        """
        self.path = path
        self.upid = upid
        self.max_length = max_length
        self.limit_length = limit_length
        self.joint_data = joint_data
        if mode in ['train', 'valid', 'test']:
            self.data = self.cache_data[mode]
        if mode == 'all':
            self.data = (self.cache_data['train']
                         + self.cache_data['valid']
                         + self.cache_data['test'])
        # NOTE(review): an unrecognized mode leaves self.data unset and the
        # next line raises AttributeError — kept permissive for backward
        # compatibility with the original behavior.
        self.lengths = np.array([len(sample['seq']) for sample in self.data])
        self.max_len = np.max(self.lengths)
        self.min_len = np.min(self.lengths)

    def _raw_data(self, path, upid):
        """Load one proteome's records and attach per-residue scores.

        Returns the list of record dicts from ``data_<upid>.pkl`` with a
        'score' key added from the parallel ``data_<upid>_score.pkl`` list.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        if not os.path.exists(path):
            # BUG FIX: the original did `raise "no such file..."`, which is
            # itself a TypeError in Python 3 (cannot raise a str). Raise a
            # real exception with the same message.
            raise FileNotFoundError("no such file:{} !!!".format(path))
        path = osp.join(path, upid)
        # Context managers so the pickle file handles are always closed.
        with open(osp.join(path, 'data_{}.pkl'.format(upid)), 'rb') as f:
            data_ = cPickle.load(f)
        with open(osp.join(path, 'data_{}_score.pkl'.format(upid)), 'rb') as f:
            score_ = cPickle.load(f)
        # score_ is assumed parallel to data_ (same order/length) — the
        # original indexed it positionally too.
        for i in range(len(data_)):
            data_[i]['score'] = score_[i]['res_score']
        return data_

    def _data_info(self, data):
        """Return (kept indices, seq->index map) for a list of records.

        When ``self.limit_length`` is truthy, only records whose sequence
        length lies strictly in (30, self.max_length) are kept.
        """
        len_inds = []
        seq2ind = {}
        for ind, temp in enumerate(data):
            # Records carry keys: 'title', 'seq', 'CA', 'C', 'O', 'N'
            if self.limit_length and not (30 < len(temp['seq']) < self.max_length):
                continue
            len_inds.append(ind)
            seq2ind[temp['seq']] = ind
        return len_inds, seq2ind

    def get_data(self, path, upid, **kwargs):
        """Split one proteome's records into folds using the saved split JSON."""
        data_ = self._raw_data(path, upid)
        path = osp.join(path, upid)
        file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'
        split_path = osp.join(path, file_name)
        # NOTE(review): assert is stripped under `python -O`; kept to match
        # the original behavior rather than upgraded to a raise.
        assert os.path.exists(split_path)
        with open(split_path, 'r') as f:
            split = json.load(f)
        return {fold: [data_[i] for i in split[fold]]
                for fold in ('train', 'valid', 'test')}

    def get_full_data(self, path, **kwargs):
        """Load the joint (all-proteome) split indices from ``path/full/``."""
        file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'
        split_path = osp.join(path, 'full', file_name)
        assert os.path.exists(split_path)
        with open(split_path, 'r') as f:
            return json.load(f)

    @property
    def cache_data(self):
        """Fold -> record-list dict, built once and memoized.

        BUG FIX: the original defined this as a plain method while every
        caller subscripts it (``self.cache_data[mode]``), which would raise
        TypeError. The ``cached_property`` import in this file shows it was
        meant to be a cached property; memoized here with a builtin
        ``property`` so no extra import is required.
        """
        if not hasattr(self, '_cache_data'):
            self._cache_data = self._build_cache_data()
        return self._cache_data

    def _build_cache_data(self):
        """Assemble the train/valid/test dict (joint or single-proteome)."""
        path = self.path
        if self.joint_data:
            datanames = [d for d in os.listdir(path) if '_v2' in d]
            data_dict = {'train': [], 'valid': [], 'test': []}
            full_inds = self.get_full_data(path)
            for dataname in datanames:
                temp = self._raw_data(path, dataname)
                for fold in ('train', 'valid'):
                    data_dict[fold] += [temp[i] for i in full_inds[dataname][fold]]
                for i in full_inds[dataname]['test']:
                    item = temp[i]
                    # Tag test items with their source proteome for later
                    # per-category evaluation.
                    item['category'] = dataname
                    data_dict['test'].append(item)
        else:
            data_dict = self.get_data(path, self.upid)
            for item in data_dict['test']:
                item['category'] = self.upid
        return data_dict

    def change_mode(self, mode):
        """Switch self.data to another fold ('train' | 'valid' | 'test')."""
        self.data = self.cache_data[mode]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]