| |
|
| |
|
| |
|
| | import math
|
| | import os, csv, json
|
| | import io, textwrap, itertools
|
| | import subprocess
|
| | from Bio import SeqIO
|
| | import torch
|
| | import numpy as np
|
| | import sys, random
|
| | from sklearn.metrics import confusion_matrix
|
| | import matplotlib.pyplot as plt
|
| | import pynvml, requests
|
| | from collections import OrderedDict
|
| |
|
# Global matplotlib defaults: larger plot font, and render minus signs as
# plain ASCII hyphens so saved figures don't depend on a unicode-capable font.
plt.rcParams.update({'font.size': 18})
plt.rcParams['axes.unicode_minus'] = False
|
| |
|
| | from .file_operator import file_reader
|
| | from .multi_label_metrics import prob_2_pred, relevant_indexes, metrics_multi_label
|
| | from .metrics import metrics_multi_class, metrics_binary, metrics_regression
|
| |
|
# Valid nucleotide characters: DNA/RNA bases plus the unknown marker 'N'.
common_nucleotide_set = {'A', 'T', 'C', 'G', 'U', 'N'}

# Valid amino-acid characters: the 20 standard residues plus the unknown marker 'X'.
common_amino_acid_set = {'R', 'X', 'S', 'G', 'W', 'I', 'Q', 'A', 'T', 'V', 'K', 'Y', 'C', 'N', 'L', 'F', 'D', 'M', 'P', 'H', 'E'}
|
| |
|
| |
|
def to_device(device, batch):
    '''
    Move a (possibly nested, up to three levels deep) batch dict onto a device.
    Values that are None or plain int/str/float scalars are kept as-is;
    everything else is assumed to expose .to(device).
    :param device: target torch device
    :param batch: batch dict
    :return: (moved batch, sample count from the last tensor moved, 0 if none)
    '''
    def _movable(v):
        # not None and not a plain scalar/string -> treat as tensor-like
        return v is not None and not isinstance(v, (int, str, float))

    moved = {}
    last_tensor = None
    for k1, v1 in batch.items():
        moved[k1] = {}
        if isinstance(v1, dict):
            for k2, v2 in v1.items():
                moved[k1][k2] = {}
                if isinstance(v2, dict):
                    for k3, v3 in v2.items():
                        if _movable(v3):
                            moved[k1][k2][k3] = v3.to(device)
                            last_tensor = v3
                        else:
                            moved[k1][k2][k3] = v3
                elif _movable(v2):
                    moved[k1][k2] = v2.to(device)
                    last_tensor = v2
                else:
                    moved[k1][k2] = v2
        elif _movable(v1):
            moved[k1] = v1.to(device)
            last_tensor = v1
        else:
            moved[k1] = v1
    # sample count comes from dim 0 of the last tensor encountered
    sample_num = last_tensor.shape[0] if last_tensor is not None else 0
    return moved, sample_num
|
| |
|
| |
|
def get_parameter_number(model):
    '''
    Summarize a model's parameter/buffer counts and byte sizes.
    Counts are reported in units of 1024*1024 ("M"), sizes in MB.
    :param model: torch.nn.Module
    :return: dict of formatted strings
    '''
    mega = 1024 * 1024
    param_size = param_sum = 0
    trainable_size = trainable_num = 0
    for p in model.parameters():
        n = p.nelement()
        nbytes = n * p.element_size()
        param_sum += n
        param_size += nbytes
        if p.requires_grad:
            trainable_num += n
            trainable_size += nbytes
    buffer_sum = sum(b.nelement() for b in model.buffers())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    return {
        'total_num': "%fM" % round((buffer_sum + param_sum) / mega, 2),
        'total_size': "%fMB" % round((buffer_size + param_size) / mega, 2),
        'param_sum': "%fM" % round(param_sum / mega, 2),
        'param_size': "%fMB" % round(param_size / mega, 2),
        'buffer_sum': "%fM" % round(buffer_sum / mega, 2),
        'buffer_size': "%fMB" % round(buffer_size / mega, 2),
        'trainable_num': "%fM" % round(trainable_num / mega, 10),
        'trainable_size': "%fMB" % round(trainable_size / mega, 10)
    }
|
| |
|
| |
|
def set_seed(args):
    '''
    Seed the python, numpy and torch RNGs for reproducibility.
    :param args: namespace providing .seed and .n_gpu
    '''
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu > 0:
        # also seed the CUDA generators on every device
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
|
| |
|
| |
|
def label_id_2_label_name(output_mode, label_list, prob, threshold=0.5):
    '''
    Map predicted probabilities back to label names.
    :param output_mode: multi-label / multi-class / binary-class
    :param label_list: index -> label name mapping
    :param prob: probability array
    :param threshold: decision threshold for (binary/multi-)label outputs
    :return: list of lists for multi-label, list of names otherwise
    :raises KeyError: for an unsupported output mode
    '''
    if output_mode in ["multi-label", "multi_label"]:
        pred = prob_2_pred(prob, threshold)
        pred_index = relevant_indexes(pred)
        # one list of names per sample row
        return [[label_list[i] for i in pred_index[row]]
                for row in range(prob.shape[0])]
    if output_mode in ["multi-class", "multi_class"]:
        return [label_list[i] for i in np.argmax(prob, axis=1)]
    if output_mode in ["binary-class", "binary_class"]:
        flat = prob.flatten(order="C") if prob.ndim == 2 else prob
        return [label_list[i] for i in prob_2_pred(flat, threshold)]
    raise KeyError(output_mode)
