Update app.py
Browse files
app.py
CHANGED
|
@@ -558,7 +558,7 @@ def modelTFT(csv_file, prax):
|
|
| 558 |
train_dataloader,
|
| 559 |
val_dataloader,
|
| 560 |
model_path="optuna_test",
|
| 561 |
-
n_trials=
|
| 562 |
max_epochs=MAX_EPOCHS,
|
| 563 |
gradient_clip_val_range=(0.01, 0.5),
|
| 564 |
hidden_size_range=(8, 64),
|
|
@@ -568,7 +568,7 @@ def modelTFT(csv_file, prax):
|
|
| 568 |
dropout_range=(0.1, 0.3),
|
| 569 |
trainer_kwargs=dict(limit_train_batches=30),
|
| 570 |
reduce_on_plateau_patience=4,
|
| 571 |
-
pruner=optuna.pruners.MedianPruner(n_min_trials=
|
| 572 |
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
|
| 573 |
)
|
| 574 |
#torch.cuda.empty_cache()
|
|
@@ -582,6 +582,7 @@ def modelTFT(csv_file, prax):
|
|
| 582 |
#fast_dev_run=True, # comment in to check that network or dataset has no serious bugs
|
| 583 |
callbacks=[lr_logger, early_stop_callback],
|
| 584 |
logger=logger,
|
|
|
|
| 585 |
)
|
| 586 |
|
| 587 |
tft = TemporalFusionTransformer.from_dataset(
|
|
@@ -795,7 +796,7 @@ def modelTFT_OpenGap(csv_file, prax):
|
|
| 795 |
train_dataloader,
|
| 796 |
val_dataloader,
|
| 797 |
model_path="optuna_test",
|
| 798 |
-
n_trials=
|
| 799 |
max_epochs=MAX_EPOCHS,
|
| 800 |
gradient_clip_val_range=(0.01, 0.5),
|
| 801 |
hidden_size_range=(8, 64),
|
|
@@ -805,7 +806,7 @@ def modelTFT_OpenGap(csv_file, prax):
|
|
| 805 |
dropout_range=(0.1, 0.3),
|
| 806 |
trainer_kwargs=dict(limit_train_batches=30),
|
| 807 |
reduce_on_plateau_patience=4,
|
| 808 |
-
pruner=optuna.pruners.MedianPruner(n_min_trials=
|
| 809 |
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
|
| 810 |
)
|
| 811 |
#torch.cuda.empty_cache()
|
|
@@ -819,6 +820,7 @@ def modelTFT_OpenGap(csv_file, prax):
|
|
| 819 |
#fast_dev_run=True, # comment in to check that network or dataset has no serious bugs
|
| 820 |
callbacks=[lr_logger, early_stop_callback],
|
| 821 |
logger=logger,
|
|
|
|
| 822 |
)
|
| 823 |
|
| 824 |
tft = TemporalFusionTransformer.from_dataset(
|
|
|
|
| 558 |
train_dataloader,
|
| 559 |
val_dataloader,
|
| 560 |
model_path="optuna_test",
|
| 561 |
+
n_trials=5,
|
| 562 |
max_epochs=MAX_EPOCHS,
|
| 563 |
gradient_clip_val_range=(0.01, 0.5),
|
| 564 |
hidden_size_range=(8, 64),
|
|
|
|
| 568 |
dropout_range=(0.1, 0.3),
|
| 569 |
trainer_kwargs=dict(limit_train_batches=30),
|
| 570 |
reduce_on_plateau_patience=4,
|
| 571 |
+
pruner=optuna.pruners.MedianPruner(n_min_trials=3, n_startup_trials=3),
|
| 572 |
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
|
| 573 |
)
|
| 574 |
#torch.cuda.empty_cache()
|
|
|
|
| 582 |
#fast_dev_run=True, # comment in to check that network or dataset has no serious bugs
|
| 583 |
callbacks=[lr_logger, early_stop_callback],
|
| 584 |
logger=logger,
|
| 585 |
+
precision="bf16-mixed",
|
| 586 |
)
|
| 587 |
|
| 588 |
tft = TemporalFusionTransformer.from_dataset(
|
|
|
|
| 796 |
train_dataloader,
|
| 797 |
val_dataloader,
|
| 798 |
model_path="optuna_test",
|
| 799 |
+
n_trials=5,
|
| 800 |
max_epochs=MAX_EPOCHS,
|
| 801 |
gradient_clip_val_range=(0.01, 0.5),
|
| 802 |
hidden_size_range=(8, 64),
|
|
|
|
| 806 |
dropout_range=(0.1, 0.3),
|
| 807 |
trainer_kwargs=dict(limit_train_batches=30),
|
| 808 |
reduce_on_plateau_patience=4,
|
| 809 |
+
pruner=optuna.pruners.MedianPruner(n_min_trials=3, n_warmup_steps=3),
|
| 810 |
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
|
| 811 |
)
|
| 812 |
#torch.cuda.empty_cache()
|
|
|
|
| 820 |
#fast_dev_run=True, # comment in to check that network or dataset has no serious bugs
|
| 821 |
callbacks=[lr_logger, early_stop_callback],
|
| 822 |
logger=logger,
|
| 823 |
+
precision="bf16-mixed",
|
| 824 |
)
|
| 825 |
|
| 826 |
tft = TemporalFusionTransformer.from_dataset(
|