File size: 5,297 Bytes
57d41d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from src.model import TrashNetClassifier
from src.data_loader import get_dataloaders
from src import config


import logging
import time
from datetime import datetime
import os


def setup_tuning_logging(log_dir):
    """Route root-logger output to a timestamped file in *log_dir* and to the console.

    Args:
        log_dir: Directory that receives the log file; created if it does not exist.

    Returns:
        Path of the log file, named ``hyperparameter_tuning_<YYYYmmdd_HHMMSS>.log``.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"hyperparameter_tuning_{timestamp}.log")

    # force=True: without it basicConfig is a silent no-op whenever the root logger
    # already has handlers (an imported library, an earlier call, a test runner),
    # leaving the returned log file empty. Forcing replaces the existing root handlers.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(),
        ],
        force=True,
    )
    return log_file

def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, epochs):
    """Run one optimisation pass over *loader*; return sample-weighted (loss, accuracy)."""
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    num_batches = len(loader)
    for batch_idx, (images, labels) in enumerate(loader):
        if batch_idx % 20 == 0:
            logging.info(f"  Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{num_batches}")

        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # The criterion returns the batch *mean*; scale back to a sum so a short
        # final batch is not over-weighted when averaging over the epoch.
        loss_sum += loss.item() * batch_size
        correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        seen += batch_size

    # max(..., 1) avoids ZeroDivisionError on an empty loader (metrics become 0.0).
    denom = max(seen, 1)
    return loss_sum / denom, correct / denom


def _evaluate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without gradients; return sample-weighted (loss, accuracy)."""
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            seen += batch_size

    denom = max(seen, 1)
    return loss_sum / denom, correct / denom


def train_model_for_validation(model, train_loader, val_loader, lr, weight_decay, device, epochs=config.TUNING_EPOCHS):
    """Train *model* with Adam and return the best validation accuracy seen across epochs.

    Loss and accuracy are averaged per sample (not per batch), so a smaller final
    batch does not skew the metrics used to rank hyperparameters.

    Args:
        model: The network to train; moved to *device* in place.
        train_loader: Iterable of ``(images, labels)`` batches used for optimisation.
        val_loader: Iterable of ``(images, labels)`` batches used for evaluation.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        device: Target ``torch.device`` (or device string).
        epochs: Number of training epochs; defaults to ``config.TUNING_EPOCHS``.

    Returns:
        The highest per-epoch validation accuracy as a float in [0, 1]
        (0.0 if ``epochs`` is 0).
    """
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    best_val_acc = 0.0

    logging.info(f"Starting validation training with lr={lr}, weight_decay={weight_decay}")

    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, epochs
        )
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)

        logging.info(f"  Epoch {epoch+1}/{epochs}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            logging.info(f"  New best validation accuracy: {best_val_acc:.4f}")

    return best_val_acc

def run_hyperparameter_search(*, learning_rates=None, weight_decays=None, num_trials=None, seed=None):
    """Random-search over an (lr, weight_decay) grid and return the best configuration.

    Each trial trains a fresh ``TrashNetClassifier`` via ``train_model_for_validation``
    and is scored by its best validation accuracy. Grid points are sampled WITHOUT
    replacement, so no trial re-trains an already-evaluated configuration.

    Args:
        learning_rates: Candidate learning rates (default ``[1e-5, 1e-4, 5e-4, 1e-3]``).
        weight_decays: Candidate weight decays (default ``[1e-5, 1e-4, 1e-3]``).
        num_trials: Number of configurations to try (default ``config.TUNING_TRIALS``);
            capped at the grid size.
        seed: If given, seeds ``random``, ``numpy`` and ``torch`` global RNGs for a
            reproducible search. If omitted, global RNG state is left untouched.

    Returns:
        ``{"lr": ..., "weight_decay": ...}`` for the best trial, or the placeholder
        ``{"lr": 0, "weight_decay": 0}`` if no trial was run.
    """
    log_file = setup_tuning_logging(config.LOG_DIR)
    logging.info(f"Hyperparameter tuning logs will be saved to: {log_file}")

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device(config.DEVICE)
    logging.info(f"Using device: {device}")

    logging.info("Loading datasets...")
    train_loader, val_loader, _, class_names = get_dataloaders(
        data_dir=config.DATA_DIR,
        batch_size=config.TUNING_BATCH_SIZE,
        image_size=config.IMAGE_SIZE,
        num_workers=config.NUM_WORKERS
    )

    if learning_rates is None:
        learning_rates = [1e-5, 1e-4, 5e-4, 1e-3]
    if weight_decays is None:
        weight_decays = [1e-5, 1e-4, 1e-3]
    if num_trials is None:
        num_trials = config.TUNING_TRIALS

    grid = [(lr, wd) for lr in learning_rates for wd in weight_decays]
    if num_trials > len(grid):
        logging.warning(
            f"Requested {num_trials} trials but the grid has only {len(grid)} "
            f"distinct configurations; running {len(grid)}."
        )
    # Sample distinct grid points so the trial budget is never spent re-training
    # an identical configuration.
    trials = random.sample(grid, k=max(0, min(num_trials, len(grid))))
    num_trials = len(trials)

    best_acc = 0.0
    best_config = {"lr": 0, "weight_decay": 0}
    have_best = False

    logging.info("Starting hyperparameter search...")
    logging.info(f"Number of trials: {num_trials}")
    logging.info(f"Learning rates to try: {learning_rates}")
    logging.info(f"Weight decays to try: {weight_decays}")

    if not trials:
        logging.warning("No trials to run; returning placeholder config.")
        return best_config

    start_time = time.time()

    for trial, (lr, weight_decay) in enumerate(trials, start=1):
        trial_start = time.time()

        logging.info(f"\nTrial {trial}/{num_trials}")
        logging.info(f"Testing lr={lr}, weight_decay={weight_decay}")

        model = TrashNetClassifier(num_classes=len(class_names))

        val_acc = train_model_for_validation(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            lr=lr,
            weight_decay=weight_decay,
            device=device
        )

        trial_time = time.time() - trial_start
        logging.info(f"Trial {trial} completed in {trial_time:.2f}s")
        logging.info(f"Validation accuracy: {val_acc:.4f}")

        # The first trial always becomes the baseline, so a real configuration is
        # returned even if every trial scores 0.0 (never the lr=0 placeholder).
        if not have_best or val_acc > best_acc:
            have_best = True
            best_acc = val_acc
            best_config = {"lr": lr, "weight_decay": weight_decay}
            logging.info("New best config found!")

    total_time = time.time() - start_time
    logging.info(f"\nHyperparameter search completed in {total_time:.2f}s")
    logging.info(f"Best config: lr={best_config['lr']}, weight_decay={best_config['weight_decay']}")
    logging.info(f"Best validation accuracy: {best_acc:.4f}")

    return best_config