Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import json | |
| import numpy as np | |
| import torch.utils.data as data | |
class CASPDataset(data.Dataset):
    """In-memory dataset of protein chains read from ``casp15.jsonl``.

    Each line of the file is parsed as one JSON object with the keys
    ``name``, ``seq``, ``coords`` (a mapping that includes ``N``, ``CA``,
    ``C`` and ``O``) and ``classification``.  Entries whose sequence
    contains any character outside the 20-letter alphabet
    ``ACDEFGHIKLMNPQRSTVWY`` are dropped.
    """

    def __init__(self, path='./', split='test'):
        """Load every usable chain from ``<path>/casp15.jsonl``.

        Args:
            path: Directory that contains ``casp15.jsonl``.
            split: Kept for interface compatibility; the loader does not
                use it.

        Raises:
            FileNotFoundError: If ``path`` does not exist, or if
                ``casp15.jsonl`` is missing inside it.
        """
        super().__init__()
        if not os.path.exists(path):
            # Raising a plain string is a TypeError in Python 3; raise a
            # real exception that carries the intended message instead.
            raise FileNotFoundError("no such file:{} !!!".format(path))

        alphabet_set = set('ACDEFGHIKLMNPQRSTVWY')
        self.data = []
        with open(os.path.join(path, 'casp15.jsonl'), encoding='utf-8') as f:
            # Stream the file instead of materialising every line first.
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                entry = json.loads(line)
                seq = entry['seq']
                if set(seq) - alphabet_set:
                    # Sequence has characters outside the alphabet: skip
                    # it before paying for the coordinate conversion.
                    continue
                coords = {
                    key: np.asarray(val)
                    for key, val in entry['coords'].items()
                }
                chain_mask = np.ones(len(seq))
                self.data.append({
                    'title': entry['name'],
                    'seq': seq,
                    'CA': coords['CA'],
                    'C': coords['C'],
                    'O': coords['O'],
                    'N': coords['N'],
                    'chain_mask': chain_mask,
                    # ``1 * chain_mask`` yields a separate array, so the
                    # encoding does not alias the mask.
                    'chain_encoding': 1 * chain_mask,
                    'classification': entry['classification'],
                })

    def __len__(self):
        """Return the number of chains that passed the alphabet filter."""
        return len(self.data)

    def get_item(self, index):
        """Return the record at ``index``; alias of ``__getitem__``."""
        return self[index]

    def __getitem__(self, index):
        """Return the record dict stored at ``index``."""
        return self.data[index]
if __name__ == '__main__':
    # Manual smoke test: load the dataset from a fixed directory and print
    # every record it contains.
    dataset = CASPDataset('/gaozhangyang/experiments/OpenCPD/data/casp15/')
    # The loop variable is deliberately not called ``data``: that name is the
    # module-level alias for ``torch.utils.data`` and would be rebound here.
    for sample in dataset:
        print(sample)