ifieryarrows commited on
Commit
4c2d2a0
·
verified ·
1 Parent(s): 5f4b46f

Sync from GitHub (tests passed)

Browse files
deep_learning/training/hyperopt.py CHANGED
@@ -83,8 +83,12 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
83
  Optimises for a composite score:
84
  score = -val_loss + 0.5 * directional_accuracy
85
  """
86
- import pytorch_lightning as pl
87
- from pytorch_lightning.callbacks import EarlyStopping
 
 
 
 
88
  from optuna.integration import PyTorchLightningPruningCallback
89
 
90
  from deep_learning.data.dataset import build_datasets, create_dataloaders
@@ -148,7 +152,10 @@ def run_hyperopt(
148
  Dict with best params, best value, and study summary.
149
  """
150
  import optuna
151
- import pytorch_lightning as pl
 
 
 
152
 
153
  from app.db import SessionLocal, init_db
154
  from deep_learning.data.feature_store import build_tft_dataframe
 
83
  Optimises for a composite score:
84
  score = -val_loss + 0.5 * directional_accuracy
85
  """
86
+ try:
87
+ import lightning.pytorch as pl
88
+ from lightning.pytorch.callbacks import EarlyStopping
89
+ except ImportError:
90
+ import pytorch_lightning as pl # type: ignore[no-redef]
91
+ from pytorch_lightning.callbacks import EarlyStopping # type: ignore[no-redef]
92
  from optuna.integration import PyTorchLightningPruningCallback
93
 
94
  from deep_learning.data.dataset import build_datasets, create_dataloaders
 
152
  Dict with best params, best value, and study summary.
153
  """
154
  import optuna
155
+ try:
156
+ import lightning.pytorch as pl
157
+ except ImportError:
158
+ import pytorch_lightning as pl # type: ignore[no-redef]
159
 
160
  from app.db import SessionLocal, init_db
161
  from deep_learning.data.feature_store import build_tft_dataframe
deep_learning/training/trainer.py CHANGED
@@ -39,8 +39,15 @@ def train_tft_model(
39
  Returns:
40
  Dict with metrics, checkpoint path, and feature importance.
41
  """
42
- import pytorch_lightning as pl
43
- from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
 
 
 
 
 
 
 
44
 
45
  from app.db import SessionLocal, init_db
46
  from deep_learning.data.feature_store import build_tft_dataframe
 
39
  Returns:
40
  Dict with metrics, checkpoint path, and feature importance.
41
  """
42
+ # pytorch_forecasting >=1.0 uses the unified `lightning` package.
43
+ # Importing from `pytorch_lightning` gives a different LightningModule
44
+ # base class, causing "model must be a LightningModule" at trainer.fit().
45
+ try:
46
+ import lightning.pytorch as pl
47
+ from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
48
+ except ImportError:
49
+ import pytorch_lightning as pl # type: ignore[no-redef]
50
+ from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint # type: ignore[no-redef]
51
 
52
  from app.db import SessionLocal, init_db
53
  from deep_learning.data.feature_store import build_tft_dataframe
requirements.txt CHANGED
@@ -24,7 +24,7 @@ scikit-learn>=1.3.2
24
 
25
  # TFT-ASRO Deep Learning
26
  pytorch-forecasting>=1.0.0
27
- pytorch-lightning>=2.0.0
28
  optuna>=3.5.0
29
  joblib>=1.3.0
30
 
 
24
 
25
  # TFT-ASRO Deep Learning
26
  pytorch-forecasting>=1.0.0
27
+ lightning>=2.0.0
28
  optuna>=3.5.0
29
  joblib>=1.3.0
30