| """ |
| Diffutslator 配置文件 |
| 所有超参数集中管理 |
| """ |
|
|
| from dataclasses import dataclass, field |
| from typing import Optional |
| import os |
|
|
|
|
@dataclass
class ModelConfig:
    """Model architecture hyperparameters.

    Validated on construction: every size field must be positive,
    ``d_model`` must be divisible by ``n_heads``, and ``dropout`` must lie
    in ``[0, 1)``.
    """

    # Network dimensions.
    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 4
    d_ff: int = 512
    max_len: int = 128
    dropout: float = 0.1

    # Vocabulary sizes for the Chinese (zh) and English (en) sides.
    vocab_size_zh: int = 8000
    vocab_size_en: int = 8000

    # Special token strings.
    pad_token: str = "<pad>"
    sos_token: str = "<sos>"
    eos_token: str = "<eos>"
    unk_token: str = "<unk>"
    mask_token: str = "<mask>"

    def __post_init__(self):
        """Reject hyperparameters that cannot describe a valid model.

        Raises:
            ValueError: if a size field is not positive, ``d_model`` is not
                a multiple of ``n_heads``, or ``dropout`` is outside
                ``[0, 1)``.
        """
        # Checked first so the modulo below can never divide by zero.
        for name in (
            "d_model",
            "n_heads",
            "n_layers",
            "d_ff",
            "max_len",
            "vocab_size_zh",
            "vocab_size_en",
        ):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        # Presumably consumed by multi-head attention, which needs d_model
        # split evenly across heads; fail here rather than at model build.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")
|
|
|
|
@dataclass
class DiffusionConfig:
    """Diffusion-process hyperparameters.

    Validated on construction: ``timesteps`` must be positive,
    ``0 < ddim_steps <= timesteps``, and
    ``0 < beta_start <= beta_end < 1``.
    """

    # Total number of diffusion steps, and the (smaller) step count used
    # by DDIM-style sampling.
    timesteps: int = 1000
    ddim_steps: int = 50

    # Endpoints of the noise (beta) schedule.
    beta_start: float = 0.0001
    beta_end: float = 0.02

    # NOTE(review): meaning is defined by the consumer of this config
    # (presumably a noise scale applied to target length) — confirm.
    length_noise_scale: float = 0.3

    def __post_init__(self):
        """Reject schedules that cannot drive a valid diffusion process.

        Raises:
            ValueError: if ``timesteps`` is not positive, ``ddim_steps`` is
                outside ``(0, timesteps]``, or the beta endpoints violate
                ``0 < beta_start <= beta_end < 1``.
        """
        if self.timesteps <= 0:
            raise ValueError(f"timesteps must be positive, got {self.timesteps}")
        # A strided (DDIM) sampler walks a subset of the full chain, so it
        # cannot take more steps than the chain has.
        if not 0 < self.ddim_steps <= self.timesteps:
            raise ValueError(
                f"ddim_steps must be in (0, timesteps={self.timesteps}], "
                f"got {self.ddim_steps}"
            )
        # Each beta is a per-step noise variance and must lie strictly in (0, 1).
        if not 0.0 < self.beta_start <= self.beta_end < 1.0:
            raise ValueError(
                "beta schedule must satisfy 0 < beta_start <= beta_end < 1, "
                f"got beta_start={self.beta_start}, beta_end={self.beta_end}"
            )
|
|
|
|
@dataclass
class TrainingConfig:
    """Training-loop settings.

    Covers batch sizing, learning-rate/regularization settings, loop length
    and cadence, the quick-validation switch, and checkpoint locations.
    """

    # Batch sizing.
    batch_size: int = field(default=64)
    gradient_accumulation: int = field(default=1)

    # Learning rate, weight decay and warmup length.
    learning_rate: float = field(default=1e-4)
    weight_decay: float = field(default=0.01)
    warmup_steps: int = field(default=500)

    # Loop length and save/eval cadence (units are defined by the training
    # loop that reads them — confirm there).
    epochs: int = field(default=10)
    save_every: int = field(default=1)
    eval_every: int = field(default=100)

    # Quick-validation mode; presumably limits training to `quick_samples`
    # examples when enabled.
    quick_mode: bool = field(default=False)
    quick_samples: int = field(default=1000)

    # Checkpoint directory (created under the project directory by
    # `Config`) and an optional checkpoint to resume from.
    checkpoint_dir: str = field(default="checkpoints")
    resume: Optional[str] = field(default=None)
