Spaces:
Runtime error
Runtime error
| """ | |
| Configuration management for ResNet-18 CIFAR-100 training | |
| """ | |
| import os | |
| from dataclasses import dataclass, field | |
| from typing import Tuple, Optional | |
@dataclass
class ModelConfig:
    """Model architecture configuration.

    Declared as a dataclass so the annotated fields below become keyword
    constructor arguments (``ModelConfig(num_classes=10)``) and are stored
    per instance, which is what ``Config.from_dict`` / ``Config.to_dict``
    rely on.

    Attributes:
        name: Architecture identifier; defaults to "ResNet18".
        num_classes: Number of output classes; defaults to 100.
        input_channels: Number of input image channels; defaults to 3.
    """

    name: str = "ResNet18"
    num_classes: int = 100
    input_channels: int = 3
@dataclass
class TrainingConfig:
    """Training hyperparameters configuration.

    Declared as a dataclass so every field can be overridden by keyword
    (``TrainingConfig(epochs=10)``) and is stored on the instance.

    Attributes:
        batch_size: Samples per batch; defaults to 128.
        learning_rate: Optimizer learning rate; defaults to 0.1.
        weight_decay: Weight-decay coefficient; defaults to 5e-4.
        momentum: Optimizer momentum; defaults to 0.9.
        epochs: Number of training epochs; defaults to 100.
        target_accuracy: Accuracy goal; defaults to 73.0
            (presumably a percentage -- confirm against the training loop).
    """

    batch_size: int = 128
    learning_rate: float = 0.1
    weight_decay: float = 5e-4
    momentum: float = 0.9
    epochs: int = 100
    target_accuracy: float = 73.0
@dataclass
class DataConfig:
    """Data loading configuration.

    Declared as a dataclass so every field can be overridden by keyword
    (``DataConfig(num_workers=0)``) and is stored on the instance.

    Attributes:
        dataset_name: Dataset identifier; defaults to "CIFAR100".
        data_dir: Directory holding the dataset; defaults to "./data".
        num_workers: Data-loader worker count; defaults to 2.
        pin_memory: Whether to request pinned memory; defaults to True.
        random_crop_padding: Padding used by the random-crop augmentation.
        rotation_degrees: Rotation range used by the rotation augmentation.
        color_jitter_brightness: Brightness strength for colour jitter.
        color_jitter_contrast: Contrast strength for colour jitter.
        color_jitter_saturation: Saturation strength for colour jitter.
        color_jitter_hue: Hue strength for colour jitter.
        mean: Per-channel normalization mean (three floats).
        std: Per-channel normalization standard deviation (three floats).
    """

    dataset_name: str = "CIFAR100"
    data_dir: str = "./data"
    num_workers: int = 2
    pin_memory: bool = True

    # Data augmentation parameters
    random_crop_padding: int = 4
    rotation_degrees: int = 15
    color_jitter_brightness: float = 0.2
    color_jitter_contrast: float = 0.2
    color_jitter_saturation: float = 0.2
    color_jitter_hue: float = 0.1

    # Normalization values for CIFAR-100 (tuples are immutable, so they are
    # safe to use directly as dataclass defaults)
    mean: Tuple[float, float, float] = (0.5071, 0.4867, 0.4408)
    std: Tuple[float, float, float] = (0.2675, 0.2565, 0.2761)
@dataclass
class SystemConfig:
    """System and device configuration.

    Declared as a dataclass so every field can be overridden by keyword
    (``SystemConfig(device="cpu")``) and is stored on the instance.

    Attributes:
        device: Explicit device string, or None to auto-detect.
        save_model: Whether the model should be saved; defaults to True.
        model_save_path: Path for the saved model; defaults to
            "best_model.pth".
        log_file_path: Path for the training log; defaults to
            "training_logs.md".
    """

    device: Optional[str] = None  # Auto-detect if None
    save_model: bool = True
    model_save_path: str = "best_model.pth"
    log_file_path: str = "training_logs.md"
@dataclass
class Config:
    """Main configuration combining all sub-configurations.

    Declared as a dataclass so the ``field(default_factory=...)``
    declarations below are honoured: each ``Config()`` gets its own fresh
    sub-configuration instances instead of sharing class-level state.

    Attributes:
        model: Model architecture settings.
        training: Training hyperparameters.
        data: Data loading and augmentation settings.
        system: Device and output-path settings.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    system: SystemConfig = field(default_factory=SystemConfig)

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'Config':
        """Create config from dictionary.

        Args:
            config_dict: Mapping with optional 'model', 'training', 'data'
                and 'system' keys, each holding keyword arguments for the
                matching sub-configuration. A missing (or None) section
                falls back to that sub-configuration's defaults.

        Returns:
            A new instance of ``cls`` built from the given sections.

        Raises:
            TypeError: If a section contains a key that is not a field of
                the corresponding sub-configuration.
        """
        # `or {}` also covers a section explicitly set to None (e.g. an
        # empty section in a loaded YAML/JSON file), which `**` cannot unpack.
        return cls(
            model=ModelConfig(**(config_dict.get('model') or {})),
            training=TrainingConfig(**(config_dict.get('training') or {})),
            data=DataConfig(**(config_dict.get('data') or {})),
            system=SystemConfig(**(config_dict.get('system') or {})),
        )

    def to_dict(self) -> dict:
        """Convert config to dictionary.

        Returns:
            A dict with 'model', 'training', 'data' and 'system' keys, each
            mapping field names to their current values. The inner dicts
            are shallow copies, so editing the result does not modify this
            configuration.
        """
        return {
            'model': dict(vars(self.model)),
            'training': dict(vars(self.training)),
            'data': dict(vars(self.data)),
            'system': dict(vars(self.system)),
        }
def get_device(config: "SystemConfig") -> str:
    """Return the device string to run on.

    An explicitly configured ``config.device`` is returned unchanged.
    Otherwise the best available backend is auto-detected, preferring
    "mps", then "cuda", and falling back to "cpu".

    Args:
        config: System configuration; only its ``device`` attribute is read.

    Returns:
        The configured device, or one of "mps", "cuda" or "cpu".
    """
    if config.device is not None:
        return config.device

    # Imported lazily: torch is only needed for auto-detection, so an
    # explicit device setting does not pay for (or require) the import.
    import torch

    # Older torch builds have no `torch.backends.mps`; treat a missing
    # backend as "not available" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"