feat: add torch compile feature
Browse files
train.py
CHANGED
|
@@ -154,5 +154,8 @@ detector = FontDetector(
|
|
| 154 |
num_epochs=num_epochs,
|
| 155 |
)
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
| 158 |
trainer.test(detector, datamodule=data_module)
|
|
|
|
| 154 |
num_epochs=num_epochs,
|
| 155 |
)
|
| 156 |
|
| 157 |
+
if torch.__version__ >= "2.0":
|
| 158 |
+
detector = torch.compile(detector)
|
| 159 |
+
|
| 160 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
| 161 |
trainer.test(detector, datamodule=data_module)
|