|
|
import itertools |
|
|
|
|
|
from copy import deepcopy |
|
|
|
|
|
import argparse |
|
|
import socket |
|
|
from scipy.stats import spearmanr, pearsonr |
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score, \ |
|
|
r2_score |
|
|
from typing import Optional |
|
|
|
|
|
import math |
|
|
import torch |
|
|
import numpy as np |
|
|
from fairseq.data import Dictionary |
|
|
from torch.utils.data import DataLoader, DistributedSampler |
|
|
|
|
|
from model.LMConfig import LMConfig |
|
|
from model.codon_tables import AA_str |
|
|
|
|
|
def compute_metrics_regression(preds, labels):
    """Regression metrics: Spearman/Pearson correlation, MSE, RMSE and R^2.

    Args:
        preds: predicted values (array-like of floats).
        labels: ground-truth values, same length as preds.

    Returns:
        dict with keys 'spearmanr', 'pearsonr', 'mse', 'rmse', 'r2'.
    """
    mean_squared_error = np.mean((preds - labels) ** 2)
    return {
        'spearmanr': spearmanr(preds, labels)[0],
        'pearsonr': pearsonr(preds, labels)[0],
        'mse': mean_squared_error,
        'rmse': np.sqrt(mean_squared_error),
        'r2': r2_score(labels, preds),
    }
|
|
|
|
|
|
|
|
def compute_metrics_dict(preds, labels, average='macro', multi_class='ovr', cls='binary'):
    """Compute evaluation metrics for a classification (or regression) task.

    Args:
        preds: predictions — hard class labels, or per-class scores/logits
            of shape (N, C).
        labels: ground-truth labels.
        average: averaging mode for multi-class metrics
            ('micro', 'macro', 'weighted', 'binary').
        multi_class: AUC strategy for multi-class problems ('ovr', 'ovo');
            currently unused (AUC computation is disabled).
        cls: task type — 'regression', 'identity', or a classification
            mode (default 'binary').

    Reference:
        https://rcxqhxlmkf.feishu.cn/wiki/ONHBwenBjiNUkgk54mQcwVBznEg#share-RWVDdIzU2oC5dZxCgqKcHYtrnfc

    Returns:
        dict mapping metric name -> value.
    """
    if cls == 'regression':
        return compute_metrics_regression(preds, labels)

    if cls == 'identity':
        # Nucleotide-level predictions: compare per-position argmax with the
        # labels, both per nucleotide and per codon (groups of 3 positions).
        pred_labels = np.argmax(preds, axis=1)
        pred_codon = [list(pred_labels[i:i + 3]) for i in range(0, len(pred_labels), 3)]
        true_codon = [list(labels[i:i + 3]) for i in range(0, len(pred_labels), 3)]
        identity_codon = sum(1 for c1, c2 in zip(pred_codon, true_codon) if c1 == c2) / len(true_codon)
        identity_NN = sum(1 for c1, c2 in zip(pred_labels, labels) if c1 == c2) / len(labels)
        return {'identity_codon': identity_codon, 'identity_NN': identity_NN}

    if preds.ndim > 1 and preds.shape[1] > 1:
        # Per-class scores/logits: derive probabilities via a numerically
        # stable softmax (kept for potential AUC use) and labels via argmax.
        # BUGFIX: the original called the non-existent np.sigmoid inside an
        # `elif preds.shape[1] == 2` branch that was also unreachable, since
        # shape[1] == 2 was already caught by this `shape[1] > 1` branch.
        shifted = preds - np.max(preds, axis=1, keepdims=True)
        exp_scores = np.exp(shifted)
        pred_probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)
        pred_labels = np.argmax(preds, axis=1)
    else:
        # Predictions are already hard labels.
        pred_labels = preds
        pred_probs = None

    accuracy = accuracy_score(labels, pred_labels)
    precision = precision_score(labels, pred_labels, average=average, zero_division=0)
    recall = recall_score(labels, pred_labels, average=average, zero_division=0)
    f1 = f1_score(labels, pred_labels, average=average, zero_division=0)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
    }