|
| |
|
| |
|
def plot_bins(data, xlabel, ylabel, bins, filepath):
    '''
    Draw a histogram of the data and either show it or save it to a file.
    :param data: values to histogram
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param bins: number of bins
    :param filepath: png save path; when None the figure is shown instead
    '''
    plt.figure(figsize=(40, 20), dpi=100)
    plt.hist(data, bins=bins)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if filepath is None:
        plt.show()
    else:
        plt.savefig(filepath)
    plt.clf()
    plt.close()
|
| |
|
| |
|
def plot_confusion_matrix_for_binary_class(targets, preds, cm=None, savepath=None):
    '''
    Render a 2x2 confusion matrix as an annotated heatmap.
    :param targets: ground-truth binary labels
    :param preds: predicted binary labels
    :param cm: precomputed confusion matrix; computed from targets/preds when None
    :param savepath: figure save path; shown interactively when falsy
    '''
    plt.figure(figsize=(40, 20), dpi=100)
    if cm is None:
        cm = confusion_matrix(targets, preds, labels=[0, 1])
    plt.matshow(cm, cmap=plt.cm.Oranges)
    plt.colorbar()
    size = len(cm)
    for row in range(size):
        for col in range(size):
            # write the raw count in the middle of each cell
            plt.annotate(cm[row, col], xy=(col, row),
                         verticalalignment='center', horizontalalignment='center')
    plt.ylabel('True')
    plt.xlabel('Prediction')
    if savepath:
        plt.savefig(savepath, dpi=100)
    else:
        plt.show()
    plt.close("all")
|
| |
|
| |
|
def save_labels(filepath, label_list):
    '''
    Write label names to a file, one per line, preceded by a "label" header.
    :param filepath: destination path
    :param label_list: iterable of label names
    '''
    with open(filepath, "w") as wfp:
        wfp.write("label" + "\n")
        wfp.writelines(name + "\n" for name in label_list)
|
| |
|
| |
|
def load_labels(filepath, header=True):
    '''
    Read label names from a file, one per line.
    :param filepath: path of the label file
    :param header: whether the file carries a header line to drop
    :return: list of label names
    '''
    with open(filepath, "r") as rfp:
        labels = [line.strip() for line in rfp]
    # drop the first line when a header is declared or detected
    drop_first = len(labels) > 0 and (header or labels[0] == "label")
    return labels[1:] if drop_first else labels
|
| |
|
| |
|
def load_vocab(vocab_path):
    '''
    Load a vocabulary file: one token per line, id assigned in line order.
    :param vocab_path: path of the vocab file
    :return: dict mapping token -> id
    '''
    vocab = {}
    with open(vocab_path, "r") as fp:
        for line in fp:
            # id = current vocab size (duplicates get re-assigned, as before)
            vocab[line.strip()] = len(vocab)
    return vocab
|
| |
|
| |
|
def subprocess_popen(statement):
    '''
    Execute a shell command and capture its stdout.

    Security note: the statement runs through the shell (shell=True), so it
    must come from trusted code, never from user input.

    Bug fix: the previous implementation polled `p.poll()` in a loop and
    could fall through (implicitly returning None) when the process exited
    before the first poll; `communicate()` waits deterministically and
    avoids that race.

    :param statement: shell command string
    :return: list of decoded stdout lines (trailing '\\r'/'\\n' stripped) on
             success, False when the command exits with a non-zero status
    '''
    p = subprocess.Popen(statement, shell=True, stdout=subprocess.PIPE)
    stdout, _ = p.communicate()
    if p.returncode != 0:
        print("fail.")
        return False
    return [line.decode('utf-8').strip('\r\n')
            for line in stdout.splitlines(keepends=True)]
