File size: 799 Bytes
c4ed1c5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | import torch
from torchvision import transforms
target_size = (128, 128)
transform = {
'train': transforms.Compose([
transforms.ToTensor(),
transforms.Resize(target_size),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.ConvertImageDtype(torch.float),
]),
'test': transforms.Compose([
transforms.ToTensor(),
transforms.Resize(target_size),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.ConvertImageDtype(torch.float),
]),
}
target_transform = transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)) |