File size: 8,520 Bytes
f94b780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when ``denominator`` is zero.

    Lets an empty loader report 0.0 metrics instead of raising ZeroDivisionError.
    """
    return numerator / denominator if denominator else 0.0


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_batch_loss, accuracy_percent)``.

    If ``optimizer`` is given, the model is switched to training mode and its weights
    are updated after every batch. Otherwise the model is switched to evaluation mode
    and the pass runs with gradient tracking disabled.

    The loss is the mean of the per-batch ``criterion`` values. Accuracy is the
    percentage of samples whose arg-max over dim 1 of the model output equals the label.
    Each batch is expected to unpack as ``(images, labels)``.
    """
    import torch
    from tqdm import tqdm

    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to model.eval()

    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted rather than len(loader), so plain iterables also work

    progress = tqdm(loader, desc=desc, leave=False)
    with torch.set_grad_enabled(training):
        for images, labels in progress:
            # Move tensors to the specified device
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()  # Reset gradients
            outputs = model(images)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            if training:
                loss.backward()  # Backpropagation
                optimizer.step()  # Update weights

            # Read the scalar once: .item() forces a device synchronization.
            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            progress.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Accuracy': f'{_safe_div(100 * correct, total):.2f}%'
            })

    return _safe_div(running_loss, num_batches), _safe_div(100 * correct, total)


def _save_metric_plot(epochs_range, series, ylabel, title, path):
    """Plot each ``(values, label)`` pair in ``series`` against ``epochs_range`` and save to ``path``.

    The figure is always closed, even if saving fails, so repeated calls do not leak figures.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    try:
        for values, label in series:
            ax.plot(epochs_range, values, label=label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        fig.savefig(path)
    finally:
        plt.close(fig)


def train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs, output_folder, device="cpu"):
    """
    Train a neural network model, evaluating it on validation and test data every epoch.
    Additionally, plots accuracy and loss per epoch using matplotlib and saves them as images.

    This function performs a complete training loop, including:
    - Moving the model to the specified device (CPU/GPU)
    - Training the model for a specified number of epochs
    - Evaluating on the validation and test loaders after every epoch
    - Tracking and logging training, validation, and testing metrics
    - Saving the best (based on validation accuracy) and last model weights
    - Plotting and saving accuracy and loss graphs per epoch

    Parameters:
    -----------
    model : torch.nn.Module
        The neural network model to be trained
    train_loader : torch.utils.data.DataLoader
        Iterable of ``(images, labels)`` batches used for training the model
    val_loader : torch.utils.data.DataLoader
        Iterable of ``(images, labels)`` batches used for validation (drives best-model selection)
    test_loader : torch.utils.data.DataLoader
        Iterable of ``(images, labels)`` batches used for testing; it is evaluated every epoch
        but never used for model selection
    optimizer : torch.optim.Optimizer
        Optimization algorithm for updating model weights
    criterion : torch.nn.Module
        Loss function used to compute the model's performance
    epochs : int
        Number of complete passes through the entire training dataset
    output_folder : str
        Folder path where the model weights and plots will be saved (created if missing)
    device : str, optional
        Computing device to use for training (default is "cpu")
        Can be "cpu" or "cuda" for GPU training

    Returns:
    --------
    None

    Side Effects:
    -------------
    - Prints training, validation, and testing metrics for each epoch
    - Saves the best performing model (highest validation accuracy; the first epoch
      always writes a checkpoint) to "<output_folder>/best_model.pth"
    - Saves the final model to "<output_folder>/last_model.pth"
    - Saves the loss plot as "loss_plot.png" and accuracy plot as "accuracy_plot.png" in the output folder

    Notes:
    ------
    An empty loader yields 0.0 loss/accuracy for that phase instead of raising
    ZeroDivisionError.

    Example:
    --------
    >>> model = MyModel()
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> criterion = nn.CrossEntropyLoss()
    >>> train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs=10, output_folder="weights")
    """
    import os
    import torch

    # Ensure the output folder exists before anything is written into it
    os.makedirs(output_folder, exist_ok=True)

    print(f"Device Found: {device}, Starting Training 🚀")

    # Move model to the specified device
    model = model.to(device)

    best_model_path = os.path.join(output_folder, "best_model.pth")
    # None means "no checkpoint written yet". The first epoch therefore always saves,
    # so best_model.pth can never be missing or left stale from a previous run.
    best_val_accuracy = None

    # Lists to store metrics per epoch for plotting
    train_losses, val_losses, test_losses = [], [], []
    train_accuracies, val_accuracies, test_accuracies = [], [], []

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"

        train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device, f"{tag} (Training)", optimizer=optimizer
        )
        val_loss, val_accuracy = _run_epoch(
            model, val_loader, criterion, device, f"{tag} (Validation)"
        )
        test_loss, test_accuracy = _run_epoch(
            model, test_loader, criterion, device, f"{tag} (Testing)"
        )

        # Store metrics for plotting
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        test_accuracies.append(test_accuracy)

        # Log the metrics for this epoch
        print(
            f"Epoch [{epoch+1}/{epochs}]: "
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}% | "
            f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% | "
            f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%"
        )

        # Save the best model based on validation accuracy
        if best_val_accuracy is None or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), best_model_path)

    # Save the last model
    torch.save(model.state_dict(), os.path.join(output_folder, "last_model.pth"))
    reported_best = best_val_accuracy if best_val_accuracy is not None else 0.0
    print("Training completed. Best validation accuracy: {:.2f}%".format(reported_best))

    # ----------------------
    # Plotting Metrics with Matplotlib
    # ----------------------
    epochs_range = range(1, epochs + 1)

    loss_plot_path = os.path.join(output_folder, 'loss_plot.png')
    _save_metric_plot(
        epochs_range,
        [(train_losses, 'Train Loss'), (val_losses, 'Validation Loss'), (test_losses, 'Test Loss')],
        'Loss',
        'Loss per Epoch',
        loss_plot_path,
    )
    print(f"Loss plot saved to {loss_plot_path}")

    acc_plot_path = os.path.join(output_folder, 'accuracy_plot.png')
    _save_metric_plot(
        epochs_range,
        [
            (train_accuracies, 'Train Accuracy'),
            (val_accuracies, 'Validation Accuracy'),
            (test_accuracies, 'Test Accuracy'),
        ],
        'Accuracy (%)',
        'Accuracy per Epoch',
        acc_plot_path,
    )
    print(f"Accuracy plot saved to {acc_plot_path}")