"""Configuration management utilities.""" from dataclasses import dataclass, field, asdict from typing import Optional, Dict, Any import yaml from pathlib import Path @dataclass class ModelConfig: """Model configuration.""" name: str = "facebook/wav2vec2-base" device: str = "cuda" checkpoint: Optional[str] = None @dataclass class RLConfig: """Reinforcement learning configuration.""" algorithm: str = "ppo" learning_rate: float = 3.0e-4 batch_size: int = 32 num_episodes: int = 1000 episode_length: int = 100 gamma: float = 0.99 clip_epsilon: float = 0.2 # PPO specific max_grad_norm: float = 1.0 @dataclass class DataConfig: """Data configuration.""" dataset_path: str = "data/processed" train_split: float = 0.7 val_split: float = 0.15 test_split: float = 0.15 sample_rate: int = 16000 @dataclass class CurriculumConfig: """Curriculum learning configuration.""" enabled: bool = True levels: int = 5 advancement_threshold: float = 0.8 @dataclass class OptimizationConfig: """Optimization configuration.""" mixed_precision: bool = True gradient_checkpointing: bool = False @dataclass class CheckpointConfig: """Checkpointing configuration.""" interval: int = 50 # episodes save_dir: str = "checkpoints" keep_last_n: int = 5 @dataclass class MonitoringConfig: """Monitoring configuration.""" log_interval: int = 10 visualization_interval: int = 50 tensorboard_dir: str = "runs" @dataclass class ReproducibilityConfig: """Reproducibility configuration.""" random_seed: int = 42 @dataclass class TrainingConfig: """Complete training configuration.""" model: ModelConfig = field(default_factory=ModelConfig) rl: RLConfig = field(default_factory=RLConfig) data: DataConfig = field(default_factory=DataConfig) curriculum: CurriculumConfig = field(default_factory=CurriculumConfig) optimization: OptimizationConfig = field(default_factory=OptimizationConfig) checkpointing: CheckpointConfig = field(default_factory=CheckpointConfig) monitoring: MonitoringConfig = field(default_factory=MonitoringConfig) reproducibility: ReproducibilityConfig = field(default_factory=ReproducibilityConfig) @classmethod def from_yaml(cls, path: str) -> "TrainingConfig": """Load configuration from YAML file.""" with open(path, 'r') as f: config_dict = yaml.safe_load(f) return cls( model=ModelConfig(**config_dict.get('model', {})), rl=RLConfig(**config_dict.get('rl', {})), data=DataConfig(**config_dict.get('data', {})), curriculum=CurriculumConfig(**config_dict.get('curriculum', {})), optimization=OptimizationConfig(**config_dict.get('optimization', {})), checkpointing=CheckpointConfig(**config_dict.get('checkpointing', {})), monitoring=MonitoringConfig(**config_dict.get('monitoring', {})), reproducibility=ReproducibilityConfig(**config_dict.get('reproducibility', {})) ) def to_yaml(self, path: str) -> None: """Save configuration to YAML file.""" config_dict = { 'model': asdict(self.model), 'rl': asdict(self.rl), 'data': asdict(self.data), 'curriculum': asdict(self.curriculum), 'optimization': asdict(self.optimization), 'checkpointing': asdict(self.checkpointing), 'monitoring': asdict(self.monitoring), 'reproducibility': asdict(self.reproducibility) } Path(path).parent.mkdir(parents=True, exist_ok=True) with open(path, 'w') as f: yaml.dump(config_dict, f, default_flow_style=False) def to_dict(self) -> Dict[str, Any]: """Convert configuration to dictionary.""" return { 'model': asdict(self.model), 'rl': asdict(self.rl), 'data': asdict(self.data), 'curriculum': asdict(self.curriculum), 'optimization': asdict(self.optimization), 'checkpointing': asdict(self.checkpointing), 'monitoring': asdict(self.monitoring), 'reproducibility': asdict(self.reproducibility) }