import itertools
from copy import deepcopy
import argparse
import socket

from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, \
    roc_auc_score, r2_score

import torch
import numpy as np
from fairseq.data import Dictionary
from torch.utils.data import DataLoader, DistributedSampler

from model.LMConfig import LMConfig
from model.codon_tables import AA_str


def compute_metrics_regression(preds, labels):
    spr = spearmanr(preds, labels)[0]
    pr = pearsonr(preds, labels)[0]
    mse = np.mean((preds - labels) ** 2)
    rmse = np.sqrt(mse)
    r2 = r2_score(labels, preds)
    return {'spearmanr': spr, 'pearsonr': pr, 'mse': mse, 'rmse': rmse, 'r2': r2}


def compute_metrics_dict(preds, labels, average='macro', multi_class='ovr', cls='binary'):
    """
    Compute evaluation metrics for classification tasks.

    Args:
        preds: predictions (class labels, or probabilities/logits)
        labels: ground-truth labels
        average: averaging mode for multi-class metrics ('micro', 'macro', 'weighted', 'binary')
        multi_class: AUC strategy for multi-class problems ('ovr', 'ovo')

    Reference: https://rcxqhxlmkf.feishu.cn/wiki/ONHBwenBjiNUkgk54mQcwVBznEg#share-RWVDdIzU2oC5dZxCgqKcHYtrnfc
    """
    if cls == 'regression':
        return compute_metrics_regression(preds, labels)

    if cls == 'identity':
        # Codon-level identity: group nucleotide predictions into triplets.
        pred_labels = np.argmax(preds, axis=1)
        pred_codon = [list(pred_labels[i:i + 3]) for i in range(0, len(pred_labels), 3)]
        true_codon = [list(labels[i:i + 3]) for i in range(0, len(labels), 3)]
        identity_codon = sum(1 for c1, c2 in zip(pred_codon, true_codon) if c1 == c2) / len(true_codon)
        identity_NN = sum(1 for c1, c2 in zip(pred_labels, labels) if c1 == c2) / len(labels)
        return {'identity_codon': identity_codon, 'identity_NN': identity_NN}

    # If preds are probabilities/logits rather than class labels, convert them.
    # The two-column (binary) case is checked first; the previous ordering
    # (shape[1] > 1 before shape[1] == 2) made the binary branch unreachable.
    if preds.ndim > 1 and preds.shape[1] == 2:
        # Binary logits: NumPy has no np.sigmoid, so apply it explicitly.
        pred_probs = 1.0 / (1.0 + np.exp(-preds))
        pred_labels = (pred_probs[:, 1] > 0.5).astype(int)
    elif preds.ndim > 1 and preds.shape[1] > 2:
        # Multi-class probabilities/logits.
        pred_probs = None
        pred_labels = np.argmax(preds, axis=1)
    else:
        # Already class labels.
        pred_labels = preds
        pred_probs = None

    # Core classification metrics.
    accuracy = accuracy_score(labels, pred_labels)
    precision = precision_score(labels, pred_labels, average=average, zero_division=0)
    recall = recall_score(labels, pred_labels, average=average, zero_division=0)
    f1 = f1_score(labels, pred_labels, average=average, zero_division=0)

    # Optional extras, currently disabled:
    # cm = confusion_matrix(labels, pred_labels)
    # AUC-ROC (only when probabilities are available):
    # if pred_probs is not None:
    #     if len(np.unique(labels)) == 2:
    #         auc_roc = roc_auc_score(labels, pred_probs[:, 1])
    #     else:
    #         auc_roc = roc_auc_score(labels, pred_probs, multi_class=multi_class, average=average)
    # Per-class precision/recall/F1 can be recovered by passing average=None
    # to precision_score / recall_score / f1_score.

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        # 'auc_roc': auc_roc,
        # 'confusion_matrix': cm,
    }
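# Illustrative usage (toy data, not from the project): plain class labels pass
# straight through, so the metrics can be sanity-checked by hand.
# >>> compute_metrics_dict(np.array([0, 1, 1, 0]), np.array([0, 1, 0, 0]))
# {'accuracy': 0.75, 'precision': ..., 'recall': ..., 'f1_score': ...}
# Regression metrics are dispatched the same way:
# >>> compute_metrics_dict(np.array([1.0, 2.0, 3.0]), np.array([1.1, 1.9, 3.2]), cls='regression')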
def flatten_col(col, group=1, exclude='_', frames=None):
    """
    Flatten the given column (or nested list) into grouped tokens.

    frames=['0','1','2','01','12','02','012']: only honoured when group == 1
    group == 1 and frames given : return every requested reading frame
    group == 1 and frames=None  : single nucleotides (NN)
    group == 2                  : dinucleotides (DiNN)
    group == 3                  : codons
    """
    if isinstance(col, str):
        str1 = list(col)
    else:
        nested_list = col.apply(list).tolist()
        str1 = list(itertools.chain(*nested_list))
    exclude_num = str1.count(exclude)
    if exclude_num != 0:
        # Drop any triplet that contains the padding mark.
        triplets1 = [''.join(str1[i:i + 3]) for i in range(0, len(str1), 3)]
        triplets1 = [triplet for triplet in triplets1 if exclude not in triplet]
        str1 = ''.join(triplets1)
    if group == 1:
        if frames:
            return multi_frames(deepcopy(str1), frames)
        return str1
    if len(str1) % group != 0:
        raise ValueError(f"Sequence length must be a multiple of {group}")
    triplets1 = [''.join(str1[i:i + group]) for i in range(0, len(str1), group)]
    return triplets1


def multi_frames(str1, frames):
    str1_list = []
    for frame in frames:
        if len(frame) == 1:
            triplets1 = [str1[i + int(frame)] for i in range(0, len(str1), 3)]
        else:
            triplets1 = [''.join([str1[i + int(fr)] for fr in frame])
                         for i in range(0, len(str1) - 3 + 1, 3)]
        str1_list.append(''.join(triplets1))
    return str1_list


def get_correct(labels, preds, prefix='', average='macro'):
    str1 = labels
    str2 = preds
    if len(str1) == 0:
        raise ValueError(f"{prefix}str1 is empty")
    if len(str1) != len(str2):
        raise ValueError(f"Sequences must have equal length, str1_len:{len(str1)}, str2_len:{len(str2)}")
    correct = sum(1 for c1, c2 in zip(str1, str2) if c1 == c2)
    data = {
        'identity': correct / len(str1),
        'label_seq': ''.join(str1),
        'pred_seq': ''.join(str2),
    }
    # Map symbols to integer ids so the sklearn-based metrics can be reused.
    alphabet = set(str1) | set(str2)
    alphabet = {k: v for k, v in zip(alphabet, range(len(alphabet)))}
    labels = [alphabet[k] for k in str1]
    preds = [alphabet[k] for k in str2]
    data.update(compute_metrics_dict(np.array(preds).flatten(), np.array(labels).flatten(),
                                     cls='binary', average=average))
    return {f'{prefix}{k}': v for k, v in data.items()}


def calculate_accuracy(label, pred, group=1, exclude='_', frames=None):
    str1 = flatten_col(label, group=group, exclude=exclude, frames=frames)
    str2 = flatten_col(pred, group=group, exclude=exclude, frames=frames)
    if frames:
        ans_dict = {}
        for frame, s1, s2 in zip(frames, str1, str2):
            ans_dict.update(get_correct(s1, s2, prefix=f'{frame}_'))
        return ans_dict
    else:
        return get_correct(str1, str2)
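# Illustrative call (toy sequences): codon-level identity between a label
# string and an equal-length prediction string.
# >>> calculate_accuracy('AUGGCC', 'AUGGCA', group=3)
# One of the two codons matches, so 'identity' is 0.5; the dict also carries
# the sklearn metrics added by compute_metrics_dict.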
# Correlation computed along positions, adapted from
# https://github.com/lucidrains/enformer-pytorch/blob/main/enformer_pytorch/metrics.py
def MeanPearsonCorrCoefPerChannel(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Assumes input of shape (batch, length, channels) and aggregates over the
    # batch and position dimensions. Note preds.shape[1] would count positions,
    # not channels, so the last dimension is used here.
    n_channels = preds.shape[-1]
    reduce_dims = (0, 1)

    # Running sums, one entry per channel.
    product = torch.zeros(n_channels, dtype=torch.float32, device=preds.device)
    true_sum = torch.zeros(n_channels, dtype=torch.float32, device=preds.device)
    true_squared_sum = torch.zeros(n_channels, dtype=torch.float32, device=preds.device)
    pred_sum = torch.zeros(n_channels, dtype=torch.float32, device=preds.device)
    pred_squared_sum = torch.zeros(n_channels, dtype=torch.float32, device=preds.device)
    count = torch.zeros(n_channels, dtype=torch.float32, device=preds.device)

    # Accumulate the sufficient statistics.
    product += torch.sum(preds * target, dim=reduce_dims)
    true_sum += torch.sum(target, dim=reduce_dims)
    true_squared_sum += torch.sum(torch.square(target), dim=reduce_dims)
    pred_sum += torch.sum(preds, dim=reduce_dims)
    pred_squared_sum += torch.sum(torch.square(preds), dim=reduce_dims)
    count += torch.sum(torch.ones_like(target), dim=reduce_dims)

    # Means.
    true_mean = true_sum / count
    pred_mean = pred_sum / count

    # Covariance: sum(xy) - mean_x * sum_y - mean_y * sum_x + n * mean_x * mean_y.
    covariance = (product
                  - true_mean * pred_sum
                  - pred_mean * true_sum
                  + count * true_mean * pred_mean)

    # Variances and the square root of their product.
    true_var = true_squared_sum - count * torch.square(true_mean)
    pred_var = pred_squared_sum - count * torch.square(pred_mean)
    tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)

    # Pearson correlation per channel. The absolute value is returned so that
    # a loss such as 1 - |r| shrinks as the correlation strengthens.
    correlation = covariance / tp_var
    return correlation.abs()


def init_config(vocab_path, n_layers, max_seq_len):
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.mask_index = tokenizer.add_symbol('')  # ['', '', '', '', 'G', 'A', 'U', 'C', 'N', '']
    [tokenizer.add_symbol(word) for word in AA_str]  # 10-31
    # logit_dim alternatives tried previously: tokenizer.nspecial, 9.
    lm_config = LMConfig(dim=256, logit_dim=len(tokenizer), n_layers=n_layers, max_seq_len=max_seq_len,
                         vocab_size=len(tokenizer), padding_idx=tokenizer.pad_index)  # n_layers 8
    return lm_config, tokenizer


'''socket port helpers'''
def find_free_port():
    # Create a temporary socket bound to a random free port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        print("Binding to a random port...")
        s.bind(('127.0.0.1', 0))
        # Return the port number assigned by the OS.
        return s.getsockname()[1]


def is_port_in_use(port):
    # Check whether the port is already occupied.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('127.0.0.1', port)) == 0


def get_port():
    '''TODO: this cannot guarantee that all ranks agree on one port; still buggy.'''
    # Dynamically grab an unused port and double-check it is actually free.
    free_port = find_free_port()
    max_attempts = 100
    attempts = 0
    while is_port_in_use(free_port) and attempts < max_attempts:
        free_port = find_free_port()
        attempts += 1
        print(f"[{attempts}/{max_attempts}] Port {free_port} is in use, trying another port...")
    if attempts >= max_attempts:
        raise RuntimeError("Could not find an unoccupied port")
    return free_port
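# Sanity check (illustrative; shapes are hypothetical, batch=2, length=5,
# channels=3): a perfect linear relation gives |r| == 1 in every channel.
# >>> x = torch.randn(2, 5, 3)
# >>> MeanPearsonCorrCoefPerChannel(x, 2 * x + 1)  # ~tensor([1., 1., 1.])
# When launching DDP, get_port() can seed MASTER_PORT on rank 0 (subject to
# the TODO above; assumes `import os`):
# >>> os.environ.setdefault("MASTER_PORT", str(get_port()))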
parser.add_argument("--learning_rate", type=float, default=5e-6) parser.add_argument("--celoss_alpha", type=float, default=0.1) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--use_wandb", action="store_true") parser.add_argument("--wandb_project", type=str, default="RiboUTR-PT") parser.add_argument("--num_workers", type=int, default=1) parser.add_argument("--ddp", action="store_true",help='DistributedDataParallel') parser.add_argument("--accumulation_steps", type=int, default=1) parser.add_argument("--grad_clip", type=float, default=1.0) parser.add_argument("--warmup_iters", type=int, default=0) parser.add_argument("--log_interval", type=int, default=10) # 100 parser.add_argument("--save_interval", type=int, default=100) # 100 parser.add_argument('--local_rank', type=int, default=-1) parser.add_argument("--data_path", type=str, default="./dataset/sft_mini_512.jsonl") # sft_data.jsonl """dataset""" parser.add_argument('--n_layers', default=8, type=int) # 8 parser.add_argument('--is_twod', default=True, type=bool) parser.add_argument('--max_seq_len', default=1205, type=int) # 512 parser.add_argument('--use_moe', action='store_true', help="add moe layer") # False # ? mlm_pretrained_model_path # parser.add_argument("--mlm_pretrained_model_path", type=str, default="/public/home/jiang_jiuhong/soft/ERNIE-RNA/checkpoint/ERNIE-RNA_checkpoint/ERNIE-RNA_pretrain.pt") parser.add_argument("--mlm_pretrained_model_path", type=str, default=f"./checkpoint/ernierna.pt") # parser.add_argument("--mlm_pretrained_model_path", type=str, default=f"{username}/soft/ERNIE-RNA/checkpoint/ERNIE-RNA_checkpoint/ERNIE-RNA_pretrain.pt") parser.add_argument("--arg_overrides", type=dict,default={"data": f'./utils/ernie_rna/'}, help="The path of vocabulary") parser.add_argument('--finetune', action='store_true') ## if --finetune: true parser.add_argument('--scaler', action='store_true') ## if --finetune: true # parser.add_argument("--ffasta", default='./dataset/sequence/full.fa', type=str, help="The path of input seqs") parser.add_argument("--ffasta", default='./dataset/experiment/nature/reference/GRCh38.p14/mRNA_300.pkl', type=str, help="The path of input seqs") parser.add_argument("--exp_pretrain_data_path", default='./dataset/experiment/nature/', type=str, help="The path of expPretrain data") parser.add_argument("--downstream_data_path", default='./dataset/downstream/', type=str, help="The path of Task/TR,VL,TS.csv") parser.add_argument('--task', type=str, default='predict_web', help='task in downstream dir') parser.add_argument("--seq_len", type=int, default=1205, help="The length of sequence") parser.add_argument("--env_counts", type=int, default=10, help="The length of sequence") parser.add_argument("--column", type=str, default="sequence", help="The sequences' column name") parser.add_argument("--label", type=str, default="label", help="The label") parser.add_argument("--pad_method", type=str, default="pre", help="The method which pad sequence") parser.add_argument("--region", default=300, type=int, help="The context length/2") parser.add_argument("--env_id", default=1, type=int, help="0") parser.add_argument("--limit", default=-1, type=int, help="less samples") parser.add_argument('--debug', action='store_true', help="debug mode") parser.add_argument('--codon_table_path', type=str, default="maotao_file/codon_table/codon_usage_{species}.csv", help="The method which pad sequence") 
"""predict mode""" parser.add_argument('--predict', action='store_true', help="save predict result") parser.add_argument('--test_file', default=None, help="asign test file") """design mode""" parser.add_argument('--Kozak_GS6H_Stop3', default='GCCACC,GGGAGCCACCACCACCATCACCAC,TGATAATAG', help="kozak,tag,Stop3") return parser def get_dataset_args(): parser = argparse.ArgumentParser() return parser # # parser.add_argument("--ffasta", default='./dataset/sequence/full.fa', type=str, help="The path of input seqs") # parser.add_argument("--ffasta", default='./dataset/experiment/nature/reference/GRCh38.p14/mRNA_300.pkl', type=str, help="The path of input seqs") # parser.add_argument("--exp_pretrain_data_path", default='./dataset/experiment/nature/', type=str, help="The path of expPretrain data") # parser.add_argument("--downstream_data_path", default='./dataset/downstream/', type=str, help="The path of Task/TR,VL,TS.csv") # parser.add_argument("--arg_overrides", type=dict,default={"data": f'./utils/ernie_rna/'}, help="The path of vocabulary") # GRCh38.p14 # parser.add_argument("--seq_len", type=int, default=50, help="The length of sequence") # parser.add_argument("--column", type=str, default="sequence", help="The sequences' column name") # parser.add_argument("--label", type=str, default="label", help="The label") # parser.add_argument("--pad_method", type=str, default="pre", help="The method which pad sequence") # parser.add_argument("--region", default=300, type=int, help="The context length/2") # parser.add_argument("--env_id", default=0, type=int, help="The context length/2") # parser.add_argument("--limit", default=10, type=int, help="less samples") # parser.add_argument('--debug', action='store_true', help="debug mode") # return parser def unifi_dataloader(train_ds, args, ddp=False, data_tag='TR'): train_sampler = DistributedSampler(train_ds) if ddp else None drop_last = True if ddp else False if data_tag =='TR': train_loader = DataLoader( train_ds, batch_size=args.batch_size, pin_memory=True, drop_last=drop_last, # 以避免各卡处理的批次数量不同。 测试的时候容易把唯一的batch丢掉 shuffle=False, num_workers=args.num_workers, sampler=train_sampler, # 验证集不需要 # collate_fn = train_ds.collate_fn ) else: train_loader = DataLoader( train_ds, batch_size=args.batch_size, pin_memory=True, drop_last=drop_last, # 以避免各卡处理的批次数量不同。 shuffle=False, num_workers=args.num_workers, # collate_fn = train_ds.collate_fn ) return train_loader def ddp_broadcast_early_stopping(ddp_local_rank, args, early_stopping, current_loss, model,dist): # 分布式训练逻辑 if ddp_local_rank == 0: early_stopping(current_loss, model) # 如果监控的是SPR,直接传入-SPR即可 if early_stopping.early_stop:early_stopping.counter = 0 # 重置 early_stopping.counter, 为了温度从高到低蒸馏 # 广播 should_stop 的值到其他进程 to_broadcast = torch.tensor([early_stopping.early_stop], dtype=torch.bool, device=args.device) to_broadcast_counter = torch.tensor([early_stopping.counter], dtype=torch.int, device=args.device) dist.broadcast(to_broadcast, 0) dist.broadcast(to_broadcast_counter, 0) else: # 非主进程等待主进程广播 to_broadcast = torch.tensor([False], dtype=torch.bool, device=args.device) # 这个False只是缓冲池 to_broadcast_counter = torch.tensor([0], dtype=torch.int, device=args.device) # 这个False只是缓冲池 dist.broadcast(to_broadcast, 0) dist.broadcast(to_broadcast_counter, 0) early_stopping.early_stop = bool(to_broadcast.item()) early_stopping.counter = int(to_broadcast_counter.item()) class EarlyStopping: """Early stops the training if validation loss doesn't improve after a given patience.""" def __init__(self, patience=7, verbose=False, 
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after the last validation-loss improvement.
                            Default: 7
            verbose (bool): If True, prints a message for each validation-loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path the checkpoint is saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): Trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when the validation loss decreases.'''
        model.eval()
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). '
                f'Saving model ..., {self.path}')
        self.save_model(model, self.path)
        self.val_loss_min = val_loss

    @staticmethod
    def save_model(model, path):
        # Unwrap DDP so the checkpoint loads without the wrapper prefix.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save(state_dict, path)


def generate_inputs(x):
    pad_mark = '_'
    bos = '<'
    eos = '>'
    region = 300
    link = 'N'
    utr5 = x["UTR5"]
    utr3 = x["UTR3"]
    cds = x["CDS"]
    utr5 = process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
    cds_h = process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
    cds_t = process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
    utr3 = process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
    seq = utr5 + cds_h + cds_t + utr3
    # Keep the 5' half and the 3' half, joined by a 3-nt linker.
    seq = seq[:region * 2 + 1] + link * 3 + seq[-region * 2 - 1:]
    return seq


def process_utr(utr, input_len, pad_method, pad_mark='_', bos='<', eos='>'):
    if len(utr) < input_len:
        if pad_method == 'pre':
            padded_utr = pad_mark * (input_len - len(utr)) + bos + utr
        elif pad_method == 'behind':
            padded_utr = utr + eos + pad_mark * (input_len - len(utr))
        else:
            raise ValueError(f"Unknown pad_method: {pad_method}")
    else:
        if pad_method == 'pre':
            padded_utr = bos + utr[-input_len:]
        elif pad_method == 'behind':
            padded_utr = utr[:input_len] + eos
        else:
            raise ValueError(f"Unknown pad_method: {pad_method}")
    return padded_utr


def find_unused_parameters(model, output):
    contributing_parameters = set(get_contributing_params(output))
    all_parameters = set(model.parameters())
    non_contributing = all_parameters - contributing_parameters
    print("Parameters that did not contribute to the output:")
    for param in non_contributing:
        # Recover the parameter's name.
        for name, p in model.named_parameters():
            if p is param:
                print(f"  {name}")


def get_contributing_params(y, top_level=True):
    """Yield all parameters that contribute to the output y."""
    nf = y.grad_fn.next_functions if top_level else y.next_functions
    for f, _ in nf:
        try:
            yield f.variable
        except AttributeError:
            pass  # node holds no tensor
        if f is not None:
            yield from get_contributing_params(f, top_level=False)
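# Illustrative checks (toy values, hypothetical model):
# process_utr pads or truncates to a fixed window; the output is one character
# longer than input_len because of the BOS/EOS marker.
# >>> process_utr('AUGC', 6, 'pre')   # '__<AUGC'
# find_unused_parameters walks the autograd graph of an output; for a plain
# Linear layer every parameter contributes, so nothing is listed.
# >>> model = torch.nn.Linear(4, 2)
# >>> out = model(torch.randn(1, 4))
# >>> find_unused_parameters(model, out)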