File size: 3,349 Bytes
5f2f308 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | from dataclasses import dataclass
from typing import List, Optional
import yaml
from pathlib import Path
@dataclass
class ModelConfig:
    """Model-related settings.

    Built by ``Config.from_yaml`` from the ``model`` mapping of the YAML
    file. The mapping is unpacked as keyword arguments, so its keys must
    match these field names exactly.
    """

    # Field order defines the positional ``__init__`` signature; note that
    # the *_hidden_dim pair is declared student-first, unlike the *_model pair.
    teacher_model: str  # presumably a name or path identifying the teacher — confirm against the loader
    student_model: str  # presumably a name or path identifying the student — confirm against the loader
    student_hidden_dim: int  # presumably the student's hidden size — TODO confirm
    teacher_hidden_dim: int  # presumably the teacher's hidden size — TODO confirm
@dataclass
class DataConfig:
    """Dataset and data-loading settings.

    Built by ``Config.from_yaml`` from the ``data`` mapping of the YAML
    file. The mapping is unpacked as keyword arguments, so its keys must
    match these field names exactly.
    """

    # Field order defines the positional ``__init__`` signature.
    root_dir: str  # presumably the dataset root directory — confirm against the dataset code
    sample_cap: int  # presumably an upper bound on the number of samples used — TODO confirm
    val_ratio: float  # presumably the fraction of data held out for validation — TODO confirm
    batch_size: int
    max_frames: int
    sample_rate: int
    # NOTE(review): the four names below match torch ``DataLoader`` keyword
    # arguments and are presumably forwarded to it — confirm at the call site.
    pin_memory: bool
    persistent_workers: bool
    prefetch_factor: int
    num_workers: int
@dataclass
class TrainingConfig:
    """Training-loop settings.

    Built by ``Config.from_yaml`` from the ``training`` mapping of the YAML
    file. The mapping is unpacked as keyword arguments, so its keys must
    match these field names exactly.
    """

    # Field order defines the positional ``__init__`` signature.
    max_steps: int
    warmup_steps: int
    eval_steps: int  # presumably the interval (in steps) between evaluations — TODO confirm
    log_steps: int  # presumably the interval (in steps) between log lines — TODO confirm
    # NOTE(review): ``DataConfig`` also declares ``batch_size``; which of the
    # two the trainer actually reads is not visible from this file.
    batch_size: int
    grad_accum: int  # presumably the number of gradient-accumulation steps — TODO confirm
    max_grad_norm: float  # presumably the gradient-clipping threshold — TODO confirm
@dataclass
class OptimizerConfig:
    """Optimizer settings.

    Built by ``Config.from_yaml`` from the ``optimizer`` mapping of the YAML
    file. The mapping is unpacked as keyword arguments, so its keys must
    match these field names exactly.
    """

    lr: float  # learning rate
    weight_decay: float
@dataclass
class LoRAConfig:
    """LoRA adapter settings.

    Built by ``Config.from_yaml`` from the ``lora`` mapping of the YAML
    file. The mapping is unpacked as keyword arguments, so its keys must
    match these field names exactly.
    """

    # Field order defines the positional ``__init__`` signature.
    r: int  # presumably the LoRA rank — TODO confirm
    alpha: int  # presumably the LoRA scaling factor — TODO confirm
    dropout: float
    target_modules: List[str]  # presumably names of the modules that receive adapters — TODO confirm
@dataclass
class TAIDConfig:
    """Configuration for TAID (Temperature-Aware Interpolation Distillation)."""

    # NOTE(review): the published TAID method is named "Temporally Adaptive
    # Interpolated Distillation"; confirm which expansion this project intends.
    # The three values presumably define an interpolation schedule at its
    # start, midpoint and end — confirm against the code that consumes them.
    start: float
    mid: float
    end: float
@dataclass
class DistillationConfig:
    """Distillation-loss settings.

    Built by ``Config.from_yaml`` from the ``distillation`` mapping of the
    YAML file. The nested ``taid`` mapping is converted to a ``TAIDConfig``
    before this class is constructed; the remaining keys are unpacked as
    keyword arguments and must match these field names exactly.
    """

    # Field order defines the positional ``__init__`` signature.
    temperature: float
    kl_weight: float  # presumably the weight of the KL-divergence loss term — TODO confirm
    hidden_beta: float  # presumably the weight of a hidden-state loss term — TODO confirm
    taid: TAIDConfig
