fix: data augmentation
Browse files- detector/data.py +27 -18
detector/data.py
CHANGED
|
@@ -17,13 +17,19 @@ from PIL import Image
|
|
| 17 |
|
| 18 |
|
| 19 |
class RandomColorJitter(object):
|
| 20 |
-
def __init__(
|
|
|
|
|
|
|
| 21 |
self.brightness = brightness
|
| 22 |
self.contrast = contrast
|
| 23 |
self.saturation = saturation
|
| 24 |
self.hue = hue
|
|
|
|
| 25 |
|
| 26 |
def __call__(self, batch):
|
|
|
|
|
|
|
|
|
|
| 27 |
image, label = batch
|
| 28 |
text_color = label[2:5].clone().view(3, 1, 1)
|
| 29 |
stroke_color = label[7:10].clone().view(3, 1, 1)
|
|
@@ -54,10 +60,14 @@ class RandomColorJitter(object):
|
|
| 54 |
|
| 55 |
|
| 56 |
class RandomCrop(object):
|
| 57 |
-
def __init__(self, crop_factor: float = 0.1):
|
| 58 |
self.crop_factor = crop_factor
|
|
|
|
| 59 |
|
| 60 |
def __call__(self, batch):
|
|
|
|
|
|
|
|
|
|
| 61 |
image, label = batch
|
| 62 |
width, height = image.size
|
| 63 |
|
|
@@ -80,10 +90,14 @@ class RandomCrop(object):
|
|
| 80 |
|
| 81 |
|
| 82 |
class RandomRotate(object):
|
| 83 |
-
def __init__(self, max_angle: int = 15):
|
| 84 |
self.max_angle = max_angle
|
|
|
|
| 85 |
|
| 86 |
def __call__(self, batch):
|
|
|
|
|
|
|
|
|
|
| 87 |
image, label = batch
|
| 88 |
|
| 89 |
angle = random.uniform(-self.max_angle, self.max_angle)
|
|
@@ -177,8 +191,8 @@ class FontDataset(Dataset):
|
|
| 177 |
if self.transforms is not None:
|
| 178 |
transform = transforms.Compose(
|
| 179 |
[
|
| 180 |
-
|
| 181 |
-
|
| 182 |
]
|
| 183 |
)
|
| 184 |
image, label = transform((image, label))
|
|
@@ -210,20 +224,15 @@ class FontDataset(Dataset):
|
|
| 210 |
|
| 211 |
transform = transforms.Compose(
|
| 212 |
[
|
| 213 |
-
|
| 214 |
-
RandomCrop(crop_factor=0.54),
|
| 215 |
-
|
| 216 |
]
|
| 217 |
)
|
| 218 |
image, label = transform((image, label))
|
| 219 |
|
| 220 |
-
transform = transforms.
|
| 221 |
-
|
| 222 |
-
transforms.RandomApply(
|
| 223 |
-
transforms.GaussianBlur(random.randint(2, 5), sigma=(0.1, 5.0)),
|
| 224 |
-
p=0.8,
|
| 225 |
-
),
|
| 226 |
-
]
|
| 227 |
)
|
| 228 |
|
| 229 |
image = transform(image)
|
|
@@ -259,9 +268,9 @@ class FontDataModule(LightningDataModule):
|
|
| 259 |
train_shuffle: bool = True,
|
| 260 |
val_shuffle: bool = False,
|
| 261 |
test_shuffle: bool = False,
|
| 262 |
-
train_transforms: bool =
|
| 263 |
-
val_transforms: bool =
|
| 264 |
-
test_transforms: bool =
|
| 265 |
crop_roi_bbox: bool = False,
|
| 266 |
regression_use_tanh: bool = False,
|
| 267 |
**kwargs,
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class RandomColorJitter(object):
|
| 20 |
+
def __init__(
|
| 21 |
+
self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05, preserve=0.2
|
| 22 |
+
):
|
| 23 |
self.brightness = brightness
|
| 24 |
self.contrast = contrast
|
| 25 |
self.saturation = saturation
|
| 26 |
self.hue = hue
|
| 27 |
+
self.preserve = preserve
|
| 28 |
|
| 29 |
def __call__(self, batch):
|
| 30 |
+
if random.random() < self.preserve:
|
| 31 |
+
return batch
|
| 32 |
+
|
| 33 |
image, label = batch
|
| 34 |
text_color = label[2:5].clone().view(3, 1, 1)
|
| 35 |
stroke_color = label[7:10].clone().view(3, 1, 1)
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
class RandomCrop(object):
|
| 63 |
+
def __init__(self, crop_factor: float = 0.1, preserve: float = 0.2):
|
| 64 |
self.crop_factor = crop_factor
|
| 65 |
+
self.preserve = preserve
|
| 66 |
|
| 67 |
def __call__(self, batch):
|
| 68 |
+
if random.random() < self.preserve:
|
| 69 |
+
return batch
|
| 70 |
+
|
| 71 |
image, label = batch
|
| 72 |
width, height = image.size
|
| 73 |
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
class RandomRotate(object):
|
| 93 |
+
def __init__(self, max_angle: int = 15, preserve: float = 0.2):
|
| 94 |
self.max_angle = max_angle
|
| 95 |
+
self.preserve = preserve
|
| 96 |
|
| 97 |
def __call__(self, batch):
|
| 98 |
+
if random.random() < self.preserve:
|
| 99 |
+
return batch
|
| 100 |
+
|
| 101 |
image, label = batch
|
| 102 |
|
| 103 |
angle = random.uniform(-self.max_angle, self.max_angle)
|
|
|
|
| 191 |
if self.transforms is not None:
|
| 192 |
transform = transforms.Compose(
|
| 193 |
[
|
| 194 |
+
RandomColorJitter(preserve=0.2),
|
| 195 |
+
RandomCrop(preserve=0.2),
|
| 196 |
]
|
| 197 |
)
|
| 198 |
image, label = transform((image, label))
|
|
|
|
| 224 |
|
| 225 |
transform = transforms.Compose(
|
| 226 |
[
|
| 227 |
+
RandomColorJitter(preserve=0.2),
|
| 228 |
+
RandomCrop(crop_factor=0.54, preserve=0),
|
| 229 |
+
RandomRotate(preserve=0.2),
|
| 230 |
]
|
| 231 |
)
|
| 232 |
image, label = transform((image, label))
|
| 233 |
|
| 234 |
+
transform = transforms.GaussianBlur(
|
| 235 |
+
random.randint(1, 3) * 2 - 1, sigma=(0.1, 5.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
)
|
| 237 |
|
| 238 |
image = transform(image)
|
|
|
|
| 268 |
train_shuffle: bool = True,
|
| 269 |
val_shuffle: bool = False,
|
| 270 |
test_shuffle: bool = False,
|
| 271 |
+
train_transforms: bool = None,
|
| 272 |
+
val_transforms: bool = None,
|
| 273 |
+
test_transforms: bool = None,
|
| 274 |
crop_roi_bbox: bool = False,
|
| 275 |
regression_use_tanh: bool = False,
|
| 276 |
**kwargs,
|