# Load the ResNet50 model
def ResNet50(num_classes, channels=3):
    """Build a ResNet-50: ``Bottleneck`` blocks in a [3, 4, 6, 3] stage layout.

    Args:
        num_classes: Number of output classes of the final classifier.
        channels: Number of input image channels (default 3).

    ``ResNet`` and ``Bottleneck`` are defined elsewhere in this project.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels)


model = ResNet50(num_classes=1000)

# Parallelize training across multiple GPUs.
# NOTE: if enabled, state_dict keys gain a "module." prefix, which changes the
# layout of the checkpoint saved at the end of this script.
# model = torch.nn.DataParallel(model)

# Move the model's parameters and buffers to the target device.
model = model.to(device)

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


def evaluate_model(model, val_loader, criterion, class_names=None):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: The network to evaluate; it is switched to eval mode.
        val_loader: Iterable of ``(inputs, labels)`` batches. Labels are
            assumed to be integer class indices in ``[0, len(class_names))``.
        criterion: Loss function; assumed to use ``reduction="mean"`` (the
            default for ``CrossEntropyLoss``) so that it can be re-weighted
            by batch size below.
        class_names: Optional sequence of class names, indexed by label.
            Defaults to the module-level ``val_dataset.classes``.

    Returns:
        ``(val_loss, accuracy, per_class_accuracy)``. ``val_loss`` is the mean
        per-sample loss. ``accuracy`` is a percentage. ``per_class_accuracy``
        maps class name -> percentage and only includes classes that were
        seen at least once.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    if class_names is None:
        class_names = val_dataset.classes
    num_classes = len(class_names)

    model.eval()
    val_loss = 0.0  # sum of per-sample losses
    correct = 0
    total = 0
    # Per-class counters stay on the device so that counting does not force a
    # host sync for every sample.
    class_correct = torch.zeros(num_classes, dtype=torch.long, device=device)
    class_total = torch.zeros(num_classes, dtype=torch.long, device=device)

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight each batch by its size so a short final batch is not
            # over-represented in the average.
            val_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            hits = predicted == labels
            correct += hits.sum().item()
            total += batch_size

            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[hits], minlength=num_classes)

    if total == 0:
        raise ValueError("evaluate_model: val_loader yielded no samples")

    val_loss /= total
    accuracy = 100.0 * correct / total
    per_class_accuracy = {
        name: 100.0 * hit / seen
        for name, hit, seen in zip(
            class_names, class_correct.tolist(), class_total.tolist()
        )
        if seen > 0
    }
    return val_loss, accuracy, per_class_accuracy


# Train the model
print("Training the model on ImageNet")
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero out gradients accumulated from the previous step.
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size; assumes criterion uses reduction="mean".
        running_loss += loss.item() * batch_size

        # Training accuracy uses this batch's outputs from before the update,
        # which avoids a second forward pass.
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_loader yielded no samples")

    # Average loss and accuracy for the epoch
    train_loss = running_loss / total
    train_accuracy = 100.0 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")

# Run validation after all epochs
print("Validating the model on unseen data after training...")
val_loss, val_accuracy, per_class_accuracy = evaluate_model(model, val_loader, criterion)
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
print("Per-class Accuracy:")
for class_name, acc in per_class_accuracy.items():
    print(f"{class_name}: {acc:.2f}%")

# Save the model at the end of training. One variable holds the path so the
# confirmation message always names the file that was actually written.
checkpoint_path = "resnet50_imagenet.pth"
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved as {checkpoint_path}")