Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files- deep_learning/training/hyperopt.py +10 -3
- deep_learning/training/trainer.py +9 -2
- requirements.txt +1 -1
deep_learning/training/hyperopt.py
CHANGED
|
@@ -83,8 +83,12 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 83 |
Optimises for a composite score:
|
| 84 |
score = -val_loss + 0.5 * directional_accuracy
|
| 85 |
"""
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
from optuna.integration import PyTorchLightningPruningCallback
|
| 89 |
|
| 90 |
from deep_learning.data.dataset import build_datasets, create_dataloaders
|
|
@@ -148,7 +152,10 @@ def run_hyperopt(
|
|
| 148 |
Dict with best params, best value, and study summary.
|
| 149 |
"""
|
| 150 |
import optuna
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
from app.db import SessionLocal, init_db
|
| 154 |
from deep_learning.data.feature_store import build_tft_dataframe
|
|
|
|
| 83 |
Optimises for a composite score:
|
| 84 |
score = -val_loss + 0.5 * directional_accuracy
|
| 85 |
"""
|
| 86 |
+
try:
|
| 87 |
+
import lightning.pytorch as pl
|
| 88 |
+
from lightning.pytorch.callbacks import EarlyStopping
|
| 89 |
+
except ImportError:
|
| 90 |
+
import pytorch_lightning as pl # type: ignore[no-redef]
|
| 91 |
+
from pytorch_lightning.callbacks import EarlyStopping # type: ignore[no-redef]
|
| 92 |
from optuna.integration import PyTorchLightningPruningCallback
|
| 93 |
|
| 94 |
from deep_learning.data.dataset import build_datasets, create_dataloaders
|
|
|
|
| 152 |
Dict with best params, best value, and study summary.
|
| 153 |
"""
|
| 154 |
import optuna
|
| 155 |
+
try:
|
| 156 |
+
import lightning.pytorch as pl
|
| 157 |
+
except ImportError:
|
| 158 |
+
import pytorch_lightning as pl # type: ignore[no-redef]
|
| 159 |
|
| 160 |
from app.db import SessionLocal, init_db
|
| 161 |
from deep_learning.data.feature_store import build_tft_dataframe
|
deep_learning/training/trainer.py
CHANGED
|
@@ -39,8 +39,15 @@ def train_tft_model(
|
|
| 39 |
Returns:
|
| 40 |
Dict with metrics, checkpoint path, and feature importance.
|
| 41 |
"""
|
| 42 |
-
|
| 43 |
-
from pytorch_lightning
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
from app.db import SessionLocal, init_db
|
| 46 |
from deep_learning.data.feature_store import build_tft_dataframe
|
|
|
|
| 39 |
Returns:
|
| 40 |
Dict with metrics, checkpoint path, and feature importance.
|
| 41 |
"""
|
| 42 |
+
# pytorch_forecasting >=1.0 uses the unified `lightning` package.
|
| 43 |
+
# Importing from `pytorch_lightning` gives a different LightningModule
|
| 44 |
+
# base class, causing "model must be a LightningModule" at trainer.fit().
|
| 45 |
+
try:
|
| 46 |
+
import lightning.pytorch as pl
|
| 47 |
+
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
|
| 48 |
+
except ImportError:
|
| 49 |
+
import pytorch_lightning as pl # type: ignore[no-redef]
|
| 50 |
+
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint # type: ignore[no-redef]
|
| 51 |
|
| 52 |
from app.db import SessionLocal, init_db
|
| 53 |
from deep_learning.data.feature_store import build_tft_dataframe
|
requirements.txt
CHANGED
|
@@ -24,7 +24,7 @@ scikit-learn>=1.3.2
|
|
| 24 |
|
| 25 |
# TFT-ASRO Deep Learning
|
| 26 |
pytorch-forecasting>=1.0.0
|
| 27 |
-
|
| 28 |
optuna>=3.5.0
|
| 29 |
joblib>=1.3.0
|
| 30 |
|
|
|
|
| 24 |
|
| 25 |
# TFT-ASRO Deep Learning
|
| 26 |
pytorch-forecasting>=1.0.0
|
| 27 |
+
lightning>=2.0.0
|
| 28 |
optuna>=3.5.0
|
| 29 |
joblib>=1.3.0
|
| 30 |
|