Data transforms now use v2
Browse files- model_training/data_setup.py +15 -11
model_training/data_setup.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
| 1 |
#####################################
|
| 2 |
# Packages & Dependencies
|
| 3 |
#####################################
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
from torch.utils.data import DataLoader
|
| 6 |
|
| 7 |
import utils
|
|
@@ -14,18 +16,20 @@ import numpy as np
|
|
| 14 |
import matplotlib.pyplot as plt
|
| 15 |
|
| 16 |
# Transformations applied to each image
|
| 17 |
-
BASE_TRANSFORMS =
|
| 18 |
-
|
| 19 |
-
|
|
|
|
| 20 |
])
|
| 21 |
|
| 22 |
-
TRAIN_TRANSFORMS =
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
| 29 |
])
|
| 30 |
|
| 31 |
|
|
|
|
| 1 |
#####################################
|
| 2 |
# Packages & Dependencies
|
| 3 |
#####################################
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision import datasets
|
| 6 |
+
from torchvision.transforms import v2
|
| 7 |
from torch.utils.data import DataLoader
|
| 8 |
|
| 9 |
import utils
|
|
|
|
| 16 |
import matplotlib.pyplot as plt
|
| 17 |
|
| 18 |
# Transformations applied to each image
|
| 19 |
+
BASE_TRANSFORMS = v2.Compose([
|
| 20 |
+
v2.ToImage(), # Convert to tensor
|
| 21 |
+
v2.ToDtype(torch.float32, scale = True), # Rescale pixel values to within [0, 1]
|
| 22 |
+
v2.Normalize(mean = [0.1307], std = [0.3081]) # Normalize with MNIST stats
|
| 23 |
])
|
| 24 |
|
| 25 |
+
TRAIN_TRANSFORMS = v2.Compose([
|
| 26 |
+
v2.RandomAffine(degrees = 15, # Rotate up to -/+ 15 degrees
|
| 27 |
+
scale = (0.8, 1.2), # Scale between 80 and 120 percent
|
| 28 |
+
translate = (0.08, 0.08), # Translate up to -/+ 8 percent in both x and y
|
| 29 |
+
shear = 10), # Shear up to -/+ 10 degrees
|
| 30 |
+
v2.ToImage(), # Convert to tensor
|
| 31 |
+
v2.ToDtype(torch.float32, scale = True), # Rescale pixel values to within [0, 1]
|
| 32 |
+
v2.Normalize(mean = [0.1307], std = [0.3081]), # Normalize with MNIST stats
|
| 33 |
])
|
| 34 |
|
| 35 |
|