import csv
import os
import random
import time

import numpy as np
import torch
from dateutil import parser
from torch.utils.data import DataLoader


class StructureDataset:
    """In-memory list of structure entries filtered by alphabet and length.

    An entry (a dict with at least a ``"seq"`` key) is kept when every
    character of its sequence is in ``alphabet`` and the sequence is no
    longer than ``max_length``.

    Args:
        pdb_dict_list: iterable of entry dicts.
        verbose: accepted for backward compatibility; progress printing is
            disabled, so it currently has no effect.
        truncate: if not None, stop loading as soon as this many entries
            have been kept.
        max_length: maximum allowed sequence length (inclusive).
        alphabet: string of allowed sequence characters.
    """

    def __init__(
        self,
        pdb_dict_list,
        verbose=True,
        truncate=None,
        max_length=100,
        alphabet="ACDEFGHIKLMNPQRSTVWYX",
    ):
        alphabet_set = set(alphabet)
        # Tallies of rejected entries, kept for inspection/debugging.
        discard_count = {"bad_chars": 0, "too_long": 0, "bad_seq_length": 0}
        self.discard_count = discard_count
        self.data = []

        for entry in pdb_dict_list:
            seq = entry["seq"]
            if set(seq) - alphabet_set:
                discard_count["bad_chars"] += 1
            elif len(seq) <= max_length:
                self.data.append(entry)
            else:
                discard_count["too_long"] += 1

            # Truncate early
            if truncate is not None and len(self.data) == truncate:
                break

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class StructureLoader:
    """Group dataset entries of similar length into size-bounded batches.

    Entries are sorted by sequence length and packed greedily so that
    ``longest_length_in_batch * number_of_entries <= batch_size``.
    Iterating yields lists of dataset entries; the batch order is reshuffled
    on every iteration.

    Args:
        dataset: indexable collection of dicts with a ``"seq"`` key.
        batch_size: budget for ``max_length * entries`` within one batch.
        shuffle, collate_fn, drop_last: accepted for interface
            compatibility; they are not used by this class.

    Entries whose own length exceeds ``batch_size`` can never satisfy the
    budget and are left out of every batch.
    """

    def __init__(
        self,
        dataset,
        batch_size=100,
        shuffle=True,
        collate_fn=lambda x: x,
        drop_last=False,
    ):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]["seq"]) for i in range(self.size)]
        self.batch_size = batch_size
        sorted_ix = np.argsort(self.lengths).tolist()

        # Cluster into batches of similar sizes. Because indices are visited
        # in ascending length order, the current entry is always the longest
        # one in the batch being built.
        clusters, batch = [], []
        for ix in sorted_ix:
            size = self.lengths[ix]
            if size * (len(batch) + 1) <= self.batch_size:
                batch.append(ix)
                continue
            # The entry does not fit: close the current batch (if any) ...
            if batch:
                clusters.append(batch)
            # ... and start the next one with this entry instead of dropping
            # it. An entry larger than the whole budget is skipped.
            batch = [ix] if size <= self.batch_size else []
        if batch:
            clusters.append(batch)
        self.clusters = clusters

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            yield [self.dataset[i] for i in b_idx]