|
| |
|
| |
|
def prepare_inputs(input_type, embedding_type, batch):
    '''
    Assemble the model kwargs dict from a positional batch tuple for a
    pair (a/b) sample.

    The positional layout of `batch` depends on input_type; the last
    element is always the labels tensor. NOTE(review): several positions
    are skipped (e.g. batch[3] for "sequence", batch[3]/batch[9] for
    "sefn") -- presumably unused collate outputs such as lengths; confirm
    against the dataset collate function.

    :param input_type: "sequence", "embedding", "structure",
                       "sefn" (sequence + embedding) or "ssfn" (sequence + structure)
    :param embedding_type: "vector"/"bos" embeddings carry no attention
                           mask; any other value means per-token matrices
    :param batch: positional tuple produced by the data loader
    :return: kwargs dict for the model forward, or None for unknown input_type
    '''
    if input_type == "sequence":
        inputs = {
            "input_ids_a": batch[0],
            "attention_mask_a": batch[1],
            "token_type_ids_a": batch[2],
            "input_ids_b": batch[4],
            "attention_mask_b": batch[5],
            "token_type_ids_b": batch[6],
            "labels": batch[-1]
        }
    elif input_type == "embedding":
        if embedding_type not in ["vector", "bos"]:
            # per-token embedding matrices come with their own masks
            inputs = {
                "embedding_info_a": batch[0],
                "embedding_attention_mask_a": batch[1],
                "embedding_info_b": batch[2],
                "embedding_attention_mask_b": batch[3],
                "labels": batch[-1]
            }
        else:
            # pooled vector/bos embeddings need no attention mask
            inputs = {
                "embedding_info_a": batch[0],
                "embedding_attention_mask_a": None,
                "embedding_info_b": batch[1],
                "embedding_attention_mask_b": None,
                "labels": batch[-1]
            }
    elif input_type == "structure":
        inputs = {
            "struct_input_ids_a": batch[0],
            "struct_contact_map_a": batch[1],
            "struct_input_ids_b": batch[2],
            "struct_contact_map_b": batch[3],
            "labels": batch[-1]
        }
    elif input_type == "sefn":
        if embedding_type not in ["vector", "bos"]:
            inputs = {
                "input_ids_a": batch[0],
                "attention_mask_a": batch[1],
                "token_type_ids_a": batch[2],
                "embedding_info_a": batch[4],
                "embedding_attention_mask_a": batch[5],
                "input_ids_b": batch[6],
                "attention_mask_b": batch[7],
                "token_type_ids_b": batch[8],
                "embedding_info_b": batch[10],
                "embedding_attention_mask_b": batch[11],
                "labels": batch[-1],
            }
        else:
            inputs = {
                "input_ids_a": batch[0],
                "attention_mask_a": batch[1],
                "token_type_ids_a": batch[2],
                "embedding_info_a": batch[4],
                "embedding_attention_mask_a": None,
                "input_ids_b": batch[5],
                "attention_mask_b": batch[6],
                "token_type_ids_b": batch[7],
                "embedding_info_b": batch[9],
                "embedding_attention_mask_b": None,
                "labels": batch[-1],
            }
    elif input_type == "ssfn":
        inputs = {
            "input_ids_a": batch[0],
            "attention_mask_a": batch[1],
            "token_type_ids_a": batch[2],
            "struct_input_ids_a": batch[4],
            "struct_contact_map_a": batch[5],
            "input_ids_b": batch[6],
            "attention_mask_b": batch[7],
            "token_type_ids_b": batch[8],
            "struct_input_ids_b": batch[10],
            "struct_contact_map_b": batch[11],
            "labels": batch[-1]
        }
    else:
        inputs = None
    return inputs
|
| |
|
| |
|
def gene_seq_replace_re(seq):
    '''
    Decode a digit-encoded nucleic-acid sequence back to bases
    (1->A, 2->T, 3->C, 4->G, anything else -> N).
    :param seq: encoded sequence string
    :return: decoded sequence string
    '''
    decode = {'1': "A", '2': "T", '3': "C", '4': "G"}
    return "".join(decode.get(ch, "N") for ch in seq)
|
| |
|
| |
|
def gene_seq_replace(seq):
    '''
    Encode a nucleic-acid sequence as digits
    (A->1, T/U->2, C->3, G->4, anything else -> 5), case-insensitive.
    :param seq: raw gene sequence
    :return: encoded sequence string
    '''
    encode = {
        "A": "1", "a": "1",
        "T": "2", "U": "2", "t": "2", "u": "2",
        "C": "3", "c": "3",
        "G": "4", "g": "4",
    }
    return "".join(encode.get(ch, "5") for ch in seq)
|
| |
|
| |
|
def get_labels(label_filepath, header=True):
    '''
    Read label names from a file; supports an optional header line and a
    two-column "index,label" layout (detected from a comma in the header).
    :param label_filepath: path of the label file
    :param header: whether the first line is a header to skip
    :return: list of label names
    '''
    labels = []
    multi_cols = False
    with open(label_filepath, "r") as fp:
        for line_no, raw in enumerate(fp, start=1):
            line = raw.strip()
            if line_no == 1 and (header or line == "label"):
                # header row; a comma there signals "index,label" columns
                multi_cols = line.find(",") > 0
                continue
            if multi_cols:
                sep = line.find(",")
                labels.append(line[sep + 1:].strip() if sep > 0 else line)
            else:
                labels.append(line)
    return labels
|
| |
|
| |
|
def available_gpu_id():
    '''
    Pick the GPU whose utilization is lowest (and below 10%).
    :return: the chosen GPU index, or -1 when CUDA is unavailable or no
             GPU qualifies
    '''
    pynvml.nvmlInit()
    if not torch.cuda.is_available():
        print("GPU not available")
        return -1
    best_gpu = -1
    best_rate = 0
    for gpu_idx in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_idx)
        # memory info was queried but never used originally; kept for parity
        pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
        free_rate = 100 - utilization.gpu
        if utilization.gpu < 10 and best_rate < free_rate:
            best_rate = free_rate
            best_gpu = gpu_idx
    if best_gpu > -1:
        print("Available GPU ID: %d, Free Rate: %0.2f%%" % (best_gpu, best_rate))
    else:
        print("No Available GPU!")
    pynvml.nvmlShutdown()
    return best_gpu