|
|
def flatten_col(col, group=1, exclude='_', frames=None):
    """
    Flatten a column (Series of strings via .apply) or a plain string.

    frames=['0','1','2','01','12','02','012']: only honoured when group == 1

    group == 1 and frames given : one flattened string per frame
    group == 1 and frames=None  : single nucleotides (NN)
    group == 2 : dinucleotides (DiNN)
    group == 3 : codons
    """
    if type(col) == str:
        chars = list(col)
    else:
        # Series of strings -> one flat list of characters.
        chars = list(itertools.chain(*col.apply(list).tolist()))

    if chars.count(exclude) != 0:
        # Drop every codon (triplet) that contains the exclude/padding mark.
        codons = [''.join(chars[i:i + 3]) for i in range(0, len(chars), 3)]
        chars = ''.join(codon for codon in codons if exclude not in codon)

    if group == 1:
        if frames:
            return multi_frames(deepcopy(chars), frames)
        return chars
    if len(chars) % group != 0:
        raise ValueError(f"字符串长度必须相同且是{group}的倍数")
    return [''.join(chars[i:i + group]) for i in range(0, len(chars), group)]
|
|
def multi_frames(str1, frames):
    """For each frame spec (e.g. '0', '12'), pick those codon positions out
    of every codon of str1 and return the concatenated string per frame."""
    results = []
    for frame in frames:
        if len(frame) == 1:
            offset = int(frame)
            picked = [str1[i + offset] for i in range(0, len(str1), 3)]
        else:
            offsets = [int(ch) for ch in frame]
            picked = [''.join(str1[i + off] for off in offsets)
                      for i in range(0, len(str1) - 3 + 1, 3)]
        results.append(''.join(picked))
    return results
|
|
|
|
|
def get_correct(labels, preds, prefix='', average='macro'):
    """Position-wise identity plus classification metrics for two equal-length
    sequences; every key in the returned dict is prefixed with `prefix`.

    Raises:
        ValueError: if `labels` is empty or the two lengths differ.
    """
    str1, str2 = labels, preds
    if len(str1) == 0:
        raise ValueError(f"{prefix}str1 is empty")
    if len(str1) != len(str2):
        raise ValueError(f"字符串长度必须相同,str1_len:{len(str1)},str2_len:{len(str2)}")

    matches = sum(1 for a, b in zip(str1, str2) if a == b)
    data = {
        'identity': matches / len(str1),
        'label_seq': ''.join(str1),
        'pred_seq': ''.join(str2),
    }

    # Map every symbol seen in either sequence to an integer id so the
    # classification metrics can run on numeric arrays.
    alphabet = {sym: idx for idx, sym in enumerate(set(str1) | set(str2))}
    label_ids = np.array([alphabet[s] for s in str1]).flatten()
    pred_ids = np.array([alphabet[s] for s in str2]).flatten()
    data.update(compute_metrics_dict(pred_ids, label_ids, cls='binary', average=average))

    return {f'{prefix}{k}': v for k, v in data.items()}
|
|
|
|
|
|
|
|
def calculate_accuracy(label, pred, group=1, exclude='_', frames=None):
    """Flatten label/pred identically, then score them with get_correct.

    With `frames`, metrics are computed once per frame and each key is
    prefixed '<frame>_'; otherwise a single un-prefixed dict is returned.
    """
    flat_label = flatten_col(label, group=group, exclude=exclude, frames=frames)
    flat_pred = flatten_col(pred, group=group, exclude=exclude, frames=frames)

    if not frames:
        return get_correct(flat_label, flat_pred)

    results = {}
    for frame, lab_seq, pred_seq in zip(frames, flat_label, flat_pred):
        results.update(get_correct(lab_seq, pred_seq, prefix=f'{frame}_'))
    return results