def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG (``np.random.seed()`` with no argument)."""
    np.random.seed()


class NoamOpt:
    """Optimizer wrapper that sets the learning rate before every step.

    The rate is ``factor * model_size**-0.5 *
    min(step**-0.5, step * warmup**-1.5)``.
    """

    def __init__(self, model_size, factor, warmup, optimizer, step):
        self.optimizer = optimizer
        self._step = step
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and rate."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate for ``step`` (default: current step).

        ``step`` must be positive; ``step == 0`` raises ZeroDivisionError
        because of the ``step ** -0.5`` term.
        """
        if step is None:
            step = self._step
        return self.factor * (
            self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        self.optimizer.zero_grad()


def get_std_opt(parameters, d_model, step):
    """Build a NoamOpt (factor 2, warmup 4000) around an Adam optimizer."""
    return NoamOpt(
        d_model,
        2,
        4000,
        torch.optim.Adam(parameters, lr=0, betas=(0.9, 0.98), eps=1e-9),
        step,
    )


def get_pdbs(data_loader, repeat=1, max_length=10000, num_units=1000000):
    """Convert loader samples into per-chain dictionaries.

    Each sample from ``data_loader`` is a dict whose values are indexed with
    ``[0]`` (presumably to strip a batch dimension of size 1 -- confirm
    against the loader). Samples without a ``"label"`` key, or with more
    distinct chain indices than there are chain labels (352), are skipped.

    For every chain, a run of six ``H`` characters found within the first or
    last ten residues is trimmed off together with everything outside it;
    chains left with fewer than four residues are ignored.

    Args:
        data_loader: iterable of sample dicts with keys ``"seq"``, ``"idx"``,
            ``"xyz"``, ``"masked"`` and ``"label"``.
        repeat: number of passes over ``data_loader``.
        max_length: samples whose concatenated sequence is longer than this
            are not collected.
        num_units: stop (and return) as soon as this many dicts are collected.

    Returns:
        List of dicts with keys ``seq_chain_X``, ``coords_chain_X``, ``name``,
        ``masked_list``, ``visible_list``, ``num_of_chains`` and ``seq``.
    """
    # 52 letters followed by "0".."299" -> 352 chain labels.
    chain_alphabet = list(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    ) + [str(i) for i in range(300)]
    pdb_dict_list = []

    for _ in range(repeat):
        for t in data_loader:
            t = {k: v[0] for k, v in t.items()}
            if "label" not in t:
                continue
            chain_ids = np.unique(t["idx"])
            if len(chain_ids) >= 352:
                continue

            my_dict = {}
            seq_parts = []
            mask_list = []
            visible_list = []
            seq_chars = np.array(list(t["seq"]))

            for idx in chain_ids:
                letter = chain_alphabet[idx]
                # NOTE(review): the slicing below treats `res` as shape
                # (1, n_residues); that relies on how np.argwhere behaves for
                # the type of t["idx"] -- confirm against the loader.
                res = np.argwhere(t["idx"] == idx)
                initial_sequence = "".join(list(seq_chars[res][0,]))

                # His-tag trimming. The checks are applied one after another
                # on the already-trimmed `res`, exactly in this order.
                if initial_sequence[-6:] == "HHHHHH":
                    res = res[:, :-6]
                if initial_sequence[0:6] == "HHHHHH":
                    res = res[:, 6:]
                if initial_sequence[-7:-1] == "HHHHHH":
                    res = res[:, :-7]
                if initial_sequence[-8:-2] == "HHHHHH":
                    res = res[:, :-8]
                if initial_sequence[-9:-3] == "HHHHHH":
                    res = res[:, :-9]
                if initial_sequence[-10:-4] == "HHHHHH":
                    res = res[:, :-10]
                if initial_sequence[1:7] == "HHHHHH":
                    res = res[:, 7:]
                if initial_sequence[2:8] == "HHHHHH":
                    res = res[:, 8:]
                if initial_sequence[3:9] == "HHHHHH":
                    res = res[:, 9:]
                if initial_sequence[4:10] == "HHHHHH":
                    res = res[:, 10:]

                # Ignore chains that are too short after trimming.
                if res.shape[1] < 4:
                    continue

                chain_seq = "".join(list(seq_chars[res][0,]))
                my_dict["seq_chain_" + letter] = chain_seq
                seq_parts.append(chain_seq)
                if idx in t["masked"]:
                    mask_list.append(letter)
                else:
                    visible_list.append(letter)

                all_atoms = np.array(t["xyz"][res,])[0,]  # [L, 14, 3]
                # Atom slots 0..3 are stored as N, CA, C and O.
                my_dict["coords_chain_" + letter] = {
                    "N_chain_" + letter: all_atoms[:, 0, :].tolist(),
                    "CA_chain_" + letter: all_atoms[:, 1, :].tolist(),
                    "C_chain_" + letter: all_atoms[:, 2, :].tolist(),
                    "O_chain_" + letter: all_atoms[:, 3, :].tolist(),
                }

            concat_seq = "".join(seq_parts)
            my_dict["name"] = t["label"]
            my_dict["masked_list"] = mask_list
            my_dict["visible_list"] = visible_list
            my_dict["num_of_chains"] = len(mask_list) + len(visible_list)
            my_dict["seq"] = concat_seq
            if len(concat_seq) <= max_length:
                pdb_dict_list.append(my_dict)
            # Return rather than break: a break would only leave the inner
            # loop and later `repeat` passes would exceed `num_units`.
            if len(pdb_dict_list) >= num_units:
                return pdb_dict_list
    return pdb_dict_list


class PDB_dataset(torch.utils.data.Dataset):
    """Dataset over cluster IDs that loads one random member per access.

    Args:
        IDs: sequence of cluster IDs.
        loader: callable ``loader(item, params)`` returning a sample.
        train_dict: mapping from cluster ID to a list of items.
        params: passed through to ``loader``.
    """

    def __init__(self, IDs, loader, train_dict, params):
        self.IDs = IDs
        self.train_dict = train_dict
        self.loader = loader
        self.params = params

    def __len__(self):
        return len(self.IDs)

    def __getitem__(self, index):
        ID = self.IDs[index]
        # Pick a random member of the cluster on every access.
        sel_idx = np.random.randint(0, len(self.train_dict[ID]))
        out = self.loader(self.train_dict[ID][sel_idx], self.params)
        return out


def loader_pdb(item, params):
    """Load one chain and a randomly chosen assembly that contains it.

    Args:
        item: sequence whose first element is ``"<pdbid>_<chid>"``.
        params: dict with ``"DIR"`` (data root) and ``"HOMO"`` (threshold
            above which another chain counts as similar and is masked).

    Returns:
        A dict with keys ``seq``, ``xyz``, ``idx``, ``masked`` and ``label``.
        If the metadata file is missing, or a chain needed by the selected
        assembly was not loaded, returns the sentinel
        ``{"seq": np.zeros(5)}`` (it has no ``"label"`` key).
    """
    # NOTE(review): torch.load unpickles the files it reads; only use this
    # on data from a trusted source.
    pdbid, chid = item[0].split("_")
    PREFIX = "%s/pdb/%s/%s" % (params["DIR"], pdbid[1:3], pdbid)

    # load metadata
    if not os.path.isfile(PREFIX + ".pt"):
        return {"seq": np.zeros(5)}
    meta = torch.load(PREFIX + ".pt")
    asmb_ids = meta["asmb_ids"]
    asmb_chains = meta["asmb_chains"]
    chids = np.array(meta["chains"])

    # find candidate assemblies which contain chid chain
    asmb_candidates = set(
        [a for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(",")]
    )

    # if the chain is missing from all the assemblies
    # then return this chain alone
    if len(asmb_candidates) < 1:
        chain = torch.load("%s_%s.pt" % (PREFIX, chid))
        L = len(chain["seq"])
        return {
            "seq": chain["seq"],
            "xyz": chain["xyz"],
            "idx": torch.zeros(L).int(),
            "masked": torch.Tensor([0]).int(),
            "label": item[0],
        }

    # randomly pick one assembly from candidates (a one-element list, which
    # broadcasts in the comparison below)
    asmb_i = random.sample(list(asmb_candidates), 1)

    # indices of selected transforms
    idx = np.where(np.array(asmb_ids) == asmb_i)[0]

    # load relevant chains
    # NOTE(review): this iterates asmb_chains[i] character by character,
    # whereas the code below splits the same string on ","; chain IDs longer
    # than one character would not be loaded here and end in the KeyError
    # fallback -- confirm whether that is intended.
    chains = {
        c: torch.load("%s_%s.pt" % (PREFIX, c))
        for i in idx
        for c in asmb_chains[i]
        if c in meta["chains"]
    }

    # generate assembly
    asmb = {}
    for k in idx:
        # pick k-th xform
        xform = meta["asmb_xform%d" % k]
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]

        # select chains which k-th xform should be applied to
        s1 = set(meta["chains"])
        s2 = set(asmb_chains[k].split(","))
        chains_k = s1 & s2

        # transform selected chains
        for c in chains_k:
            try:
                xyz = chains[c]["xyz"]
                xyz_ru = torch.einsum("bij,raj->brai", u, xyz) + r[:, None, None, :]
                asmb.update({(c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru)})
            except KeyError:
                return {"seq": np.zeros(5)}

    # select chains which share considerable similarity to chid
    seqid = meta["tm"][chids == chid][0, :, 1]
    homo = set(
        [ch_j for seqid_j, ch_j in zip(seqid, chids) if seqid_j > params["HOMO"]]
    )

    # stack all chains in the assembly together
    seq, xyz, idx, masked = "", [], [], []
    for counter, (k, v) in enumerate(asmb.items()):
        seq += chains[k[0]]["seq"]
        xyz.append(v)
        idx.append(torch.full((v.shape[0],), counter))
        if k[0] in homo:
            masked.append(counter)

    return {
        "seq": seq,
        "xyz": torch.cat(xyz, dim=0),
        "idx": torch.cat(idx, dim=0),
        "masked": torch.Tensor(masked).int(),
        "label": item[0],
    }


