Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files- deep_learning/training/hyperopt.py +4 -1
- requirements.txt +1 -0
deep_learning/training/hyperopt.py
CHANGED
|
@@ -102,7 +102,10 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 102 |
except ImportError:
|
| 103 |
import pytorch_lightning as pl # type: ignore[no-redef]
|
| 104 |
from pytorch_lightning.callbacks import EarlyStopping # type: ignore[no-redef]
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
from deep_learning.data.dataset import build_datasets, create_dataloaders
|
| 108 |
from deep_learning.models.tft_copper import create_tft_model
|
|
|
|
| 102 |
except ImportError:
|
| 103 |
import pytorch_lightning as pl # type: ignore[no-redef]
|
| 104 |
from pytorch_lightning.callbacks import EarlyStopping # type: ignore[no-redef]
|
| 105 |
+
try:
|
| 106 |
+
from optuna_integration.pytorch_lightning import PyTorchLightningPruningCallback
|
| 107 |
+
except ImportError:
|
| 108 |
+
from optuna.integration import PyTorchLightningPruningCallback # type: ignore[no-redef]
|
| 109 |
|
| 110 |
from deep_learning.data.dataset import build_datasets, create_dataloaders
|
| 111 |
from deep_learning.models.tft_copper import create_tft_model
|
requirements.txt
CHANGED
|
@@ -26,6 +26,7 @@ scikit-learn>=1.3.2
|
|
| 26 |
pytorch-forecasting>=1.0.0
|
| 27 |
lightning>=2.0.0
|
| 28 |
optuna>=3.5.0
|
|
|
|
| 29 |
joblib>=1.3.0
|
| 30 |
|
| 31 |
# Data processing
|
|
|
|
| 26 |
pytorch-forecasting>=1.0.0
|
| 27 |
lightning>=2.0.0
|
| 28 |
optuna>=3.5.0
|
| 29 |
+
optuna-integration[pytorch_lightning]>=3.5.0
|
| 30 |
joblib>=1.3.0
|
| 31 |
|
| 32 |
# Data processing
|