| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import torchvision.models as models |
| import torchvision.transforms as transforms |
| from torch.utils.data import Dataset, DataLoader |
| from PIL import Image |
| import os |
| from tqdm import tqdm |
| import wandb |
| import argparse |
| import random |
| import numpy as np |
| import io |
| import torchvision.transforms.functional as F |
| import torchvision.transforms.v2 as v2 |
|
|
class HAM10000Dataset(Dataset):
    """Two-class HAM10000 image dataset ('bkl' vs. 'mel').

    Expects ``root_dir`` to contain one sub-directory per class name
    (``bkl/`` and ``mel/``), each holding the image files for that class.
    Labels are the class index: bkl -> 0, mel -> 1.
    """

    def __init__(self, root_dir, transform=None):
        """Index all image files under ``root_dir``.

        Args:
            root_dir: Directory containing one sub-directory per class.
            transform: Optional callable applied to each PIL image in
                ``__getitem__``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['bkl', 'mel']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.images = []
        self.labels = []

        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            # sorted(): os.listdir order is filesystem-dependent, which would
            # make the seeded train/val split irreproducible across machines.
            for img_name in sorted(os.listdir(class_dir)):
                # Case-insensitive so files like .JPG/.PNG are not silently skipped.
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.images.append(os.path.join(class_dir, img_name))
                    self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform``, return (image, label)."""
        img_path = self.images[idx]
        label = self.labels[idx]

        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label
|
|
|
|
def main():
    """Train a ResNet-50 binary classifier (bkl vs. mel) on HAM10000.

    Parses CLI flags, builds train/val transform pipelines, fine-tunes an
    ImageNet-pretrained ResNet-50 with mixed precision and a OneCycle LR
    schedule, logs to wandb, and saves the final weights to disk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--resize', type=int, default=224,
                        help='Size to resize images to (default: 224)')
    parser.add_argument('--seed', type=int, default=1,
                        help='Seed for random number generator (default: 1)')
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA device number (default: 0)')
    parser.add_argument('--auditor_augs', action='store_true', default=False,
                        help='Enable auditor augmentations (default: False)')
    parser.add_argument('--auto_aug', action='store_true', default=False,
                        help='Enable auto augmentations (default: False)')
    args = parser.parse_args()

    # Seed every RNG in play so the train/val split and augmentations are
    # reproducible for a given --seed.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    wandb.init(project="ModelAuditor", name="HAM10000_ResNet50_" + str(args.seed) + "_" + str(args.resize) +
               ("_AuditorAugs" if args.auditor_augs else "") + ("_AutoAugs" if args.auto_aug else ""))

    # Auditor-supplied augmentations. These run on PIL images (before
    # ToTensor); do NOT put ToTensor in this list — it is inserted exactly
    # once below, immediately before Normalize.
    # NOTE(review): the auditor aug list is currently empty in this file.
    if args.auditor_augs:
        aug_list = [

        ]
    else:
        aug_list = []

    # BUG FIX: the original pipeline applied ToTensor twice when --auto_aug
    # was used without --auditor_augs (TypeError at runtime), and omitted it
    # entirely when --auditor_augs was used without --auto_aug (Normalize on
    # a PIL image crashes). ToTensor now appears exactly once, right before
    # Normalize, in both branches.
    if args.auto_aug:
        train_transform = transforms.Compose([
            transforms.Resize((args.resize, args.resize)),
            transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
        ] + aug_list + [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize((args.resize, args.resize)),
        ] + aug_list + [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    # Deterministic evaluation pipeline: no augmentation, same normalization.
    val_transform = transforms.Compose([
        transforms.Resize((args.resize, args.resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # BUG FIX: the original split a single dataset built with train_transform,
    # so validation images received training augmentations (e.g. AutoAugment)
    # and val_transform was never used. We build two views of the same files
    # and split a single seeded permutation of the indices, which mirrors the
    # 80/20 random_split of the original.
    train_view = HAM10000Dataset(root_dir='data/ham10000/vidir_modern', transform=train_transform)
    val_view = HAM10000Dataset(root_dir='data/ham10000/vidir_modern', transform=val_transform)

    train_size = int(0.8 * len(train_view))
    indices = torch.randperm(len(train_view)).tolist()
    train_dataset = torch.utils.data.Subset(train_view, indices[:train_size])
    val_dataset = torch.utils.data.Subset(val_view, indices[train_size:])

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_dataset, batch_size=64, num_workers=8)

    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")

    # ImageNet-pretrained backbone with the classifier head replaced for two classes.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, 2)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Mixed-precision loss scaler (no-op with a warning when CUDA is absent).
    scaler = torch.cuda.amp.GradScaler()

    n_epochs = 10

    # OneCycle schedule: 2 warmup epochs, cosine anneal for the rest, stepped per batch.
    warmup_epochs = 2
    total_steps = len(train_loader) * n_epochs
    warmup_steps = len(train_loader) * warmup_epochs
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.001,
        total_steps=total_steps,
        pct_start=warmup_steps / total_steps,
        anneal_strategy='cos'
    )

    for epoch in range(n_epochs):
        # --- training pass ---
        model.train()
        train_loss = 0
        for x, y in tqdm(train_loader, desc=f'Epoch {epoch+1}/{n_epochs}'):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()

            # Forward in autocast; backward through the scaler for stable fp16 grads.
            with torch.cuda.amp.autocast():
                outputs = model(x)
                loss = criterion(outputs, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        # --- validation pass ---
        model.eval()
        val_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                with torch.cuda.amp.autocast():
                    outputs = model(x)
                    loss = criterion(outputs, y)
                val_loss += loss.item()

                _, predicted = outputs.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        val_loss /= len(val_loader)
        accuracy = 100. * correct / total

        current_lr = scheduler.get_last_lr()[0]
        wandb.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_accuracy": accuracy,
            "epoch": epoch + 1,
            "learning_rate": current_lr
        })

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {accuracy:.2f}%')

    # Persist final weights; filename encodes the run configuration.
    torch.save(model.state_dict(), f'ham10000_resnet50_{args.seed}_{args.resize}' +
               ("_AuditorAugs" if args.auditor_augs else "") +
               ("_AutoAugs" if args.auto_aug else "") + '.pt')
|
|
|
|
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()