import os
import os.path as osp
import json
import numpy as np
import pickle as cPickle

import torch.utils.data as data
from src.datasets.utils import cached_property


class AlphaFoldDataset(data.Dataset):
    def __init__(self, path='./', upid='', mode='train', max_length=500, limit_length=1, joint_data=0):
        
        self.path = path
        self.upid = upid
        self.max_length = max_length
        self.limit_length = limit_length
        self.joint_data = joint_data
        
        if mode in ['train', 'valid', 'test']:
            self.data = self.cache_data[mode]
        elif mode == 'all':
            self.data = self.cache_data['train'] + self.cache_data['valid'] + self.cache_data['test']
        else:
            raise ValueError("unknown mode: {}".format(mode))

        self.lengths = np.array([len(sample['seq']) for sample in self.data])
        self.max_len = np.max(self.lengths)
        self.min_len = np.min(self.lengths)

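    def _raw_data(self, path, upid):
        if not os.path.exists(path):
            # Raising a plain string is a TypeError in Python 3; raise a real exception.
            raise FileNotFoundError("no such file: {}".format(path))
        path = osp.join(path, upid)
        # Structures and per-residue scores are stored in two parallel pickle files.
        with open(osp.join(path, 'data_{}.pkl'.format(upid)), 'rb') as f:
            data_ = cPickle.load(f)
        with open(osp.join(path, 'data_{}_score.pkl'.format(upid)), 'rb') as f:
            score_ = cPickle.load(f)
        # Attach each entry's residue-level score to the structure at the same index.
        for i in range(len(data_)):
            data_[i]['score'] = score_[i]['res_score']
        return data_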

    def _data_info(self, data):
        # Return the indices of usable entries and a sequence -> index map.
        # When limit_length is set, keep only sequences with 30 < length < max_length.
        len_inds = []
        seq2ind = {}
        for ind, temp in enumerate(data):  # 'title', 'seq', 'CA', 'C', 'O', 'N'
            if self.limit_length and not (30 < len(temp['seq']) < self.max_length):
                continue
            len_inds.append(ind)
            seq2ind[temp['seq']] = ind
        return len_inds, seq2ind

    def get_data(self, path, upid, **kwargs):
        data_ = self._raw_data(path, upid)
        path = osp.join(path, upid)

        # The '_l' split file is used when sequences are length-limited.
        file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'
        split_path = osp.join(path, file_name)
        assert os.path.exists(split_path), "missing split file: {}".format(split_path)
        with open(split_path, 'r') as f:
            split = json.load(f)
        # Each fold in the split file is a list of indices into data_.
        data_dict = {fold: [data_[i] for i in split[fold]] for fold in ['train', 'valid', 'test']}
        return data_dict

    def get_full_data(self, path, **kwargs):
        # Load the shared split file that maps each dataset name to its fold indices.
        file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'
        split_path = osp.join(path, 'full', file_name)
        assert os.path.exists(split_path), "missing split file: {}".format(split_path)
        with open(split_path, 'r') as f:
            split = json.load(f)
        return split

    @cached_property
    def cache_data(self): # TODO: joint_data
        path = self.path
        upid = self.upid
        if self.joint_data:
            datanames = [dataname for dataname in os.listdir(path) if ('_v2' in dataname)]
            data_dict = {'train':[], 'valid':[], 'test':[]}
            full_inds = self.get_full_data(path)

            for dataname in datanames:
                temp = self._raw_data(path, dataname)
                train_idx, valid_idx, test_idx = (full_inds[dataname][fold] for fold in ['train', 'valid', 'test'])
                data_dict['train'] += [temp[i] for i in train_idx]
                data_dict['valid'] += [temp[i] for i in valid_idx]

                # Tag each test entry with the name of the dataset it came from.
                for i in test_idx:
                    temp[i]['category'] = dataname
                    data_dict['test'].append(temp[i])

        else:
            data_dict = self.get_data(path, upid)
            for item in data_dict['test']:
                item['category'] = upid

        return data_dict
    
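    def change_mode(self, mode):
        # Switch to another cached fold and refresh the length statistics to match.
        self.data = self.cache_data[mode]
        self.lengths = np.array([len(sample['seq']) for sample in self.data])
        self.max_len = np.max(self.lengths)
        self.min_len = np.min(self.lengths)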
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
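`_data_info` keeps an entry only when its sequence length lies strictly between 30 and `max_length` (if `limit_length` is set) and maps each kept sequence to its index. The sketch below is a standalone version of that filter, assuming plain dicts with a `'seq'` key in place of the pickled records; the function name `filter_by_length` is hypothetical.

```python
from typing import Dict, List, Tuple


def filter_by_length(records: List[dict], max_length: int = 500,
                     limit_length: bool = True) -> Tuple[List[int], Dict[str, int]]:
    """Hypothetical standalone version of AlphaFoldDataset._data_info.

    Keeps index i when 30 < len(records[i]['seq']) < max_length (only if
    limit_length is truthy) and maps each kept sequence to its index.
    """
    len_inds: List[int] = []
    seq2ind: Dict[str, int] = {}
    for ind, rec in enumerate(records):
        if limit_length and not (30 < len(rec['seq']) < max_length):
            continue  # too short or too long
        len_inds.append(ind)
        seq2ind[rec['seq']] = ind  # a repeated sequence keeps its last index
    return len_inds, seq2ind


records = [{'seq': 'A' * 20}, {'seq': 'M' * 100}, {'seq': 'G' * 600}, {'seq': 'M' * 100}]
inds, seq2ind = filter_by_length(records, max_length=500)
print(inds)                    # [1, 3]
print(list(seq2ind.values()))  # [3]
```

Because `seq2ind` is keyed by sequence, duplicate sequences collapse to a single entry pointing at the last occurrence, while `len_inds` still lists every kept index.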
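With `joint_data` enabled, `cache_data` reads a shared split file (`full/split_clu_l.json` or `full/split_clu.json`) that maps each dataset name to its own `train`/`valid`/`test` index lists, concatenates the selected entries across datasets, and tags each test entry with a `category` field naming its source. A minimal in-memory sketch of that merge follows; the function `merge_folds` and the dataset names are hypothetical, and plain lists of dicts stand in for the pickles loaded by `_raw_data`.

```python
from typing import Dict, List

FOLDS = ('train', 'valid', 'test')


def merge_folds(raw: Dict[str, List[dict]],
                split: Dict[str, Dict[str, List[int]]]) -> Dict[str, List[dict]]:
    """Hypothetical in-memory version of the joint_data branch of cache_data.

    raw maps a dataset name to its list of records; split maps the same name
    to {'train': [...], 'valid': [...], 'test': [...]} index lists.
    """
    merged: Dict[str, List[dict]] = {fold: [] for fold in FOLDS}
    for name, records in raw.items():
        for fold in FOLDS:
            for i in split[name][fold]:
                item = records[i]
                if fold == 'test':
                    item['category'] = name  # remember which dataset this came from
                merged[fold].append(item)
    return merged


raw = {
    'proteomeA_v2': [{'seq': 'MKV'}, {'seq': 'GGA'}, {'seq': 'LLP'}],
    'proteomeB_v2': [{'seq': 'AAA'}, {'seq': 'CCC'}],
}
split = {
    'proteomeA_v2': {'train': [0], 'valid': [1], 'test': [2]},
    'proteomeB_v2': {'train': [1], 'valid': [], 'test': [0]},
}
merged = merge_folds(raw, split)
print([r['seq'] for r in merged['train']])      # ['MKV', 'CCC']
print([r['category'] for r in merged['test']])  # ['proteomeA_v2', 'proteomeB_v2']
```

As in the original, the `category` field is written onto the record in place, and the order of the merged lists follows the order in which dataset names are visited (in the original, the order returned by `os.listdir`).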
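`cache_data` is decorated with the project's own `cached_property` from `src.datasets.utils`, whose implementation is not shown here. Assuming it behaves like the standard library's `functools.cached_property` (Python 3.8+), the pickle loading runs once per dataset instance and `change_mode` only re-indexes the cached dictionary. The sketch below illustrates this with a hypothetical `Loader` class that counts how often the loader body runs.

```python
from functools import cached_property


class Loader:
    """Hypothetical stand-in for AlphaFoldDataset's caching behaviour."""

    def __init__(self) -> None:
        self.load_calls = 0
        self.data = self.cache_data['train']

    @cached_property
    def cache_data(self) -> dict:
        # Runs once; the result is then stored on the instance.
        self.load_calls += 1
        return {'train': [1, 2, 3], 'valid': [4], 'test': [5, 6]}

    def change_mode(self, mode: str) -> None:
        self.data = self.cache_data[mode]  # served from the cache, no reload


loader = Loader()
loader.change_mode('test')
loader.change_mode('valid')
print(loader.load_calls, loader.data)  # 1 [4]
```

Under this assumption the cache is per instance, so building separate `train`, `valid`, and `test` datasets loads the pickles three times, whereas one instance switched with `change_mode` loads them once.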