|
| |
|
| |
|
def eval_metrics(output_mode, truths, preds, threshold=0.5):
    '''
    Dispatch to the metrics function matching the task's output mode.
    :param output_mode: multi-label / multi-class / regression / binary-class
    :param truths: ground-truth array
    :param preds: prediction array
    :param threshold: decision threshold for (binary/multi-)label modes
    :return: metrics dict from the selected metrics function
    :raises Exception: for an unsupported output mode
    '''
    print("\ntruths size: ", truths.shape)
    print("\npreds size: ", preds.shape)
    if output_mode in ("multi-label", "multi_label"):
        return metrics_multi_label(truths, preds, threshold=threshold)
    if output_mode in ("multi-class", "multi_class"):
        return metrics_multi_class(truths, preds)
    if output_mode == "regression":
        return metrics_regression(truths, preds)
    if output_mode in ("binary-class", "binary_class"):
        return metrics_binary(truths, preds, threshold=threshold)
    raise Exception("Not Support this output mode: %s" % output_mode)
|
| |
|
| |
|
def load_trained_model(model_config, args, model_class, model_dirpath):
    '''
    Load a trained model, preferring a HuggingFace-style from_pretrained
    and falling back to a manual state-dict load.
    :param model_config: config used when constructing the model manually
    :param args: run args; the fallback path reads args.model_dirpath
    :param model_class: model class to instantiate
    :param model_dirpath: directory holding the pretrained weights
    :return: the loaded model
    '''
    print("load pretrained model: %s" % model_dirpath)
    try:
        model = model_class.from_pretrained(model_dirpath, args=args)
    except Exception as e:
        # Fallback: build from config and load the raw checkpoint.
        # NOTE(review): this loads from args.model_dirpath, not the
        # model_dirpath parameter -- confirm the two are always identical
        # for every caller.
        model = model_class(model_config, args=args)
        pretrained_net_dict = torch.load(os.path.join(args.model_dirpath, "pytorch.pth"),
                                         map_location=torch.device("cpu"))
        model_state_dict_keys = set()
        for key in model.state_dict():
            model_state_dict_keys.add(key)
        new_state_dict = OrderedDict()
        for k, v in pretrained_net_dict.items():
            if k.startswith("module."):
                # strip the DataParallel/DDP "module." prefix
                name = k[7:]
            else:
                name = k
            # keep only keys the freshly built model actually has
            if name in model_state_dict_keys:
                new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    return model
|
| |
|
| |
|
def clean_seq(protein_id, seq, return_rm_index=False):
    '''
    Uppercase a protein sequence and drop invalid characters
    (anything outside A-Z, plus 'J').
    :param protein_id: id used only for logging invalid sequences
    :param seq: raw amino-acid sequence
    :param return_rm_index: when True also return the removed positions
    :return: cleaned sequence, optionally with the set of removed indexes
    '''
    seq = seq.upper()
    kept = []
    removed_idx = set()
    bad_chars = set()
    for pos, ch in enumerate(seq):
        if 'A' <= ch <= 'Z' and ch not in ['J']:
            kept.append(ch)
        else:
            bad_chars.add(ch)
            removed_idx.add(pos)
    if bad_chars:
        # log the offending sequence so it can be traced upstream
        print("id: %s. Seq: %s" % (protein_id, seq))
        print("invalid char set:", bad_chars)
        print("return_rm_index:", removed_idx)
    new_seq = "".join(kept)
    if return_rm_index:
        return new_seq, removed_idx
    return new_seq
|
| |
|
| |
|
def sample_size(data_dirpath):
    '''
    Count the samples in a dataset file, or in every non-hidden file of a
    dataset directory (csv/tsv header rows are filtered out).
    :param data_dirpath: dataset file path or directory path
    :return: total sample count
    '''
    if os.path.isdir(data_dirpath):
        filepaths = [os.path.join(data_dirpath, name)
                     for name in os.listdir(data_dirpath)
                     if not name.startswith(".")]
    else:
        filepaths = [data_dirpath]
    total = 0
    for filepath in filepaths:
        # only csv/tsv files carry a header row
        has_header = filepath.endswith(".tsv") or filepath.endswith(".csv")
        print("sample_size filepath: %s" % filepath)
        total += sum(1 for _ in file_reader(filepath, header=has_header, header_filter=True))
    return total
