Mhara's picture
Upload folder using huggingface_hub
dae5c90 verified
Raw
History Blame Contribute Delete
20.6 kB
import os
import random
from torchvision import datasets, transforms
from timm.data.constants import \
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data import create_transform
import pandas as pd
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
import cv2
import numpy as np
from .data_processing import ita_to_group
class LocalISICDataset(Dataset):
    """Binary ISIC dataset (benign=0, malignant=1) read from a local folder tree.

    Expected layout::

        <root>/<split>/benign/*.jpg
        <root>/<split>/malignant/*.jpg
        <root>/masks/<split>/benign/*.jpg      (optional)
        <root>/masks/<split>/malignant/*.jpg   (optional)

    Masks, when both mask folders exist, are paired with images by sorted file
    order. During training, malignant samples can be oversampled: each malignant
    image appears once per entry of ``augment_transforms``.
    """

    def __init__(self, root, transform=None, skin_color_csv=None, augment_transforms=None, split='train',
                 cielab=False, skin_former=False, segment_out_skin=False):
        """
        Args:
            root (str or ``pathlib.Path``): Root directory path.
            transform (callable, optional): A function/transform that takes in a PIL image
                and returns a transformed version. E.g, ``transforms.RandomCrop``.
            skin_color_csv (str, optional): Path to a ``;``-separated CSV with per-image skin
                information, keyed by ``image_name``:
                - ita: The individual typology angle (ITA), a colorimetric skin-tone measure.
                - group: The skin-colour group each image belongs to.
                Images missing from this CSV are dropped from the dataset.
            augment_transforms (dict, optional): Maps augmentation name -> callable; each one is
                applied to malignant training images (oversampling them by ``len(dict)``).
            split (str): 'train' or 'valid', determines which subset to use.
            cielab (bool): If True, convert images to CIELAB color space.
            skin_former (bool): If True, randomly shift the skin tone of training images
                (only effective together with ``cielab``).
            segment_out_skin (bool): If True, multiply the image by its mask
                (only effective together with ``cielab``).
        """
        self.root = root
        self.transform = transform
        self.split = split
        # NOTE: attribute keeps its historical misspelling so external code reading it still works.
        self.skin_fomer = skin_former
        self.segment_out_skin = segment_out_skin
        # Probability of applying the skin-tone shift, indexed by skin-colour group.
        self.group_shift_probs = [0.2, 0.4, 0, 0]
        if augment_transforms is None:
            self.augment_transforms = None
            self.oversample_ratio = 1
        else:
            self.augment_transforms = augment_transforms
            self.oversample_ratio = len(self.augment_transforms)

        split_dir = os.path.join(self.root, split)
        benign_images = [(path, 0) for path in self._list_jpgs(os.path.join(split_dir, 'benign'))]
        malignant_images = [(path, 1) for path in self._list_jpgs(os.path.join(split_dir, 'malignant'))]

        benign_mask_dir = os.path.join(self.root, 'masks', split, 'benign')
        malignant_mask_dir = os.path.join(self.root, 'masks', split, 'malignant')
        if os.path.exists(benign_mask_dir) and os.path.exists(malignant_mask_dir):
            benign_masks = self._list_jpgs(benign_mask_dir)
            malignant_masks = self._list_jpgs(malignant_mask_dir)
        else:
            benign_masks = [None] * len(benign_images)
            malignant_masks = [None] * len(malignant_images)

        # Each sample is ((path, label), mask); benign samples always come first.
        self.samples = []
        self.samples.extend(zip(benign_images, benign_masks))
        self.samples.extend(zip(malignant_images, malignant_masks))
        self.classes = 2
        self.use_cielab = cielab

        if skin_color_csv is not None:
            self.skin_data = pd.read_csv(skin_color_csv, sep=';')
            skin_info_dict = {}
            if 'group' in self.skin_data.columns:
                for _, row in self.skin_data.iterrows():
                    skin_info_dict[row['image_name']] = {'ita': row['ita'], 'group': row['group']}
            self.samples_with_skin = []
            missing = 0
            for (path, label), mask in self.samples:
                info = skin_info_dict.get(os.path.basename(path))
                if info is None:
                    missing += 1
                elif not skin_former and not segment_out_skin:
                    self.samples_with_skin.append((path, label, info['group']))
                else:
                    self.samples_with_skin.append((path, label, info['group'], info['ita'], mask))
            if missing:
                print(f"{missing} images have no skin color information and were skipped.")
            self.samples = self.samples_with_skin

        # Count classes on the FINAL sample list: the oversampling index map in
        # _resolve_index relies on benign_count matching the samples actually kept.
        self.benign_count, self.malignant_count = self._count_classes()
        if self.augment_transforms is None:
            self.class_distribution = (self.benign_count, self.malignant_count)
        else:
            self.class_distribution = (self.benign_count, self.malignant_count * self.oversample_ratio)

        self.groups = 1
        if skin_color_csv is not None:
            self.groups = len(self.skin_data['group'].unique())
            self.group = self.skin_data['group'].value_counts()

    @staticmethod
    def _list_jpgs(directory):
        """Return the sorted full paths of the ``.jpg`` files (case-insensitive) in ``directory``."""
        return sorted(os.path.join(directory, name) for name in os.listdir(directory)
                      if name.lower().endswith('.jpg'))

    @staticmethod
    def _unpack_sample(sample):
        """Normalize any stored sample layout to ``(path, target, group, ita, mask)``.

        Handles 5-tuples ``(path, target, group, ita, mask)``, 3-tuples
        ``(path, target, group)``, nested pairs ``((path, target), mask)`` and
        plain pairs ``(path, target)``; missing fields become ``-1``/``None``.
        """
        if len(sample) == 5:
            return tuple(sample)
        if len(sample) == 3:
            path, target, group = sample
            return path, target, group, None, None
        first, second = sample
        if isinstance(first, tuple):
            path, target = first
            return path, target, -1, None, second
        return first, second, -1, None, None

    def _count_classes(self):
        """Return ``(number of label-0 samples, number of label-1 samples)``."""
        targets = [self._unpack_sample(sample)[1] for sample in self.samples]
        return targets.count(0), targets.count(1)

    def _resolve_index(self, idx):
        """Map a dataset index to ``(index into self.samples, augmentation name)``.

        With training augmentations, indices ``[0, benign_count)`` are benign
        samples and each later block of ``oversample_ratio`` indices is one
        malignant sample, paired in turn with every augmentation name.
        """
        if self.augment_transforms is not None and self.split == 'train' and idx >= self.benign_count:
            offset = idx - self.benign_count
            adjusted_idx = self.benign_count + offset // self.oversample_ratio
            augment_type = list(self.augment_transforms.keys())[offset % self.oversample_ratio]
            if adjusted_idx >= len(self.samples) and self.malignant_count:
                adjusted_idx = self.benign_count + ((adjusted_idx - self.benign_count) % self.malignant_count)
            return adjusted_idx, augment_type
        return idx, "original"

    @staticmethod
    def _read_lab(image_path):
        """Read an image with OpenCV and convert it from BGR to CIELAB."""
        return cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2LAB)

    @staticmethod
    def _shift_skin_tone(np_image, ita, mask_path):
        """Shift the L and b channels of a LAB image towards a random target ITA.

        Returns the shifted ``uint8`` image and the group of the target ITA.
        """
        # Weight is 1 where the mask pixel is 0 and 0 where it is 255.
        weight = 1 - cv2.imread(mask_path)[:, :, 0] / 255
        assert len(weight.shape) == 2 or weight.shape[-1] == 1, "Mask has to be grayscale"
        assert weight.min() >= 0 and weight.max() <= 1, "Mask values have to be in [0,1]"
        # Target ITA uniform in [-10, 28); 28 is the max ITA in our darkest groups.
        target_ita = random.random() * 38 - 10
        # ITA may be stored with a decimal comma, or already parsed as a number.
        delta_ita = float(str(ita).replace(",", ".")) - target_ita
        # Work in float so negative shifts and overflow are handled, then clip to uint8.
        shifted = np_image.astype(np.float32)
        shifted[:, :, 2] += weight * delta_ita * 0.5   # Shift b
        shifted[:, :, 0] -= weight * delta_ita * 0.12  # Shift L
        return np.clip(shifted, 0, 255).astype(np.uint8), ita_to_group(target_ita)

    def __len__(self):
        if self.augment_transforms is None or self.split != 'train':
            return len(self.samples)
        return self.benign_count + self.malignant_count * self.oversample_ratio

    def __getitem__(self, idx):
        """Return ``(image, target, group)``; ``group`` is -1 when no skin CSV was given."""
        adjusted_idx, augment_type = self._resolve_index(idx)
        image_path, target, group, ita, mask = self._unpack_sample(self.samples[adjusted_idx])

        if not self.use_cielab:
            image = Image.open(image_path).convert('RGB')
        elif (not self.skin_fomer and not self.segment_out_skin) or (self.split != 'train' and not self.segment_out_skin):
            image = Image.fromarray(self._read_lab(image_path))
        elif self.skin_fomer:
            np_image = self._read_lab(image_path)
            if 0 <= group < len(self.group_shift_probs):
                shift_prob = self.group_shift_probs[group]
            else:
                shift_prob = 0
            if random.random() < shift_prob:
                np_image, group = self._shift_skin_tone(np_image, ita, mask)
            image = Image.fromarray(np_image)
        else:
            if mask is None:
                raise ValueError(f"segment_out_skin requires a mask for {image_path}")
            np_image = self._read_lab(image_path).astype(np.float32)
            keep = cv2.imread(mask)[:, :, 0] / 255
            np_image *= keep[:, :, np.newaxis]
            image = Image.fromarray(np.clip(np_image, 0, 255).astype(np.uint8))

        if target == 1 and self.split == 'train' and augment_type != "original" and self.augment_transforms is not None:
            image = self.augment_transforms[augment_type](image)
        if self.transform:
            image = self.transform(image)
        return image, target, group

    def get_class_distribution(self):
        """Return ``(benign count, malignant count)``, malignant scaled by the oversample ratio."""
        return self.class_distribution

    def get_labels(self):
        """Return the label of every dataset index, in index order (oversampling included)."""
        return [self._unpack_sample(self.samples[self._resolve_index(i)[0]])[1] for i in range(len(self))]
