ASR_latent_space / src /utils /config.py
Yiru Yang
fresh LFS version for Hugging Face push
5f2f308
from dataclasses import dataclass
from typing import List, Optional
import yaml
from pathlib import Path
@dataclass
class ModelConfig:
teacher_model: str
student_model: str
student_hidden_dim: int
teacher_hidden_dim: int
@dataclass
class DataConfig:
root_dir: str
sample_cap: int
val_ratio: float
batch_size: int
max_frames: int
sample_rate: int
pin_memory: bool
persistent_workers: bool
prefetch_factor: int
num_workers: int
@dataclass
class TrainingConfig:
max_steps: int
warmup_steps: int
eval_steps: int
log_steps: int
batch_size: int
grad_accum: int
max_grad_norm: float
@dataclass
class OptimizerConfig:
lr: float
weight_decay: float
@dataclass
class LoRAConfig:
r: int
alpha: int
dropout: float
target_modules: List[str]
@dataclass
class TAIDConfig:
"""TAID (Temperature-Aware Interpolation Distillation) 配置"""
start: float
mid: float
end: float
@dataclass
class DistillationConfig:
temperature: float
kl_weight: float
hidden_beta: float
taid: TAIDConfig
@dataclass
class OutputConfig:
dir: str
@dataclass
class Config:
data: DataConfig
model: ModelConfig
training: TrainingConfig
optimizer: OptimizerConfig
lora: LoRAConfig
distillation: DistillationConfig
output: OutputConfig
@classmethod
def from_yaml(cls, yaml_path: str) -> 'Config':
"""从YAML文件加载配置"""
yaml_path = Path(yaml_path)
if not yaml_path.exists():
raise FileNotFoundError(f"配置文件不存在:{yaml_path}")
with open(yaml_path, 'r', encoding='utf-8') as f:
config_dict = yaml.safe_load(f)
# 处理 TAID 配置
distill_config = config_dict['distillation']
taid_config = TAIDConfig(**distill_config.pop('taid'))
distill_config['taid'] = taid_config
return cls(
data=DataConfig(**config_dict['data']),
model=ModelConfig(**config_dict['model']),
training=TrainingConfig(**config_dict['training']),
optimizer=OptimizerConfig(**config_dict['optimizer']),
lora=LoRAConfig(**config_dict['lora']),
distillation=DistillationConfig(**distill_config),
output=OutputConfig(**config_dict['output'])
)
def save(self, save_path: str) -> None:
"""保存配置到YAML文件"""
config_dict = {
'data': self.data.__dict__,
'model': self.model.__dict__,
'training': self.training.__dict__,
'optimizer': self.optimizer.__dict__,
'lora': self.lora.__dict__,
'distillation': {
**{k: v for k, v in self.distillation.__dict__.items() if k != 'taid'},
'taid': self.distillation.taid.__dict__
},
'output': self.output.__dict__
}
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
with open(save_path, 'w', encoding='utf-8') as f:
yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True)
def load_config(config_path: str = "configs/default_config.yaml") -> Config:
"""加载配置的便捷函数"""
return Config.from_yaml(config_path)