feat: add crop roi bbox
Browse files- detector/data.py +16 -3
- train.py +7 -0
detector/data.py
CHANGED
|
@@ -96,11 +96,13 @@ class FontDataset(Dataset):
|
|
| 96 |
config_path: str = "configs/font.yml",
|
| 97 |
regression_use_tanh: bool = False,
|
| 98 |
transforms: bool = False,
|
|
|
|
| 99 |
):
|
| 100 |
self.path = path
|
| 101 |
self.fonts = load_font_with_exclusion(config_path)
|
| 102 |
self.regression_use_tanh = regression_use_tanh
|
| 103 |
self.transforms = transforms
|
|
|
|
| 104 |
|
| 105 |
self.images = [
|
| 106 |
os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
|
|
@@ -146,6 +148,12 @@ class FontDataset(Dataset):
|
|
| 146 |
with open(label_path, "rb") as f:
|
| 147 |
label: FontLabel = pickle.load(f)
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
# encode label
|
| 150 |
label = self.fontlabel2tensor(label, label_path)
|
| 151 |
|
|
@@ -188,6 +196,7 @@ class FontDataModule(LightningDataModule):
|
|
| 188 |
train_transforms: bool = False,
|
| 189 |
val_transforms: bool = False,
|
| 190 |
test_transforms: bool = False,
|
|
|
|
| 191 |
regression_use_tanh: bool = False,
|
| 192 |
**kwargs,
|
| 193 |
):
|
|
@@ -197,13 +206,17 @@ class FontDataModule(LightningDataModule):
|
|
| 197 |
self.val_shuffle = val_shuffle
|
| 198 |
self.test_shuffle = test_shuffle
|
| 199 |
self.train_dataset = FontDataset(
|
| 200 |
-
train_path,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
)
|
| 202 |
self.val_dataset = FontDataset(
|
| 203 |
-
val_path, config_path, regression_use_tanh, val_transforms
|
| 204 |
)
|
| 205 |
self.test_dataset = FontDataset(
|
| 206 |
-
test_path, config_path, regression_use_tanh, test_transforms
|
| 207 |
)
|
| 208 |
|
| 209 |
def get_train_num_iter(self, num_device: int) -> int:
|
|
|
|
| 96 |
config_path: str = "configs/font.yml",
|
| 97 |
regression_use_tanh: bool = False,
|
| 98 |
transforms: bool = False,
|
| 99 |
+
crop_roi_bbox: bool = False,
|
| 100 |
):
|
| 101 |
self.path = path
|
| 102 |
self.fonts = load_font_with_exclusion(config_path)
|
| 103 |
self.regression_use_tanh = regression_use_tanh
|
| 104 |
self.transforms = transforms
|
| 105 |
+
self.crop_roi_bbox = crop_roi_bbox
|
| 106 |
|
| 107 |
self.images = [
|
| 108 |
os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
|
|
|
|
| 148 |
with open(label_path, "rb") as f:
|
| 149 |
label: FontLabel = pickle.load(f)
|
| 150 |
|
| 151 |
+
if self.crop_roi_bbox:
|
| 152 |
+
left, top, width, height = label.bbox
|
| 153 |
+
image = TF.crop(image, top, left, height, width)
|
| 154 |
+
label.image_width = width
|
| 155 |
+
label.image_height = height
|
| 156 |
+
|
| 157 |
# encode label
|
| 158 |
label = self.fontlabel2tensor(label, label_path)
|
| 159 |
|
|
|
|
| 196 |
train_transforms: bool = False,
|
| 197 |
val_transforms: bool = False,
|
| 198 |
test_transforms: bool = False,
|
| 199 |
+
crop_roi_bbox: bool = False,
|
| 200 |
regression_use_tanh: bool = False,
|
| 201 |
**kwargs,
|
| 202 |
):
|
|
|
|
| 206 |
self.val_shuffle = val_shuffle
|
| 207 |
self.test_shuffle = test_shuffle
|
| 208 |
self.train_dataset = FontDataset(
|
| 209 |
+
train_path,
|
| 210 |
+
config_path,
|
| 211 |
+
regression_use_tanh,
|
| 212 |
+
train_transforms,
|
| 213 |
+
crop_roi_bbox,
|
| 214 |
)
|
| 215 |
self.val_dataset = FontDataset(
|
| 216 |
+
val_path, config_path, regression_use_tanh, val_transforms, crop_roi_bbox
|
| 217 |
)
|
| 218 |
self.test_dataset = FontDataset(
|
| 219 |
+
test_path, config_path, regression_use_tanh, test_transforms, crop_roi_bbox
|
| 220 |
)
|
| 221 |
|
| 222 |
def get_train_num_iter(self, num_device: int) -> int:
|
train.py
CHANGED
|
@@ -48,6 +48,12 @@ parser.add_argument(
|
|
| 48 |
action="store_true",
|
| 49 |
help="Use pretrained model for ResNet (default: False)",
|
| 50 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
args = parser.parse_args()
|
| 53 |
|
|
@@ -85,6 +91,7 @@ data_module = FontDataModule(
|
|
| 85 |
test_shuffle=False,
|
| 86 |
regression_use_tanh=regression_use_tanh,
|
| 87 |
train_transforms=augmentation,
|
|
|
|
| 88 |
)
|
| 89 |
|
| 90 |
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
|
|
|
|
| 48 |
action="store_true",
|
| 49 |
help="Use pretrained model for ResNet (default: False)",
|
| 50 |
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"-i",
|
| 53 |
+
"--crop-roi-bbox",
|
| 54 |
+
action="store_true",
|
| 55 |
+
help="Crop ROI bounding box (default: False)",
|
| 56 |
+
)
|
| 57 |
|
| 58 |
args = parser.parse_args()
|
| 59 |
|
|
|
|
| 91 |
test_shuffle=False,
|
| 92 |
regression_use_tanh=regression_use_tanh,
|
| 93 |
train_transforms=augmentation,
|
| 94 |
+
crop_roi_bbox=args.crop_roi_bbox,
|
| 95 |
)
|
| 96 |
|
| 97 |
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
|