SoluProtMutDemo / code /data_loader.py
vvelda's picture
Initial commit
b140e2c verified
import os
import math
import numpy as np
from Bio.PDB.Polypeptide import one_to_index
import torch
from torch.utils.data import Dataset
class Collate_Protein_Batch():
    """Namespace for the function that merges dataset items into one batch."""

    @staticmethod
    def collate(p_batch):
        """Flatten a list of (chainA, chainB, label, weight) items into tensors.

        Each chain is indexed as ``(name, aa_ids, coords, seq_pos, axes, ...)``;
        a falsy chain (e.g. ``None``) is skipped. The per-chain arrays are
        concatenated along axis 0, and every element of a chain is tagged with
        that chain's running index within the batch.

        Args:
            p_batch: iterable of 4-tuples ``(protA, protB, label, w)``.

        Returns:
            Tuple ``(names, aas, coords, seq_pos, axes, instance, weights,
            labels)``. ``names`` is a plain list, the rest are tensors.
            ``instance`` is int32. ``labels`` holds only the labels that are
            not ``None``, so it may be shorter than ``weights``.
        """
        names = []
        per_chain = {key: [] for key in ("aas", "coords", "seq_pos", "axes", "instance")}
        labels = []
        weights = []

        chain_index = 0
        for first, second, label, weight in p_batch:
            for chain in (first, second):
                if not chain:
                    continue
                names.append(chain[0])
                per_chain["aas"].append(chain[1])
                per_chain["coords"].append(chain[2])
                per_chain["seq_pos"].append(chain[3])
                per_chain["axes"].append(chain[4])
                # One id per chain, repeated for every residue of that chain.
                per_chain["instance"].append(np.ones_like(chain[1])*chain_index)
                chain_index += 1
            # Unlabelled items contribute a weight but no label.
            if label is not None:
                labels.append(label)
            weights.append(weight)

        def stacked(key):
            return torch.as_tensor(np.concatenate(per_chain[key], axis=0))

        return (
            names,
            stacked("aas"),
            stacked("coords"),
            stacked("seq_pos"),
            stacked("axes"),
            stacked("instance").to(torch.int32),
            torch.as_tensor(weights),
            torch.as_tensor(labels),
        )
# One-letter amino-acid codes; a letter's position in AA1 is its integer id.
# The trailing "X" is presumably the unknown-residue placeholder (confirm).
AA1 = "ACDEFGHIKLMNPQRSTVWYX"
AA_TO_ID = {letter: index for index, letter in enumerate(AA1)}
def create_datapoint(pdb_code: str, seq: str, coords, w: float = 1):
    """Build a single-chain, unlabelled item in the batch-item layout.

    Args:
        pdb_code: identifier stored as the chain name.
        seq: one-letter amino-acid sequence; each letter must be a key of
            ``AA_TO_ID`` (a ``KeyError`` is raised otherwise).
        coords: coordinates, passed through untouched.
        w: weight stored in the item.

    Returns:
        ``(chain, None, None, w)`` where ``chain`` is
        ``(pdb_code, aa_ids, coords, seq_positions, [], [])``; the second
        chain and the label are both ``None``.
    """
    residue_ids = [AA_TO_ID[aa] for aa in seq]
    seq_positions = list(range(len(seq)))
    chain = (pdb_code, residue_ids, coords, seq_positions, [], [])
    return (chain, None, None, w)
def collate_batch(p_batch):
    """Module-level shortcut for ``Collate_Protein_Batch.collate``.

    Takes the same list of ``(protA, protB, label, w)`` items and returns
    whatever ``Collate_Protein_Batch.collate`` returns for it.
    """
    collated = Collate_Protein_Batch.collate(p_batch)
    return collated
