Spaces:
Configuration error
Configuration error
File size: 8,520 Bytes
def _run_epoch(model, loader, criterion, device, description, optimizer=None, progress_bar=None):
    """
    Run one full pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in training mode and its
    weights are updated after every batch; otherwise the model is put in
    evaluation mode and the pass runs under ``torch.no_grad()``.

    Parameters:
    -----------
    model : torch.nn.Module
        Model to run; it must already live on ``device``
    loader : iterable
        Yields ``(inputs, labels)`` batches
    criterion : callable
        Loss function called as ``criterion(outputs, labels)``
    device : str or torch.device
        Device every batch is moved to
    description : str
        Caption shown on the progress bar
    optimizer : torch.optim.Optimizer, optional
        Enables weight updates when provided
    progress_bar : callable, optional
        ``tqdm``-compatible wrapper; ``None`` disables progress reporting

    Returns:
    --------
    tuple of (float, float)
        Mean per-batch loss and accuracy in percent. Both are 0.0 when the
        loader yields nothing.
    """
    import contextlib

    import torch

    is_training = optimizer is not None
    model.train(is_training)

    batches = loader
    if progress_bar is not None:
        batches = progress_bar(loader, desc=description, leave=False)

    loss_sum = 0.0
    batch_count = 0
    correct = 0
    seen = 0

    # Gradients are only needed while training; evaluation passes skip them.
    grad_context = contextlib.nullcontext() if is_training else torch.no_grad()
    with grad_context:
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            loss_sum += batch_loss
            # Count batches while iterating so loaders without __len__ work too.
            batch_count += 1

            _, predicted = torch.max(outputs, 1)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()

            if progress_bar is not None:
                running_accuracy = 100 * correct / seen if seen else 0.0
                batches.set_postfix({
                    'Loss': f'{batch_loss:.4f}',
                    'Accuracy': f'{running_accuracy:.2f}%'
                })

    # Guard the divisions: an empty loader must not abort the whole run.
    mean_loss = loss_sum / batch_count if batch_count else 0.0
    accuracy = 100 * correct / seen if seen else 0.0
    return mean_loss, accuracy


def _save_metric_plot(plt, epochs_range, curves, ylabel, title, path):
    """Draw one labelled line per ``(label, values)`` pair in ``curves`` and save the figure to ``path``."""
    plt.figure()
    for label, values in curves:
        plt.plot(epochs_range, values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(path)
    plt.close()


def train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs, output_folder, device="cpu"):
    """
    Train a classification model and evaluate it on validation and test data every epoch.

    For each epoch the function:
    - Trains the model on ``train_loader``
    - Evaluates it on ``val_loader`` and ``test_loader`` without gradients
    - Prints the loss and accuracy of all three splits
    - Saves the weights whenever validation accuracy improves

    After the last epoch it saves the final weights and, when matplotlib is
    available, writes loss and accuracy curves as PNG images.

    Parameters:
    -----------
    model : torch.nn.Module
        The neural network model to be trained
    train_loader : torch.utils.data.DataLoader
        Batches of ``(inputs, labels)`` used for training
    val_loader : torch.utils.data.DataLoader
        Batches used to select the best weights
    test_loader : torch.utils.data.DataLoader
        Batches used to report test metrics each epoch
    optimizer : torch.optim.Optimizer
        Optimization algorithm for updating model weights
    criterion : torch.nn.Module
        Loss function called as ``criterion(outputs, labels)``
    epochs : int
        Number of complete passes through the training data
    output_folder : str
        Folder where the model weights and plots are saved (created if missing)
    device : str, optional
        Computing device to use for training (default is "cpu")

    Returns:
    --------
    None

    Side Effects:
    -------------
    - Prints training, validation, and testing metrics for each epoch
    - Saves the best weights (by validation accuracy) to "<output_folder>/best_model.pth"
    - Saves the final weights to "<output_folder>/last_model.pth"
    - Saves "loss_plot.png" and "accuracy_plot.png" in the output folder
      (skipped with a message when matplotlib is not installed)

    Example:
    --------
    >>> model = MyModel()
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> criterion = nn.CrossEntropyLoss()
    >>> train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs=10, output_folder="weights")
    """
    import os

    import torch

    # Progress bars are cosmetic: fall back to plain iteration without tqdm.
    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None

    # Plots are secondary to the weights: warn up front instead of failing.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None
        print("matplotlib is not installed; loss and accuracy plots will be skipped.")

    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)
    print(f"Device Found: {device}, Starting Training 🚀")

    # Move model to the specified device
    model = model.to(device)

    best_val_accuracy = 0.0
    best_model_path = os.path.join(output_folder, "best_model.pth")
    last_model_path = os.path.join(output_folder, "last_model.pth")

    # Per-epoch metrics kept for plotting
    train_losses, val_losses, test_losses = [], [], []
    train_accuracies, val_accuracies, test_accuracies = [], [], []

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"

        train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device, f"{tag} (Training)",
            optimizer=optimizer, progress_bar=tqdm,
        )
        val_loss, val_accuracy = _run_epoch(
            model, val_loader, criterion, device, f"{tag} (Validation)",
            progress_bar=tqdm,
        )
        test_loss, test_accuracy = _run_epoch(
            model, test_loader, criterion, device, f"{tag} (Testing)",
            progress_bar=tqdm,
        )

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        test_accuracies.append(test_accuracy)

        print(
            f"Epoch [{epoch+1}/{epochs}]: "
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}% | "
            f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% | "
            f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%"
        )

        # Always checkpoint the first epoch so best_model.pth exists even when
        # validation accuracy never rises above 0%.
        if epoch == 0 or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), best_model_path)

    # Save the last model
    torch.save(model.state_dict(), last_model_path)
    print(f"Training completed. Best validation accuracy: {best_val_accuracy:.2f}%")

    if plt is None:
        return

    # ----------------------
    # Plotting Metrics with Matplotlib
    # ----------------------
    epochs_range = range(1, epochs + 1)

    loss_plot_path = os.path.join(output_folder, 'loss_plot.png')
    _save_metric_plot(
        plt, epochs_range,
        [('Train Loss', train_losses),
         ('Validation Loss', val_losses),
         ('Test Loss', test_losses)],
        'Loss', 'Loss per Epoch', loss_plot_path,
    )
    print(f"Loss plot saved to {loss_plot_path}")

    acc_plot_path = os.path.join(output_folder, 'accuracy_plot.png')
    _save_metric_plot(
        plt, epochs_range,
        [('Train Accuracy', train_accuracies),
         ('Validation Accuracy', val_accuracies),
         ('Test Accuracy', test_accuracies)],
        'Accuracy (%)', 'Accuracy per Epoch', acc_plot_path,
    )
    print(f"Accuracy plot saved to {acc_plot_path}")
|