from dataclasses import dataclass from typing import List, Optional import yaml from pathlib import Path @dataclass class ModelConfig: teacher_model: str student_model: str student_hidden_dim: int teacher_hidden_dim: int @dataclass class DataConfig: root_dir: str sample_cap: int val_ratio: float batch_size: int max_frames: int sample_rate: int pin_memory: bool persistent_workers: bool prefetch_factor: int num_workers: int @dataclass class TrainingConfig: max_steps: int warmup_steps: int eval_steps: int log_steps: int batch_size: int grad_accum: int max_grad_norm: float @dataclass class OptimizerConfig: lr: float weight_decay: float @dataclass class LoRAConfig: r: int alpha: int dropout: float target_modules: List[str] @dataclass class TAIDConfig: """TAID (Temperature-Aware Interpolation Distillation) 配置""" start: float mid: float end: float @dataclass class DistillationConfig: temperature: float kl_weight: float hidden_beta: float taid: TAIDConfig @dataclass class OutputConfig: dir: str @dataclass class Config: data: DataConfig model: ModelConfig training: TrainingConfig optimizer: OptimizerConfig lora: LoRAConfig distillation: DistillationConfig output: OutputConfig @classmethod def from_yaml(cls, yaml_path: str) -> 'Config': """从YAML文件加载配置""" yaml_path = Path(yaml_path) if not yaml_path.exists(): raise FileNotFoundError(f"配置文件不存在:{yaml_path}") with open(yaml_path, 'r', encoding='utf-8') as f: config_dict = yaml.safe_load(f) # 处理 TAID 配置 distill_config = config_dict['distillation'] taid_config = TAIDConfig(**distill_config.pop('taid')) distill_config['taid'] = taid_config return cls( data=DataConfig(**config_dict['data']), model=ModelConfig(**config_dict['model']), training=TrainingConfig(**config_dict['training']), optimizer=OptimizerConfig(**config_dict['optimizer']), lora=LoRAConfig(**config_dict['lora']), distillation=DistillationConfig(**distill_config), output=OutputConfig(**config_dict['output']) ) def save(self, save_path: str) -> None: """保存配置到YAML文件""" config_dict = { 'data': self.data.__dict__, 'model': self.model.__dict__, 'training': self.training.__dict__, 'optimizer': self.optimizer.__dict__, 'lora': self.lora.__dict__, 'distillation': { **{k: v for k, v in self.distillation.__dict__.items() if k != 'taid'}, 'taid': self.distillation.taid.__dict__ }, 'output': self.output.__dict__ } save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) with open(save_path, 'w', encoding='utf-8') as f: yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True) def load_config(config_path: str = "configs/default_config.yaml") -> Config: """加载配置的便捷函数""" return Config.from_yaml(config_path)