|
| |
|
| |
|
def writer_info_tb(tb_writer, logs, global_step, prefix=None):
    '''
    Recursively write a (possibly nested) metrics dict to tensorboard.
    Nested keys are joined with "_"; NaN/Inf values are skipped with a log.
    :param tb_writer: tensorboard SummaryWriter
    :param logs: metric name -> value (or nested dict of them)
    :param global_step: tensorboard step
    :param prefix: key prefix from the enclosing dict level
    '''
    for key, value in logs.items():
        if isinstance(value, dict):
            # recurse with the current key as the (only) prefix; an outer
            # prefix is not carried through, matching original behavior
            writer_info_tb(tb_writer, value, global_step, prefix=key)
        elif not math.isnan(value) and not math.isinf(value):
            tag = prefix + "_" + key if prefix else key
            tb_writer.add_scalar(tag, value, global_step)
        else:
            print("writer_info_tb NaN or Inf, Key-Value: %s=%s" % (key, value))
|
| |
|
| |
|
def get_lr(optimizer):
    '''
    Return the learning rate of the first param group that defines one.
    :param optimizer: torch optimizer
    :return: the learning rate, or None when no group defines "lr"
    '''
    return next((group["lr"] for group in optimizer.param_groups if "lr" in group), None)
|
| |
|
| |
|
def metrics_merge(results, all_results):
    '''
    Accumulate a three-level metrics dict into a running total (in place).
    :param results: new metrics, nested dict of depth 3
    :param all_results: accumulator, updated in place
    :return: the updated accumulator
    '''
    for k1, lvl1 in results.items():
        acc1 = all_results.setdefault(k1, {})
        for k2, lvl2 in lvl1.items():
            acc2 = acc1.setdefault(k2, {})
            for k3, v in lvl2.items():
                if k3 in acc2:
                    acc2[k3] += v
                else:
                    acc2[k3] = v
    return all_results
|
| |
|
| |
|
def print_shape(item):
    '''
    Recursively print the shapes inside nested dict/list containers.
    :param item: dict, list, or an object exposing .shape
    '''
    if isinstance(item, dict):
        for key, sub in item.items():
            print(key + ":")
            print_shape(sub)
    elif isinstance(item, list):
        for pos, sub in enumerate(item):
            print("idx: %d" % pos)
            print_shape(sub)
    else:
        print("shape:", item.shape)
|
| |
|
| |
|
def process_outputs(output_mode, truth, pred, output_truth, output_pred, ignore_index, keep_seq=False):
    '''
    Flatten a batch of truths/predictions, drop ignore_index positions, and
    append the result to the running numpy accumulators.
    :param output_mode: multi-class / multi-label / binary-class / regression
    :param truth: ground-truth tensor
    :param pred: prediction tensor
    :param output_truth: accumulated truths so far (numpy) or None
    :param output_pred: accumulated predictions so far (numpy) or None
    :param ignore_index: label value to mask out (not applied to multi-label)
    :param keep_seq: per-sequence mode is unsupported; returns (None, None)
    :return: accumulated (truths, preds) as numpy arrays; when the whole
             batch is masked out, the input tensors are returned unchanged
    :raises Exception: for an unsupported output mode
    '''
    if keep_seq:
        # per-sequence outputs are not implemented
        return None, None
    if output_mode in ["multi_class", "multi-class"]:
        flat_truth = truth.view(-1)
        mask = flat_truth != ignore_index
        cur_truth = flat_truth[mask]
        cur_pred = pred.view(-1, pred.shape[-1])[mask, :]
        kept = mask.sum().item()
    elif output_mode in ["multi_label", "multi-label"]:
        # multi-label rows are kept whole (no ignore_index masking)
        cur_truth = truth.view(-1, truth.shape[-1])
        cur_pred = pred.view(-1, pred.shape[-1])
        kept = pred.shape[0]
    elif output_mode in ["binary_class", "binary-class", "regression"]:
        # the original binary and regression branches were identical
        flat_truth = truth.view(-1)
        mask = flat_truth != ignore_index
        cur_truth = flat_truth[mask]
        cur_pred = pred.view(-1)[mask]
        kept = mask.sum().item()
    else:
        raise Exception("not output mode: %s" % output_mode)
    if kept > 0:
        cur_truth = cur_truth.detach().cpu().numpy()
        cur_pred = cur_pred.detach().cpu().numpy()
        if output_truth is None or output_pred is None:
            return cur_truth, cur_pred
        return (np.append(output_truth, cur_truth, axis=0),
                np.append(output_pred, cur_pred, axis=0))
    return truth, pred
