File size: 4,818 Bytes
dae5c90 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 | import str2bool
import torchvision.transforms as transforms
from src.datasets.datasets import LocalISICDataset
import matplotlib.pyplot as plt
import torch
import argparse
import numpy as np
def str2bool(v):
    """
    Convert a command-line string to a bool.

    Enables command line arguments in the format of
    '--arg1 true --arg2 false' when used as an argparse ``type=`` callable.

    Args:
        v: A bool (returned unchanged) or a string. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        True for 'yes', 'true', 't', 'y', '1';
        False for 'no', 'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is neither a bool nor one of
            the recognised strings.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Non-string input (e.g. None) has no .lower(); report it as a bad
        # argument instead of leaking an AttributeError to the caller.
        raise argparse.ArgumentTypeError('Boolean value expected.')

    token = v.strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def main(args):
    """
    Display batches of images from the 'valid' split of LocalISICDataset.

    Builds the resize/normalize transform, wraps the dataset in a shuffled
    DataLoader, and shows one matplotlib figure per batch with every image
    de-normalized and titled with its label and group.

    Args:
        args: Parsed command-line namespace. Reads ``cielab``, ``input_size``,
            ``data_path``, ``skin_color_csv``, ``skin_former``,
            ``segment_out_skin``, ``batch_size``, ``pin_mem`` and
            ``visualize_num``.
    """
    # Per-channel statistics for the selected colour pipeline. Defined once so
    # the Normalize transform and the de-normalization used for display can
    # never drift apart.
    if args.cielab:
        mean, std = [0.370, 0.133, 0.092], [0.327, 0.090, 0.105]
    else:
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    malignant_class_transform = {
        "original": transforms.Compose([]),
        "horizontal_flip": transforms.Compose([transforms.RandomHorizontalFlip(p=1.0)]),
        "vertical_flip": transforms.Compose([transforms.RandomVerticalFlip(p=1.0)]),
        "rotate": transforms.Compose([transforms.RandomRotation(15)]),
        "translate": transforms.Compose([transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))])
    }

    dataset = LocalISICDataset(
        args.data_path,
        transform=transform,
        augment_transforms=malignant_class_transform,
        split='valid',
        skin_color_csv=args.skin_color_csv,
        cielab=args.cielab,
        skin_former=args.skin_former,
        segment_out_skin=args.segment_out_skin,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=args.pin_mem,
        drop_last=True
    )

    # Shaped (3, 1, 1) so they broadcast over a channel-first image.
    display_mean = np.array(mean).reshape(3, 1, 1)
    display_std = np.array(std).reshape(3, 1, 1)

    num_batches = max(1, int(args.visualize_num // args.batch_size))

    # Iterate ONE DataLoader iterator: calling next(iter(dataloader)) per batch
    # would restart the worker processes every time and could show the same
    # samples again. zip() checks the range first, so no batch beyond
    # num_batches is requested, and an empty loader simply shows nothing.
    for _, (images, targets, groups) in zip(range(num_batches), dataloader):
        # squeeze=False always yields a 2-D axes array, so indexing also works
        # when the batch holds a single image.
        _, axes = plt.subplots(1, len(images), figsize=(15, 5), squeeze=False)

        for j, ax in enumerate(axes[0]):
            img = images[j] if isinstance(images[j], np.ndarray) else images[j].numpy()
            # Undo Normalize, clamp to the displayable range, then move the
            # channel axis last as imshow expects.
            img = np.clip(display_std * img + display_mean, 0, 1)
            img = np.transpose(img, (1, 2, 0))

            ax.imshow(img)
            ax.set_title(f'Label: {targets[j].item()}, Group: {groups[j].item()}')
            ax.axis('off')

        plt.show()
# Script entry point: parse the command-line options and run the visualization.
if __name__ == '__main__':
    parser = argparse.ArgumentParser('Melanoma images visualization')
    parser.add_argument('--data_path', default='./isic2020_challenge', type=str,
                        help='Path to dataset with train/valid folders')
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--skin_color_csv', default="./isic2020_challenge/ISIC_2020_full.csv", type=str,
                        help='Path to the CSV file containing skin color labels')
    parser.add_argument('--visualize_num', default=5, type=int,
                        help='Number of images to visualize')
    parser.add_argument('--batch_size', default=5, type=int,
                        help='Batch size for visualization')

    # Augmentation parameters
    # With action='store_true' and default=True this flag could never be
    # switched off. It now takes an optional boolean value: a bare `--cielab`
    # still means True, while `--cielab false` selects the non-CIELab branch.
    parser.add_argument('--cielab', type=str2bool, nargs='?', const=True, default=True,
                        help='Load images to CIELab colorspace')
    parser.add_argument('--skin_former', action='store_true', default=False,
                        help='Transform lighter skin types to darker ones')
    parser.add_argument('--segment_out_skin', type=str2bool, default=True,
                        help='Segment out skin from images')

    parser.add_argument('--device', default='cpu', type=str,
                        help='Device to use for training')
    parser.add_argument('--pin_mem', type=str2bool, default=True,
                        help='Pin CPU memory for data loading')

    args = parser.parse_args()
    print(args)
    main(args)