| import csv |
| import os |
| import random |
| import time |
|
|
| import numpy as np |
| import torch |
| from dateutil import parser |
| from torch.utils.data import DataLoader |
|
|
|
|
class StructureDataset:
    """In-memory list of sequence entries filtered by alphabet and length.

    An entry (a mapping with at least a ``"seq"`` string) is kept only when
    every character of its sequence belongs to ``alphabet`` and the sequence
    is at most ``max_length`` characters long.

    Args:
        pdb_dict_list: Iterable of entry mappings; each must have a ``"seq"``
            key.
        verbose: Accepted for backward compatibility. No progress output is
            produced.
        truncate: If not ``None``, stop as soon as this many entries have
            been kept. A non-positive value keeps nothing.
        max_length: Longest sequence (in characters) that is kept.
        alphabet: String of allowed sequence characters.

    Attributes:
        data: Kept entries, in input order.
        discard_count: Number of rejected entries per reason
            (``"bad_chars"``, ``"too_long"``; ``"bad_seq_length"`` is
            reserved and never incremented here).
    """

    def __init__(
        self,
        pdb_dict_list,
        verbose=True,
        truncate=None,
        max_length=100,
        alphabet="ACDEFGHIKLMNPQRSTVWYX",
    ):
        allowed = set(alphabet)
        self.discard_count = {"bad_chars": 0, "too_long": 0, "bad_seq_length": 0}
        self.data = []

        # A limit of zero (or less) means there is nothing to collect; the
        # post-append check below could never trigger for it.
        if truncate is not None and truncate <= 0:
            return

        for entry in pdb_dict_list:
            seq = entry["seq"]
            if not allowed.issuperset(seq):
                self.discard_count["bad_chars"] += 1
            elif len(seq) > max_length:
                self.discard_count["too_long"] += 1
            else:
                self.data.append(entry)

            # Stop right after the limit is reached so no further input is
            # consumed from the iterable.
            if truncate is not None and len(self.data) >= truncate:
                break

    def __len__(self):
        """Return the number of kept entries."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the kept entry at position ``idx``."""
        return self.data[idx]
|
|
|
|
class StructureLoader:
    """Iterate over a dataset in length-sorted batches under a size budget.

    Entries are sorted by ``len(entry["seq"])`` and packed greedily so that
    ``longest_sequence_in_batch * number_of_entries`` never exceeds
    ``batch_size``. Each batch is yielded as a plain list of dataset entries.

    Args:
        dataset: Indexable collection with ``__len__`` whose items have a
            ``"seq"`` entry supporting ``len()``.
        batch_size: Budget for ``max sequence length * batch entries``.
        shuffle: When true (default), the batch order is reshuffled with
            ``np.random.shuffle`` at the start of every iteration.
        collate_fn: Accepted for API compatibility; not used.
        drop_last: Accepted for API compatibility; not used.

    Attributes:
        clusters: List of batches, each a list of dataset indices.

    Note:
        An entry whose own length exceeds ``batch_size`` cannot fit in any
        batch and is left out.
    """

    def __init__(
        self,
        dataset,
        batch_size=100,
        shuffle=True,
        collate_fn=lambda x: x,
        drop_last=False,
    ):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]["seq"]) for i in range(self.size)]
        self.batch_size = batch_size
        self.shuffle = shuffle

        clusters, batch = [], []
        for ix in np.argsort(self.lengths):
            size = self.lengths[ix]
            if size > self.batch_size:
                # Lengths are ascending, so every remaining entry is also
                # too long to fit the budget on its own.
                break
            # `size` is the longest sequence seen so far, hence the padded
            # size of the batch if this entry were added.
            if size * (len(batch) + 1) > self.batch_size:
                clusters.append(batch)
                batch = []
            # The entry that overflowed the previous batch opens the next one
            # instead of being discarded.
            batch.append(ix)
        if batch:
            clusters.append(batch)
        self.clusters = clusters

    def __len__(self):
        """Return the number of batches."""
        return len(self.clusters)

    def __iter__(self):
        """Yield each batch as a list of dataset entries."""
        if self.shuffle:
            np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            yield [self.dataset[i] for i in b_idx]
|
|
|
|
def worker_init_fn(worker_id):
    """Reseed NumPy's global random generator without a fixed seed.

    ``worker_id`` is accepted but not used.
    NOTE(review): presumably passed as ``worker_init_fn`` to a torch
    ``DataLoader`` so each worker gets its own NumPy random stream - confirm
    against the caller.
    """
    # Passing None is the same as calling seed() with no argument.
    np.random.seed(None)