|
| |
|
| |
|
def print_batch(value, key=None, debug_path=None, wfp=None, local_rank=-1):
    '''
    Debug-print a batch (tensor, or nested list/dict of tensors): writes
    min/max stats and shapes to wfp or stdout, and dumps tensor contents
    to text files when possible.

    The min of torch.where(v == -100, 10000, v) reports the smallest value
    ignoring the -100 padding/ignore label.

    NOTE(review): in the list branch with debug_path=None, the 3-dim case
    still calls os.path.join(debug_path, ...) on None -- that raises
    TypeError which is swallowed by the except and printed; presumably a
    latent bug, confirm before relying on dumps without debug_path.

    :param value: tensor, or list/dict of tensors
    :param key: name used for dump filenames
    :param debug_path: directory for the dumped .txt files (cwd when None)
    :param wfp: open file handle to write stats to; stdout when None
    :param local_rank: unused here; kept for the caller's signature
    :return: None
    '''
    if isinstance(value, list):
        for idx, v in enumerate(value):
            if wfp is not None:
                if v is not None:
                    wfp.write(str([torch.min(v), torch.min(torch.where(v == -100, 10000, v)), torch.max(v)]) + "\n")
                    wfp.write(str(v.shape) + "\n")
                else:
                    wfp.write("None\n")
                wfp.write("-" * 10 + "\n")
            else:
                if v is not None:
                    print([torch.min(v), torch.min(torch.where(v == -100, 10000, v)), torch.max(v)])
                    print(v.shape)
                else:
                    print("None")
                print("-" * 50)
            if v is not None:
                try:
                    # list entries are assumed to be integer tensors
                    value = v.detach().cpu().numpy().astype(int)
                    if debug_path is not None:
                        if value.ndim == 3:
                            # one file per leading-dim slice
                            for dim_1_idx in range(value.shape[0]):
                                np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), value[dim_1_idx, :, :], fmt='%i', delimiter=",")
                        else:
                            np.savetxt(os.path.join(debug_path, "%d.txt" % idx), value, fmt='%i', delimiter=",")
                    else:
                        if value.ndim == 3:
                            for dim_1_idx in range(value.shape[0]):
                                np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), value[dim_1_idx, :, :], fmt='%i', delimiter=",")
                        else:
                            np.savetxt("%d.txt" % idx, value, fmt='%i', delimiter=",")
                except Exception as e:
                    # best-effort dump: report and continue
                    print(e)
    elif isinstance(value, dict):
        for item in value.items():
            if wfp is not None:
                wfp.write(str(item[0]) + ":\n")
            else:
                print(str(item[0]) + ':')
            print_batch(item[1], item[0], debug_path, wfp, local_rank)
    else:
        if wfp is not None:
            if value is not None:
                wfp.write(str([torch.min(value), torch.min(torch.where(value == -100, 10000, value)), torch.max(value)]) + "\n")
                wfp.write(str(value.shape) + "\n")
            else:
                wfp.write("None\n")
            wfp.write("-" * 10 + "\n")
        else:
            if value is not None:
                print([torch.min(value), torch.min(torch.where(value == -100, 10000, value)), torch.max(value)])
                print(value.shape)
            else:
                print("None")
            print("-" * 10)
        if value is not None:
            # protein structure tensors are floats; everything else is dumped as ints
            if key != "prot_structure":
                fmt = '%i'
                d_type = int
            else:
                fmt = '%0.4f'
                d_type = float
            try:
                value = value.detach().cpu().numpy().astype(d_type)
                if debug_path is not None:
                    if value.ndim == 3:
                        for dim_1_idx in range(value.shape[0]):
                            np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), value[dim_1_idx, :, :], fmt=fmt, delimiter=",")
                    else:
                        np.savetxt(os.path.join(debug_path, "%s.txt" % key), value, fmt=fmt, delimiter=",")
                else:
                    if value.ndim == 3:
                        for dim_1_idx in range(value.shape[0]):
                            np.savetxt("%s_batch_%d.txt" % (key, dim_1_idx), value[dim_1_idx, :, :], fmt=fmt, delimiter=",")
                    else:
                        np.savetxt("%s.txt" % key, value, fmt=fmt, delimiter=",")
            except Exception as e:
                print(e)
|
| |
|
| |
|
def gcd(x, y):
    '''
    Greatest common divisor via the Euclidean algorithm.
    :param x: positive integer
    :param y: positive integer
    :return: gcd of x and y
    '''
    big, small = (x, y) if x >= y else (y, x)
    remainder = big % small
    return small if remainder == 0 else gcd(small, remainder)
|
| |
|
| |
|
def lcm(x, y):
    '''
    Least common multiple.

    Replaces a duplicated hand-rolled Euclidean loop with math.gcd; this
    also generalizes the function so lcm(0, n) returns 0 instead of
    raising ZeroDivisionError.

    :param x: non-negative integer
    :param y: non-negative integer
    :return: lcm of x and y
    '''
    if x == 0 or y == 0:
        # lcm with zero is zero by convention
        return 0
    return x * y // math.gcd(x, y)
|
| |
|
| |
|
def device_memory(gpu_id):
    '''
    Print the total/used/free memory of one GPU; no-op for gpu_id None or < 0.
    :param gpu_id: index of the GPU to report on
    '''
    if gpu_id is None or gpu_id < 0:
        return
    pynvml.nvmlInit()
    gib = 1024 ** 3
    for dev_idx in range(pynvml.nvmlDeviceGetCount()):
        # only report the requested device
        if gpu_id is not None and gpu_id != dev_idx:
            continue
        handle = pynvml.nvmlDeviceGetHandleByIndex(dev_idx)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        print(f"Device {dev_idx}: {pynvml.nvmlDeviceGetName(handle)}")
        print(f"Total memory: {mem.total / gib:.8f} GB")
        print(f"Used memory: {mem.used / gib:.8f} GB")
        print(f"Free memory: {mem.free / gib:.8f} GB")
    pynvml.nvmlShutdown()
