| | import os |
| | import torch |
| | from torchvision import transforms, datasets |
| | from albumentations import ( |
| | HorizontalFlip, |
| | VerticalFlip, |
| | ShiftScaleRotate, |
| | CLAHE, |
| | RandomRotate90, |
| | Transpose, |
| | ShiftScaleRotate, |
| | HueSaturationValue, |
| | GaussNoise, |
| | Sharpen, |
| | Emboss, |
| | RandomBrightnessContrast, |
| | OneOf, |
| | Compose, |
| | ) |
| | import numpy as np |
| | from PIL import Image |
| |
|
# Point both download caches at the project-local ./cache directory:
# the HUGGINGFACE_HUB_CACHE environment variable and the torch.hub directory.
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
torch.hub.set_dir("./cache")
| |
|
def strong_aug(p=0.5):
    """Build the albumentations pipeline used to augment training images.

    Args:
        p: Value forwarded as ``p`` to the outer ``Compose``.

    Returns:
        A ``Compose`` that chains, in order: ``RandomRotate90``,
        ``Transpose``, ``HorizontalFlip``, ``VerticalFlip``, a noise stage,
        ``ShiftScaleRotate``, a tone/detail stage and ``HueSaturationValue``.
    """
    # Noise stage: a OneOf holding only GaussNoise.
    noise_stage = OneOf([GaussNoise()], p=0.2)

    # Tone/detail stage: OneOf over four contrast / sharpness style transforms.
    tone_stage = OneOf(
        [
            CLAHE(clip_limit=2),
            Sharpen(),
            Emboss(),
            RandomBrightnessContrast(),
        ],
        p=0.2,
    )

    steps = [
        RandomRotate90(p=0.2),
        Transpose(p=0.2),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        noise_stage,
        ShiftScaleRotate(p=0.2),
        tone_stage,
        HueSaturationValue(p=0.2),
    ]
    return Compose(steps, p=p)
| |
|
| |
|
def augment(aug, image):
    """Apply ``aug`` to ``image`` and return the transformed image.

    ``aug`` is called with ``image`` passed as a keyword argument and must
    return a mapping; the value stored under its ``"image"`` key is returned.
    """
    outputs = aug(image=image)
    return outputs["image"]
| |
|
| |
|
class Aug(object):
    """Callable transform that runs the ``strong_aug`` pipeline on one image."""

    def __call__(self, img):
        """Augment ``img`` and return the result as a ``PIL.Image``.

        A new ``strong_aug(p=0.9)`` pipeline is created on every call. The
        input is converted with ``np.array`` before augmentation and the
        augmented array is wrapped back with ``Image.fromarray``.
        """
        pipeline = strong_aug(p=0.9)
        pixels = np.array(img)
        augmented = augment(pipeline, pixels)
        return Image.fromarray(augmented)
| |
|
| |
|
def normalize_data():
    """Return the torchvision transform pipeline for each data split.

    Returns:
        A dict keyed by ``"train"``, ``"valid"``, ``"test"`` and ``"vid"``.
        ``"train"`` applies ``Aug`` before ``ToTensor`` and ``Normalize``;
        ``"valid"`` and ``"test"`` apply ``ToTensor`` and ``Normalize``;
        ``"vid"`` applies ``Normalize`` only (no ``ToTensor`` step).
    """
    # Per-channel statistics handed to Normalize
    # (presumably the usual ImageNet values - confirm against the model used).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def tensor_steps():
        # Fresh transform objects on each call so pipelines never share them.
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    pipelines = {}
    pipelines["train"] = transforms.Compose([Aug(), *tensor_steps()])
    pipelines["valid"] = transforms.Compose(tensor_steps())
    pipelines["test"] = transforms.Compose(tensor_steps())
    pipelines["vid"] = transforms.Compose([transforms.Normalize(mean, std)])
    return pipelines
| |
|
| |
|
def load_data(data_dir="sample/", batch_size=4):
    """Create the image-folder datasets and data loaders for every split.

    Args:
        data_dir: Directory containing ``train``, ``valid`` and ``test``
            sub-directories; each is read with ``datasets.ImageFolder``.
        batch_size: Batch size used by all three loaders.

    Returns:
        A ``(dataloaders, dataset_sizes)`` tuple.

        * ``dataloaders`` maps ``"train"``, ``"validation"`` and ``"test"``
          to a ``DataLoader``; only the training loader shuffles.
        * ``dataset_sizes`` maps ``"train"``, ``"valid"`` and ``"test"`` to
          the number of samples in that split.

        NOTE: the validation split is keyed ``"validation"`` in
        ``dataloaders`` but ``"valid"`` in ``dataset_sizes``. The mismatch is
        kept on purpose so existing callers keep working.
    """
    splits = ("train", "valid", "test")

    # Build the transform table once instead of once per split.
    split_transforms = normalize_data()

    image_datasets = {
        split: datasets.ImageFolder(
            os.path.join(data_dir, split), split_transforms[split]
        )
        for split in splits
    }
    dataset_sizes = {split: len(image_datasets[split]) for split in splits}

    # Loader key -> (dataset split, shuffle flag). Insertion order matches the
    # order in which the loaders were originally created and returned.
    loader_specs = {
        "train": ("train", True),
        "validation": ("valid", False),
        "test": ("test", False),
    }
    dataloaders = {
        key: torch.utils.data.DataLoader(
            image_datasets[split],
            batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=True,
        )
        for key, (split, shuffle) in loader_specs.items()
    }

    return dataloaders, dataset_sizes
| |
|
| |
|