# CNN / config.py
# N-I-M-I's picture
# Upload folder using huggingface_hub
# 233caeb verified
"""
Configuration file for CIFAR-10 CNN project
"""
import torch
# Data configuration
# Relative path; it resolves against the process working directory.
DATA_DIR: str = "./data"
BATCH_SIZE: int = 32
# Kept at 0 for better stability on some systems without a GPU.
# (Presumably the data-loader worker count -- verify against the loader code.)
NUM_WORKERS: int = 0
# Model configuration
NUM_CLASSES: int = 10
INPUT_CHANNELS: int = 3
IMAGE_SIZE: int = 32
# NOTE(review): RNN_DROPOUT (and possibly HIDDEN_SIZE / NUM_LAYERS) look like
# leftovers from a recurrent model in what is described as a CNN project --
# confirm which model code, if any, still reads them before removing.
HIDDEN_SIZE: int = 256
NUM_LAYERS: int = 2
RNN_DROPOUT: float = 0.2
# Training configuration
EPOCHS: int = 5
# Raised slightly so training converges faster within the few epochs above.
LEARNING_RATE: float = 0.01
WEIGHT_DECAY: float = 5e-4
MOMENTUM: float = 0.9
# Device configuration: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU. Evaluated once, at import time.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# Checkpoint configuration
CHECKPOINT_DIR = './checkpoints'
# Both checkpoint paths are derived from CHECKPOINT_DIR so that changing the
# directory cannot leave them pointing at the old location. The '/' separator
# is written literally so the resulting strings stay identical to the previous
# hard-coded values on every platform.
BEST_MODEL_PATH = f'{CHECKPOINT_DIR}/best_model.pth'
LAST_MODEL_PATH = f'{CHECKPOINT_DIR}/last_model.pth'
# Visualization configuration
# Relative path; it resolves against the process working directory.
PLOTS_DIR: str = "./plots"
# CIFAR-10 class names, listed in the dataset's standard label order
# (position in the list == integer label, 0 through 9).
CLASS_NAMES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
# Data augmentation settings
USE_AUGMENTATION: bool = True
# Presumably the padding applied before a random crop and the probability of
# a horizontal flip -- verify against the transform pipeline that reads them.
RANDOM_CROP_PADDING: int = 4
RANDOM_HORIZONTAL_FLIP: float = 0.5
# Learning rate scheduler
USE_SCHEDULER: bool = True
# NOTE(review): EPOCHS is 5 in this file. If the scheduler is stepped once per
# epoch, a step size of 20 is never reached and the learning rate is never
# decayed -- confirm against the training loop whether this is intended.
SCHEDULER_STEP_SIZE: int = 20
SCHEDULER_GAMMA: float = 0.1