# Repository path: flexpert / Flexpert-Design / src / datasets / alphafold_dataset.py
# Page header residue: "Honzus24's picture"
# Commit message: initial commit
# Commit: 7968cb0
import os
import os.path as osp
import json
import numpy as np
import pickle as cPickle
import torch.utils.data as data
from src.datasets.utils import cached_property
class AlphaFoldDataset(data.Dataset):
    """Map-style dataset over pickled records split into train/valid/test.

    Records are read from ``<path>/<upid>/data_<upid>.pkl``; each record gets
    a ``'score'`` entry copied from the ``'res_score'`` field of the matching
    entry in ``<path>/<upid>/data_<upid>_score.pkl``.  The split indices come
    from a JSON file next to the pickles (``split_clu_l.json`` when
    ``limit_length`` is truthy, ``split_clu.json`` otherwise).

    When ``joint_data`` is truthy, every entry of ``path`` whose name contains
    ``'_v2'`` is loaded and merged, using the per-dataset indices stored in
    ``<path>/full/<split file>``.

    Args:
        path: Root directory holding the dataset folders.
        upid: Name of the dataset folder (also embedded in the pickle names).
            Ignored for loading when ``joint_data`` is truthy.
        mode: One of ``'train'``, ``'valid'``, ``'test'`` or ``'all'``.
        max_length: Exclusive upper bound on sequence length used by
            ``_data_info`` when ``limit_length`` is truthy.
        limit_length: Truthy selects the ``split_clu_l.json`` split file and
            enables length filtering in ``_data_info``.
        joint_data: Truthy merges all ``'_v2'`` datasets found under ``path``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
        FileNotFoundError: If ``path`` or a required split file is missing.
    """

    _SPLITS = ('train', 'valid', 'test')

    def __init__(self, path='./', upid='', mode='train', max_length=500, limit_length=1, joint_data=0):
        # Fail early: previously an unknown mode left ``self.data`` unset and
        # surfaced later as an unrelated AttributeError.
        if mode not in self._SPLITS + ('all',):
            raise ValueError(
                "mode must be one of 'train', 'valid', 'test', 'all'; got {!r}".format(mode)
            )
        self.path = path
        self.upid = upid
        self.max_length = max_length
        self.limit_length = limit_length
        self.joint_data = joint_data
        self._set_mode(mode)

    def _set_mode(self, mode):
        """Select the active split and refresh the length statistics."""
        cache = self.cache_data
        if mode == 'all':
            self.data = cache['train'] + cache['valid'] + cache['test']
        else:
            # An unknown mode raises KeyError here, as change_mode always did.
            self.data = cache[mode]
        self.lengths = np.array([len(sample['seq']) for sample in self.data])
        if self.lengths.size:
            self.max_len = np.max(self.lengths)
            self.min_len = np.min(self.lengths)
        else:
            # np.max/np.min raise on empty input; an empty split is legal.
            self.max_len = 0
            self.min_len = 0

    def _load_split(self, directory):
        """Load the split JSON (name depends on ``limit_length``) from ``directory``."""
        file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'
        split_path = osp.join(directory, file_name)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not os.path.exists(split_path):
            raise FileNotFoundError("no such file:{} !!!".format(split_path))
        with open(split_path, 'r') as handle:
            return json.load(handle)

    def _raw_data(self, path, upid):
        """Load the records of dataset ``upid`` and attach their scores.

        Returns the unpickled record list; each record ``i`` is mutated in
        place so that ``record['score'] == scores[i]['res_score']``.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError("no such file:{} !!!".format(path))
        path = osp.join(path, upid)
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this loader at dataset files from a trusted source.
        with open(osp.join(path, 'data_{}.pkl'.format(upid)), 'rb') as handle:
            data_ = cPickle.load(handle)
        with open(osp.join(path, 'data_{}_score.pkl'.format(upid)), 'rb') as handle:
            score_ = cPickle.load(handle)
        for i, record in enumerate(data_):
            record['score'] = score_[i]['res_score']
        return data_

    def _data_info(self, data):
        """Return indices of usable records and a sequence -> index map.

        When ``limit_length`` is truthy only records whose sequence length is
        strictly between 30 and ``max_length`` are kept; otherwise every
        record is kept.  For duplicated sequences the map holds the index of
        the last kept occurrence.
        """
        len_inds = []
        seq2ind = {}
        for ind, temp in enumerate(data):
            seq = temp['seq']
            # Records presumably carry 'title', 'seq', 'CA', 'C', 'O', 'N'
            # (per the original inline note) -- only 'seq' is read here.
            if self.limit_length and not (30 < len(seq) < self.max_length):
                continue
            len_inds.append(ind)
            seq2ind[seq] = ind
        return len_inds, seq2ind

    def get_data(self, path, upid, **kwargs):
        """Load dataset ``upid`` and return ``{'train', 'valid', 'test'}`` record lists."""
        data_ = self._raw_data(path, upid)
        split = self._load_split(osp.join(path, upid))
        return {fold: [data_[i] for i in split[fold]] for fold in self._SPLITS}

    def get_full_data(self, path, **kwargs):
        """Return the joint split mapping stored under ``<path>/full``."""
        return self._load_split(osp.join(path, 'full'))

    @cached_property
    def cache_data(self):  # TODO: joint_data
        """Build (once) the dict of train/valid/test record lists.

        Every test record additionally gets a ``'category'`` entry naming the
        dataset it came from.
        """
        path = self.path
        upid = self.upid
        if not self.joint_data:
            data_dict = self.get_data(path, upid)
            for item in data_dict['test']:
                item['category'] = upid
            return data_dict

        # os.listdir returns entries in arbitrary order; sort so the merged
        # dataset order is reproducible across filesystems.
        datanames = sorted(name for name in os.listdir(path) if '_v2' in name)
        data_dict = {fold: [] for fold in self._SPLITS}
        full_inds = self.get_full_data(path)
        for dataname in datanames:
            records = self._raw_data(path, dataname)
            folds = full_inds[dataname]
            train_idx, valid_idx, test_idx = folds['train'], folds['valid'], folds['test']
            data_dict['train'] += [records[i] for i in train_idx]
            data_dict['valid'] += [records[i] for i in valid_idx]
            for i in test_idx:
                item = records[i]
                item['category'] = dataname
                data_dict['test'].append(item)
        return data_dict

    def change_mode(self, mode):
        """Switch the active split and keep ``lengths``/``max_len``/``min_len`` in sync.

        Raises:
            KeyError: If ``mode`` is not a known split (or ``'all'``).
        """
        self._set_mode(mode)

    def __len__(self):
        """Return the number of records in the active split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the record at ``index`` of the active split."""
        return self.data[index]