feat: add classification only option
Browse files
- detector/model.py +10 -2
- train.py +7 -0
detector/model.py
CHANGED
|
@@ -83,16 +83,21 @@ class ResNet101Regressor(nn.Module):
|
|
| 83 |
|
| 84 |
|
| 85 |
class FontDetectorLoss(nn.Module):
|
| 86 |
-
def __init__(
|
|
|
|
|
|
|
| 87 |
super().__init__()
|
| 88 |
self.category_loss = nn.CrossEntropyLoss()
|
| 89 |
self.regression_loss = nn.MSELoss()
|
| 90 |
self.lambda_font = lambda_font
|
| 91 |
self.lambda_direction = lambda_direction
|
| 92 |
self.lambda_regression = lambda_regression
|
|
|
|
| 93 |
|
| 94 |
def forward(self, y_hat, y):
|
| 95 |
font_cat = self.category_loss(y_hat[..., : config.FONT_COUNT], y[..., 0].long())
|
|
|
|
|
|
|
| 96 |
direction_cat = self.category_loss(
|
| 97 |
y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1].long()
|
| 98 |
)
|
|
@@ -130,6 +135,7 @@ class FontDetector(ptl.LightningModule):
|
|
| 130 |
lambda_font: float,
|
| 131 |
lambda_direction: float,
|
| 132 |
lambda_regression: float,
|
|
|
|
| 133 |
lr: float,
|
| 134 |
betas: Tuple[float, float],
|
| 135 |
num_warmup_iters: int,
|
|
@@ -138,7 +144,9 @@ class FontDetector(ptl.LightningModule):
|
|
| 138 |
):
|
| 139 |
super().__init__()
|
| 140 |
self.model = model
|
| 141 |
-
self.loss = FontDetectorLoss(
|
|
|
|
|
|
|
| 142 |
self.font_accur_train = torchmetrics.Accuracy(
|
| 143 |
task="multiclass", num_classes=config.FONT_COUNT
|
| 144 |
)
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
class FontDetectorLoss(nn.Module):
|
| 86 |
+
def __init__(
|
| 87 |
+
self, lambda_font, lambda_direction, lambda_regression, font_classification_only
|
| 88 |
+
):
|
| 89 |
super().__init__()
|
| 90 |
self.category_loss = nn.CrossEntropyLoss()
|
| 91 |
self.regression_loss = nn.MSELoss()
|
| 92 |
self.lambda_font = lambda_font
|
| 93 |
self.lambda_direction = lambda_direction
|
| 94 |
self.lambda_regression = lambda_regression
|
| 95 |
+
self.font_classfiication_only = font_classification_only
|
| 96 |
|
| 97 |
def forward(self, y_hat, y):
|
| 98 |
font_cat = self.category_loss(y_hat[..., : config.FONT_COUNT], y[..., 0].long())
|
| 99 |
+
if self.font_classfiication_only:
|
| 100 |
+
return self.lambda_font * font_cat
|
| 101 |
direction_cat = self.category_loss(
|
| 102 |
y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1].long()
|
| 103 |
)
|
|
|
|
| 135 |
lambda_font: float,
|
| 136 |
lambda_direction: float,
|
| 137 |
lambda_regression: float,
|
| 138 |
+
font_classification_only: bool,
|
| 139 |
lr: float,
|
| 140 |
betas: Tuple[float, float],
|
| 141 |
num_warmup_iters: int,
|
|
|
|
| 144 |
):
|
| 145 |
super().__init__()
|
| 146 |
self.model = model
|
| 147 |
+
self.loss = FontDetectorLoss(
|
| 148 |
+
lambda_font, lambda_direction, lambda_regression, font_classification_only
|
| 149 |
+
)
|
| 150 |
self.font_accur_train = torchmetrics.Accuracy(
|
| 151 |
task="multiclass", num_classes=config.FONT_COUNT
|
| 152 |
)
|
train.py
CHANGED
|
@@ -84,6 +84,12 @@ parser.add_argument(
|
|
| 84 |
default=get_current_tag(),
|
| 85 |
help="Model name (default: current tag)",
|
| 86 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
args = parser.parse_args()
|
| 89 |
|
|
@@ -177,6 +183,7 @@ detector = FontDetector(
|
|
| 177 |
lambda_font=lambda_font,
|
| 178 |
lambda_direction=lambda_direction,
|
| 179 |
lambda_regression=lambda_regression,
|
|
|
|
| 180 |
lr=lr,
|
| 181 |
betas=(b1, b2),
|
| 182 |
num_warmup_iters=num_warmup_iter,
|
|
|
|
| 84 |
default=get_current_tag(),
|
| 85 |
help="Model name (default: current tag)",
|
| 86 |
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"-f",
|
| 89 |
+
"--font-classification-only",
|
| 90 |
+
action="store_true",
|
| 91 |
+
help="Font classification only (default: False)",
|
| 92 |
+
)
|
| 93 |
|
| 94 |
args = parser.parse_args()
|
| 95 |
|
|
|
|
| 183 |
lambda_font=lambda_font,
|
| 184 |
lambda_direction=lambda_direction,
|
| 185 |
lambda_regression=lambda_regression,
|
| 186 |
+
font_classification_only=args.font_classification_only,
|
| 187 |
lr=lr,
|
| 188 |
betas=(b1, b2),
|
| 189 |
num_warmup_iters=num_warmup_iter,
|