feat: add deepfont baseline
Browse files- detector/model.py +31 -0
- train.py +6 -1
detector/model.py
CHANGED
|
@@ -10,6 +10,37 @@ import torch.nn as nn
|
|
| 10 |
import pytorch_lightning as ptl
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
class ResNet18Regressor(nn.Module):
|
| 14 |
def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
|
| 15 |
super().__init__()
|
|
|
|
| 10 |
import pytorch_lightning as ptl
|
| 11 |
|
| 12 |
|
| 13 |
+
class DeepFontBaseline(nn.Module):
    """DeepFont-style CNN baseline that emits one logit per font class.

    The network is a single flat ``nn.Sequential`` stored on ``self.model``:
    five convolutions (only the first two are followed by batch
    normalisation and 2x2 max pooling) and a three-layer fully connected
    head whose output width is ``config.FONT_COUNT``.

    The first linear layer is fixed at ``256 * 12 * 12`` input features, so
    the convolutional stack must produce a 12x12 map. A 3-channel 105x105
    input does this, for example (105 -> 48 -> 24 -> 24 -> 12).
    """

    def __init__(self) -> None:
        super().__init__()

        # Convolutional feature extractor. The large strided first kernel
        # and the two pooling stages do all of the spatial downsampling;
        # the remaining 3x3 convolutions are padded to preserve size.
        feature_layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=11, stride=2),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(
                in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(
                in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1
            ),
            nn.ReLU(),
        ]

        # fc: flatten the 256-channel map and classify.
        classifier_layers = [
            nn.Flatten(),
            nn.Linear(in_features=256 * 12 * 12, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=config.FONT_COUNT),
        ]

        # Unpack both lists into one flat Sequential so every submodule
        # keeps its positional index (and therefore its state_dict key).
        self.model = nn.Sequential(*feature_layers, *classifier_layers)

    def forward(self, X):
        """Run ``X`` through the sequential stack and return its output."""
        return self.model(X)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
class ResNet18Regressor(nn.Module):
|
| 45 |
def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
|
| 46 |
super().__init__()
|
train.py
CHANGED
|
@@ -39,7 +39,7 @@ parser.add_argument(
|
|
| 39 |
"--model",
|
| 40 |
type=str,
|
| 41 |
default="resnet18",
|
| 42 |
-
choices=["resnet18", "resnet34", "resnet50", "resnet101"],
|
| 43 |
help="Model to use (default: resnet18)",
|
| 44 |
)
|
| 45 |
parser.add_argument(
|
|
@@ -181,6 +181,11 @@ elif args.model == "resnet101":
|
|
| 181 |
model = ResNet101Regressor(
|
| 182 |
pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
|
| 183 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
else:
|
| 185 |
raise NotImplementedError()
|
| 186 |
|
|
|
|
| 39 |
"--model",
|
| 40 |
type=str,
|
| 41 |
default="resnet18",
|
| 42 |
+
choices=["resnet18", "resnet34", "resnet50", "resnet101", "deepfont"],
|
| 43 |
help="Model to use (default: resnet18)",
|
| 44 |
)
|
| 45 |
parser.add_argument(
|
|
|
|
| 181 |
model = ResNet101Regressor(
|
| 182 |
pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
|
| 183 |
)
|
| 184 |
+
elif args.model == "deepfont":
|
| 185 |
+
assert args.pretrained is False
|
| 186 |
+
assert args.size == 105
|
| 187 |
+
assert args.font_classification_only is True
|
| 188 |
+
model = DeepFontBaseline()
|
| 189 |
else:
|
| 190 |
raise NotImplementedError()
|
| 191 |
|