|
| |
|
| |
|
def calc_emb_filename_by_seq_id(seq_id, embedding_type):
    """
    Build the embedding filename for a sequence id.
    Uses the accession (second "|"-separated field) when present, otherwise
    a sanitized form of the whole id (spaces removed, "/" -> "_").
    :param seq_id: sequence id, possibly fasta-style (">xx|ACC|name")
    :param embedding_type: embedding type used as the filename prefix
    :return: "<embedding_type>_<id>.pt"
    """
    if seq_id[0] == ">":
        seq_id = seq_id[1:]
    parts = seq_id.split("|")
    if len(parts) > 1:
        core = parts[1].strip()
    else:
        core = seq_id.replace(" ", "").replace("/", "_")
    return embedding_type + "_" + core + ".pt"
|
| |
|
| |
|
def download_file(url, local_filename):
    '''
    Stream a remote file to local_filename, creating parent dirs as needed.
    :param url: source URL
    :param local_filename: destination path
    :return: local_filename
    '''
    with requests.get(url, stream=True) as resp:
        resp.raise_for_status()
        parent = os.path.dirname(local_filename)
        if not os.path.exists(parent):
            os.makedirs(parent)
        with open(local_filename, 'wb') as out:
            for chunk in resp.iter_content(chunk_size=8192):
                # skip keep-alive chunks
                if chunk:
                    out.write(chunk)
    return local_filename
|
| |
|
| |
|
def download_folder(base_url, file_names, local_dir):
    '''
    Download a list of files from a base URL into a local directory.
    :param base_url: remote directory URL
    :param file_names: names of the files to fetch
    :param local_dir: destination directory (created when missing)
    '''
    if not os.path.exists(local_dir):
        os.makedirs(local_dir)
    for name in file_names:
        download_file(f"{base_url}/{name}", os.path.join(local_dir, name))
        print(f"Downloaded {name}")
|
| |
|
| |
|
def download_trained_checkpoint_lucaone(
        llm_dir,
        llm_type="lucaone_gplm",
        llm_version="v2.0",
        llm_task_level="token_level,span_level,seq_level,structure_level",
        llm_time_str="20231125113045",
        llm_step="5600000",
        base_url="http://47.93.21.181/lucaone/TrainedCheckPoint"
):
    """
    Download the trained LucaOne checkpoint (log + model files) into
    llm_dir, skipping the download when every expected file already exists.

    NOTE(review): remote URLs are built with os.path.join, which only
    yields valid URLs with POSIX path separators -- confirm Windows is not
    a supported platform here.

    :param llm_dir: local root directory for the checkpoint
    :param llm_type: LLM model type
    :param llm_version: LLM version tag
    :param llm_task_level: comma-joined pretraining task levels
    :param llm_time_str: training run timestamp
    :param llm_step: checkpoint step
    :param base_url: remote checkpoint root URL
    :return: None
    :raises Exception: re-raised after printing manual download guidance
    """
    print("------Download Trained LLM(LucaOne)------")
    try:
        logs_file_names = ["logs.txt"]
        models_file_names = ["config.json", "pytorch.pth", "training_args.bin", "tokenizer/alphabet.pkl"]
        logs_path = "logs/lucagplm/%s/%s/%s/%s" % (llm_version, llm_task_level, llm_type, llm_time_str)
        models_path = "models/lucagplm/%s/%s/%s/%s/checkpoint-step%s" % (llm_version, llm_task_level, llm_type, llm_time_str, llm_step)
        logs_local_dir = os.path.join(llm_dir, logs_path)
        # check whether every expected file is already present locally
        exists = True
        for logs_file_name in logs_file_names:
            if not os.path.exists(os.path.join(logs_local_dir, logs_file_name)):
                exists = False
                break
        models_local_dir = os.path.join(llm_dir, models_path)
        if exists:
            for models_file_name in models_file_names:
                if not os.path.exists(os.path.join(models_local_dir, models_file_name)):
                    exists = False
                    break
        if not exists:
            print("*" * 20 + "Downloading" + "*" * 20)
            print("Downloading LucaOne TrainedCheckPoint: LucaOne-%s-%s-%s ..." % (llm_version, llm_time_str, llm_step))
            print("Wait a moment, please.")
            if not os.path.exists(logs_local_dir):
                os.makedirs(logs_local_dir)
            logs_base_url = os.path.join(base_url, logs_path)
            download_folder(logs_base_url, logs_file_names, logs_local_dir)
            if not os.path.exists(models_local_dir):
                os.makedirs(models_local_dir)
            models_base_url = os.path.join(base_url, models_path)
            download_folder(models_base_url, models_file_names, models_local_dir)
            print("LucaOne Download Succeed.")
            print("*" * 50)
    except Exception as e:
        # point the user at the manual download location before failing
        print(e)
        print("Download automatically LucaOne Trained CheckPoint failed!")
        print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s" % (os.path.abspath(llm_dir), os.path.join(base_url, "TrainedCheckPoint/")))
        raise Exception(e)
