from dataclasses import asdict, dataclass
from typing import List
import yaml
from pathlib import Path


@dataclass
class ModelConfig:
    teacher_model: str
    student_model: str
    student_hidden_dim: int
    teacher_hidden_dim: int


@dataclass
class DataConfig:
    root_dir: str
    sample_cap: int
    val_ratio: float
    batch_size: int
    max_frames: int
    sample_rate: int
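    # The four fields below look like pass-through kwargs for
    # torch.utils.data.DataLoader (an assumption; no loader is built here).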
    pin_memory: bool
    persistent_workers: bool
    prefetch_factor: int
    num_workers: int


@dataclass
class TrainingConfig:
    max_steps: int
    warmup_steps: int
    eval_steps: int
    log_steps: int
    batch_size: int
    grad_accum: int
    max_grad_norm: float


@dataclass
class OptimizerConfig:
    lr: float
    weight_decay: float


@dataclass
class LoRAConfig:
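    # Presumably consumed by peft.LoraConfig (with alpha -> lora_alpha and
    # dropout -> lora_dropout); that mapping is an assumption, not shown here.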
    r: int
    alpha: int
    dropout: float
    target_modules: List[str]


@dataclass
class TAIDConfig:
    """TAID (Temperature-Aware Interpolation Distillation) 配置"""
    start: float
    mid: float
    end: float


@dataclass
class DistillationConfig:
    temperature: float
    kl_weight: float
    hidden_beta: float
    taid: TAIDConfig


@dataclass
class OutputConfig:
    dir: str


@dataclass
class Config:
    data: DataConfig
    model: ModelConfig
    training: TrainingConfig
    optimizer: OptimizerConfig
    lora: LoRAConfig
    distillation: DistillationConfig
    output: OutputConfig

    @classmethod
    def from_yaml(cls, yaml_path: str) -> 'Config':
        """从YAML文件加载配置"""
        yaml_path = Path(yaml_path)
        if not yaml_path.exists():
            raise FileNotFoundError(f"配置文件不存在:{yaml_path}")

        with open(yaml_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)

        # Pull out the nested TAID section and rebuild it as a dataclass.
        distill_config = config_dict['distillation']
        taid_config = TAIDConfig(**distill_config.pop('taid'))
        distill_config['taid'] = taid_config

        return cls(
            data=DataConfig(**config_dict['data']),
            model=ModelConfig(**config_dict['model']),
            training=TrainingConfig(**config_dict['training']),
            optimizer=OptimizerConfig(**config_dict['optimizer']),
            lora=LoRAConfig(**config_dict['lora']),
            distillation=DistillationConfig(**distill_config),
            output=OutputConfig(**config_dict['output'])
        )

    def save(self, save_path: str) -> None:
        """保存配置到YAML文件"""
        config_dict = {
            'data': self.data.__dict__,
            'model': self.model.__dict__,
            'training': self.training.__dict__,
            'optimizer': self.optimizer.__dict__,
            'lora': self.lora.__dict__,
            'distillation': {
                **{k: v for k, v in self.distillation.__dict__.items() if k != 'taid'},
                'taid': self.distillation.taid.__dict__
            },
            'output': self.output.__dict__
        }

        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        
        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True)


def load_config(config_path: str = "configs/default_config.yaml") -> Config:
    """加载配置的便捷函数"""
    return Config.from_yaml(config_path)
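

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). It assumes a
# YAML file shaped like the dataclasses above. Every concrete value below is
# made up; "configs/default_config.yaml" is simply this module's default path.
# For example, the distillation section would look like:
#
#   distillation:
#     temperature: 2.0
#     kl_weight: 1.0
#     hidden_beta: 0.5
#     taid:
#       start: 0.2
#       mid: 0.6
#       end: 1.0
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = load_config()  # reads configs/default_config.yaml
    print(f"teacher={config.model.teacher_model} lr={config.optimizer.lr}")
    # Round-trip: persist the resolved config alongside the run outputs
    # (the output path here is an example, not prescribed by the module).
    config.save("outputs/config_snapshot.yaml")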