fix: torch.compile
Browse files
train.py
CHANGED
|
@@ -142,6 +142,9 @@ elif args.model == "resnet101":
|
|
| 142 |
else:
|
| 143 |
raise NotImplementedError()
|
| 144 |
|
|
|
|
|
|
|
|
|
|
| 145 |
detector = FontDetector(
|
| 146 |
model=model,
|
| 147 |
lambda_font=lambda_font,
|
|
@@ -154,8 +157,5 @@ detector = FontDetector(
|
|
| 154 |
num_epochs=num_epochs,
|
| 155 |
)
|
| 156 |
|
| 157 |
-
if torch.__version__ >= "2.0":
|
| 158 |
-
detector = torch.compile(detector)
|
| 159 |
-
|
| 160 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
| 161 |
trainer.test(detector, datamodule=data_module)
|
|
|
|
| 142 |
else:
|
| 143 |
raise NotImplementedError()
|
| 144 |
|
| 145 |
+
if torch.__version__ >= "2.0":
|
| 146 |
+
model = torch.compile(model)
|
| 147 |
+
|
| 148 |
detector = FontDetector(
|
| 149 |
model=model,
|
| 150 |
lambda_font=lambda_font,
|
|
|
|
| 157 |
num_epochs=num_epochs,
|
| 158 |
)
|
| 159 |
|
|
|
|
|
|
|
|
|
|
| 160 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
| 161 |
trainer.test(detector, datamodule=data_module)
|