|
| |
|
| |
|
def download_trained_checkpoint_downstream_tasks(
        save_dir="../",
        dataset_name=["CentralDogma", "GenusTax", "InfA", "ncRNAFam", "ncRPI", "PPI", "ProtLoc", "ProtStab", "SpeciesTax", "SupKTax"],
        dataset_type=["gene_protein", "gene", "gene_gene", "gene", "gene_protein", "protein", "protein", "protein", "gene", "gene"],
        task_type=["binary_class", "multi_class", "binary_class", "multi_class", "binary_class", "binary_class", "multi_class", "regression", "multi_class", "multi_class"],
        model_type=["lucappi2", "luca_base", "lucappi", "luca_base", "lucappi2", "lucappi", "luca_base", "luca_base", "luca_base", "luca_base"],
        input_type=["matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix"],
        time_str=["20240406173806", "20240412100337", "20240214105653", "20240414155526", "20240404105148", "20240216205421", "20240412140824", "20240404104215", "20240411144916", "20240212202328"],
        step=[64000, 24500, 9603, 1958484, 716380, 52304, 466005, 70371, 24000, 37000],
        base_url="http://47.93.21.181/lucaone/DownstreamTasksTrainedModels"
):
    """
    Download the trained downstream-task checkpoints (logs + model files),
    skipping any task whose expected files already exist locally.

    The per-task parameters are parallel lists indexed together; all must
    have the same length. NOTE(review): the list defaults are mutable
    default arguments -- benign here because they are never mutated, but
    callers must not modify the returned/bound defaults. NOTE(review):
    remote URLs are built with os.path.join, valid only with POSIX
    separators.

    :param save_dir: local root directory for the checkpoints
    :param dataset_name: task/dataset names
    :param dataset_type: per-task input modality (gene / protein / pairs)
    :param task_type: per-task objective (binary/multi-class/regression)
    :param model_type: per-task model architecture name
    :param input_type: per-task input representation
    :param time_str: per-task training run timestamp
    :param step: per-task checkpoint step
    :param base_url: remote checkpoint root URL
    :return: None
    :raises Exception: re-raised after printing manual download guidance
    """
    assert len(dataset_name) == len(dataset_type) == len(task_type) == \
           len(model_type) == len(input_type) == len(time_str) == len(step)
    assert isinstance(dataset_name, list)
    assert isinstance(dataset_type, list)
    assert isinstance(task_type, list)
    assert isinstance(model_type, list)
    assert isinstance(input_type, list)
    assert isinstance(time_str, list)
    assert isinstance(step, list)
    download_succeed_task_num = 0
    print("------Download Trained Models------")
    for idx in range(len(dataset_name)):
        try:
            logs_file_names = ["logs.txt", "label.txt"]
            models_file_names = ["config.json", "pytorch_model.bin", "training_args.bin", "tokenizer/alphabet.pkl"]
            logs_path = "logs/%s/%s/%s/%s/%s/%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx])
            models_path = "models/%s/%s/%s/%s/%s/%s/checkpoint-%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx], str(step[idx]))
            logs_local_dir = os.path.join(save_dir, logs_path)
            # check whether every expected file is already present locally
            exists = True
            for logs_file_name in logs_file_names:
                if not os.path.exists(os.path.join(logs_local_dir, logs_file_name)):
                    exists = False
                    break
            models_local_dir = os.path.join(save_dir, models_path)
            if exists:
                for models_file_name in models_file_names:
                    if not os.path.exists(os.path.join(models_local_dir, models_file_name)):
                        exists = False
                        break
            if not exists:
                print("*" * 20 + "Downloading" + "*" * 20)
                print("Downloading Downstream Task: %s TrainedCheckPoint: %s-%s-%s ..." % (dataset_name[idx], dataset_name[idx], time_str[idx], str(step[idx])))
                print("Wait a moment, please.")
                if not os.path.exists(logs_local_dir):
                    os.makedirs(logs_local_dir)
                logs_base_url = os.path.join(base_url, dataset_name[idx], logs_path)
                download_folder(logs_base_url, logs_file_names, logs_local_dir)
                if not os.path.exists(models_local_dir):
                    os.makedirs(models_local_dir)
                models_base_url = os.path.join(base_url, dataset_name[idx], models_path)
                download_folder(models_base_url, models_file_names, models_local_dir)
                print("Downstream Task: %s Trained Model Download Succeed." % dataset_name[idx])
                print("*" * 50)
                download_succeed_task_num += 1
        except Exception as e:
            # point the user at the manual download location before failing
            print(e)
            print("Download automatically LucaDownstream Task: %s Trained CheckPoint failed!" % dataset_name[idx])
            print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s" % (os.path.abspath(save_dir), os.path.join(base_url, dataset_name[idx])))
            raise Exception(e)
    print("%d Downstream Task Trained Model Download Succeed." % download_succeed_task_num)