class EnzymeClassDataset(Dataset):
    """Dataset of (original chain, mutated chain, label, weight) rows.

    Sequences come from a FASTA file, per-chain coordinates from
    ``<chain>.npy`` files, and the rows themselves from ``<p_dataset>.csv``
    whose columns are ``orig, mut, label, weight[, fold]``.

    Attributes set by ``__init__``:
        list_chains_: dict ``chain name -> (np.ndarray of residue ids, folder)``.
        datapoints_: list of ``[orig, mut, label, weight]`` rows.
        data_: ``None`` for lazy loading, otherwise a dict
            ``chain name -> (coords, seq_pos, axes)`` loaded up front.
    """

    def __init__(
        self,
        p_path = 'data',
        p_data_path = 'chains',
        p_dataset = 'training',
        p_fastafile = 'chain_list_pdb.fasta',
        p_random_seed = None,
        p_fold: str = None,  # particular fold from 1 to N
        p_train_mode = False,  # to select all but the given fold (for training)
        p_data_aug = False,
        p_batch_pairs = False,
        p_load_data = False
    ):
        """Read the FASTA and CSV files and optionally preload coordinates.

        Args:
            p_path: root folder holding the FASTA file, the CSV and the
                coordinate folder.
            p_data_path: sub-folder of ``p_path`` with the ``.npy`` files.
            p_dataset: CSV file name without the ``.csv`` extension.
            p_fastafile: FASTA file name inside ``p_path``.
            p_random_seed: seed for the augmentation noise generator.
            p_fold: cross-validation fold (>= 1) matched against CSV column 5;
                ``None`` disables fold filtering. Strings and ints both work.
            p_train_mode: if true keep every row NOT in ``p_fold``; otherwise
                keep only the rows in ``p_fold``.
            p_data_aug: add Gaussian noise (sigma 0.05) to the coordinates.
            p_batch_pairs: with augmentation on, reuse the same noise for both
                chains of a row.
            p_load_data: load all coordinate files now instead of on access.

        Raises:
            Exception: if ``p_fold`` is given and is smaller than 1.
        """
        if p_fold is not None and int(p_fold) < 1:
            raise Exception("Fold for CV should be a positive integer! Got: " + str(p_fold))
        # The CSV fold column is text, so compare as strings; this also makes
        # an integer p_fold behave the same as its string form.
        fold = None if p_fold is None else str(p_fold).strip()

        # Random state.
        self.random_state_ = np.random.RandomState(p_random_seed)

        # Save the data augmentation parameters.
        self.data_augment_ = p_data_aug
        self.batch_pairs_ = p_batch_pairs

        # Get the paths.
        self.pdb_folder_ = os.path.join(p_path, p_data_path)
        pdb_fasta_file = os.path.join(p_path, p_fastafile)

        # Load the sequences from the fasta file.
        self.list_chains_ = {}
        self._read_fasta(pdb_fasta_file, self.pdb_folder_)

        # Load the datapoints.
        self.datapoints_ = []
        with open(os.path.join(p_path, p_dataset + '.csv'), 'r') as labels_map_file:
            for cur_line in labels_map_file:
                cur_line = cur_line.rstrip()
                if not cur_line:
                    continue  # tolerate blank / trailing lines
                line_split = cur_line.split(',')
                line_split[2] = float(line_split[2])
                # Set the default weight if not available.
                line_split[3] = float(line_split[3]) if line_split[3] else 1
                # Cross-validation row selection.
                if fold is not None and (line_split[4].strip() == fold) == p_train_mode:
                    continue  # do not include this fold
                self.datapoints_.append(line_split[:4])  # orig_pdb, mut_pdb, label, weight

        if p_load_data:
            # Keyed by chain name so __getitem__ can look chains up directly.
            self.data_ = {}
            num_chains = len(self.list_chains_)
            print()
            for cur_iter, chain_name in enumerate(self.list_chains_):
                self.data_[chain_name] = self._load_chain(chain_name)
                if cur_iter % 100 == 0:
                    print("\r Loading {:6d}/{:6d}".format(cur_iter, num_chains), end="")
            print()
        else:
            self.data_ = None

    def _read_fasta(self, fasta_file, folder):
        """Fill ``list_chains_`` from ``fasta_file``, tagging entries with ``folder``.

        Sequence lines that follow one header are concatenated, blank lines
        are ignored, and a repeated header replaces the earlier record.
        Letters missing from ``AA_TO_ID`` raise ``KeyError``.
        """
        chain_name = ''
        chain_ids = []
        with open(fasta_file, 'r') as my_fasta_file:
            for cur_line in my_fasta_file:
                cur_line = cur_line.strip()
                if not cur_line:
                    continue
                if cur_line.startswith('>'):
                    if chain_ids:
                        self.list_chains_[chain_name] = (np.array(chain_ids), folder)
                    chain_name = cur_line[1:]
                    chain_ids = []
                else:
                    chain_ids.extend(AA_TO_ID[cur_aa] for cur_aa in cur_line)
        if chain_ids:
            self.list_chains_[chain_name] = (np.array(chain_ids), folder)

    def _load_chain(self, chain_name):
        """Return ``(coords, seq_pos, axes)`` for one chain, reading its ``.npy`` file.

        ``seq_pos`` is simply ``0..len(sequence)-1`` and ``axes`` is an empty
        list; no separate files are read for them.
        """
        aa_ids, folder = self.list_chains_[chain_name]
        coords = np.load(os.path.join(folder, chain_name + ".npy"))
        return coords, list(range(len(aa_ids))), []

    def _get_chain(self, chain_name, noise):
        """Build the 5-tuple for one chain and return it with the noise used.

        Coordinates are shifted so the centre of their axis-aligned bounding
        box is at the origin. With augmentation on, ``noise`` is reused only
        when it is given and ``batch_pairs_`` is set; otherwise fresh noise is
        drawn from ``random_state_``.
        """
        cur_aa_ids = self.list_chains_[chain_name][0]
        if self.data_ is None:
            cur_pos, cur_seq_pos, cur_axes = self._load_chain(chain_name)
        else:
            cur_pos, cur_seq_pos, cur_axes = self.data_[chain_name]

        cur_min = np.amin(cur_pos, axis=0, keepdims=True)
        cur_max = np.amax(cur_pos, axis=0, keepdims=True)
        center = (cur_max + cur_min)*0.5
        cur_pos = cur_pos - center  # new array; preloaded data is not modified

        if self.data_augment_:
            if noise is None or not self.batch_pairs_:
                noise = self.random_state_.normal(0.0, 0.05, cur_pos.shape)
            assert cur_pos.shape == noise.shape, \
                "paired chains must have identical coordinate shapes to share noise"
            cur_pos = cur_pos + noise

        return (chain_name, cur_aa_ids, cur_pos, cur_seq_pos, cur_axes), noise

    def __len__(self):
        """Number of rows kept from the CSV."""
        return len(self.datapoints_)

    def __getitem__(self, idx):
        """Return ``(orig_chain, mut_chain, label, weight)`` for row ``idx``.

        Each chain is ``(name, aa_ids, coords, seq_pos, axes)``. The original
        chain is processed first, so it is the one that draws the noise shared
        with the mutant when ``batch_pairs_`` is enabled.
        """
        orig_pdb, mut_pdb, label, weight = self.datapoints_[idx]
        orig_chain, noise = self._get_chain(orig_pdb, None)
        mut_chain, _ = self._get_chain(mut_pdb, noise)
        return orig_chain, mut_chain, label, weight