@dataclass
class OutputConfig:
    """Output settings.

    Built by ``Config.from_yaml`` from the ``output`` mapping of the YAML
    file, whose only expected key is ``dir``.
    """

    # The field name mirrors the YAML key, so it cannot be renamed even
    # though it matches the name of a Python builtin.
    dir: str  # presumably the directory where run outputs are written — TODO confirm
@dataclass
class Config:
    """Top-level configuration grouping every section of the YAML file.

    Each field holds the dataclass built from the top-level YAML key of the
    same name. Field order defines the positional ``__init__`` signature.
    """

    data: DataConfig
    model: ModelConfig
    training: TrainingConfig
    optimizer: OptimizerConfig
    lora: LoRAConfig
    distillation: DistillationConfig
    output: OutputConfig

    @classmethod
    def from_yaml(cls, yaml_path: str) -> 'Config':
        """Load a configuration from a YAML file.

        Args:
            yaml_path: Path of the YAML file to read (decoded as UTF-8).

        Returns:
            A ``Config`` whose sections are built from the file's top-level
            mappings; the nested ``distillation.taid`` mapping becomes a
            ``TAIDConfig``.

        Raises:
            FileNotFoundError: If ``yaml_path`` does not exist.
            KeyError: If a required section, or ``distillation.taid``, is
                missing.
            TypeError: If the document, a section, or ``distillation.taid``
                is not a mapping, or if a mapping's keys do not match the
                fields of the dataclass it is unpacked into.
        """
        yaml_path = Path(yaml_path)
        if not yaml_path.exists():
            raise FileNotFoundError(f"配置文件不存在:{yaml_path}")

        with open(yaml_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)

        # safe_load returns None for an empty document and a list/scalar for
        # non-mapping documents. Fail here with a message that names the file
        # instead of an opaque subscript error further down.
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"Top-level YAML content of {yaml_path} must be a mapping, "
                f"got {type(config_dict).__name__}"
            )

        section_names = (
            'data', 'model', 'training', 'optimizer',
            'lora', 'distillation', 'output',
        )
        missing = [name for name in section_names if name not in config_dict]
        if missing:
            raise KeyError(
                f"Missing config section(s) in {yaml_path}: {', '.join(missing)}"
            )
        malformed = [
            name for name in section_names
            if not isinstance(config_dict[name], dict)
        ]
        if malformed:
            raise TypeError(
                f"Config section(s) in {yaml_path} must be mappings: "
                f"{', '.join(malformed)}"
            )

        # Build the nested TAID section on a shallow copy so the parsed
        # document itself is left unmodified.
        distill_config = dict(config_dict['distillation'])
        if 'taid' not in distill_config:
            raise KeyError(
                f"Missing 'taid' in the 'distillation' section of {yaml_path}"
            )
        taid_dict = distill_config['taid']
        if not isinstance(taid_dict, dict):
            raise TypeError(
                f"'distillation.taid' in {yaml_path} must be a mapping, "
                f"got {type(taid_dict).__name__}"
            )
        distill_config['taid'] = TAIDConfig(**taid_dict)

        return cls(
            data=DataConfig(**config_dict['data']),
            model=ModelConfig(**config_dict['model']),
            training=TrainingConfig(**config_dict['training']),
            optimizer=OptimizerConfig(**config_dict['optimizer']),
            lora=LoRAConfig(**config_dict['lora']),
            distillation=DistillationConfig(**distill_config),
            output=OutputConfig(**config_dict['output']),
        )

    def save(self, save_path: str) -> None:
        """Write this configuration to a YAML file.

        Missing parent directories of ``save_path`` are created. The file is
        written as UTF-8 in block style with non-ASCII characters unescaped.

        Args:
            save_path: Destination path of the YAML file.
        """
        # Imported locally because the module header only imports `dataclass`.
        from dataclasses import asdict

        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        with open(save_path, 'w', encoding='utf-8') as f:
            # asdict recurses into nested dataclasses, so the `taid`
            # sub-section is serialized without special-casing.
            yaml.dump(asdict(self), f, default_flow_style=False, allow_unicode=True)
def load_config(config_path: str = "configs/default_config.yaml") -> Config:
    """Convenience wrapper that loads a ``Config`` from a YAML file.

    Args:
        config_path: Location of the YAML file; defaults to
            ``configs/default_config.yaml``.

    Returns:
        The ``Config`` produced by ``Config.from_yaml`` for that path.
    """
    loaded = Config.from_yaml(config_path)
    return loaded