# trash-classification-pytorch / src / hyperparameter_tuning.py
# (author: neecat — commit 57d41d5, "add modified files")
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from src.model import TrashNetClassifier
from src.data_loader import get_dataloaders
from src import config
import logging
import time
from datetime import datetime
import os
def setup_tuning_logging(log_dir):
    """Configure root logging for a tuning run.

    Creates `log_dir` if needed and attaches two handlers: a timestamped
    file inside `log_dir` and a stream handler (console).

    Args:
        log_dir: Directory where the log file is written.

    Returns:
        str: Path of the log file that was configured.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"hyperparameter_tuning_{timestamp}.log")
    handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
        handlers=handlers,
    )
    return log_file
def train_model_for_validation(model, train_loader, val_loader, lr, weight_decay, device, epochs=None):
    """Train `model` briefly and return the best validation accuracy seen.

    Inner objective for the hyperparameter search: trains with Adam at the
    given learning rate / weight decay and tracks the best per-epoch
    validation accuracy.

    Args:
        model: torch.nn.Module to train (moved to `device`).
        train_loader: DataLoader yielding (images, labels) training batches.
        val_loader: DataLoader yielding (images, labels) validation batches.
        lr: Adam learning rate.
        weight_decay: Adam L2 weight-decay coefficient.
        device: torch.device (or device string) to run on.
        epochs: Number of epochs; defaults to config.TUNING_EPOCHS.
            (Resolved lazily so the module import does not eagerly evaluate
            the config attribute — backward-compatible with the old default.)

    Returns:
        float: Best validation accuracy (fraction of correct samples) over
        all epochs; 0.0 if `epochs` is 0.
    """
    if epochs is None:
        epochs = config.TUNING_EPOCHS
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    best_val_acc = 0.0
    logging.info(f"Starting validation training with lr={lr}, weight_decay={weight_decay}")
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        train_correct = 0
        train_seen = 0
        for batch_idx, (images, labels) in enumerate(train_loader):
            if batch_idx % 20 == 0:
                logging.info(f" Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}")
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            preds = torch.argmax(outputs, dim=1)
            # Accumulate correct-sample counts and sample-weighted loss
            # instead of averaging per-batch means: a smaller final batch
            # would otherwise be over-weighted in the epoch statistics.
            train_correct += (preds == labels).sum().item()
            running_loss += loss.item() * labels.size(0)
            train_seen += labels.size(0)
        train_loss = running_loss / max(train_seen, 1)
        train_acc = train_correct / max(train_seen, 1)
        # ---- validation pass ----
        model.eval()
        val_loss_sum = 0.0
        val_correct = 0
        val_seen = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                preds = torch.argmax(outputs, dim=1)
                val_loss_sum += loss.item() * labels.size(0)
                val_correct += (preds == labels).sum().item()
                val_seen += labels.size(0)
        val_loss = val_loss_sum / max(val_seen, 1)
        val_acc = val_correct / max(val_seen, 1)
        logging.info(f" Epoch {epoch+1}/{epochs}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            logging.info(f" New best validation accuracy: {best_val_acc:.4f}")
    return best_val_acc
def run_hyperparameter_search():
    """Random-search Adam hyperparameters for the trash classifier.

    Samples (learning rate, weight decay) pairs from a small grid, trains
    each candidate briefly with `train_model_for_validation`, and keeps the
    configuration achieving the best validation accuracy.

    Returns:
        dict: {"lr": float, "weight_decay": float} of the best trial.
    """
    log_file = setup_tuning_logging(config.LOG_DIR)
    logging.info(f"Hyperparameter tuning logs will be saved to: {log_file}")
    device = torch.device(config.DEVICE)
    logging.info(f"Using device: {device}")
    logging.info("Loading datasets...")
    train_loader, val_loader, _, class_names = get_dataloaders(
        data_dir=config.DATA_DIR,
        batch_size=config.TUNING_BATCH_SIZE,
        image_size=config.IMAGE_SIZE,
        num_workers=config.NUM_WORKERS
    )
    learning_rates = [1e-5, 1e-4, 5e-4, 1e-3]
    weight_decays = [1e-5, 1e-4, 1e-3]
    num_trials = config.TUNING_TRIALS
    # Build the full grid and sample trials without replacement when
    # possible: independent random.choice draws could repeat the same
    # (lr, weight_decay) pair and waste an entire training run.
    grid = [(lr, wd) for lr in learning_rates for wd in weight_decays]
    if num_trials <= len(grid):
        candidates = random.sample(grid, num_trials)
    else:
        # More trials requested than distinct configurations exist; fall
        # back to sampling with replacement so trial count is honored.
        candidates = [random.choice(grid) for _ in range(num_trials)]
    best_acc = 0.0
    best_config = {"lr": 0, "weight_decay": 0}
    logging.info("Starting hyperparameter search...")
    logging.info(f"Number of trials: {num_trials}")
    logging.info(f"Learning rates to try: {learning_rates}")
    logging.info(f"Weight decays to try: {weight_decays}")
    start_time = time.time()
    for trial, (lr, weight_decay) in enumerate(candidates):
        trial_start = time.time()
        logging.info(f"\nTrial {trial+1}/{num_trials}")
        logging.info(f"Testing lr={lr}, weight_decay={weight_decay}")
        model = TrashNetClassifier(num_classes=len(class_names))
        val_acc = train_model_for_validation(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            lr=lr,
            weight_decay=weight_decay,
            device=device
        )
        trial_time = time.time() - trial_start
        logging.info(f"Trial {trial+1} completed in {trial_time:.2f}s")
        logging.info(f"Validation accuracy: {val_acc:.4f}")
        if val_acc > best_acc:
            best_acc = val_acc
            best_config = {"lr": lr, "weight_decay": weight_decay}
            logging.info(f"New best config found!")
    total_time = time.time() - start_time
    logging.info(f"\nHyperparameter search completed in {total_time:.2f}s")
    logging.info(f"Best config: lr={best_config['lr']}, weight_decay={best_config['weight_decay']}")
    logging.info(f"Best validation accuracy: {best_acc:.4f}")
    return best_config