Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,653 Bytes
7968cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import os
import json
import numpy as np
import torch.utils.data as data
class TSDataset(data.Dataset):
    """Protein test-set dataset combining the TS50 and TS500 benchmarks.

    Reads ``ts50.json`` and ``ts500.json`` from ``path`` and flattens both
    into a single list of per-protein sample dicts, TS50 entries first.

    Each JSON file is expected to contain a list of records with ``'name'``,
    ``'seq'`` and ``'coords'`` keys. ``coords`` must convert to a rank-3
    array whose second axis indexes backbone atoms in the order
    N (0), CA (1), C (2), O (3); presumably shape (L, 4, 3) — TODO confirm.

    Each sample dict has the keys ``'title'``, ``'seq'``, ``'CA'``, ``'C'``,
    ``'O'``, ``'N'`` (per-atom coordinate arrays sliced from ``coords``) and
    ``'category'`` (``'ts50'`` or ``'ts500'``).

    Args:
        path: Directory containing ``ts50.json`` and ``ts500.json``.
        split: Accepted for interface compatibility; it does not affect which
            records are loaded. Stored as ``self.split``.

    Raises:
        FileNotFoundError: If ``path`` or either JSON file does not exist.
    """

    # (file name inside ``path``, category label) in load order.
    # NOTE(review): an earlier comment claimed TS500 holds proteins of length
    # 500+ and TS50 proteins shorter than 500. Nothing here checks lengths, and
    # the names more likely refer to the number of chains in each benchmark
    # (50 / 500) — confirm before relying on either reading.
    _FILES = (('ts50.json', 'ts50'), ('ts500.json', 'ts500'))

    def __init__(self, path='./', split='test'):
        if not os.path.exists(path):
            # Raise a real exception: raising a bare string is itself a
            # TypeError in Python 3 and would hide the intended message.
            raise FileNotFoundError("no such file:{} !!!".format(path))
        self.split = split
        self.data = []
        for filename, category in self._FILES:
            self.data.extend(
                self._load_records(os.path.join(path, filename), category))

    @staticmethod
    def _load_records(json_path, category):
        """Load one benchmark JSON file and convert its records to sample dicts."""
        # ``with`` closes the handle promptly instead of leaking it.
        with open(json_path, encoding='utf-8') as fh:
            records = json.load(fh)
        samples = []
        for rec in records:
            coords = np.array(rec['coords'])
            samples.append({'title': rec['name'],
                            'seq': rec['seq'],
                            'CA': coords[:, 1, :],
                            'C': coords[:, 2, :],
                            'O': coords[:, 3, :],
                            'N': coords[:, 0, :],
                            'category': category,
                            })
        return samples

    def __len__(self):
        """Return the total number of TS50 + TS500 samples."""
        return len(self.data)

    def get_item(self, index):
        """Return the sample dict at ``index`` (same as ``self[index]``)."""
        return self.data[index]

    def __getitem__(self, index):
        """Return the sample dict at ``index``."""
        return self.data[index]