|
|
|
|
@dataclass
class DataConfig:
    """Dataset locations and preprocessing limits.

    The three dataset paths, when relative, are made absolute by `Config`
    relative to its project directory.
    """

    # Raw corpora: a Tatoeba TSV file plus a zh/en file pair.
    tatoeba_path: str = field(default="../_dataset/tatoeba.tsv")
    cveto_zh_path: str = field(default="../_dataset/cveto/train.zh")
    cveto_en_path: str = field(default="../_dataset/cveto/train.en")

    # Sample limits; max_samples=None presumably means "no cap".
    max_samples: Optional[int] = field(default=None)
    min_len: int = field(default=2)
    max_len: int = field(default=128)

    # Preprocessing cache switch and directory.
    use_cache: bool = field(default=True)
    cache_dir: str = field(default=".cache")
|
|
|
|
@dataclass
class Config:
    """Top-level configuration bundling model, diffusion, training and data settings.

    Constructing a ``Config`` has filesystem side effects: relative dataset,
    checkpoint and cache paths are resolved against ``project_dir``, and the
    checkpoint and cache directories are created if they do not exist.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)

    # Base directory for relative paths. Empty (the default) means the
    # directory containing this file.
    project_dir: str = ""

    def __post_init__(self):
        """Resolve relative paths to absolute ones and create output directories.

        Raises:
            OSError: if the checkpoint or cache directory cannot be created.
        """
        # Honor an explicitly supplied project_dir instead of unconditionally
        # overwriting it with this file's directory.
        if not self.project_dir:
            self.project_dir = os.path.dirname(os.path.abspath(__file__))
        self.project_dir = os.path.abspath(self.project_dir)

        def resolve(path: str) -> str:
            # Absolute paths are kept as-is; relative ones hang off project_dir.
            if os.path.isabs(path):
                return path
            return os.path.join(self.project_dir, path)

        self.data.tatoeba_path = resolve(self.data.tatoeba_path)
        self.data.cveto_zh_path = resolve(self.data.cveto_zh_path)
        self.data.cveto_en_path = resolve(self.data.cveto_en_path)

        # Store the output directories as absolute paths too, so the value
        # callers read is the directory created below, independent of the
        # current working directory. os.path.join(project_dir, <absolute>)
        # returns the absolute path unchanged, so callers that still join
        # with project_dir keep getting the same location.
        self.training.checkpoint_dir = resolve(self.training.checkpoint_dir)
        self.data.cache_dir = resolve(self.data.cache_dir)

        os.makedirs(self.training.checkpoint_dir, exist_ok=True)
        os.makedirs(self.data.cache_dir, exist_ok=True)

    @classmethod
    def quick(cls) -> "Config":
        """Return a configuration for a fast smoke-test run.

        Enables quick mode, runs 5 epochs at batch size 32, and caps both
        ``training.quick_samples`` and ``data.max_samples`` at 1000.
        """
        config = cls()
        config.training.quick_mode = True
        config.training.quick_samples = 1000
        # Keep the data cap in sync with the quick-mode sample count.
        config.data.max_samples = config.training.quick_samples
        config.training.epochs = 5
        config.training.batch_size = 32
        return config
|
|
|
|
| |
# Module-level default configuration. Constructing it runs
# Config.__post_init__, which creates the checkpoint and cache directories
# under the project directory, so importing this module has that side effect.
default_config: Config = Config()
|
|