Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import numpy as np | |
| import random | |
| from src.model import TrashNetClassifier | |
| from src.data_loader import get_dataloaders | |
| from src import config | |
| import logging | |
| import time | |
| from datetime import datetime | |
| import os | |
def setup_tuning_logging(log_dir):
    """Route root logging to a new timestamped file and to the console.

    Args:
        log_dir: Directory for the log file; created if it does not exist.

    Returns:
        Path of the log file,
        ``<log_dir>/hyperparameter_tuning_<YYYYmmdd_HHMMSS>.log``.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"hyperparameter_tuning_{timestamp}.log")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit encoding so non-ASCII messages do not depend on the
            # platform's default locale encoding.
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(),
        ],
        # Without force=True, basicConfig is a silent no-op when the root
        # logger already has handlers: the FileHandler above would still be
        # constructed (creating an empty file and leaking its handle) but
        # nothing would ever be written to it. force=True (Python 3.8+)
        # removes and closes the existing root handlers first.
        force=True,
    )
    return log_file
def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, epochs):
    """Run one optimisation pass over ``loader``; return per-sample (mean loss, accuracy)."""
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for batch_idx, (images, labels) in enumerate(loader):
        if batch_idx % 20 == 0:
            logging.info(f" Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(loader)}")
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # CrossEntropyLoss returns the batch mean; weight it by the batch size
        # so a short final batch does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        seen += batch_size
    if seen == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / seen, correct / seen


def _evaluate_on_loader(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return per-sample (mean loss, accuracy)."""
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            seen += batch_size
    if seen == 0:
        raise ValueError("val_loader yielded no samples")
    return loss_sum / seen, correct / seen


def train_model_for_validation(model, train_loader, val_loader, lr, weight_decay, device, epochs=None):
    """Train ``model`` with Adam and return its best per-epoch validation accuracy.

    Args:
        model: The ``nn.Module`` to train; moved to ``device``.
        train_loader: Sized iterable of ``(images, labels)`` batches used for
            training (``len()`` is used in progress logging).
        val_loader: Iterable of ``(images, labels)`` batches evaluated after
            every epoch.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        device: Device (or device string) for the model and batches.
        epochs: Number of training epochs. ``None`` (the default) means
            ``config.TUNING_EPOCHS``, read at call time.

    Returns:
        The highest validation accuracy (fraction of correctly classified
        samples) over all epochs; 0.0 if no epoch ran.

    Raises:
        ValueError: If either loader yields no samples.
    """
    if epochs is None:
        # Resolved at call time rather than at definition time, so later
        # changes to the config module are honoured.
        epochs = config.TUNING_EPOCHS
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
    )
    best_val_acc = 0.0
    logging.info(f"Starting validation training with lr={lr}, weight_decay={weight_decay}")
    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, epochs
        )
        val_loss, val_acc = _evaluate_on_loader(model, val_loader, criterion, device)
        logging.info(f" Epoch {epoch+1}/{epochs}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            logging.info(f" New best validation accuracy: {best_val_acc:.4f}")
    return best_val_acc
def _sample_configs(learning_rates, weight_decays, num_trials):
    """Yield ``num_trials`` random (lr, weight_decay) pairs, exhausting the grid before any repeat."""
    grid = [(lr, wd) for lr in learning_rates for wd in weight_decays]
    pool = []
    for _ in range(num_trials):
        if not pool:
            # Fresh shuffled copy of the whole grid each time it runs out.
            pool = random.sample(grid, len(grid))
        yield pool.pop()


def run_hyperparameter_search(*, learning_rates=None, weight_decays=None, num_trials=None):
    """Random-search learning rate and weight decay for ``TrashNetClassifier``.

    Each trial trains a fresh model with one (lr, weight_decay) pair and
    scores it by its best validation accuracy. Pairs are drawn without
    replacement, so no combination is re-trained until the whole grid has
    been tried.

    Args (keyword-only, optional):
        learning_rates: Candidate learning rates; defaults to
            ``[1e-5, 1e-4, 5e-4, 1e-3]``.
        weight_decays: Candidate weight decays; defaults to
            ``[1e-5, 1e-4, 1e-3]``.
        num_trials: Number of trials; defaults to ``config.TUNING_TRIALS``.

    Returns:
        dict with keys ``"lr"`` and ``"weight_decay"`` for the best trial, or
        the placeholder ``{"lr": 0, "weight_decay": 0}`` if no trial ran.

    Raises:
        ValueError: If either candidate list is empty.
    """
    if learning_rates is None:
        learning_rates = [1e-5, 1e-4, 5e-4, 1e-3]
    if weight_decays is None:
        weight_decays = [1e-5, 1e-4, 1e-3]
    if num_trials is None:
        num_trials = config.TUNING_TRIALS
    if not learning_rates or not weight_decays:
        raise ValueError("learning_rates and weight_decays must be non-empty")

    log_file = setup_tuning_logging(config.LOG_DIR)
    logging.info(f"Hyperparameter tuning logs will be saved to: {log_file}")
    device = torch.device(config.DEVICE)
    logging.info(f"Using device: {device}")
    logging.info("Loading datasets...")
    train_loader, val_loader, _, class_names = get_dataloaders(
        data_dir=config.DATA_DIR,
        batch_size=config.TUNING_BATCH_SIZE,
        image_size=config.IMAGE_SIZE,
        num_workers=config.NUM_WORKERS,
    )

    # None until the first trial finishes, so the first result always wins
    # (a strict ``>`` against 0.0 would keep the lr=0 placeholder if every
    # trial scored 0.0).
    best_acc = None
    best_config = {"lr": 0, "weight_decay": 0}
    logging.info("Starting hyperparameter search...")
    logging.info(f"Number of trials: {num_trials}")
    logging.info(f"Learning rates to try: {learning_rates}")
    logging.info(f"Weight decays to try: {weight_decays}")
    start_time = time.time()
    trials = _sample_configs(learning_rates, weight_decays, num_trials)
    for trial, (lr, weight_decay) in enumerate(trials):
        trial_start = time.time()
        logging.info(f"\nTrial {trial+1}/{num_trials}")
        logging.info(f"Testing lr={lr}, weight_decay={weight_decay}")
        model = TrashNetClassifier(num_classes=len(class_names))
        val_acc = train_model_for_validation(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            lr=lr,
            weight_decay=weight_decay,
            device=device,
        )
        trial_time = time.time() - trial_start
        logging.info(f"Trial {trial+1} completed in {trial_time:.2f}s")
        logging.info(f"Validation accuracy: {val_acc:.4f}")
        if best_acc is None or val_acc > best_acc:
            best_acc = val_acc
            best_config = {"lr": lr, "weight_decay": weight_decay}
            logging.info("New best config found!")

    total_time = time.time() - start_time
    logging.info(f"\nHyperparameter search completed in {total_time:.2f}s")
    if best_acc is None:
        logging.warning("No trials were run; returning placeholder config")
        return best_config
    logging.info(f"Best config: lr={best_config['lr']}, weight_decay={best_config['weight_decay']}")
    logging.info(f"Best validation accuracy: {best_acc:.4f}")
    return best_config