|
|
|
|
class NoamOpt:
    """Optimizer wrapper that sets the learning rate from the step count.

    The rate used at step ``s`` is::

        factor * model_size ** -0.5 * min(s ** -0.5, s * warmup ** -1.5)

    Args:
        model_size: Value whose inverse square root scales the rate.
        factor: Constant multiplier applied to the rate.
        warmup: Warm-up constant in the formula above.
        optimizer: Wrapped optimizer; must expose ``param_groups``,
            ``step()`` and ``zero_grad()``.
        step: Number of steps already taken (the starting step counter).
    """

    def __init__(self, model_size, factor, warmup, optimizer, step):
        self.optimizer = optimizer
        self._step = step
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return the wrapped optimizer's ``param_groups``."""
        return self.optimizer.param_groups

    def step(self):
        """Advance the step counter, set the new rate and step the optimizer."""
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate for ``step`` (default: the current step).

        A non-positive step yields ``0.0``: that is the value the warm-up
        term ``step * warmup ** -1.5`` takes at zero, whereas evaluating
        ``step ** -0.5`` there would raise ``ZeroDivisionError``.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * (
            self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        """Forward ``zero_grad()`` to the wrapped optimizer."""
        self.optimizer.zero_grad()
|
|
|
|
def get_std_opt(parameters, d_model, step):
    """Build a ``NoamOpt`` around an Adam optimizer with fixed settings.

    Adam is created with ``lr=0``, ``betas=(0.9, 0.98)`` and ``eps=1e-9``;
    the wrapper uses ``factor=2`` and ``warmup=4000``.

    Args:
        parameters: Parameters handed to ``torch.optim.Adam``.
        d_model: Passed to ``NoamOpt`` as ``model_size``.
        step: Starting step counter for the schedule.

    Returns:
        The configured ``NoamOpt`` instance.
    """
    adam = torch.optim.Adam(parameters, lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(
        model_size=d_model,
        factor=2,
        warmup=4000,
        optimizer=adam,
        step=step,
    )
|
|
|
|
def get_pdbs(data_loader, repeat=1, max_length=10000, num_units=1000000):
    """Convert loader batches into per-structure dictionaries of chains.

    Every batch is a mapping whose values are indexed with ``[0]`` to obtain
    one example. Examples without a ``"label"`` key are skipped, as are
    examples with 352 or more distinct chain indices (the number of
    available chain labels). For each distinct value in ``t["idx"]`` the
    residues of that chain are selected, runs of ``"HHHHHH"`` near either
    end are trimmed, and chains left with fewer than 4 residues are dropped.

    Each result dictionary contains:
        * ``"seq_chain_<X>"``: sequence of chain ``<X>``;
        * ``"coords_chain_<X>"``: mapping with ``"N_chain_<X>"``,
          ``"CA_chain_<X>"``, ``"C_chain_<X>"`` and ``"O_chain_<X>"`` lists
          taken from atom positions 0-3 of ``t["xyz"]``;
        * ``"name"``: the example label;
        * ``"masked_list"`` / ``"visible_list"``: chain labels whose index is
          / is not contained in ``t["masked"]``;
        * ``"num_of_chains"``: number of kept chains;
        * ``"seq"``: concatenation of the kept chain sequences.

    Args:
        data_loader: Iterable of batches (mappings with ``"seq"``, ``"xyz"``,
            ``"idx"``, ``"masked"`` and ``"label"`` entries).
        repeat: Number of passes over ``data_loader``.
        max_length: Structures whose concatenated sequence is longer than
            this are not returned.
        num_units: Stop (across all passes) once at least this many
            structures have been collected.

    Returns:
        List of structure dictionaries.
    """
    # Chain labels: 52 ASCII letters followed by "0".."299" (352 in total).
    chain_alphabet = list(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    ) + [str(n) for n in range(300)]

    his_tag = "HHHHHH"
    # (window of the untrimmed sequence to test, columns of `res` to keep).
    # Order matters and trims accumulate: each window is evaluated on the
    # untrimmed sequence but applied to the already-trimmed `res`.
    # NOTE(review): kept exactly as in the original; confirm the cumulative
    # trimming is intended before changing it.
    trims = (
        (slice(-6, None), slice(None, -6)),
        (slice(0, 6), slice(6, None)),
        (slice(-7, -1), slice(None, -7)),
        (slice(-8, -2), slice(None, -8)),
        (slice(-9, -3), slice(None, -9)),
        (slice(-10, -4), slice(None, -10)),
        (slice(1, 7), slice(7, None)),
        (slice(2, 8), slice(8, None)),
        (slice(3, 9), slice(9, None)),
        (slice(4, 10), slice(10, None)),
    )

    pdb_dict_list = []
    for _ in range(repeat):
        for t in data_loader:
            t = {k: v[0] for k, v in t.items()}
            if "label" not in t:
                continue
            chain_ids = np.unique(t["idx"])
            if len(chain_ids) >= len(chain_alphabet):
                continue

            my_dict = {}
            chain_seqs = []
            mask_list = []
            visible_list = []
            seq_array = np.array(list(t["seq"]))
            for idx in chain_ids:
                letter = chain_alphabet[idx]
                # `res` is indexed as a (1, n_residues) array of positions
                # below (columns are residues).
                res = np.argwhere(t["idx"] == idx)
                initial_sequence = "".join(list(seq_array[res][0,]))
                for window, keep in trims:
                    if initial_sequence[window] == his_tag:
                        res = res[:, keep]
                if res.shape[1] < 4:
                    continue

                chain_seq = "".join(list(seq_array[res][0,]))
                my_dict["seq_chain_" + letter] = chain_seq
                chain_seqs.append(chain_seq)
                if idx in t["masked"]:
                    mask_list.append(letter)
                else:
                    visible_list.append(letter)

                all_atoms = np.array(t["xyz"][res,])[0,]
                my_dict["coords_chain_" + letter] = {
                    "N_chain_" + letter: all_atoms[:, 0, :].tolist(),
                    "CA_chain_" + letter: all_atoms[:, 1, :].tolist(),
                    "C_chain_" + letter: all_atoms[:, 2, :].tolist(),
                    "O_chain_" + letter: all_atoms[:, 3, :].tolist(),
                }

            concat_seq = "".join(chain_seqs)
            my_dict["name"] = t["label"]
            my_dict["masked_list"] = mask_list
            my_dict["visible_list"] = visible_list
            my_dict["num_of_chains"] = len(mask_list) + len(visible_list)
            my_dict["seq"] = concat_seq
            if len(concat_seq) <= max_length:
                pdb_dict_list.append(my_dict)
            # Return rather than break: a break would only leave the inner
            # loop and the next `repeat` pass would keep appending.
            if len(pdb_dict_list) >= num_units:
                return pdb_dict_list
    return pdb_dict_list
|
|
|
|
class PDB_dataset(torch.utils.data.Dataset):
    """Dataset that returns one randomly chosen member per ID.

    Args:
        IDs: Sequence of keys into ``train_dict``; defines the length.
        loader: Callable invoked as ``loader(member, params)``.
        train_dict: Mapping from each ID to a sequence of members.
        params: Passed unchanged to ``loader``.
    """

    def __init__(self, IDs, loader, train_dict, params):
        self.IDs, self.loader = IDs, loader
        self.train_dict, self.params = train_dict, params

    def __len__(self):
        """Return the number of IDs."""
        return len(self.IDs)

    def __getitem__(self, index):
        """Load a member of the ID at ``index``, picked via ``np.random``."""
        members = self.train_dict[self.IDs[index]]
        pick = np.random.randint(0, len(members))
        return self.loader(members[pick], self.params)
|
|
|
|
def loader_pdb(item, params):
    """Load one chain, in the context of a random assembly containing it.

    ``item[0]`` has the form ``"<pdbid>_<chain>"``. Metadata is read from
    ``<DIR>/pdb/<pdbid[1:3]>/<pdbid>.pt`` and chain data from
    ``<same prefix>_<chain>.pt``.

    * If the metadata file does not exist, ``{"seq": np.zeros(5)}`` is
      returned (a result without a ``"label"`` key).
    * If no assembly lists the chain, the chain alone is returned with an
      all-zero ``"idx"`` and ``"masked" == [0]``.
    * Otherwise one assembly containing the chain is drawn with
      ``random.sample``; every transform stored under ``"asmb_xform<k>"`` is
      applied to each of its chains that also appears in
      ``meta["chains"]``, and all resulting copies are concatenated.

    Args:
        item: Sequence whose first element is the ``"<pdbid>_<chain>"`` id.
        params: Mapping with ``"DIR"`` (dataset root) and ``"HOMO"``
            (threshold compared against ``meta["tm"][..., 1]``; presumably a
            sequence-identity cutoff - confirm against the data format).

    Returns:
        Mapping with ``"seq"`` (concatenated sequence), ``"xyz"``
        (concatenated coordinates), ``"idx"`` (per-residue index of the chain
        copy), ``"masked"`` (indices of copies whose chain exceeds the
        ``HOMO`` threshold relative to the requested chain) and ``"label"``
        (``item[0]``); or the dummy mapping described above.
    """
    pdbid, chid = item[0].split("_")
    PREFIX = "%s/pdb/%s/%s" % (params["DIR"], pdbid[1:3], pdbid)

    if not os.path.isfile(PREFIX + ".pt"):
        return {"seq": np.zeros(5)}
    # NOTE: torch.load unpickles the file; only use with trusted data.
    meta = torch.load(PREFIX + ".pt")
    asmb_ids = meta["asmb_ids"]
    asmb_chains = meta["asmb_chains"]
    chids = np.array(meta["chains"])
    known_chains = set(meta["chains"])

    # Assemblies that contain the requested chain.
    asmb_candidates = {
        a for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(",")
    }

    # Chain missing from every assembly: return it on its own.
    if not asmb_candidates:
        chain = torch.load("%s_%s.pt" % (PREFIX, chid))
        L = len(chain["seq"])
        return {
            "seq": chain["seq"],
            "xyz": chain["xyz"],
            "idx": torch.zeros(L).int(),
            "masked": torch.Tensor([0]).int(),
            "label": item[0],
        }

    # Draw one assembly and find every metadata row that belongs to it.
    asmb_i = random.sample(list(asmb_candidates), 1)[0]
    asmb_rows = np.where(np.array(asmb_ids) == asmb_i)[0]

    # Load each involved chain once. The chain list is comma-separated, so it
    # must be split; iterating the raw string would yield single characters
    # and never match multi-character chain IDs.
    chains = {}
    for i in asmb_rows:
        for c in asmb_chains[i].split(","):
            if c in known_chains and c not in chains:
                chains[c] = torch.load("%s_%s.pt" % (PREFIX, c))

    # Apply every transform of every row to the row's chains.
    asmb = {}
    for k in asmb_rows:
        xform = meta["asmb_xform%d" % k]
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]

        # De-duplicate while keeping the listed order, so the chain order of
        # the result does not depend on set hashing.
        for c in dict.fromkeys(asmb_chains[k].split(",")):
            if c not in known_chains:
                continue
            try:
                xyz = chains[c]["xyz"]
            except KeyError:
                return {"seq": np.zeros(5)}
            xyz_ru = torch.einsum("bij,raj->brai", u, xyz) + r[:, None, None, :]
            asmb.update({(c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru)})

    # Chains whose meta["tm"][..., 1] entry against the requested chain
    # exceeds the threshold.
    seqid = meta["tm"][chids == chid][0, :, 1]
    homo = {
        ch_j for seqid_j, ch_j in zip(seqid, chids) if seqid_j > params["HOMO"]
    }

    seq_parts, xyz_parts, idx_parts, masked = [], [], [], []
    for counter, (key, coords) in enumerate(asmb.items()):
        chain_id = key[0]
        seq_parts.append(chains[chain_id]["seq"])
        xyz_parts.append(coords)
        idx_parts.append(torch.full((coords.shape[0],), counter))
        if chain_id in homo:
            masked.append(counter)

    return {
        "seq": "".join(seq_parts),
        "xyz": torch.cat(xyz_parts, dim=0),
        "idx": torch.cat(idx_parts, dim=0),
        "masked": torch.Tensor(masked).int(),
        "label": item[0],
    }
|
|
|
|
def _read_cluster_ids(path):
    """Return the set of integers listed one per line in ``path``."""
    with open(path) as f:
        # Blank lines (e.g. a trailing newline) are ignored.
        return {int(line) for line in f if line.strip()}


def build_training_clusters(params, debug):
    """Split the entries of the list file into train/valid/test clusters.

    The CSV at ``params["LIST"]`` is read after skipping its first row. A row
    is kept when ``float(row[2]) <= params["RESCUT"]`` and the date in
    ``row[1]`` is not later than ``params["DATCUT"]``. Kept rows are grouped
    by ``int(row[4])``; a group goes to ``valid`` if that number is listed in
    the ``params["VAL"]`` file, to ``test`` if listed in ``params["TEST"]``,
    and to ``train`` otherwise. Each group member is ``[row[0], row[3]]``.

    Args:
        params: Mapping with ``"LIST"``, ``"VAL"`` and ``"TEST"`` file paths
            plus the ``"RESCUT"`` and ``"DATCUT"`` cut-offs.
        debug: When true, the VAL/TEST files are not consulted, only the
            first 20 kept rows are used, and ``valid`` is the same object as
            ``train``.

    Returns:
        Tuple ``(train, valid, test)`` of dicts mapping a cluster number to
        its list of members.
    """
    if debug:
        val_ids, test_ids = set(), set()
    else:
        val_ids = _read_cluster_ids(params["VAL"])
        test_ids = _read_cluster_ids(params["TEST"])

    # Parse the cut-off date once instead of once per row.
    date_cutoff = parser.parse(params["DATCUT"])

    with open(params["LIST"], "r") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the first (header) row; tolerate empty files
        rows = [
            [r[0], r[3], int(r[4])]
            for r in reader
            if float(r[2]) <= params["RESCUT"]
            and parser.parse(r[1]) <= date_cutoff
        ]

    if debug:
        rows = rows[:20]

    train, valid, test = {}, {}, {}
    for first, second, cluster in rows:
        if cluster in val_ids:
            target = valid
        elif cluster in test_ids:
            target = test
        else:
            target = train
        target.setdefault(cluster, []).append([first, second])

    if debug:
        valid = train
    return train, valid, test
|
|