|
|
|
|
|
def MeanPearsonCorrCoefPerChannel(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Absolute Pearson correlation per channel (last dimension).

    Assumes preds/target are shaped (batch, length, channels) — TODO confirm
    against callers. Statistics are reduced over the first two dimensions,
    leaving one correlation value per channel.

    BUGFIX: the channel count is now taken from the last dimension (the one
    left un-reduced) instead of ``preds.shape[1]``; the original only worked
    when ``shape[1] == shape[-1]`` because the ``(0, 1)``-reduced sums have
    ``shape[-1]`` entries.

    Returns:
        1-D tensor of |r| values, one per channel.
    """
    reduce_dims = (0, 1)

    # Sufficient statistics for Pearson's r, accumulated in one pass.
    product = torch.sum(preds * target, dim=reduce_dims)
    true_sum = torch.sum(target, dim=reduce_dims)
    true_squared_sum = torch.sum(torch.square(target), dim=reduce_dims)
    pred_sum = torch.sum(preds, dim=reduce_dims)
    pred_squared_sum = torch.sum(torch.square(preds), dim=reduce_dims)
    count = torch.sum(torch.ones_like(target), dim=reduce_dims)

    true_mean = true_sum / count
    pred_mean = pred_sum / count

    # E[xy] - related cross terms; algebraically count * cov(x, y).
    covariance = (product
                  - true_mean * pred_sum
                  - pred_mean * true_sum
                  + count * true_mean * pred_mean)

    true_var = true_squared_sum - count * torch.square(true_mean)
    pred_var = pred_squared_sum - count * torch.square(pred_mean)
    tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)

    correlation = covariance / tp_var
    return correlation.abs()
|
|
|
|
|
def init_config(vocab_path, n_layers, max_seq_len):
    """Load the fairseq dictionary at `vocab_path`, extend it with the mask
    token and the amino-acid symbols, and build a matching LMConfig.

    Args:
        vocab_path: path understood by fairseq's Dictionary.load.
        n_layers: number of transformer layers for the LMConfig.
        max_seq_len: maximum sequence length for the LMConfig.

    Returns:
        (lm_config, tokenizer) tuple.
    """
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.mask_index = tokenizer.add_symbol('<mask>')
    # Plain loop instead of a list comprehension used only for side effects.
    for word in AA_str:
        tokenizer.add_symbol(word)

    lm_config = LMConfig(dim=256, logit_dim=len(tokenizer), n_layers=n_layers,
                         max_seq_len=max_seq_len, vocab_size=len(tokenizer),
                         padding_idx=tokenizer.pad_index)
    return lm_config, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''socket port helpers'''
|
|
def find_free_port():
    """Ask the OS for an ephemeral free TCP port on localhost and return it."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        print("Binding to a random port...")
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]
|
|
|
|
|
def is_port_in_use(port):
    """Return True when something is already accepting connections on
    127.0.0.1:`port` (connect_ex returns 0 on success)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        return sock.connect_ex(('127.0.0.1', port)) == 0
|
|
|
|
|
def get_port():
    """Find a free TCP port, retrying while the candidate is taken.

    TODO: there is no guarantee that every rank ends up with the same port
    number; this logic is still flawed.
    """
    free_port = find_free_port()
    max_attempts = 100
    attempts = 0

    while is_port_in_use(free_port) and attempts < max_attempts:
        free_port = find_free_port()
        attempts += 1
        print(f"[{attempts}/{max_attempts}]Port {free_port} is in use, trying another port...")

    if attempts >= max_attempts:
        raise RuntimeError("无法找到未被占用的端口")
    return free_port
|
|
def get_pretraining_args():
    """Build the argparse parser holding all pretraining / fine-tuning options.

    Sections: training hyper-parameters, dataset/model options, predict mode
    and design mode. Returns the configured ArgumentParser (not yet parsed).
    """
    # --- training hyper-parameters ---
    parser = argparse.ArgumentParser(description="MiniMind Full SFT")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--celoss_alpha", type=float, default=0.1)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="RiboUTR-PT")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true",help='DistributedDataParallel')
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--save_interval", type=int, default=100)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument("--data_path", type=str, default="./dataset/sft_mini_512.jsonl")

    """dataset"""
    parser.add_argument('--n_layers', default=8, type=int)
    # NOTE(review): type=bool parses any non-empty string (including "False")
    # as True on the command line — confirm this flag is only set via default.
    parser.add_argument('--is_twod', default=True, type=bool)
    parser.add_argument('--max_seq_len', default=1205, type=int)
    parser.add_argument('--use_moe', action='store_true', help="add moe layer")

    parser.add_argument("--mlm_pretrained_model_path", type=str, default=f"./checkpoint/ernierna.pt")

    # NOTE(review): type=dict cannot parse a command-line string; this option
    # is effectively usable only through its default value.
    parser.add_argument("--arg_overrides", type=dict,default={"data": f'./utils/ernie_rna/'}, help="The path of vocabulary")

    parser.add_argument('--finetune', action='store_true')
    parser.add_argument('--scaler', action='store_true')

    parser.add_argument("--ffasta", default='./dataset/experiment/nature/reference/GRCh38.p14/mRNA_300.pkl',
                        type=str, help="The path of input seqs")
    parser.add_argument("--exp_pretrain_data_path", default='./dataset/experiment/nature/', type=str,
                        help="The path of expPretrain data")
    parser.add_argument("--downstream_data_path", default='./dataset/downstream/', type=str,
                        help="The path of Task/TR,VL,TS.csv")
    parser.add_argument('--task', type=str, default='predict_web',
                        help='task in downstream dir')
    parser.add_argument("--seq_len", type=int, default=1205, help="The length of sequence")
    parser.add_argument("--env_counts", type=int, default=10, help="The length of sequence")
    parser.add_argument("--column", type=str, default="sequence", help="The sequences' column name")
    parser.add_argument("--label", type=str, default="label", help="The label")
    parser.add_argument("--pad_method", type=str, default="pre", help="The method which pad sequence")
    parser.add_argument("--region", default=300, type=int, help="The context length/2")
    parser.add_argument("--env_id", default=1, type=int, help="0")
    parser.add_argument("--limit", default=-1, type=int, help="less samples")
    parser.add_argument('--debug', action='store_true', help="debug mode")
    parser.add_argument('--codon_table_path', type=str, default="maotao_file/codon_table/codon_usage_{species}.csv", help="The method which pad sequence")

    """predict mode"""
    parser.add_argument('--predict', action='store_true', help="save predict result")
    parser.add_argument('--test_file', default=None, help="asign test file")

    """design mode"""
    parser.add_argument('--Kozak_GS6H_Stop3', default='GCCACC,GGGAGCCACCACCACCATCACCAC,TGATAATAG', help="kozak,tag,Stop3")

    return parser
|
|
|
|
|
def get_dataset_args():
    """Return a bare ArgumentParser reserved for dataset-specific options."""
    return argparse.ArgumentParser()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unifi_dataloader(train_ds, args, ddp=False, data_tag='TR'):
    """Wrap a dataset in a DataLoader with unified settings.

    Under DDP a DistributedSampler is attached for the training split only
    (data_tag == 'TR'); evaluation loaders iterate the full dataset on every
    rank. Tail batches are dropped only under DDP so all ranks execute the
    same number of steps.

    The original duplicated the whole DataLoader call in both branches (they
    differed only in the sampler kwarg) and constructed a DistributedSampler
    even when it was never used; both are consolidated here.

    Args:
        train_ds: dataset to wrap.
        args: namespace providing batch_size and num_workers.
        ddp: whether distributed training is active.
        data_tag: 'TR' for the training split, anything else for eval.

    Returns:
        torch.utils.data.DataLoader
    """
    # Only build the sampler when it will actually be attached.
    sampler = DistributedSampler(train_ds) if (ddp and data_tag == 'TR') else None
    return DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=ddp,       # equal step counts across ranks under DDP
        shuffle=False,       # ordering handled by the sampler (or preserved)
        num_workers=args.num_workers,
        sampler=sampler,
    )
|
|
def ddp_broadcast_early_stopping(ddp_local_rank, args, early_stopping, current_loss, model,dist):
    """Evaluate early stopping on rank 0 and broadcast its decision to all ranks.

    Rank 0 feeds the current validation loss into `early_stopping` (which may
    save a checkpoint), then broadcasts the stop flag and patience counter.
    Every other rank posts placeholder tensors and receives rank 0's values,
    then all ranks overwrite their local EarlyStopping state so they stop in
    the same iteration.

    Note: both branches must issue the two broadcasts in the same order, or
    the collective calls deadlock.
    """
    if ddp_local_rank == 0:
        early_stopping(current_loss, model)
        # NOTE(review): the counter is zeroed once the stop flag trips —
        # presumably so a resumed loop does not instantly re-trigger; confirm.
        if early_stopping.early_stop:early_stopping.counter = 0

        to_broadcast = torch.tensor([early_stopping.early_stop], dtype=torch.bool, device=args.device)
        to_broadcast_counter = torch.tensor([early_stopping.counter], dtype=torch.int, device=args.device)
        dist.broadcast(to_broadcast, 0)
        dist.broadcast(to_broadcast_counter, 0)
    else:
        # Placeholders; overwritten in place with rank 0's values by broadcast.
        to_broadcast = torch.tensor([False], dtype=torch.bool, device=args.device)
        to_broadcast_counter = torch.tensor([0], dtype=torch.int, device=args.device)
        dist.broadcast(to_broadcast, 0)
        dist.broadcast(to_broadcast_counter, 0)
    # Synchronize local EarlyStopping state with the broadcast decision.
    early_stopping.early_stop = bool(to_broadcast.item())
    early_stopping.counter = int(to_broadcast_counter.item())
|
|
|
|
|
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): Trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0           # calls since last improvement
        self.best_score = None     # best (negated) validation loss so far
        self.early_stop = False
        # BUGFIX: np.Inf was removed in NumPy 2.0; np.inf is the supported name.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record a new validation loss; save a checkpoint on improvement.

        Returns:
            bool: True once patience is exhausted (training should stop).
        """
        score = -val_loss

        if self.best_score is None:
            # First observation is always an "improvement".
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        # Side effect: switches the model to eval mode before saving.
        model.eval()
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..., {self.path}')
        self.save_model(model, self.path)
        self.val_loss_min = val_loss

    @staticmethod
    def save_model(model, path):
        """Save a state_dict to `path`, unwrapping DistributedDataParallel if needed."""
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save(state_dict, path)
|
|
|
|
|
|
|
|
def generate_inputs(x):
    """Assemble the padded model input from a record with UTR5/CDS/UTR3 keys.

    The 5'UTR + CDS head and the CDS tail + 3'UTR are each padded/truncated
    to a fixed window via process_utr, then the two windows are joined around
    an 'NNN' linker.
    """
    pad_mark = '_'
    bos = '<'
    eos = '>'
    region = 300
    link = 'N'

    head = process_utr(x["UTR5"], region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
    cds_head = process_utr(x["CDS"], region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
    cds_tail = process_utr(x["CDS"], region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
    tail = process_utr(x["UTR3"], region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)

    seq = head + cds_head + cds_tail + tail
    return seq[:region * 2 + 1] + link * 3 + seq[-region * 2 - 1:]
|
|
|
|
|
def process_utr(utr, input_len, pad_method, pad_mark='_', bos='<', eos='>'):
    """Pad or truncate a sequence to `input_len` and attach a boundary marker.

    'pre'    : keep the tail; pad on the left / truncate from the left and
               prefix `bos`  ->  '__<SEQ'  or  '<SEQ_tail'
    'behind' : keep the head; pad on the right / truncate from the right and
               suffix `eos`  ->  'SEQ>__'  or  'SEQ_head>'

    The result is always input_len + 1 characters (marker included).

    Raises:
        ValueError: for an unknown pad_method. (BUGFIX: the original fell
        through to an UnboundLocalError on the return statement instead.)
    """
    if pad_method not in ('pre', 'behind'):
        raise ValueError(f"unknown pad_method: {pad_method!r}")

    if len(utr) < input_len:
        if pad_method == 'pre':
            return pad_mark * (input_len - len(utr)) + bos + utr
        return utr + eos + pad_mark * (input_len - len(utr))
    if pad_method == 'pre':
        return bos + utr[-input_len:]
    return utr[:input_len] + eos
|
|
|
|
|
|
|
|
|
|
|
def find_unused_parameters(model, output):
    """Print the names of model parameters that did not contribute to `output`.

    Walks the autograd graph of `output` (via get_contributing_params) to
    collect contributing parameters, then reports every model parameter not
    in that set. Useful when debugging DDP unused-parameter errors.
    """
    contributing_parameters = set(get_contributing_params(output))
    all_parameters = set(model.parameters())
    non_contributing = all_parameters - contributing_parameters
    # Build an id -> name map once instead of re-scanning named_parameters()
    # for every unused parameter (the original nested loop was quadratic).
    name_by_id = {id(p): name for name, p in model.named_parameters()}
    print("未参与计算的参数:")
    for param in non_contributing:
        name = name_by_id.get(id(param))
        if name is not None:
            print(f" {name}")
|
|
def get_contributing_params(y, top_level=True):
    """Yield every leaf parameter reachable from y's autograd graph.

    At the top level `y` is an output tensor (graph entered via grad_fn);
    recursive calls receive grad_fn nodes directly.
    """
    next_fns = y.grad_fn.next_functions if top_level else y.next_functions
    for fn, _ in next_fns:
        # AccumulateGrad nodes expose their leaf tensor as `.variable`.
        if hasattr(fn, 'variable'):
            yield fn.variable
        if fn is not None:
            yield from get_contributing_params(fn, top_level=False)