# Honzus24's picture
# initial commit
# 7968cb0
import os
import json
import numpy as np
from tqdm import tqdm
import random
import torch.utils.data as data
from .utils import cached_property
from transformers import AutoTokenizer
from src.tools.utils import load_yaml_config
class CATHDataset(data.Dataset):
    """CATH protein-structure dataset backed by a ``chain_set.jsonl`` file.

    Each entry of the JSONL file provides a sequence plus backbone atom
    coordinates; items are filtered to the standard 20-residue alphabet,
    split into train/valid/test via the CATH split JSON files, and randomly
    cropped to ``max_length`` residues on access.
    """

    # Cluster-specific default location of the dynamics config.  Kept as the
    # default so existing callers are unaffected, but now overridable via the
    # ``dynamics_config_path`` parameter instead of being hardcoded inline.
    DYNAMICS_CONFIG_PATH = '/scratch/project/fta-24-31/koubapet/ProteinInvBench/src/models/configs/FlexibilityProtTrans.yaml'

    def __init__(self, path='./', split='train', max_length=500, test_name='All',
                 data=None, removeTS=0, version=4.2,
                 data_jsonl_name='/chain_set.jsonl',
                 dynamics_config_path=None):
        """Load (or accept) the dataset for one split.

        Args:
            path: Directory containing the JSONL chain set and split files.
            split: One of 'train' / 'valid' / 'test' / 'predict'.
            max_length: Maximum sequence length; longer items are cropped.
            test_name: 'All', 'L100' or 'sc' — selects the test split file.
            data: Pre-built item list; when given, file loading is skipped.
            removeTS: If truthy, drop entries listed in ``remove.json``.
            version: CATH release (4.2 or 4.3) — selects split handling.
            data_jsonl_name: Name of the chain-set JSONL file under ``path``.
            dynamics_config_path: YAML config whose ``data_jsonl_name`` marks
                the dynamics dataset; defaults to ``DYNAMICS_CONFIG_PATH``.
        """
        self.version = version
        self.path = path
        self.mode = split
        self.max_length = max_length
        self.test_name = test_name
        self.removeTS = removeTS
        self.data_jsonl_name = data_jsonl_name
        if dynamics_config_path is None:
            dynamics_config_path = self.DYNAMICS_CONFIG_PATH
        # TODO: pass this flag in explicitly instead of sniffing a config file.
        self.using_dynamics = data_jsonl_name == load_yaml_config(dynamics_config_path)['data_jsonl_name']
        print(self.data_jsonl_name)
        if self.removeTS:
            self.remove = json.load(open(self.path + '/remove.json', 'r'))['remove']
        if data is None:
            if split == 'predict':
                # No dedicated predict split exists; fall back to validation.
                _split = 'valid'
                print('In predict mode for CATH4.3 using VALIDATION split as the data. Consider switching to TEST set.')
            else:
                _split = split
            self.data = self.cache_data[_split]
        else:
            self.data = data
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")

    @cached_property
    def cache_data(self):
        """Parse the JSONL chain set and split it into train/valid/test lists.

        Returns:
            dict with keys 'train', 'valid', 'test', each a list of item dicts
            (title, seq, N/CA/C/O coords, chain_mask, chain_encoding, and —
            when ``using_dynamics`` — norm_bfactors).

        Raises:
            FileNotFoundError: if ``self.path`` does not exist.
        """
        alphabet = 'ACDEFGHIKLMNPQRSTVWY'
        alphabet_set = set(alphabet)
        print("path is: ", self.path)
        if not os.path.exists(self.path):
            # BUG FIX: the original `raise "no such file..."` raised a str,
            # which is itself a TypeError in Python 3; raise a real exception.
            raise FileNotFoundError("no such file:{} !!!".format(self.path))
        with open(self.path + '/' + self.data_jsonl_name) as f:
            lines = f.readlines()
        data_list = []
        for line in tqdm(lines):
            entry = json.loads(line)
            if self.removeTS and entry['name'] in self.remove:
                continue
            seq = entry['seq']
            for key, val in entry['coords'].items():
                entry['coords'][key] = np.asarray(val)
            # Keep only sequences over the standard 20-residue alphabet that
            # fit within max_length.
            bad_chars = set(seq).difference(alphabet_set)
            if len(bad_chars) == 0 and len(seq) <= self.max_length:
                chain_length = len(seq)
                chain_mask = np.ones(chain_length)
                item = {
                    'title': entry['name'],
                    'seq': seq,
                    'CA': entry['coords']['CA'],
                    'C': entry['coords']['C'],
                    'O': entry['coords']['O'],
                    'N': entry['coords']['N'],
                    'chain_mask': chain_mask,
                    'chain_encoding': 1 * chain_mask,
                }
                if self.using_dynamics:  # TODO: pass this bool properly
                    item['norm_bfactors'] = entry['bfactor']
                data_list.append(item)
        # NOTE(review): 4.2 and 4.3 currently read the same split file; any
        # other `version` value leaves `dataset_splits` unbound and raises.
        if self.version == 4.2:
            with open(self.path + '/chain_set_splits.json') as f:
                dataset_splits = json.load(f)
        if self.version == 4.3:
            with open(self.path + '/chain_set_splits.json') as f:
                dataset_splits = json.load(f)
            if self.test_name == 'L100':
                with open(self.path + '/test_split_L100.json') as f:
                    test_splits = json.load(f)
                dataset_splits['test'] = test_splits['test']
            if self.test_name == 'sc':
                with open(self.path + '/test_split_sc.json') as f:
                    test_splits = json.load(f)
                dataset_splits['test'] = test_splits['test']
        name2set = {}
        name2set.update({name: 'train' for name in dataset_splits['train']})
        name2set.update({name: 'valid' for name in dataset_splits['validation']})
        name2set.update({name: 'test' for name in dataset_splits['test']})
        data_dict = {'train': [], 'valid': [], 'test': []}
        # Renamed loop variable from `data`, which shadowed the
        # `torch.utils.data as data` module alias.
        for record in data_list:
            split = name2set.get(record['title'])
            if split:
                if split == 'test':
                    # NOTE(review): 'Unkown' typo kept byte-identical —
                    # downstream consumers may match this exact value.
                    record['category'] = 'Unkown'
                    record['score'] = 100.0
                data_dict[split].append(record)
        return data_dict

    def change_mode(self, mode):
        """Switch the active split ('train' / 'valid' / 'test') in place."""
        self.data = self.cache_data[mode]

    def __len__(self):
        """Number of items in the active split."""
        return len(self.data)

    def get_item(self, index):
        """Return the raw (uncropped) item dict at ``index``."""
        return self.data[index]

    def __getitem__(self, index):
        """Return item ``index``, randomly cropped to ``max_length`` residues.

        BUG FIX: operate on a shallow copy of the stored dict.  The original
        wrote the cropped values back into ``self.data[index]``, so every
        epoch re-cropped an already-cropped entry, shrinking it further.
        """
        item = dict(self.data[index])
        L = len(item['seq'])
        if L > self.max_length:
            # Pick a random window of max_length residues.
            start = random.randint(0, L - self.max_length)
            end = start + self.max_length
            for key in ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding'):
                item[key] = item[key][start:end]
        return item