import argparse
import itertools
import socket
from copy import deepcopy

import numpy as np
import torch
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             confusion_matrix, roc_auc_score, r2_score)
from fairseq.data import Dictionary
from torch.utils.data import DataLoader, DistributedSampler

from model.LMConfig import LMConfig
from model.codon_tables import AA_str

def compute_metrics_regression(preds, labels):
    spr = spearmanr(preds, labels)[0]
    pr = pearsonr(preds, labels)[0]
    mse = np.mean((preds - labels) ** 2)
    rmse = np.sqrt(mse)
    r2 = r2_score(labels, preds)
    return {'spearmanr': spr, 'pearsonr': pr, 'mse': mse, 'rmse': rmse, 'r2': r2}


def compute_metrics_dict(preds, labels, average='macro', multi_class='ovr', cls='binary'):
    """
    Compute evaluation metrics for classification tasks.

    Args:
        preds: predictions (class labels or per-class scores)
        labels: ground-truth labels
        average: averaging mode for multi-class metrics ('micro', 'macro', 'weighted', 'binary')
        multi_class: AUC strategy for multi-class problems ('ovr', 'ovo')
    https://rcxqhxlmkf.feishu.cn/wiki/ONHBwenBjiNUkgk54mQcwVBznEg#share-RWVDdIzU2oC5dZxCgqKcHYtrnfc
    """
    if cls == 'regression':
        return compute_metrics_regression(preds, labels)

    if cls == 'identity':
        # Codon-level and nucleotide-level identity.
        pred_labels = np.argmax(preds, axis=1)
        pred_codon = [list(pred_labels[i:i + 3]) for i in range(0, len(pred_labels), 3)]
        true_codon = [list(labels[i:i + 3]) for i in range(0, len(labels), 3)]
        identity_codon = sum(1 for c1, c2 in zip(pred_codon, true_codon) if c1 == c2) / len(true_codon)
        identity_NN = sum(1 for c1, c2 in zip(pred_labels, labels) if c1 == c2) / len(labels)
        return {'identity_codon': identity_codon, 'identity_NN': identity_NN}

    # If preds carries per-class scores rather than class labels, convert them.
    if preds.ndim > 1 and preds.shape[1] == 2:
        # Binary scores: softmax over the two classes, then threshold at 0.5.
        # (The original checked shape[1] > 1 first, which made this branch
        # unreachable, and called the nonexistent np.sigmoid.)
        exps = np.exp(preds - preds.max(axis=1, keepdims=True))
        pred_probs = exps / exps.sum(axis=1, keepdims=True)
        pred_labels = (pred_probs[:, 1] > 0.5).astype(int)
    elif preds.ndim > 1 and preds.shape[1] > 1:
        # Multi-class scores: take the argmax class.
        pred_probs = None
        pred_labels = np.argmax(preds, axis=1)
    else:
        # Already class labels.
        pred_labels = preds
        pred_probs = None
    # Core classification metrics
    accuracy = accuracy_score(labels, pred_labels)
    precision = precision_score(labels, pred_labels, average=average, zero_division=0)
    recall = recall_score(labels, pred_labels, average=average, zero_division=0)
    f1 = f1_score(labels, pred_labels, average=average, zero_division=0)

    # Confusion matrix (disabled)
    # cm = confusion_matrix(labels, pred_labels)

    # AUC-ROC (only when probabilities are available)
    # auc_roc = None
    # if pred_probs is not None:
    #     try:
    #         if len(np.unique(labels)) == 2:
    #             # binary
    #             auc_roc = roc_auc_score(labels, pred_probs[:, 1])
    #         else:
    #             # multi-class
    #             auc_roc = roc_auc_score(labels, pred_probs, multi_class=multi_class, average=average)
    #     except Exception as e:
    #         auc_roc = None
    #         exit(f'Error computing AUC-ROC for classification.{e}')

    # # Per-class metrics (multi-class only)
    # per_class_metrics = {}
    # if len(np.unique(labels)) > 2:
    #     precision_per_class = precision_score(labels, pred_labels, average=None, zero_division=0)
    #     recall_per_class = recall_score(labels, pred_labels, average=None, zero_division=0)
    #     f1_per_class = f1_score(labels, pred_labels, average=None, zero_division=0)
    #
    #     for i in range(len(precision_per_class)):
    #         per_class_metrics[f'class_{i}'] = {
    #             'precision': precision_per_class[i],
    #             'recall': recall_per_class[i],
    #             'f1': f1_per_class[i]
    #         }
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        # 'auc_roc': auc_roc,
        # 'confusion_matrix': cm,
        # 'per_class_metrics': per_class_metrics
    }
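
# Usage sketch (illustrative; the _example_* helper and its values are not part
# of the original pipeline): classification metrics from 2-class scores, and
# regression metrics via cls='regression'.
def _example_compute_metrics():
    scores = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # per-class scores
    labels = np.array([1, 0, 1])
    cls_metrics = compute_metrics_dict(scores, labels, average='macro')
    reg_metrics = compute_metrics_dict(np.array([0.2, 0.5, 0.9]),
                                       np.array([0.1, 0.6, 1.0]), cls='regression')
    return cls_metrics, reg_metrics
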
def flatten_col(col, group=1, exclude='_', frames=None):
    """
    Flatten a column (or a string) of sequences into tokens.

    frames=['0','1','2','01','12','02','012']: only valid when group == 1.

    group == 1 and frames given: one flattened string per frame
    group == 1 and frames is None: single nucleotides (NN)
    group == 2: dinucleotides (DiNN)
    group == 3: codons
    """
    if isinstance(col, str):
        str1 = list(col)
    else:
        # pandas Series of strings
        nested_list = col.apply(list).tolist()
        str1 = list(itertools.chain(*nested_list))
    exclude_num = str1.count(exclude)
    if exclude_num != 0:
        # Drop padding triplets that contain the exclude character.
        triplets1 = [''.join(str1[i:i + 3]) for i in range(0, len(str1), 3)]
        triplets1 = [triplet for triplet in triplets1 if exclude not in triplet]
        str1 = ''.join(triplets1)
    if group == 1:
        if frames:
            return multi_frames(deepcopy(str1), frames)
        return str1
    if len(str1) % group != 0:
        raise ValueError(f"Sequence length must be a multiple of {group}")
    triplets1 = [''.join(str1[i:i + group]) for i in range(0, len(str1), group)]
    return triplets1

def multi_frames(str1, frames):
    """Extract per-frame views of a codon sequence.

    Each frame names codon positions to keep, e.g. '0' keeps the first
    nucleotide of every codon and '12' keeps the second and third.
    """
    str1_list = []
    for frame in frames:
        # Walk codon by codon; stop before a trailing partial codon so that
        # single-position frames cannot index out of range.
        triplets1 = [''.join(str1[i + int(fr)] for fr in frame)
                     for i in range(0, len(str1) - 3 + 1, 3)]
        str1_list.append(''.join(triplets1))
    return str1_list
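
# Usage sketch (illustrative; the sequence below is a made-up toy): split a
# codon string into per-frame views. Frame '0' keeps the first position of
# every codon; '12' keeps the second and third.
def _example_frames():
    seq = 'ATGGCCTAA'  # three codons
    per_frame = flatten_col(seq, group=1, frames=['0', '12'])  # ['AGT', 'TGCCAA']
    codons = flatten_col(seq, group=3)                         # ['ATG', 'GCC', 'TAA']
    return per_frame, codons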

def get_correct(labels, preds, prefix='', average='macro'):
    str1 = labels
    str2 = preds
    if len(str1) == 0:
        raise ValueError(f"{prefix}str1 is empty")
    if len(str1) != len(str2):
        raise ValueError(f"Sequences must be the same length, str1_len:{len(str1)}, str2_len:{len(str2)}")
    correct = sum(1 for c1, c2 in zip(str1, str2) if c1 == c2)
    data = {
        'identity': correct / len(str1),
        'label_seq': ''.join(str1),
        'pred_seq': ''.join(str2)
    }
    # Map characters to integer ids so the sklearn metrics can be reused.
    alphabet = set(str1) | set(str2)
    alphabet = {k: v for k, v in zip(alphabet, range(len(alphabet)))}
    labels = [alphabet[k] for k in str1]
    preds = [alphabet[k] for k in str2]

    data.update(
        compute_metrics_dict(np.array(preds).flatten(), np.array(labels).flatten(), cls='binary', average=average))
    return {f'{prefix}{k}': v for k, v in data.items()}


def calculate_accuracy(label, pred, group=1, exclude='_', frames=None):
    str1 = flatten_col(label, group=group, exclude=exclude, frames=frames)
    str2 = flatten_col(pred, group=group, exclude=exclude, frames=frames)
    if frames:
        ans_dict = {}
        for frame, s1, s2 in zip(frames, str1, str2):
            ans_dict.update(get_correct(s1, s2, prefix=f'{frame}_'))
        return ans_dict
    else:
        return get_correct(str1, str2)
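
# Usage sketch (illustrative; toy strings): nucleotide identity between a label
# and a predicted sequence. Padding codons containing '_' are dropped first.
def _example_calculate_accuracy():
    metrics = calculate_accuracy('ATGGCC', 'ATGGCA', group=1)
    return metrics['identity']  # 5 of 6 positions match -> ~0.833
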
# Correlation computation along positions from https://github.com/lucidrains/enformer-pytorch/blob/main/enformer_pytorch/metrics.py
def MeanPearsonCorrCoefPerChannel(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Expects channels-last tensors of shape (batch, positions, channels), as in
    # the Enformer metrics this is adapted from. (The original read the channel
    # count from shape[1], which does not match the (0, 1) reduction below.)
    reduce_dims = (0, 1)  # aggregate over the batch and position dimensions

    # Accumulate the per-channel sufficient statistics.
    product = torch.sum(preds * target, dim=reduce_dims)
    true_sum = torch.sum(target, dim=reduce_dims)
    true_squared_sum = torch.sum(torch.square(target), dim=reduce_dims)
    pred_sum = torch.sum(preds, dim=reduce_dims)
    pred_squared_sum = torch.sum(torch.square(preds), dim=reduce_dims)
    count = torch.sum(torch.ones_like(target), dim=reduce_dims)

    # Means
    true_mean = true_sum / count
    pred_mean = pred_sum / count

    # Covariance
    covariance = (product
                  - true_mean * pred_sum
                  - pred_mean * true_sum
                  + count * true_mean * pred_mean)

    # Variances and their product
    true_var = true_squared_sum - count * torch.square(true_mean)
    pred_var = pred_squared_sum - count * torch.square(pred_mean)
    tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)

    # Pearson correlation per channel; callers can use 1 - r as a loss
    # (higher correlation -> lower loss).
    correlation = covariance / tp_var
    return correlation.abs()
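
# Usage sketch (illustrative; random tensors): per-channel |Pearson r| for
# channels-last (batch, positions, channels) tensors, as in Enformer-style
# track evaluation.
def _example_pearson_per_channel():
    preds = torch.randn(8, 128, 4)                 # (batch, positions, channels)
    target = preds + 0.1 * torch.randn(8, 128, 4)  # noisy copy -> r close to 1
    return MeanPearsonCorrCoefPerChannel(preds, target)  # shape: (4,)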

def init_config(vocab_path, n_layers, max_seq_len):
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.mask_index = tokenizer.add_symbol('<mask>')  # ['<s>', '<pad>', '</s>', '<unk>', 'G', 'A', 'U', 'C', 'N', '<mask>']
    [tokenizer.add_symbol(word) for word in AA_str]  # amino-acid symbols occupy ids 10-31
    lm_config = LMConfig(dim=256, logit_dim=len(tokenizer), n_layers=n_layers, max_seq_len=max_seq_len,
                         vocab_size=len(tokenizer), padding_idx=tokenizer.pad_index)
    return lm_config, tokenizer
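
# Usage sketch (illustrative; the dictionary path is a placeholder that follows
# the './utils/ernie_rna/' + 'small_dict.txt' convention used elsewhere here):
def _example_init_config():
    lm_config, tokenizer = init_config('./utils/ernie_rna/small_dict.txt',
                                       n_layers=8, max_seq_len=1205)
    return lm_config, tokenizer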


'''Socket port helpers'''
def find_free_port():
    # Bind a temporary socket to port 0 so the OS assigns a free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))
        return s.getsockname()[1]

def is_port_in_use(port):
    # A successful connect means something is already listening on the port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('127.0.0.1', port)) == 0

def get_port():
    '''TODO: this cannot guarantee that all ranks agree on the same port; still buggy.'''
    # Find a free port dynamically and double-check that it is not in use.
    free_port = find_free_port()
    max_attempts = 100
    attempts = 0

    while is_port_in_use(free_port) and attempts < max_attempts:
        attempts += 1
        print(f"[{attempts}/{max_attempts}] Port {free_port} is in use, trying another port...")
        free_port = find_free_port()

    if attempts >= max_attempts:
        raise RuntimeError("Could not find a free port")
    return free_port
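
# Usage sketch (illustrative): pick a free port for torch.distributed before
# launching workers. The race noted in get_port's docstring still applies.
def _example_master_port():
    import os
    os.environ['MASTER_PORT'] = str(get_port())
    return os.environ['MASTER_PORT']
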
def get_pretraining_args():
    """pretrain"""
    # time torchrun --nproc_per_node 8 --master_port=22353 train_riboutr.py
    # --limit=-1 --batch_size=32 --n_layers=8 --use_wandb --ddp --local_rank=0 --epochs=100 --wandb_project=Amino_MOE0401 --use_moe=True --save_interval=100 --out_dir=exp_log/out_demo10
    parser = argparse.ArgumentParser(description="MiniMind Full SFT")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--celoss_alpha", type=float, default=0.1)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="RiboUTR-PT")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true",help='DistributedDataParallel')
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=10) # 100
    parser.add_argument("--save_interval", type=int, default=100) # 100
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument("--data_path", type=str, default="./dataset/sft_mini_512.jsonl") # sft_data.jsonl

    """dataset"""

    parser.add_argument('--n_layers', default=8, type=int) # 8
    parser.add_argument('--is_twod', default=True, type=bool)  # note: argparse's type=bool treats any non-empty string as True
    parser.add_argument('--max_seq_len', default=1205, type=int) # 512
    parser.add_argument('--use_moe', action='store_true', help="add moe layer") # False
    parser.add_argument("--mlm_pretrained_model_path", type=str, default="./checkpoint/ernierna.pt")
    parser.add_argument("--arg_overrides", type=dict,default={"data": f'./utils/ernie_rna/'}, help="The path of vocabulary")


    parser.add_argument('--finetune', action='store_true')  # set to enable finetuning
    parser.add_argument('--scaler', action='store_true')  # set to enable a gradient scaler

    # parser.add_argument("--ffasta", default='./dataset/sequence/full.fa', type=str, help="The path of input seqs")
    parser.add_argument("--ffasta", default='./dataset/experiment/nature/reference/GRCh38.p14/mRNA_300.pkl',
                        type=str, help="The path of input seqs")
    parser.add_argument("--exp_pretrain_data_path", default='./dataset/experiment/nature/', type=str,
                        help="The path of expPretrain data")
    parser.add_argument("--downstream_data_path", default='./dataset/downstream/', type=str,
                        help="The path of Task/TR,VL,TS.csv")
    parser.add_argument('--task', type=str, default='predict_web',
                        help='task in downstream dir')
    parser.add_argument("--seq_len", type=int, default=1205, help="The length of sequence")
    parser.add_argument("--env_counts", type=int, default=10, help="The length of sequence")
    parser.add_argument("--column", type=str, default="sequence", help="The sequences' column name")
    parser.add_argument("--label", type=str, default="label", help="The label")
    parser.add_argument("--pad_method", type=str, default="pre", help="The method which pad sequence")
    parser.add_argument("--region", default=300, type=int, help="The context length/2")
    parser.add_argument("--env_id", default=1, type=int, help="0")
    parser.add_argument("--limit", default=-1, type=int, help="less samples")
    parser.add_argument('--debug', action='store_true', help="debug mode")
    parser.add_argument('--codon_table_path', type=str, default="maotao_file/codon_table/codon_usage_{species}.csv", help="Template path of the per-species codon usage table")


    """predict mode"""

    parser.add_argument('--predict', action='store_true', help="save predict result")
    parser.add_argument('--test_file', default=None, help="assign a test file")
    """design mode"""
    parser.add_argument('--Kozak_GS6H_Stop3', default='GCCACC,GGGAGCCACCACCACCATCACCAC,TGATAATAG', help="Kozak, GS6H tag, and triple stop codons (comma-separated)")
    return parser
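
# Usage sketch (illustrative; the flag values are placeholders, not recommended
# settings): build the parser and parse arguments programmatically.
def _example_parse_args():
    parser = get_pretraining_args()
    args = parser.parse_args(['--batch_size', '16', '--epochs', '10', '--ddp'])
    return args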

def get_dataset_args():
    # Placeholder: the dataset-specific arguments that used to live here have
    # been folded into get_pretraining_args.
    parser = argparse.ArgumentParser()
    return parser


def unifi_dataloader(train_ds, args, ddp=False, data_tag='TR'):
    # Only the training split ('TR') gets a DistributedSampler under DDP;
    # validation/test loaders iterate the full dataset on every rank.
    sampler = DistributedSampler(train_ds) if (ddp and data_tag == 'TR') else None
    return DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        # drop_last keeps per-rank batch counts equal under DDP, but beware:
        # at test time it can drop the only (partial) batch.
        drop_last=ddp,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=sampler,
        # collate_fn = train_ds.collate_fn
    )
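
# Usage sketch (illustrative; the toy dataset and args namespace are stand-ins
# for this project's dataset classes and parsed CLI arguments):
def _example_dataloader():
    import types
    toy_args = types.SimpleNamespace(batch_size=4, num_workers=0)
    toy_ds = torch.utils.data.TensorDataset(torch.randn(16, 8))
    return unifi_dataloader(toy_ds, toy_args, ddp=False, data_tag='TR')
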
def ddp_broadcast_early_stopping(ddp_local_rank, args, early_stopping, current_loss, model, dist):
    # Distributed early stopping: rank 0 decides, then broadcasts to the others.
    if ddp_local_rank == 0:
        early_stopping(current_loss, model)  # if monitoring SPR, pass -SPR so lower is better
        if early_stopping.early_stop:
            early_stopping.counter = 0  # reset the counter for high-to-low temperature distillation
        # Broadcast the stop flag and counter to the other ranks.
        to_broadcast = torch.tensor([early_stopping.early_stop], dtype=torch.bool, device=args.device)
        to_broadcast_counter = torch.tensor([early_stopping.counter], dtype=torch.int, device=args.device)
        dist.broadcast(to_broadcast, 0)
        dist.broadcast(to_broadcast_counter, 0)
    else:
        # Non-zero ranks allocate placeholder buffers and receive the broadcast.
        to_broadcast = torch.tensor([False], dtype=torch.bool, device=args.device)
        to_broadcast_counter = torch.tensor([0], dtype=torch.int, device=args.device)
        dist.broadcast(to_broadcast, 0)
        dist.broadcast(to_broadcast_counter, 0)
        early_stopping.early_stop = bool(to_broadcast.item())
        early_stopping.counter = int(to_broadcast_counter.item())

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): Trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        model.eval()
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ..., {self.path}')
        self.save_model(model, self.path)
        self.val_loss_min = val_loss

    @staticmethod
    def save_model(model, path):
        # Unwrap DDP so the checkpoint can be loaded without the 'module.' prefix.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save(state_dict, path)
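
# Usage sketch (illustrative; the loss values are fabricated): a minimal
# validation loop around EarlyStopping. Note it writes 'checkpoint.pt' on each
# improvement; pass -SPR instead of a loss when monitoring a correlation.
def _example_early_stopping():
    model = torch.nn.Linear(4, 1)
    stopper = EarlyStopping(patience=2, path='checkpoint.pt')
    for val_loss in [1.0, 0.9, 0.95, 0.96]:
        if stopper(val_loss, model):  # True once patience is exhausted
            break
    return stopper.early_stop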


def generate_inputs(x):
    pad_mark = '_'
    bos = '<'
    eos = '>'
    region = 300
    link = 'N'

    utr5 = x["UTR5"]
    utr3 = x["UTR3"]
    cds = x["CDS"]

    # Pad/trim each region to a fixed window around the start and stop codons.
    utr5 = process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
    cds_h = process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
    cds_t = process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
    utr3 = process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
    seq = utr5 + cds_h + cds_t + utr3
    # Join the 5' half (UTR5 + CDS head) and 3' half (CDS tail + UTR3) with a linker.
    seq = seq[:region * 2 + 1] + link * 3 + seq[-region * 2 - 1:]
    return seq

def process_utr(utr, input_len, pad_method, pad_mark='_', bos='<', eos='>'):
    if len(utr) < input_len:
        if pad_method == 'pre':
            padded_utr = pad_mark * (input_len - len(utr)) + bos + utr
        elif pad_method == 'behind':
            padded_utr = utr + eos + pad_mark * (input_len - len(utr))
        else:
            raise ValueError(f"Unknown pad_method: {pad_method}")
    else:
        if pad_method == 'pre':
            padded_utr = bos + utr[-input_len:]
        elif pad_method == 'behind':
            padded_utr = utr[:input_len] + eos
        else:
            raise ValueError(f"Unknown pad_method: {pad_method}")
    return padded_utr
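
# Usage sketch (illustrative; toy sequences): left-pad a short region and
# truncate a long one to the fixed window used by generate_inputs.
def _example_process_utr():
    short = process_utr('AUGGC', 10, 'pre')      # '_____<AUGGC'
    long_ = process_utr('A' * 15, 10, 'behind')  # 'AAAAAAAAAA>'
    return short, long_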



def find_unused_parameters(model, output):
    contributing_parameters = set(get_contributing_params(output))
    all_parameters = set(model.parameters())
    non_contributing = all_parameters - contributing_parameters
    print("Parameters that did not contribute to the output:")
    for param in non_contributing:
        # Look up the name that corresponds to this parameter.
        for name, p in model.named_parameters():
            if p is param:
                print(f"  {name}")

def get_contributing_params(y, top_level=True):
    """Yield every parameter that contributes to the output y by walking its autograd graph."""
    nf = y.grad_fn.next_functions if top_level else y.next_functions
    for f, _ in nf:
        try:
            yield f.variable  # AccumulateGrad nodes expose the leaf tensor
        except AttributeError:
            pass  # this node holds no tensor
        if f is not None:
            yield from get_contributing_params(f, top_level=False)
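
# Usage sketch (illustrative; the tiny model is a stand-in for this project's
# models): only the first layer contributes to the output, so the second
# layer's parameters are reported as unused.
def _example_find_unused():
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    output = model[0](torch.randn(1, 4)).sum()  # bypass the second layer
    find_unused_parameters(model, output)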