import os
import json
import numpy as np
from tqdm import tqdm
import random
import torch.utils.data as data
from .utils import cached_property
from transformers import AutoTokenizer
from src.tools.utils import load_yaml_config
class FlexCATHDataset(data.Dataset):
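    """CATH protein-chain dataset with optional per-residue dynamics labels.

    Chains are read from a JSONL file (sequence plus N/CA/C/O backbone
    coordinates and, when use_dynamics is set, 'gt_flex'/'enm_vals'
    flexibility fields) and grouped into train/valid/test splits.
    """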
    def __init__(self, path='./', split='train', max_length=500, test_name='All',
                 data=None, removeTS=0, version=4.3,
                 data_jsonl_name='/chain_set.jsonl', use_dynamics=True):
self.version = version
self.path = path
self.mode = split
self.max_length = max_length
self.test_name = test_name
self.removeTS = removeTS
self.data_jsonl_name = data_jsonl_name
self.using_dynamics = use_dynamics
print(self.data_jsonl_name)
if self.removeTS:
self.remove = json.load(open(self.path+'/remove.json', 'r'))['remove']
if data is None:
if split == 'predict':
_split = 'valid'
                print('Predict mode for CATH 4.3 uses the VALIDATION split; consider switching to the TEST split.')
else:
_split = split
self.data = self.cache_data[_split]
else:
self.data = data
self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
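    # cache_data parses the chain JSONL once and groups entries by split;
    # the @cached_property decorator memoizes the result, so change_mode()
    # below can switch splits without re-reading the file.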
@cached_property
def cache_data(self):
        alphabet = 'ACDEFGHIKLMNPQRSTVWY'
        alphabet_set = set(alphabet)
print("path is: ", self.path)
        if not os.path.exists(self.path):
            # Python 3 can only raise exception instances, not strings.
            raise FileNotFoundError('no such file: {} !!!'.format(self.path))
else:
with open(self.path+'/'+self.data_jsonl_name) as f:
lines = f.readlines()
data_list = []
for line in tqdm(lines):
entry = json.loads(line)
if self.removeTS and entry['name'] in self.remove:
continue
seq = entry['seq']
for key, val in entry['coords'].items():
entry['coords'][key] = np.asarray(val)
                bad_chars = set(seq).difference(alphabet_set)
if len(bad_chars) == 0:
if len(entry['seq']) <= self.max_length:
chain_length = len(entry['seq'])
chain_mask = np.ones(chain_length)
data_list.append({
'title':entry['name'],
'seq':entry['seq'],
'CA':entry['coords']['CA'],
'C':entry['coords']['C'],
'O':entry['coords']['O'],
'N':entry['coords']['N'],
'chain_mask': chain_mask,
'chain_encoding': 1*chain_mask
})
if self.using_dynamics:
data_list[-1]['gt_flex'] = entry['gt_flex']
data_list[-1]['enm_vals'] = entry['enm_vals']
if 'original_gt_flex' in entry:
data_list[-1]['original_gt_flex'] = entry['original_gt_flex']
if 'eng_mask' in entry:
data_list[-1]['eng_mask'] = entry['eng_mask']
                # Sequences containing non-canonical residues are skipped.
            if self.version in (4.2, 4.3):
                with open(self.path + '/chain_set_splits.json') as f:
                    dataset_splits = json.load(f)
if self.test_name == 'L100':
with open(self.path+'/test_split_L100.json') as f:
test_splits = json.load(f)
dataset_splits['test'] = test_splits['test']
if self.test_name == 'sc':
with open(self.path+'/test_split_sc.json') as f:
test_splits = json.load(f)
dataset_splits['test'] = test_splits['test']
name2set = {}
name2set.update({name:'train' for name in dataset_splits['train']})
name2set.update({name:'valid' for name in dataset_splits['validation']})
name2set.update({name:'test' for name in dataset_splits['test']})
data_dict = {'train':[],'valid':[],'test':[]}
            # Named `chain` to avoid shadowing the `torch.utils.data` module
            # imported as `data` at the top of this file.
            for chain in data_list:
                split_name = name2set.get(chain['title'])
                if split_name == 'train':
                    data_dict['train'].append(chain)
                elif split_name == 'valid':
                    data_dict['valid'].append(chain)
                elif split_name == 'test':
                    chain['category'] = 'Unknown'
                    chain['score'] = 100.0
                    data_dict['test'].append(chain)
return data_dict
def change_mode(self, mode):
self.data = self.cache_data[mode]
def __len__(self):
return len(self.data)
def get_item(self, index):
return self.data[index]
    def __getitem__(self, index):
        # Work on a shallow copy so the random crop below does not
        # permanently overwrite the cached entry.
        item = dict(self.data[index])
        L = len(item['seq'])
        if L > self.max_length:
            # Choose a random crop window of length max_length.
            truncate_index = random.randint(0, L - self.max_length)
            keys = ['seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding']
            # Only slice the dynamics fields when they were loaded; otherwise
            # use_dynamics=False would raise a KeyError here.
            if self.using_dynamics:
                keys += ['gt_flex', 'enm_vals']
            for key in keys:
                item[key] = item[key][truncate_index:truncate_index + self.max_length]
        return item
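
# A minimal usage sketch (not part of the original module). It assumes a data
# directory laid out as the class expects: <path>/chain_set.jsonl and
# <path>/chain_set_splits.json; the path below is hypothetical.
if __name__ == '__main__':
    dataset = FlexCATHDataset(path='./data/cath4.3', split='train',
                              max_length=500, use_dynamics=False)
    print(len(dataset), 'training chains')
    sample = dataset[0]
    # Each item is a dict of the sequence, backbone coordinates, and masks.
    print(sample['title'], len(sample['seq']), sample['CA'].shape)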