# Uploaded by malarsaravanan, commit 3072e38 ("Upload 11 files").
"""
Configuration management for ResNet-18 CIFAR-100 training
"""
import os
from dataclasses import asdict, dataclass, field
from typing import Tuple, Optional
@dataclass
class ModelConfig:
    """Settings that describe the network architecture.

    Attributes:
        name: Architecture identifier; defaults to ``"ResNet18"``.
        num_classes: Number of output classes; defaults to 100.
        input_channels: Number of input channels; defaults to 3.
    """

    name: str = field(default="ResNet18")
    num_classes: int = field(default=100)
    input_channels: int = field(default=3)
@dataclass
class TrainingConfig:
    """Hyperparameters and stopping target for a training run.

    Attributes:
        batch_size: Number of samples per mini-batch.
        learning_rate: Optimizer learning rate.
        weight_decay: Optimizer weight-decay coefficient.
        momentum: Optimizer momentum.
        epochs: Number of training epochs.
        target_accuracy: Accuracy goal for the run. Presumably a percentage,
            given the default of 73.0; confirm against the training loop.
    """

    batch_size: int = field(default=128)
    learning_rate: float = field(default=0.1)
    weight_decay: float = field(default=5e-4)
    momentum: float = field(default=0.9)
    epochs: int = field(default=100)
    target_accuracy: float = field(default=73.0)
@dataclass
class DataConfig:
    """Dataset location, loader options, augmentation and normalization settings.

    Attributes:
        dataset_name: Dataset identifier; defaults to ``"CIFAR100"``.
        data_dir: Directory holding the dataset files.
        num_workers: Number of data-loading workers.
        pin_memory: Pin-memory flag for data loading (presumably forwarded to
            the data loader; confirm at the call site).
        random_crop_padding: Padding used for random-crop augmentation.
        rotation_degrees: Rotation range, in degrees, for augmentation.
        color_jitter_brightness: Brightness strength for color jitter.
        color_jitter_contrast: Contrast strength for color jitter.
        color_jitter_saturation: Saturation strength for color jitter.
        color_jitter_hue: Hue strength for color jitter.
        mean: Three normalization means for CIFAR-100.
        std: Three normalization standard deviations for CIFAR-100.
    """

    dataset_name: str = field(default="CIFAR100")
    data_dir: str = field(default="./data")
    num_workers: int = field(default=2)
    pin_memory: bool = field(default=True)

    # --- augmentation ---
    random_crop_padding: int = field(default=4)
    rotation_degrees: int = field(default=15)
    color_jitter_brightness: float = field(default=0.2)
    color_jitter_contrast: float = field(default=0.2)
    color_jitter_saturation: float = field(default=0.2)
    color_jitter_hue: float = field(default=0.1)

    # --- normalization statistics for CIFAR-100 (tuples, so safe as shared defaults) ---
    mean: Tuple[float, float, float] = field(default=(0.5071, 0.4867, 0.4408))
    std: Tuple[float, float, float] = field(default=(0.2675, 0.2565, 0.2761))
@dataclass
class SystemConfig:
    """Runtime device selection and output-file settings.

    Attributes:
        device: Explicit device string, or ``None`` to request auto-detection.
        save_model: Whether the model should be saved.
        model_save_path: File path for the saved model.
        log_file_path: File path for training logs.
    """

    device: Optional[str] = field(default=None)  # None means "auto-detect"
    save_model: bool = field(default=True)
    model_save_path: str = field(default="best_model.pth")
    log_file_path: str = field(default="training_logs.md")
@dataclass
class Config:
    """Top-level configuration bundling the model, training, data and system sections.

    Each section defaults to a freshly constructed instance of its dataclass
    (via ``default_factory``), so separate ``Config()`` objects never share
    section instances.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    system: SystemConfig = field(default_factory=SystemConfig)

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'Config':
        """Build a ``Config`` from a nested dictionary.

        Args:
            config_dict: Mapping with optional ``'model'``, ``'training'``,
                ``'data'`` and ``'system'`` keys. Each value holds keyword
                arguments for the corresponding section dataclass. A missing
                section, or one whose value is ``None`` (for example an empty
                section in a YAML file), falls back to that section's defaults.

        Returns:
            A new ``Config`` instance.

        Raises:
            TypeError: If a section contains a key that is not a field of the
                corresponding section dataclass.
        """
        def section(key: str) -> dict:
            # `or {}` also covers an explicit None value, which `**` cannot unpack.
            return config_dict.get(key) or {}

        return cls(
            model=ModelConfig(**section('model')),
            training=TrainingConfig(**section('training')),
            data=DataConfig(**section('data')),
            system=SystemConfig(**section('system')),
        )

    def to_dict(self) -> dict:
        """Return this configuration as a nested dictionary of plain values.

        The result has the keys ``'model'``, ``'training'``, ``'data'`` and
        ``'system'``, each mapping field names to values. ``dataclasses.asdict``
        builds independent copies, so mutating the returned dictionaries does
        not modify this ``Config`` (returning ``__dict__`` would alias the live
        instance state). The result can be passed back to ``from_dict``.
        """
        return asdict(self)
def get_device(config: 'SystemConfig') -> str:
    """Return the device string to use for training.

    If ``config.device`` is a non-empty string it is returned unchanged.
    Otherwise (``None`` or ``""``) the first available backend is chosen in
    the order ``"mps"``, ``"cuda"``, ``"cpu"``.

    Args:
        config: System configuration. Only its ``device`` attribute is read.

    Returns:
        The explicit ``config.device`` value, or one of ``"mps"``,
        ``"cuda"`` or ``"cpu"``.
    """
    # Imported inside the function so that importing this module does not import torch.
    import torch

    # An empty string is not a valid torch device. It can arise from an unset
    # CLI flag or environment variable, so treat it like None (auto-detect).
    if config.device:
        return config.device

    # `torch.backends.mps` is missing in older PyTorch releases. Look it up
    # defensively so those installs fall through to CUDA/CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"