def _read_id_set(path):
    """Read one integer per line from ``path``, ignoring blank lines."""
    with open(path) as f:
        return {int(line) for line in f if line.strip()}


def build_training_clusters(params, debug):
    """Split the entries of the list CSV into train/valid/test clusters.

    Args:
        params: dict with ``"VAL"`` and ``"TEST"`` (files with one integer
            cluster id per line), ``"LIST"`` (CSV with a header row),
            ``"RESCUT"`` (upper bound for column 2) and ``"DATCUT"`` (latest
            accepted date for column 1).
        debug: if true, ignore the VAL/TEST ids, use only the first 20
            accepted rows and return the training dict as validation set too.

    Returns:
        ``(train, valid, test)``: dicts mapping cluster id (column 4, as int)
        to a list of ``[column 0, column 3]`` pairs.
    """
    # The id files are read even in debug mode, as before.
    val_ids = _read_id_set(params["VAL"])
    test_ids = _read_id_set(params["TEST"])
    if debug:
        val_ids = []
        test_ids = []

    # read & clean list.csv
    date_cutoff = parser.parse(params["DATCUT"])
    with open(params["LIST"], "r") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        rows = [
            [r[0], r[3], int(r[4])]
            for r in reader
            if float(r[2]) <= params["RESCUT"]
            and parser.parse(r[1]) <= date_cutoff
        ]

    # compile training and validation sets
    train = {}
    valid = {}
    test = {}
    if debug:
        rows = rows[:20]
    for r in rows:
        cluster_id = r[2]
        if cluster_id in val_ids:
            target = valid
        elif cluster_id in test_ids:
            target = test
        else:
            target = train
        target.setdefault(cluster_id, []).append(r[:2])
    if debug:
        valid = train
    return train, valid, test