fix: transform label when horizontal flip
Browse files- detector/data.py +16 -1
detector/data.py
CHANGED
|
@@ -152,6 +152,21 @@ class RandomCropPreserveAspectRatio(object):
|
|
| 152 |
return image, label
|
| 153 |
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
class FontDataset(Dataset):
|
| 156 |
def __init__(
|
| 157 |
self,
|
|
@@ -222,6 +237,7 @@ class FontDataset(Dataset):
|
|
| 222 |
RandomColorJitter(preserve=0.2),
|
| 223 |
RandomCrop(crop_factor=0.54, preserve=0),
|
| 224 |
RandomRotate(preserve=0.2),
|
|
|
|
| 225 |
]
|
| 226 |
image_transforms = [
|
| 227 |
torchvision.transforms.GaussianBlur(
|
|
@@ -231,7 +247,6 @@ class FontDataset(Dataset):
|
|
| 231 |
torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
|
| 232 |
torchvision.transforms.ToTensor(),
|
| 233 |
RandomNoise(max_noise=0.05, preserve=0.1),
|
| 234 |
-
torchvision.transforms.RandomHorizontalFlip(p=0.5),
|
| 235 |
]
|
| 236 |
else:
|
| 237 |
raise ValueError(f"Unknown transform: {transforms}")
|
|
|
|
| 152 |
return image, label
|
| 153 |
|
| 154 |
|
| 155 |
+
class RandomHorizontalFlip(object):
|
| 156 |
+
def __init__(self, preserve: float = 0.5):
|
| 157 |
+
self.preserve = preserve
|
| 158 |
+
|
| 159 |
+
def __call__(self, batch):
|
| 160 |
+
if random.random() < self.preserve:
|
| 161 |
+
return batch
|
| 162 |
+
|
| 163 |
+
image, label = batch
|
| 164 |
+
image = TF.hflip(image)
|
| 165 |
+
label[11] = 1 - label[11]
|
| 166 |
+
|
| 167 |
+
return image, label
|
| 168 |
+
|
| 169 |
+
|
| 170 |
class FontDataset(Dataset):
|
| 171 |
def __init__(
|
| 172 |
self,
|
|
|
|
| 237 |
RandomColorJitter(preserve=0.2),
|
| 238 |
RandomCrop(crop_factor=0.54, preserve=0),
|
| 239 |
RandomRotate(preserve=0.2),
|
| 240 |
+
RandomHorizontalFlip(preserve=0.5),
|
| 241 |
]
|
| 242 |
image_transforms = [
|
| 243 |
torchvision.transforms.GaussianBlur(
|
|
|
|
| 247 |
torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
|
| 248 |
torchvision.transforms.ToTensor(),
|
| 249 |
RandomNoise(max_noise=0.05, preserve=0.1),
|
|
|
|
| 250 |
]
|
| 251 |
else:
|
| 252 |
raise ValueError(f"Unknown transform: {transforms}")
|