File size: 3,905 Bytes
b570cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import torch
import random
import numpy as np
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, model_checkpoint
from pytorch_lightning.loggers import TensorBoardLogger
from config import Config
from dataset.dataloader import RubikDataModule
from model.DeepcubeA_module import DeepcubeA
import datetime

# Process-wide PyTorch setting applied at import time: select the 'medium'
# float32 matmul precision level. See the torch.set_float32_matmul_precision
# documentation for the exact speed/precision trade-off of each level.
torch.set_float32_matmul_precision('medium')

def _resolve_accelerator_and_devices(device_spec):
    """Translate the ``devices`` option into Trainer ``accelerator``/``devices``.

    Args:
        device_spec: ``"cpu"``, ``"auto"``, or one or more comma-separated
            GPU indices such as ``"0"`` or ``"0,1"`` (case-insensitive,
            surrounding whitespace and stray commas are tolerated).

    Returns:
        An ``(accelerator, devices)`` tuple to pass to ``Trainer``.

    Raises:
        ValueError: If the spec is not a keyword and contains no GPU index,
            or if one of the indices is not an integer.
    """
    spec = device_spec.strip().lower()
    if spec == "cpu":
        # A CPU run uses a single process.
        return "cpu", 1
    if spec == "auto":
        return ("gpu" if torch.cuda.is_available() else "cpu"), "auto"

    # Explicit GPU id(s); a single id and a comma-separated list share one path.
    gpu_ids = [int(token) for token in spec.split(",") if token.strip()]
    if not gpu_ids:
        raise ValueError(
            f"Invalid devices option {device_spec!r}: expected 'cpu', 'auto', "
            "or comma-separated GPU indices."
        )
    return "gpu", gpu_ids


def main():
    """Train one ``DeepcubeA`` model over increasing values of ``K``.

    Steps, in order:

    1. Parse the configuration and seed all random number generators.
    2. Nest the log/checkpoint directories under a per-run timestamp folder
       and create them.
    3. Build the model once and warm-start its two sub-networks from
       previously saved weights.
    4. For every ``K`` from the initial value up to ``args.K`` (inclusive),
       build a fresh data module and ``Trainer`` and fit the same model.

    Raises:
        ValueError: If ``args.devices`` cannot be interpreted.
        FileNotFoundError: If a warm-start checkpoint file is missing
            (raised by ``torch.load``).
    """
    # Parse configuration.
    config = Config()
    args = config.parse_args()

    # Seed Python, NumPy and torch (and dataloader workers).
    seed_everything(args.seed, workers=True)

    # Every run gets its own timestamped folder (minute resolution); the
    # checkpoint directories are re-rooted beneath it.
    run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    args.log_dir = os.path.join(args.log_dir, run_stamp)
    args.checkpoint_dir = os.path.join(args.log_dir, args.checkpoint_dir)
    args.converged_checkpoint_dir = os.path.join(args.log_dir, args.converged_checkpoint_dir)

    # Resolve accelerator & devices before touching the filesystem so a bad
    # option fails without leaving empty run directories behind.
    accelerator, devices = _resolve_accelerator_and_devices(args.devices)

    # Create the required directories.
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(args.converged_checkpoint_dir, exist_ok=True)

    # Build the model once; it is reused for every K below.
    model = DeepcubeA(args)

    # First K to train and the last K to train (inclusive).
    initial_K = 16
    max_K = args.K  # adjust via configuration as needed

    # Warm-start weights from an earlier run.
    # NOTE(review): these paths are hard-coded and look tied to initial_K
    # (K_14 -> model_theta_e, K_15 -> model_theta) -- confirm before changing
    # initial_K or running on another machine.
    theta_e_ckpt_path = "logs/20250818_1819/converged_checkpoints/final_model_K_14.pth"
    theta_ckpt_path = "logs/20250818_1819/converged_checkpoints/final_model_K_15.pth"
    # map_location="cpu" lets checkpoints saved from CUDA tensors load on a
    # CPU-only host; load_state_dict then copies the values into the model's
    # existing parameters regardless of their device.
    model.model_theta_e.load_state_dict(torch.load(theta_e_ckpt_path, map_location="cpu"))
    model.model_theta.load_state_dict(torch.load(theta_ckpt_path, map_location="cpu"))

    if max_K < initial_K:
        # Make the otherwise silent no-op visible to the user.
        print(f"No training stages to run: args.K={max_K} is below the initial K={initial_K}.")

    for K in range(initial_K, max_K + 1):
        print(f'\n--- 开始训练 K={K} ---')

        # Propagate the current K to the model ...
        model.K = K

        # ... and to the args used to build this stage's data module.
        args.K = K

        # Fresh data module for the current K.
        data_module = RubikDataModule(args)

        # Fresh early-stopping state per K: stop when 'val_loss' has not
        # improved for 5 consecutive validation checks.
        early_stopping_callback = EarlyStopping(
            monitor='val_loss',
            patience=5,
            mode='min',
        )

        # A per-K ModelCheckpoint callback, a LearningRateMonitor and a
        # TensorBoardLogger were tried here previously and are intentionally
        # left out (the original author found the checkpoint callback not
        # useful for now).
        # NOTE(review): args.checkpoint_dir is created above but never passed
        # to the Trainer; with enable_checkpointing=True and no explicit
        # ModelCheckpoint, Lightning picks its own default location -- confirm
        # this is intended.

        # New Trainer per K. Per the original note, validation runs once per
        # epoch by default (stated as 5000 steps).
        trainer = Trainer(
            max_epochs=args.max_epochs,
            accelerator=accelerator,
            precision="16-mixed",  # enable mixed precision
            devices=devices,
            logger=False,
            callbacks=[early_stopping_callback],
            deterministic=True,
            enable_progress_bar=True,
            enable_checkpointing=True
        )

        # Diagnostic: show the Trainer's logging interval.
        print(trainer.log_every_n_steps)

        # Train the model for this K.
        trainer.fit(model, datamodule=data_module)

        print(f'--- 完成训练 K={K} ---\n')
    
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()