File size: 2,693 Bytes
b570cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdf1dec
b570cf2
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse

from numpy import False_

class Config:
    def __init__(self):
        self.parser = argparse.ArgumentParser(description='PyTorch Lightning Training Config')
        
        # 数据配置
        self.parser.add_argument('--data_dir', type=str, default='./data', help='数据存储目录')
        self.parser.add_argument('--batch_size', type=int, default=10000, help='批次大小 (根据Readme设置为10000)')
        self.parser.add_argument('--num_workers', type=int, default=16, help='数据加载线程数')
        self.parser.add_argument('--K', type=int, default=30, help='最大打乱次数 (对于魔方设置为30)')
        self.parser.add_argument('--num_val_samples', type=int, default=10000 * 100, help='每个epoch样本数')
        self.parser.add_argument('--num_train_samples', type=int, default=10000 * 1000, help='每个epoch样本数')
    
        # 训练配置
        self.parser.add_argument('--max_epochs', type=int, default=20, help='最大训练轮数')
        self.parser.add_argument('--learning_rate', type=float, default=2e-4, help='学习率')
        self.parser.add_argument('--weight_decay', type=float, default=0, help='权重衰减 (根据Readme不使用正则化)')
        self.parser.add_argument('--devices', type=str, default="2", help="Devices to use: 'cpu', 'auto', '0', '1', '0,1', etc.")
        self.parser.add_argument('--convergence_threshold', type=float, default=0.05, help='收敛阈值 (根据Readme设置为0.05)')
        self.parser.add_argument('--chunk_size', type=int, default=10000 * 12, help='分块大小 (用于模型预测时的分块处理)')
        self.parser.add_argument('--compile', type=bool, default=True, help='是否编译模型')
        
        # 其他配置
        self.parser.add_argument('--log_dir', type=str, default='./logs', help='日志存储目录')
        self.parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='模型 checkpoint 存储目录')
        self.parser.add_argument('--converged_checkpoint_dir', type=str, default='converged_checkpoints', help='收敛模型 checkpoint 存储目录')
        self.parser.add_argument('--seed', type=int, default=42, help='随机种子')

        # inference
        self.parser.add_argument('--model_path', type=str, default='/app/checkpoint/final_model_K_30.pth', help='模型路径')
        self.parser.add_argument('--actions', type=str, default=None, help='指定的魔方动作序列,用空格分隔,如 "U R F D L B"')
        
    def parse_args(self):
        return self.parser.parse_args()

if __name__ == '__main__':
    config = Config()
    args = config.parse_args()
    print(args)