import pickle

import numpy as np
import torch
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    classification_report,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
)
from torch.utils.data import Dataset

# NOTE: rna2vec and tokenize_sequences are called below but are defined
# elsewhere in this repository; import them from their companion module.

def word2idx(word, words):
    # Map a token to its vocabulary index; unknown tokens map to 0.
    if word in words:
        return int(words[word])
    return 0

def pad_seq(dataset, max_len):
    # Right-pad each sequence with zeros up to a fixed length of max_len.
    output = []
    for seq in dataset:
        pad = np.zeros(max_len)
        pad[:len(seq)] = seq
        output.append(pad)
    return np.array(output)

def str2bool(seq):
    # Convert "positive"/"negative" labels to 1/0; any other value is skipped.
    out = []
    for s in seq:
        if s == "positive":
            out.append(1)
        elif s == "negative":
            out.append(0)
    return np.array(out)
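
# Example usage of the helpers above (illustrative, with made-up toy inputs):
#   word2idx("A", {"A": 1, "C": 2})       # -> 1; unseen tokens -> 0
#   pad_seq([[3, 1, 2]], max_len=5)       # -> array([[3., 1., 2., 0., 0.]])
#   str2bool(["positive", "negative"])    # -> array([1, 0])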

class API_Dataset(Dataset):
    # Holds tokenized aptamer/protein pairs, labels, and attention masks.
    def __init__(self, apta, esm_prot, y, apta_attn_mask, prot_attn_mask):
        super().__init__()
        self.apta = np.array(apta, dtype=np.int64)
        self.esm_prot = np.array(esm_prot, dtype=np.int64)
        self.y = np.array(y, dtype=np.int64)
        self.apta_attn_mask = np.array(apta_attn_mask)
        self.prot_attn_mask = np.array(prot_attn_mask)
        self.len = len(self.apta)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return (
            torch.tensor(self.apta[index], dtype=torch.int64),
            torch.tensor(self.esm_prot[index], dtype=torch.int64),
            torch.tensor(self.y[index], dtype=torch.int64),
            torch.tensor(self.apta_attn_mask[index], dtype=torch.int64),
            torch.tensor(self.prot_attn_mask[index], dtype=torch.int64),
        )
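
# Example usage (illustrative sketch; the shapes and values are assumptions):
#   from torch.utils.data import DataLoader
#   apta = np.zeros((8, 275), dtype=np.int64)    # tokenized aptamers
#   prot = np.zeros((8, 1680), dtype=np.int64)   # tokenized proteins
#   y = np.array([0, 1] * 4)
#   loader = DataLoader(
#       API_Dataset(apta, prot, y, np.ones_like(apta), np.ones_like(prot)),
#       batch_size=4, shuffle=True)
#   apta_b, prot_b, y_b, apta_mask_b, prot_mask_b = next(iter(loader))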

def find_opt_threshold(target, pred):
    # Sweep thresholds in steps of 0.001 and keep the one maximizing F1.
    result = 0
    best = 0
    for i in range(1000):
        pred_threshold = np.where(pred > i / 1000, 1, 0)
        now = f1_score(target, pred_threshold)
        if now > best:
            result = i / 1000
            best = now
    return result
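
# Example usage (illustrative toy values):
#   target = np.array([0, 0, 1, 1])
#   pred = np.array([0.1, 0.4, 0.35, 0.8])
#   thr = find_opt_threshold(target, pred)   # F1-optimal cutoff
#   pred_bin = np.where(pred > thr, 1, 0)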

def argument_seqset(seqset):
    # Augment a paired sequence set with reversed copies of both sequences.
    # ("argument" appears to be a typo for "augment"; the name is kept in
    # case other modules import it.)
    arg_seqset = []
    for s, ss in seqset:
        arg_seqset.append([s, ss])
        arg_seqset.append([s[::-1], ss[::-1]])
    return arg_seqset

def augment_apis(apta, prot, ys):
    # Double the data: keep each (aptamer, protein, label) triple and add a
    # copy with the aptamer sequence reversed.
    aug_apta = []
    aug_prot = []
    aug_y = []
    for a, p, y in zip(apta, prot, ys):
        aug_apta.append(a)
        aug_prot.append(p)
        aug_y.append(y)
        aug_apta.append(a[::-1])
        aug_prot.append(p)
        aug_y.append(y)
    return np.array(aug_apta), np.array(aug_prot), np.array(aug_y)
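
# Example usage (illustrative toy triple):
#   a, p, y = augment_apis(np.array(["AUGC"]), np.array(["MKV"]),
#                          np.array(["positive"]))
#   # a -> ["AUGC", "CGUA"], p -> ["MKV", "MKV"], y -> ["positive", "positive"]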

def load_data_source(filepath):
    # The pickle holds a DataFrame whose "dataset" column marks the split.
    with open(filepath, "rb") as fr:
        dataset = pickle.load(fr)
    dataset_train = np.array(dataset[dataset["dataset"] == "training dataset"])
    dataset_test = np.array(dataset[dataset["dataset"] == "test dataset"])
    dataset_bench = np.array(dataset[dataset["dataset"] == "benchmark dataset"])
    return dataset_train, dataset_test, dataset_bench

def get_dataset(filepath, prot_max_len, n_prot_vocabs, prot_words):
    dataset_train, dataset_test, dataset_bench = load_data_source(filepath)
    arg_apta, arg_prot, arg_y = augment_apis(
        dataset_train[:, 0], dataset_train[:, 1], dataset_train[:, 2]
    )
    datasets_train = [
        rna2vec(arg_apta),
        tokenize_sequences(arg_prot, prot_max_len, n_prot_vocabs, prot_words),
        str2bool(arg_y),
    ]
    datasets_test = [
        rna2vec(dataset_test[:, 0]),
        tokenize_sequences(dataset_test[:, 1], prot_max_len, n_prot_vocabs, prot_words),
        str2bool(dataset_test[:, 2]),
    ]
    datasets_bench = [
        rna2vec(dataset_bench[:, 0]),
        tokenize_sequences(dataset_bench[:, 1], prot_max_len, n_prot_vocabs, prot_words),
        str2bool(dataset_bench[:, 2]),
    ]
    return datasets_train, datasets_test, datasets_bench
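
# Example usage (illustrative sketch; prot_words is assumed to be a
# word -> index dict built elsewhere, and the path is an assumption):
#   train, test, bench = get_dataset("data/dataset.pkl", prot_max_len=1680,
#                                    n_prot_vocabs=len(prot_words),
#                                    prot_words=prot_words)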

def get_esm_dataset(filepath, batch_converter, alphabet):
    dataset_train, dataset_test, dataset_bench = load_data_source(filepath)
    # Columns are [aptamer, protein, label]; only the training split is augmented.
    arg_apta, arg_prot, arg_y = augment_apis(
        dataset_train[:, 0], dataset_train[:, 1], dataset_train[:, 2]
    )
    # batch_converter expects (label, sequence) pairs and returns
    # (labels, strings, tokens); only the token tensor is used here.
    train_inputs = list(zip(arg_y, arg_prot))
    _, _, prot_tokens = batch_converter(train_inputs)
    datasets_train = [rna2vec(arg_apta), prot_tokens, str2bool(arg_y)]

    test_inputs = list(enumerate(dataset_test[:, 1]))
    _, _, test_prot_tokens = batch_converter(test_inputs)
    datasets_test = [rna2vec(dataset_test[:, 0]), test_prot_tokens, str2bool(dataset_test[:, 2])]

    bench_inputs = list(enumerate(dataset_bench[:, 1]))
    _, _, bench_prot_tokens = batch_converter(bench_inputs)
    # Truncate to 1678 tokens, then pad shorter rows with the padding index so
    # every benchmark protein has the same fixed width.
    bench_prot_tokenized = bench_prot_tokens[:, :1678]
    prot_ex = torch.ones((bench_prot_tokenized.shape[0], 1678), dtype=torch.int64) * alphabet.padding_idx
    prot_ex[:, :bench_prot_tokenized.shape[1]] = bench_prot_tokenized
    datasets_bench = [rna2vec(dataset_bench[:, 0]), prot_ex, str2bool(dataset_bench[:, 2])]

    return datasets_train, datasets_test, datasets_bench
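
# Example usage (illustrative sketch, assuming the fair-esm package; the
# checkpoint name and pickle path are assumptions):
#   import esm
#   model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
#   batch_converter = alphabet.get_batch_converter()
#   train, test, bench = get_esm_dataset("data/dataset.pkl", batch_converter, alphabet)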

def get_nt_esm_dataset(filepath, nt_tokenizer, batch_converter, alphabet):
    dataset_train, dataset_test, dataset_bench = load_data_source(filepath)
    arg_apta, arg_prot, arg_y = augment_apis(
        dataset_train[:, 0], dataset_train[:, 1], dataset_train[:, 2]
    )
    max_length = 275  # fixed aptamer token length; nt_tokenizer.model_max_length could be used instead
    # batch_converter expects (label, sequence) pairs; only tokens are used.
    train_inputs = list(zip(arg_y, arg_prot))
    _, _, prot_tokens = batch_converter(train_inputs)
    apta_toks = nt_tokenizer.batch_encode_plus(
        arg_apta, return_tensors="pt", padding="max_length", max_length=max_length
    )["input_ids"]
    apta_attention_mask = apta_toks != nt_tokenizer.pad_token_id
    prot_attention_mask = prot_tokens != alphabet.padding_idx
    datasets_train = [apta_toks, prot_tokens, str2bool(arg_y), apta_attention_mask, prot_attention_mask]

    # Pad test proteins to a fixed width of 1680 tokens (this assumes no test
    # protein tokenizes to more than 1680 tokens).
    test_inputs = list(enumerate(dataset_test[:, 1]))
    _, _, test_prot_tokens = batch_converter(test_inputs)
    prot_ex = torch.ones((test_prot_tokens.shape[0], 1680), dtype=torch.int64) * alphabet.padding_idx
    prot_ex[:, :test_prot_tokens.shape[1]] = test_prot_tokens
    apta_toks = nt_tokenizer.batch_encode_plus(
        dataset_test[:, 0], return_tensors="pt", padding="max_length", max_length=max_length
    )["input_ids"]
    apta_attention_mask = apta_toks != nt_tokenizer.pad_token_id
    prot_attention_mask = prot_ex != alphabet.padding_idx
    datasets_test = [apta_toks, prot_ex, str2bool(dataset_test[:, 2]), apta_attention_mask, prot_attention_mask]

    # Same fixed-width padding for the benchmark split.
    bench_inputs = list(enumerate(dataset_bench[:, 1]))
    _, _, bench_prot_tokens = batch_converter(bench_inputs)
    prot_ex = torch.ones((bench_prot_tokens.shape[0], 1680), dtype=torch.int64) * alphabet.padding_idx
    prot_ex[:, :bench_prot_tokens.shape[1]] = bench_prot_tokens
    apta_toks = nt_tokenizer.batch_encode_plus(
        dataset_bench[:, 0], return_tensors="pt", padding="max_length", max_length=max_length
    )["input_ids"]
    apta_attention_mask = apta_toks != nt_tokenizer.pad_token_id
    prot_attention_mask = prot_ex != alphabet.padding_idx
    datasets_bench = [apta_toks, prot_ex, str2bool(dataset_bench[:, 2]), apta_attention_mask, prot_attention_mask]

    return datasets_train, datasets_test, datasets_bench
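
# Example usage (illustrative sketch; both checkpoint names are assumptions):
#   import esm
#   from transformers import AutoTokenizer
#   nt_tokenizer = AutoTokenizer.from_pretrained(
#       "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species",
#       trust_remote_code=True)
#   _, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
#   batch_converter = alphabet.get_batch_converter()
#   train, test, bench = get_nt_esm_dataset(
#       "data/dataset.pkl", nt_tokenizer, batch_converter, alphabet)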

def get_scores(target, pred):
    # Binarize predictions at the F1-optimal threshold, then report standard
    # classification metrics (the AUCs are computed on the raw scores).
    threshold = find_opt_threshold(target, pred)
    pred_threshold = np.where(pred > threshold, 1, 0)
    acc = accuracy_score(target, pred_threshold)
    roc_auc = roc_auc_score(target, pred)
    mcc = matthews_corrcoef(target, pred_threshold)
    f1 = f1_score(target, pred_threshold)
    pr_auc = average_precision_score(target, pred)
    cls_report = classification_report(target, pred_threshold)
    scores = {
        "threshold": threshold,
        "acc": acc,
        "roc_auc": roc_auc,
        "mcc": mcc,
        "f1": f1,
        "pr_auc": pr_auc,
        "cls_report": cls_report,
    }
    return scores
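
# Example usage (illustrative toy values):
#   target = np.array([0, 1, 1, 0, 1])
#   pred = np.array([0.2, 0.9, 0.6, 0.3, 0.7])
#   scores = get_scores(target, pred)
#   print(scores["threshold"], scores["roc_auc"], scores["f1"])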