class KaggleISICDataset(Dataset):
    """Binary ISIC dataset described by a CSV of image names and targets.

    When ``augment_transforms`` is given and ``split == 'train'``, indices
    ``[0, n_benign)`` are the benign rows and every following block of
    ``len(augment_transforms)`` indices is one malignant row, paired in turn
    with each augmentation.
    """

    def __init__(self, csv_file, image_dir, skin_color_csv=None, transform=None, augment_transforms=None,
                 split='train', test_size=0.2, seed=42):
        """
        Args:
            csv_file (str): Path to the CSV file containing image names and targets
                (ignored when ``skin_color_csv`` is given).
            image_dir (str): Directory containing the image files.
            skin_color_csv (str, optional): Path to a ``;``-separated CSV with columns
                ``image_name``, ``target``, ``group`` and ``split``; when given, its
                predefined split is used instead of a random split of ``csv_file``.
                - group : The skin-colour group each image belongs to.
                Rows whose ``split`` is not "train" form the validation subset.
            transform (callable, optional): Optional transform to be applied on an image.
            augment_transforms (dict): Dictionary containing augmentations to be applied on malignant cases.
            split (str): 'train' or 'valid', determines which subset to use.
            test_size (float): Proportion of the dataset to allocate for validation.
            seed (int): Random seed for reproducibility.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.split = split
        if augment_transforms is None:
            self.augment_transforms = None
            self.oversample_ratio = 1
        else:
            self.augment_transforms = augment_transforms
            self.oversample_ratio = len(self.augment_transforms)

        if skin_color_csv is not None:
            print("Skin color csv is defined. Using predefined train/valid splits.")
            df = pd.read_csv(skin_color_csv, sep=';')
            train_df = df[df['split'] == "train"]
            # Everything not marked "train" is held out, so validation never overlaps training.
            valid_df = df[df['split'] != "train"]
        else:
            df = pd.read_csv(csv_file)
            train_df, valid_df = train_test_split(
                df, test_size=test_size, stratify=df['target'], random_state=seed
            )
        self.data = train_df if split == 'train' else valid_df

        if skin_color_csv is not None:
            self.samples = [(row['image_name'], row['target'], row['group']) for _, row in self.data.iterrows()]
            self.groups = len(self.data['group'].unique())
            self.group = self.data['group'].value_counts()
        else:
            self.samples = [(row['image_name'], row['target'], None) for _, row in self.data.iterrows()]
        self.classes = 2

        targets = self.data['target'].tolist()
        # Positional (iloc) indices of benign / malignant rows, used by _resolve_index.
        self._benign_rows = [pos for pos, t in enumerate(targets) if t == 0]
        self._malignant_rows = [pos for pos, t in enumerate(targets) if t == 1]
        malignant_total = len(self._malignant_rows)
        if self.augment_transforms is not None:
            malignant_total *= self.oversample_ratio
        self.class_distribution = (len(self._benign_rows), malignant_total)

    def __len__(self):
        if self.augment_transforms is None or self.split != 'train':
            return len(self.data)
        return len(self._benign_rows) + self.oversample_ratio * len(self._malignant_rows)

    def _resolve_index(self, idx):
        """Map a dataset index to ``(iloc position in self.data, augmentation name)``."""
        if self.augment_transforms is None or self.split != 'train':
            return idx, "original"
        n_benign = len(self._benign_rows)
        if idx < n_benign:
            return self._benign_rows[idx], "original"
        offset = idx - n_benign
        augment_names = list(self.augment_transforms.keys())
        return (self._malignant_rows[offset // self.oversample_ratio],
                augment_names[offset % self.oversample_ratio])

    def __getitem__(self, idx):
        """Return ``(image, target, group)``; ``group`` is -1 when the data has no group column."""
        position, augment_type = self._resolve_index(idx)
        row = self.data.iloc[position]
        image_name = str(row['image_name'])
        if not image_name.lower().endswith('.jpg'):
            image_name = image_name + '.jpg'
        image_path = os.path.join(self.image_dir, image_name)
        target = row['target']
        group = row['group'] if 'group' in row.index else -1
        image = Image.open(image_path).convert('RGB')
        if target == 1 and self.augment_transforms is not None and augment_type != "original":
            image = self.augment_transforms[augment_type](image)
        if self.transform:
            image = self.transform(image)
        return image, target, group

    def get_class_distribution(self):
        """Return ``(benign count, malignant count)``, malignant scaled by the oversample ratio."""
        return self.class_distribution

    def get_labels(self):
        """Return the target of every dataset index, in index order (oversampling included)."""
        targets = self.data['target'].tolist()
        return [targets[self._resolve_index(i)[0]] for i in range(len(self))]
def build_dataset(is_train, args, transform=None, nb_classes=2):
    """Build the torchvision dataset selected by ``args.data_set``.

    Args:
        is_train (bool): Build the training split (otherwise the eval split).
        args: Namespace read for ``convert_to_ffcv``, ``data_set``, ``data_path``,
            ``eval_data_path`` (``image_folder`` only) and ``nb_classes``, plus
            whatever ``build_transform`` reads when ``transform`` is None.
        transform (callable, optional): Transform attached to the dataset; built
            with ``build_transform`` when None. Ignored when converting to FFCV.
        nb_classes (int): Expected number of classes; replaced by
            ``args.nb_classes`` for ``image_folder``.

    Returns:
        tuple: ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: If ``args.data_set`` is not supported.
    """
    if args.convert_to_ffcv:
        import warnings
        warnings.warn(
            "As the transformations are built using factory method timm.data.create_transform, "
            "such a factory method should be implemented for ffcv module as well. "
            "Currently transformations cannot be dynamically allocated for ffcv module "
            "and any ffcv dataset that is built will contain no augmented images, "
            "if not specified here otherwise.",
            UserWarning,
        )
    elif transform is None:
        transform = build_transform(is_train, args)
        print("Transform = ")
        if isinstance(transform, tuple):
            for trans in transform:
                print(" - - - - - - - - - - ")
                for t in trans.transforms:
                    print(t)
        else:
            for t in transform.transforms:
                print(t)
        print("---------------------------")

    def image_folder(root):
        # FFCV conversion stores raw images, so no transform is attached then.
        if args.convert_to_ffcv:
            return datasets.ImageFolder(root)
        return datasets.ImageFolder(root, transform=transform)

    if args.data_set == 'CIFAR':
        if args.convert_to_ffcv:
            dataset = datasets.CIFAR100(args.data_path, train=is_train, download=True)
        else:
            dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
    elif args.data_set in ('IMNET', 'IMAGENET1K', 'IMAGENET100', 'TINY_IMAGENET'):
        dataset = image_folder(os.path.join(args.data_path, 'train' if is_train else 'val'))
        # IMNET historically skips the class-count check.
        if args.data_set != 'IMNET':
            assert args.nb_classes == nb_classes
    elif args.data_set == "image_folder":
        root = args.data_path if is_train else args.eval_data_path
        dataset = image_folder(root)
        nb_classes = args.nb_classes
        assert len(dataset.class_to_idx) == nb_classes
    else:
        raise NotImplementedError(f"Unsupported data_set: {args.data_set}")
    print(f"Dataset type : {args.data_set}")
    print("Reading from datapath", args.data_path)
    print("Number of classes = %d" % nb_classes)
    return dataset, nb_classes
def build_transform(is_train, args):
    """Create the image transform pipeline for training or evaluation.

    Training uses timm's ``create_transform`` with the augmentation settings in
    ``args``. For inputs of 32px or smaller, the first step is swapped for a
    padded ``RandomCrop``. Evaluation resizes with bicubic interpolation, then
    converts to a tensor and normalizes.

    The resize rule depends on the input size:

    - Inputs of 384px or larger are warped straight to a square.
    - Smaller inputs are resized to ``input_size / crop_pct`` and then
      center-cropped.
    - When ``args.crop_pct`` is None, it is filled in on ``args``: 224/256 for
      224px inputs, otherwise 1.0.
    """
    if args.imagenet_default_mean_and_std:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
    needs_resize = args.input_size > 32

    if is_train:
        train_pipeline = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not needs_resize:
            # Small (CIFAR-like) inputs: replace the random-resized crop with a padded crop.
            train_pipeline.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)
        return train_pipeline

    steps = []
    if needs_resize:
        bicubic = transforms.InterpolationMode.BICUBIC
        if args.input_size < 384:
            if args.crop_pct is None:
                args.crop_pct = 224 / 256 if args.input_size == 224 else 1.0
            resized_side = int(args.input_size / args.crop_pct)
            steps.append(transforms.Resize((resized_side, resized_side), interpolation=bicubic))
            steps.append(transforms.CenterCrop(args.input_size))
        else:
            # Warp (no cropping) when evaluated at 384 or larger.
            steps.append(transforms.Resize((args.input_size, args.input_size), interpolation=bicubic))
            print(f"Warping {args.input_size} size input images...")
    steps.extend([transforms.ToTensor(), transforms.Normalize(mean, std)])
    return transforms.Compose(steps)