ifieryarrows commited on
Commit
828ef50
·
verified ·
1 Parent(s): 0b39593

Sync from GitHub (tests passed)

Browse files
deep_learning/training/hyperopt.py CHANGED
@@ -102,7 +102,10 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
102
  except ImportError:
103
  import pytorch_lightning as pl # type: ignore[no-redef]
104
  from pytorch_lightning.callbacks import EarlyStopping # type: ignore[no-redef]
105
- from optuna.integration import PyTorchLightningPruningCallback
 
 
 
106
 
107
  from deep_learning.data.dataset import build_datasets, create_dataloaders
108
  from deep_learning.models.tft_copper import create_tft_model
 
102
  except ImportError:
103
  import pytorch_lightning as pl # type: ignore[no-redef]
104
  from pytorch_lightning.callbacks import EarlyStopping # type: ignore[no-redef]
105
+ try:
106
+ from optuna_integration.pytorch_lightning import PyTorchLightningPruningCallback
107
+ except ImportError:
108
+ from optuna.integration import PyTorchLightningPruningCallback # type: ignore[no-redef]
109
 
110
  from deep_learning.data.dataset import build_datasets, create_dataloaders
111
  from deep_learning.models.tft_copper import create_tft_model
requirements.txt CHANGED
@@ -26,6 +26,7 @@ scikit-learn>=1.3.2
26
  pytorch-forecasting>=1.0.0
27
  lightning>=2.0.0
28
  optuna>=3.5.0
 
29
  joblib>=1.3.0
30
 
31
  # Data processing
 
26
  pytorch-forecasting>=1.0.0
27
  lightning>=2.0.0
28
  optuna>=3.5.0
29
+ optuna-integration[pytorch_lightning]>=3.5.0
30
  joblib>=1.3.0
31
 
32
  # Data processing