File size: 4,691 Bytes
57d41d5
 
 
84d0c9e
 
 
 
 
 
 
57d41d5
84d0c9e
 
 
 
 
57d41d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84d0c9e
 
 
57d41d5
 
 
 
 
 
84d0c9e
 
 
 
 
 
 
 
 
57d41d5
 
 
 
84d0c9e
 
 
 
 
 
 
 
57d41d5
84d0c9e
57d41d5
 
84d0c9e
57d41d5
 
 
 
 
 
 
84d0c9e
 
57d41d5
 
 
 
 
 
 
 
 
 
 
84d0c9e
 
 
 
 
 
 
57d41d5
84d0c9e
 
57d41d5
 
 
 
 
 
 
84d0c9e
 
57d41d5
 
 
 
 
 
 
 
 
 
 
84d0c9e
 
 
 
 
 
 
 
57d41d5
84d0c9e
 
57d41d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from datetime import datetime
import os
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from src import config
import time
from torch.utils.tensorboard import SummaryWriter


def calculate_accuracy(y_pred, y_true):
    """Return the fraction of samples whose argmax prediction equals the label.

    Args:
        y_pred: logits/probabilities of shape (batch, num_classes).
        y_true: integer class labels of shape (batch,).

    Returns:
        float accuracy in [0, 1].
    """
    predicted_classes = y_pred.argmax(dim=1)
    num_correct = predicted_classes.eq(y_true).sum().item()
    return num_correct / len(y_true)


def setup_logging(log_dir):
    """Configure root logging to a timestamped file plus the console.

    Creates *log_dir* if it does not exist and attaches a FileHandler
    (named ``training_<timestamp>.log``) and a StreamHandler.

    Args:
        log_dir: directory that will receive the log file.

    Returns:
        Path of the log file that was created.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")

    handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=handlers,
    )
    return log_file


def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one full training pass over *dataloader*.

    Args:
        model: network to train; batches are moved to *device* before the
            forward pass, the model itself is assumed to already be there.
        dataloader: iterable yielding (images, labels) batches.
        criterion: loss function applied to (outputs, labels).
        optimizer: optimizer stepping the model's parameters.
        device: target device for each batch.

    Returns:
        (mean_loss, mean_accuracy) averaged over the number of batches.
    """
    model.train()
    running_loss, running_acc = 0.0, 0.0
    batch_count = len(dataloader)

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info("Training on %d batches", batch_count)
    for batch_idx, (images, labels) in enumerate(dataloader):
        if batch_idx % 10 == 0:
            logging.info("  Batch %d/%d", batch_idx, batch_count)

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        acc = calculate_accuracy(outputs, labels)

        loss.backward()

        # Clip gradient norm to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=config.GRAD_CLIP_VALUE)

        optimizer.step()

        running_loss += loss.item()
        running_acc += acc

    # An empty loader would otherwise raise ZeroDivisionError.
    if batch_count == 0:
        return 0.0, 0.0
    return running_loss / batch_count, running_acc / batch_count


def train_model(model, train_loader, val_loader, epochs=config.EPOCHS, lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY, device=config.DEVICE):
    """Full training loop: Adam + ReduceLROnPlateau, TensorBoard logging,
    and best-checkpoint saving.

    Args:
        model: network to train; moved to *device* here.
        train_loader: training dataloader.
        val_loader: validation dataloader.
        epochs: number of epochs to run (default from config).
        lr: Adam learning rate (default from config).
        weight_decay: Adam weight decay (default from config).
        device: training device (default from config).

    Returns:
        Best validation accuracy observed across all epochs.
    """
    log_file = setup_logging(config.LOG_DIR)
    logging.info("Training logs will be saved to: %s", log_file)

    logging.info("Training configuration:")
    logging.info("  Epochs: %s", epochs)
    logging.info("  Learning rate: %s", lr)
    logging.info("  Weight decay: %s", weight_decay)
    logging.info("  Device: %s", device)
    logging.info("  Batch size: %s", config.BATCH_SIZE)
    logging.info("  Image size: %s", config.IMAGE_SIZE)

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)

    # NOTE: `verbose=True` was deprecated and then removed from
    # ReduceLROnPlateau in recent PyTorch releases; omit it so this code
    # runs on current versions. LR changes are observable via the logged
    # validation metrics instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',            # maximize validation accuracy
        factor=config.LR_SCHEDULER_FACTOR,
        patience=config.LR_SCHEDULER_PATIENCE,
    )

    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0.0

    run_name = time.strftime("run_%Y%m%d-%H%M")
    log_dir = os.path.join(config.LOG_DIR, run_name)
    writer = SummaryWriter(log_dir=log_dir)

    # str(device) keeps this working whether device is a str or torch.device.
    logging.info("Training on: %s\n", str(device).upper())

    for epoch in range(epochs):
        epoch_start_time = time.time()
        logging.info("Epoch %d/%d started", epoch + 1, epochs)

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device)

        logging.info("Validating...")
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        epoch_time = time.time() - epoch_start_time

        # Scheduler watches validation accuracy (mode='max' above).
        scheduler.step(val_acc)

        logging.info(
            "Epoch %d/%d completed in %.2fs", epoch + 1, epochs, epoch_time)
        logging.info(
            "Train Loss: %.4f | Train Acc: %.2f%%", train_loss, train_acc * 100)
        logging.info(
            "Val   Loss: %.4f | Val   Acc: %.2f%%", val_loss, val_acc * 100)

        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)

        # Checkpoint only when validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), config.MODEL_SAVE_PATH)
            logging.info("Model saved!")

    writer.close()
    logging.info("Training complete. Best Val Acc: %.2f%%", best_val_acc * 100)

    return best_val_acc


def validate(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode here).
        dataloader: iterable yielding (images, labels) batches.
        criterion: loss function applied to (outputs, labels).
        device: target device for each batch.

    Returns:
        (mean_loss, mean_accuracy) averaged over the number of batches.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            total_acc += calculate_accuracy(outputs, labels)

    num_batches = len(dataloader)
    return total_loss / num_batches, total_acc / num_batches