ImageNet-Mini / default.py
cloudlyte's picture
code files
fc41719
# Load the ResNet50 model
def ResNet50(num_classes, channels=3):
    """Construct a ResNet-50 style network from ``Bottleneck`` blocks.

    Both arguments are forwarded positionally to ``ResNet`` (defined
    elsewhere in this file).

    Args:
        num_classes: Presumably the size of the classification head --
            confirm against the ``ResNet`` constructor.
        channels: Presumably the number of input image channels
            (defaults to 3) -- confirm against the ``ResNet`` constructor.

    Returns:
        The object returned by ``ResNet(Bottleneck, [3, 4, 6, 3], ...)``.
    """
    # Presumably the number of Bottleneck blocks in each of the four stages.
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(Bottleneck, blocks_per_stage, num_classes, channels)
# Build the 1000-class network and move it onto the target device.
# To parallelize training across multiple GPUs, wrap it before the move:
# model = torch.nn.DataParallel(ResNet50(num_classes=1000)).to(device)
model = ResNet50(num_classes=1000).to(device)

# Loss function and optimizer (Adam over every model parameter)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Function to evaluate the model
def evaluate_model(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and report loss and accuracy.

    Relies on three module-level names: ``device`` (where batches are moved),
    ``tqdm`` (progress bar) and ``val_dataset`` (its ``classes`` sequence
    supplies the class names; class index ``i`` maps to ``classes[i]``).

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        model: Callable network; ``model(inputs)`` must return per-class
            scores with classes along dimension 1.
        val_loader: Iterable yielding ``(inputs, labels)`` batches, where
            ``labels`` is a 1-D tensor of non-negative integer class indices.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            tensor.

    Returns:
        A tuple ``(val_loss, accuracy, per_class_accuracy)``:
        ``val_loss`` is the mean of the per-batch losses, ``accuracy`` is the
        overall percentage of correct predictions, and ``per_class_accuracy``
        maps class name to percentage correct for every class that appeared
        at least once. An empty loader yields ``(0.0, 0.0, {})`` instead of
        raising ``ZeroDivisionError``.
    """
    model.eval()
    class_names = val_dataset.classes
    num_classes = len(class_names)

    val_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    # Per-class counters live on the CPU and are updated once per batch.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            # Count batches while iterating so loaders without __len__ work.
            num_batches += 1

            _, predicted = torch.max(outputs, 1)

            # One device-to-host copy per batch instead of one sync per
            # sample; bincount then tallies every class in a single call.
            labels_cpu = labels.cpu()
            hits = predicted.cpu() == labels_cpu
            correct += int(hits.sum())
            total += labels_cpu.size(0)
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(
                labels_cpu[hits], minlength=num_classes
            )

    # Guard both divisions so an empty loader does not crash.
    val_loss = val_loss / num_batches if num_batches else 0.0
    accuracy = 100.0 * correct / total if total else 0.0
    per_class_accuracy = {
        name: 100.0 * n_correct / n_seen
        for name, n_correct, n_seen in zip(
            class_names, class_correct.tolist(), class_total.tolist()
        )
        if n_seen > 0
    }
    return val_loss, accuracy, per_class_accuracy
# Train the model
print("Training the model on ImageNet")
for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear the gradients accumulated by the previous step
        optimizer.zero_grad()

        # Forward pass, then backpropagate and update the weights
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Running totals used for the epoch summary below
        running_loss += loss.item()
        predicted = torch.max(outputs, 1).indices
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    # Mean per-batch loss and overall accuracy (%) for this epoch
    train_loss = running_loss / len(train_loader)
    train_accuracy = 100.0 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")
# Run validation once, after the final epoch
print("Validating the model on unseen data after training...")
val_loss, val_accuracy, per_class_accuracy = evaluate_model(model, val_loader, criterion)
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
print("Per-class Accuracy:")
for class_name, acc in per_class_accuracy.items():
    print(f"{class_name}: {acc:.2f}%")

# Save the final weights. The path is kept in one variable so the message
# always names the file that was actually written (the message previously
# reported a different filename than the one passed to torch.save).
checkpoint_path = "resnet50_imagenet.pth"
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved as {checkpoint_path}")