File size: 3,996 Bytes
2651102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Diffutslator 配置文件
所有超参数集中管理
"""

from dataclasses import dataclass, field
from typing import Optional
import os


@dataclass
class ModelConfig:
    """模型配置"""
    d_model: int = 256           # 嵌入维度
    n_heads: int = 4             # 注意力头数
    n_layers: int = 4            # Transformer层数
    d_ff: int = 512              # 前馈网络维度
    max_len: int = 128           # 最大序列长度
    dropout: float = 0.1         # Dropout率
    
    # 词表
    vocab_size_zh: int = 8000    # 中文词表大小
    vocab_size_en: int = 8000    # 英文词表大小
    
    # 特殊token
    pad_token: str = "<pad>"
    sos_token: str = "<sos>"
    eos_token: str = "<eos>"
    unk_token: str = "<unk>"
    mask_token: str = "<mask>"


@dataclass
class DiffusionConfig:
    """扩散过程配置"""
    timesteps: int = 1000        # 训练时的扩散步数
    ddim_steps: int = 50         # DDIM推理步数
    
    # 噪声调度 - 线性
    beta_start: float = 0.0001
    beta_end: float = 0.02
    
    # 长度变化
    length_noise_scale: float = 0.3  # 扩散时长度变化的噪声程度


@dataclass
class TrainingConfig:
    """训练配置"""
    batch_size: int = 64         # 批量大小(CPU擅长大批量)
    gradient_accumulation: int = 1  # 梯度累积步数
    
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    warmup_steps: int = 500
    
    epochs: int = 10
    save_every: int = 1          # 每多少epoch保存一次
    eval_every: int = 100        # 每多少步评估一次
    
    # 快速验证模式
    quick_mode: bool = False
    quick_samples: int = 1000
    
    # 检查点
    checkpoint_dir: str = "checkpoints"
    resume: Optional[str] = None  # 恢复训练的检查点路径


@dataclass
class DataConfig:
    """数据配置"""
    # 数据集路径
    tatoeba_path: str = "../_dataset/tatoeba.tsv"
    cveto_zh_path: str = "../_dataset/cveto/train.zh"
    cveto_en_path: str = "../_dataset/cveto/train.en"
    
    # 数据处理
    max_samples: Optional[int] = None  # 最大样本数(None=全部)
    min_len: int = 2             # 最小句子长度
    max_len: int = 128           # 最大句子长度
    
    # 缓存
    use_cache: bool = True       # 是否缓存预处理后的数据
    cache_dir: str = ".cache"


@dataclass
class Config:
    """总配置"""
    model: ModelConfig = field(default_factory=ModelConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    
    # 项目根目录
    project_dir: str = ""
    
    def __post_init__(self):
        # 设置项目根目录
        self.project_dir = os.path.dirname(os.path.abspath(__file__))
        
        # 更新相对路径为绝对路径
        if not os.path.isabs(self.data.tatoeba_path):
            self.data.tatoeba_path = os.path.join(self.project_dir, self.data.tatoeba_path)
        if not os.path.isabs(self.data.cveto_zh_path):
            self.data.cveto_zh_path = os.path.join(self.project_dir, self.data.cveto_zh_path)
        if not os.path.isabs(self.data.cveto_en_path):
            self.data.cveto_en_path = os.path.join(self.project_dir, self.data.cveto_en_path)
        
        # 创建必要目录
        os.makedirs(os.path.join(self.project_dir, self.training.checkpoint_dir), exist_ok=True)
        os.makedirs(os.path.join(self.project_dir, self.data.cache_dir), exist_ok=True)
    
    @classmethod
    def quick(cls) -> "Config":
        """快速验证模式配置"""
        config = cls()
        config.training.quick_mode = True
        config.training.quick_samples = 1000
        config.training.epochs = 5
        config.training.batch_size = 32  # CPU擅长大批量
        config.data.max_samples = 1000
        return config


# 默认配置实例
default_config = Config()