# ModelAuditor / train_Resnet50_ham10000.py
# Uploaded via huggingface_hub (commit a493204).
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from tqdm import tqdm
import wandb
import argparse
import random
import numpy as np
import io
import torchvision.transforms.functional as F
import torchvision.transforms.v2 as v2
class HAM10000Dataset(Dataset):
    """HAM10000 binary skin-lesion dataset: bkl (benign, 0) vs. mel (malignant, 1).

    Expects ``root_dir`` to contain one subdirectory per class name
    ('bkl', 'mel'), each holding image files.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Directory with one subfolder per class name.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['bkl', 'mel']  # benign (0) and malignant (1)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.images = []
        self.labels = []
        # Collect image paths for both classes. The listing is sorted so the
        # dataset order (and therefore any seeded train/val split downstream)
        # is reproducible across filesystems, and the extension check is
        # case-insensitive so files like 'FOO.JPG' are not silently dropped.
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.images.append(os.path.join(class_dir, img_name))
                    self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        """Return the number of images across both classes."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; image is forced to RGB and transformed if configured."""
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
def main():
    """Train a binary ResNet-50 classifier (bkl vs. mel) on HAM10000.

    Command-line flags control the input size, RNG seed, CUDA device, and two
    optional augmentation modes. Per-epoch metrics are logged to Weights &
    Biases and the model checkpoint is overwritten after every epoch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--resize', type=int, default=224,
                        help='Size to resize images to (default: 224)')
    parser.add_argument('--seed', type=int, default=1,
                        help='Seed for random number generator (default: 1)')
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA device number (default: 0)')
    parser.add_argument('--auditor_augs', action='store_true', default=False,
                        help='Enable auditor augmentations (default: False)')
    parser.add_argument('--auto_aug', action='store_true', default=False,
                        help='Enable auto augmentations (default: False)')
    args = parser.parse_args()

    # Seed all RNGs for reproducibility.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Initialize wandb; the run name encodes the configuration.
    wandb.init(project="ModelAuditor", name="HAM10000_ResNet50_" + str(args.seed) + "_" + str(args.resize) +
               ("_AuditorAugs" if args.auditor_augs else "") + ("_AutoAugs" if args.auto_aug else ""))

    # Extra PIL-space augmentations. Keep ToTensor OUT of this list — it is
    # appended exactly once below, right before Normalize. (Previously
    # ToTensor lived inside aug_list, which produced a double ToTensor with
    # --auto_aug and a missing one — Normalize on a PIL image — with
    # --auditor_augs; both crash at runtime.)
    if args.auditor_augs:
        aug_list = [
            # PUT HERE WHAT THE AUDITOR GIVES YOU
        ]
    else:
        aug_list = []

    # Training pipeline: resize -> (optional AutoAugment) -> extra augs
    # -> ToTensor -> ImageNet normalization.
    train_ops = [transforms.Resize((args.resize, args.resize))]
    if args.auto_aug:
        train_ops.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET))
    train_transform = transforms.Compose(train_ops + aug_list + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Validation pipeline: no augmentation, only resize + normalize.
    val_transform = transforms.Compose([
        transforms.Resize((args.resize, args.resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Two views of the same image list that differ only in the transform, so
    # the validation split is evaluated WITHOUT training augmentations.
    # (Previously random_split was applied to a single augmented dataset and
    # val_transform was defined but never used.)
    root_dir = 'data/ham10000/vidir_modern'
    train_view = HAM10000Dataset(root_dir=root_dir, transform=train_transform)
    val_view = HAM10000Dataset(root_dir=root_dir, transform=val_transform)

    # 80/20 split over a shuffled index permutation; the seeded generator
    # keeps the split identical across runs with the same --seed.
    n_total = len(train_view)
    train_size = int(0.8 * n_total)
    perm = torch.randperm(n_total, generator=torch.Generator().manual_seed(args.seed)).tolist()
    train_dataset = torch.utils.data.Subset(train_view, perm[:train_size])
    val_dataset = torch.utils.data.Subset(val_view, perm[train_size:])

    # Data loaders.
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_dataset, batch_size=64, num_workers=8)

    # Device selection (falls back to CPU when CUDA is unavailable).
    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")

    # ImageNet-pretrained ResNet-50 with a fresh 2-way classification head.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, 2)  # 2 classes: benign and malignant
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler()  # loss scaling for mixed precision

    n_epochs = 10

    # OneCycle schedule: 2 warmup epochs (pct_start), cosine annealing after.
    warmup_epochs = 2
    total_steps = len(train_loader) * n_epochs
    warmup_steps = len(train_loader) * warmup_epochs
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.001,
        total_steps=total_steps,
        pct_start=warmup_steps / total_steps,
        anneal_strategy='cos'
    )

    for epoch in range(n_epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0
        for x, y in tqdm(train_loader, desc=f'Epoch {epoch+1}/{n_epochs}'):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():  # mixed-precision forward pass
                outputs = model(x)
                loss = criterion(outputs, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # OneCycleLR steps per batch, not per epoch
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # ---- Validation phase ----
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                with torch.cuda.amp.autocast():
                    outputs = model(x)
                    loss = criterion(outputs, y)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()
        val_loss /= len(val_loader)
        accuracy = 100. * correct / total

        # Log metrics for this epoch.
        current_lr = scheduler.get_last_lr()[0]
        wandb.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_accuracy": accuracy,
            "epoch": epoch + 1,
            "learning_rate": current_lr
        })
        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {accuracy:.2f}%')

        # Save (overwrite) the checkpoint after each epoch.
        torch.save(model.state_dict(), f'ham10000_resnet50_{args.seed}_{args.resize}' +
                   ("_AuditorAugs" if args.auditor_augs else "") +
                   ("_AutoAugs" if args.auto_aug else "") + '.pt')
# Script entry point: run training only when executed directly (not on import).
if __name__ == "__main__":
    main()