Spaces:
Sleeping
Sleeping
Commit Β·
d3996f2
1
Parent(s): ae129df
feat: v3 models - XGBoost R2=0.9866, GradientBoosting R2=0.9860 as default
Browse files- Add registry_v3 singleton (artifacts/v3/), set as default registry
- Add api/routers/predict_v3.py (/api/v3/* endpoints)
- Update simulate.py router prefix to /api/v3/, use registry_v3
- Update main.py: load v3 at startup, version 3.0.0, expose /api/v3
- Update model_registry.py: v3 R2 values, v3-prefixed scaler loading
- Update download_models.py: v3 as default check
- Update upload_models_to_hub.py: add v3 to upload loop
- Update frontend: default API version to v3, add v3 type support
- Rebuild frontend dist
- api/main.py +11 -7
- api/model_registry.py +35 -19
- api/routers/predict_v3.py +143 -0
- api/routers/simulate.py +2 -2
- frontend/src/App.tsx +2 -2
- frontend/src/api.ts +4 -4
- frontend/src/components/VersionSelector.tsx +3 -3
- notebooks/02_feature_engineering.ipynb +0 -0
- notebooks/03_classical_ml.ipynb +257 -88
- notebooks/04_lstm_rnn.ipynb +0 -0
- notebooks/05_transformer.ipynb +0 -0
- notebooks/06_dynamic_graph.ipynb +0 -0
- notebooks/07_vae_lstm.ipynb +0 -0
- notebooks/08_ensemble.ipynb +0 -0
- notebooks/09_evaluation.ipynb +0 -0
- scripts/download_models.py +5 -5
- scripts/upload_models_to_hub.py +12 -13
- src/data/features.py +73 -0
- src/utils/config.py +19 -2
- src/utils/plotting.py +7 -3
api/main.py
CHANGED
|
@@ -51,13 +51,13 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 51 |
from fastapi.staticfiles import StaticFiles
|
| 52 |
from fastapi.responses import FileResponse
|
| 53 |
|
| 54 |
-
from api.model_registry import registry, registry_v1, registry_v2
|
| 55 |
from api.schemas import HealthResponse
|
| 56 |
from src.utils.logger import get_logger
|
| 57 |
|
| 58 |
log = get_logger(__name__)
|
| 59 |
|
| 60 |
-
__version__ = "
|
| 61 |
|
| 62 |
# ββ Static frontend path ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 63 |
_HERE = Path(__file__).resolve().parent
|
|
@@ -73,6 +73,8 @@ async def lifespan(app: FastAPI):
|
|
| 73 |
log.info("v1 registry ready β %d models loaded", registry_v1.model_count)
|
| 74 |
registry_v2.load_all()
|
| 75 |
log.info("v2 registry ready β %d models loaded", registry_v2.model_count)
|
|
|
|
|
|
|
| 76 |
yield
|
| 77 |
log.info("Shutting down battery-lifecycle API")
|
| 78 |
|
|
@@ -106,13 +108,13 @@ async def health():
|
|
| 106 |
return HealthResponse(
|
| 107 |
status="ok",
|
| 108 |
version=__version__,
|
| 109 |
-
models_loaded=registry_v1.model_count + registry_v2.model_count,
|
| 110 |
device=registry.device,
|
| 111 |
)
|
| 112 |
|
| 113 |
|
| 114 |
# ββ Version management βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 115 |
-
_REGISTRIES = {"v1": registry_v1, "v2": registry_v2}
|
| 116 |
_version_status: dict[str, str] = {} # "downloading" | "ready" | "error"
|
| 117 |
|
| 118 |
|
|
@@ -136,7 +138,7 @@ async def list_versions():
|
|
| 136 |
"model_count": _REGISTRIES[v].model_count,
|
| 137 |
"status": _version_status.get(v, "ready" if _version_loaded(v) else "not_downloaded"),
|
| 138 |
}
|
| 139 |
-
for v in ["v2", "v1"]
|
| 140 |
]
|
| 141 |
|
| 142 |
|
|
@@ -176,13 +178,15 @@ async def load_version(version: str, background_tasks: BackgroundTasks):
|
|
| 176 |
# ββ Include routers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 177 |
from api.routers.predict import router as predict_router, v1_router
|
| 178 |
from api.routers.predict_v2 import router as predict_v2_router
|
|
|
|
| 179 |
from api.routers.visualize import router as viz_router
|
| 180 |
from api.routers.simulate import router as simulate_router
|
| 181 |
|
| 182 |
app.include_router(predict_router) # /api/* (default, uses v2 registry)
|
| 183 |
app.include_router(v1_router) # /api/v1/* (legacy v1 models)
|
| 184 |
-
app.include_router(predict_v2_router) # /api/v2/* (v2 models
|
| 185 |
-
app.include_router(
|
|
|
|
| 186 |
app.include_router(viz_router)
|
| 187 |
|
| 188 |
|
|
|
|
| 51 |
from fastapi.staticfiles import StaticFiles
|
| 52 |
from fastapi.responses import FileResponse
|
| 53 |
|
| 54 |
+
from api.model_registry import registry, registry_v1, registry_v2, registry_v3
|
| 55 |
from api.schemas import HealthResponse
|
| 56 |
from src.utils.logger import get_logger
|
| 57 |
|
| 58 |
log = get_logger(__name__)
|
| 59 |
|
| 60 |
+
__version__ = "3.0.0"
|
| 61 |
|
| 62 |
# ββ Static frontend path ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 63 |
_HERE = Path(__file__).resolve().parent
|
|
|
|
| 73 |
log.info("v1 registry ready β %d models loaded", registry_v1.model_count)
|
| 74 |
registry_v2.load_all()
|
| 75 |
log.info("v2 registry ready β %d models loaded", registry_v2.model_count)
|
| 76 |
+
registry_v3.load_all()
|
| 77 |
+
log.info("v3 registry ready β %d models loaded", registry_v3.model_count)
|
| 78 |
yield
|
| 79 |
log.info("Shutting down battery-lifecycle API")
|
| 80 |
|
|
|
|
| 108 |
return HealthResponse(
|
| 109 |
status="ok",
|
| 110 |
version=__version__,
|
| 111 |
+
models_loaded=registry_v1.model_count + registry_v2.model_count + registry_v3.model_count,
|
| 112 |
device=registry.device,
|
| 113 |
)
|
| 114 |
|
| 115 |
|
| 116 |
# ββ Version management βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
+
_REGISTRIES = {"v1": registry_v1, "v2": registry_v2, "v3": registry_v3}
|
| 118 |
_version_status: dict[str, str] = {} # "downloading" | "ready" | "error"
|
| 119 |
|
| 120 |
|
|
|
|
| 138 |
"model_count": _REGISTRIES[v].model_count,
|
| 139 |
"status": _version_status.get(v, "ready" if _version_loaded(v) else "not_downloaded"),
|
| 140 |
}
|
| 141 |
+
for v in ["v3", "v2", "v1"]
|
| 142 |
]
|
| 143 |
|
| 144 |
|
|
|
|
| 178 |
# ββ Include routers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 179 |
from api.routers.predict import router as predict_router, v1_router
|
| 180 |
from api.routers.predict_v2 import router as predict_v2_router
|
| 181 |
+
from api.routers.predict_v3 import router as predict_v3_router
|
| 182 |
from api.routers.visualize import router as viz_router
|
| 183 |
from api.routers.simulate import router as simulate_router
|
| 184 |
|
| 185 |
app.include_router(predict_router) # /api/* (default, uses v2 registry)
|
| 186 |
app.include_router(v1_router) # /api/v1/* (legacy v1 models)
|
| 187 |
+
app.include_router(predict_v2_router) # /api/v2/* (v2 models)
|
| 188 |
+
app.include_router(predict_v3_router) # /api/v3/* (v3 models, best accuracy)
|
| 189 |
+
app.include_router(simulate_router) # /api/v3/simulate (ML-driven simulation)
|
| 190 |
app.include_router(viz_router)
|
| 191 |
|
| 192 |
|
api/model_registry.py
CHANGED
|
@@ -72,9 +72,9 @@ FEATURE_COLS_SCALAR: list[str] = [
|
|
| 72 |
|
| 73 |
# ββ Model catalog (single source of truth for versions & metadata) ββββββββββββ
|
| 74 |
MODEL_CATALOG: dict[str, dict[str, Any]] = {
|
| 75 |
-
"random_forest": {"version": "
|
| 76 |
-
"xgboost": {"version": "
|
| 77 |
-
"lightgbm": {"version": "
|
| 78 |
"ridge": {"version": "1.0.0", "display_name": "Ridge Regression", "family": "classical", "algorithm": "Ridge", "target": "soh", "r2": 0.72},
|
| 79 |
"svr": {"version": "1.0.0", "display_name": "SVR (RBF)", "family": "classical", "algorithm": "SVR", "target": "soh", "r2": 0.805},
|
| 80 |
"lasso": {"version": "1.0.0", "display_name": "Lasso", "family": "classical", "algorithm": "Lasso", "target": "soh", "r2": 0.52},
|
|
@@ -82,8 +82,8 @@ MODEL_CATALOG: dict[str, dict[str, Any]] = {
|
|
| 82 |
"knn_k5": {"version": "1.0.0", "display_name": "KNN (k=5)", "family": "classical", "algorithm": "KNeighborsRegressor", "target": "soh", "r2": 0.72},
|
| 83 |
"knn_k10": {"version": "1.0.0", "display_name": "KNN (k=10)", "family": "classical", "algorithm": "KNeighborsRegressor", "target": "soh", "r2": 0.724},
|
| 84 |
"knn_k20": {"version": "1.0.0", "display_name": "KNN (k=20)", "family": "classical", "algorithm": "KNeighborsRegressor", "target": "soh", "r2": 0.717},
|
| 85 |
-
"extra_trees": {"version": "
|
| 86 |
-
"gradient_boosting": {"version": "
|
| 87 |
"vanilla_lstm": {"version": "2.0.0", "display_name": "Vanilla LSTM", "family": "deep_pytorch", "algorithm": "VanillaLSTM", "target": "soh", "r2": 0.507},
|
| 88 |
"bidirectional_lstm": {"version": "2.0.0", "display_name": "Bidirectional LSTM", "family": "deep_pytorch", "algorithm": "BidirectionalLSTM", "target": "soh", "r2": 0.520},
|
| 89 |
"gru": {"version": "2.0.0", "display_name": "GRU", "family": "deep_pytorch", "algorithm": "GRUModel", "target": "soh", "r2": 0.510},
|
|
@@ -94,16 +94,16 @@ MODEL_CATALOG: dict[str, dict[str, Any]] = {
|
|
| 94 |
"itransformer": {"version": "2.4.0", "display_name": "iTransformer", "family": "deep_keras", "algorithm": "iTransformer", "target": "soh", "r2": 0.595},
|
| 95 |
"physics_itransformer": {"version": "2.4.1", "display_name": "Physics iTransformer", "family": "deep_keras", "algorithm": "PhysicsITransformer", "target": "soh", "r2": 0.600},
|
| 96 |
"dynamic_graph_itransformer": {"version": "2.5.0", "display_name": "DG-iTransformer", "family": "deep_keras", "algorithm": "DynamicGraphITransformer", "target": "soh", "r2": 0.595},
|
| 97 |
-
"best_ensemble": {"version": "3.0.0", "display_name": "Best Ensemble (RF+XGB+LGB)", "family": "ensemble", "algorithm": "WeightedAverage", "target": "soh", "r2": 0.
|
| 98 |
}
|
| 99 |
|
| 100 |
-
# RΒ²-proportional weights for BestEnsemble
|
| 101 |
_ENSEMBLE_WEIGHTS: dict[str, float] = {
|
| 102 |
-
"random_forest":
|
| 103 |
-
"xgboost":
|
| 104 |
-
"lightgbm":
|
| 105 |
-
"extra_trees":
|
| 106 |
-
"gradient_boosting": 0.
|
| 107 |
}
|
| 108 |
|
| 109 |
|
|
@@ -372,11 +372,17 @@ class ModelRegistry:
|
|
| 372 |
# Tree models (RF, ET, GB, XGB, LGB) were fitted on raw numpy X_train
|
| 373 |
# β NO scaler applied, passed as-is
|
| 374 |
#
|
| 375 |
-
#
|
| 376 |
-
#
|
| 377 |
-
#
|
| 378 |
scalers_dir = self._scalers_dir
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
sp = scalers_dir / fname
|
| 381 |
if sp.exists():
|
| 382 |
try:
|
|
@@ -386,7 +392,16 @@ class ModelRegistry:
|
|
| 386 |
except Exception as exc:
|
| 387 |
log.warning("Could not load %s: %s", fname, exc)
|
| 388 |
else:
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
sp_seq = scalers_dir / "sequence_scaler.joblib"
|
| 392 |
if sp_seq.exists():
|
|
@@ -789,6 +804,7 @@ class ModelRegistry:
|
|
| 789 |
# ββ Singletons βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 790 |
registry_v1 = ModelRegistry(version="v1")
|
| 791 |
registry_v2 = ModelRegistry(version="v2")
|
|
|
|
| 792 |
|
| 793 |
-
# Default registry β
|
| 794 |
-
registry =
|
|
|
|
| 72 |
|
| 73 |
# ββ Model catalog (single source of truth for versions & metadata) ββββββββββββ
|
| 74 |
MODEL_CATALOG: dict[str, dict[str, Any]] = {
|
| 75 |
+
"random_forest": {"version": "3.0.0", "display_name": "Random Forest", "family": "classical", "algorithm": "RandomForestRegressor", "target": "soh", "r2": 0.9814},
|
| 76 |
+
"xgboost": {"version": "3.0.0", "display_name": "XGBoost", "family": "classical", "algorithm": "XGBRegressor", "target": "soh", "r2": 0.9866},
|
| 77 |
+
"lightgbm": {"version": "3.0.0", "display_name": "LightGBM", "family": "classical", "algorithm": "LGBMRegressor", "target": "soh", "r2": 0.9826},
|
| 78 |
"ridge": {"version": "1.0.0", "display_name": "Ridge Regression", "family": "classical", "algorithm": "Ridge", "target": "soh", "r2": 0.72},
|
| 79 |
"svr": {"version": "1.0.0", "display_name": "SVR (RBF)", "family": "classical", "algorithm": "SVR", "target": "soh", "r2": 0.805},
|
| 80 |
"lasso": {"version": "1.0.0", "display_name": "Lasso", "family": "classical", "algorithm": "Lasso", "target": "soh", "r2": 0.52},
|
|
|
|
| 82 |
"knn_k5": {"version": "1.0.0", "display_name": "KNN (k=5)", "family": "classical", "algorithm": "KNeighborsRegressor", "target": "soh", "r2": 0.72},
|
| 83 |
"knn_k10": {"version": "1.0.0", "display_name": "KNN (k=10)", "family": "classical", "algorithm": "KNeighborsRegressor", "target": "soh", "r2": 0.724},
|
| 84 |
"knn_k20": {"version": "1.0.0", "display_name": "KNN (k=20)", "family": "classical", "algorithm": "KNeighborsRegressor", "target": "soh", "r2": 0.717},
|
| 85 |
+
"extra_trees": {"version": "3.0.0", "display_name": "ExtraTrees", "family": "classical", "algorithm": "ExtraTreesRegressor", "target": "soh", "r2": 0.9701},
|
| 86 |
+
"gradient_boosting": {"version": "3.0.0", "display_name": "GradientBoosting", "family": "classical", "algorithm": "GradientBoostingRegressor", "target": "soh", "r2": 0.9860},
|
| 87 |
"vanilla_lstm": {"version": "2.0.0", "display_name": "Vanilla LSTM", "family": "deep_pytorch", "algorithm": "VanillaLSTM", "target": "soh", "r2": 0.507},
|
| 88 |
"bidirectional_lstm": {"version": "2.0.0", "display_name": "Bidirectional LSTM", "family": "deep_pytorch", "algorithm": "BidirectionalLSTM", "target": "soh", "r2": 0.520},
|
| 89 |
"gru": {"version": "2.0.0", "display_name": "GRU", "family": "deep_pytorch", "algorithm": "GRUModel", "target": "soh", "r2": 0.510},
|
|
|
|
| 94 |
"itransformer": {"version": "2.4.0", "display_name": "iTransformer", "family": "deep_keras", "algorithm": "iTransformer", "target": "soh", "r2": 0.595},
|
| 95 |
"physics_itransformer": {"version": "2.4.1", "display_name": "Physics iTransformer", "family": "deep_keras", "algorithm": "PhysicsITransformer", "target": "soh", "r2": 0.600},
|
| 96 |
"dynamic_graph_itransformer": {"version": "2.5.0", "display_name": "DG-iTransformer", "family": "deep_keras", "algorithm": "DynamicGraphITransformer", "target": "soh", "r2": 0.595},
|
| 97 |
+
"best_ensemble": {"version": "3.0.0", "display_name": "Best Ensemble (RF+XGB+LGB)", "family": "ensemble", "algorithm": "WeightedAverage", "target": "soh", "r2": 0.9810},
|
| 98 |
}
|
| 99 |
|
| 100 |
+
# RΒ²-proportional weights for BestEnsemble (v3 values)
|
| 101 |
_ENSEMBLE_WEIGHTS: dict[str, float] = {
|
| 102 |
+
"random_forest": 0.9814,
|
| 103 |
+
"xgboost": 0.9866,
|
| 104 |
+
"lightgbm": 0.9826,
|
| 105 |
+
"extra_trees": 0.9701,
|
| 106 |
+
"gradient_boosting": 0.9860,
|
| 107 |
}
|
| 108 |
|
| 109 |
|
|
|
|
| 372 |
# Tree models (RF, ET, GB, XGB, LGB) were fitted on raw numpy X_train
|
| 373 |
# β NO scaler applied, passed as-is
|
| 374 |
#
|
| 375 |
+
# v3 scalers use a version-prefixed naming scheme:
|
| 376 |
+
# {version}_features_standard.joblib β StandardScaler
|
| 377 |
+
# {version}_features_minmax.joblib β MinMaxScaler (fallback)
|
| 378 |
scalers_dir = self._scalers_dir
|
| 379 |
+
version_prefix = self.version # e.g. "v3"
|
| 380 |
+
candidate_linear = (
|
| 381 |
+
f"{version_prefix}_features_standard.joblib",
|
| 382 |
+
"standard_scaler.joblib",
|
| 383 |
+
"linear_scaler.joblib",
|
| 384 |
+
)
|
| 385 |
+
for fname in candidate_linear:
|
| 386 |
sp = scalers_dir / fname
|
| 387 |
if sp.exists():
|
| 388 |
try:
|
|
|
|
| 392 |
except Exception as exc:
|
| 393 |
log.warning("Could not load %s: %s", fname, exc)
|
| 394 |
else:
|
| 395 |
+
# Try minmax as last resort (v3 fallback)
|
| 396 |
+
sp_mm = scalers_dir / f"{version_prefix}_features_minmax.joblib"
|
| 397 |
+
if sp_mm.exists():
|
| 398 |
+
try:
|
| 399 |
+
self.linear_scaler = joblib.load(sp_mm)
|
| 400 |
+
log.info("Linear scaler (minmax fallback) loaded from %s", sp_mm)
|
| 401 |
+
except Exception as exc:
|
| 402 |
+
log.warning("Could not load minmax scaler: %s", exc)
|
| 403 |
+
else:
|
| 404 |
+
log.warning("No linear scaler found β Ridge/Lasso/SVR/KNN will use raw features")
|
| 405 |
|
| 406 |
sp_seq = scalers_dir / "sequence_scaler.joblib"
|
| 407 |
if sp_seq.exists():
|
|
|
|
| 804 |
# ββ Singletons βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 805 |
registry_v1 = ModelRegistry(version="v1")
|
| 806 |
registry_v2 = ModelRegistry(version="v2")
|
| 807 |
+
registry_v3 = ModelRegistry(version="v3")
|
| 808 |
|
| 809 |
+
# Default registry β v3 (best models, highest RΒ²)
|
| 810 |
+
registry = registry_v3
|
api/routers/predict_v3.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
api.routers.predict_v3
|
| 3 |
+
======================
|
| 4 |
+
v3 prediction & recommendation endpoints.
|
| 5 |
+
|
| 6 |
+
v3 improvements over v2:
|
| 7 |
+
- Higher accuracy classical models (XGBoost RΒ²=0.9866, GradientBoosting RΒ²=0.9860)
|
| 8 |
+
- Updated ensemble weights proportional to v3 RΒ² values
|
| 9 |
+
- Version-aware model loading from artifacts/v3/
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from fastapi import APIRouter, HTTPException
|
| 15 |
+
|
| 16 |
+
from api.model_registry import registry_v3, classify_degradation, soh_to_color
|
| 17 |
+
from api.schemas import (
|
| 18 |
+
PredictRequest, PredictResponse,
|
| 19 |
+
BatchPredictRequest, BatchPredictResponse,
|
| 20 |
+
RecommendationRequest, RecommendationResponse, SingleRecommendation,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
router = APIRouter(prefix="/api/v3", tags=["v3-prediction"])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ββ Single prediction ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
@router.post("/predict", response_model=PredictResponse)
|
| 28 |
+
async def predict_v3(req: PredictRequest):
|
| 29 |
+
"""Predict SOH for a single cycle using v3 models."""
|
| 30 |
+
features = req.model_dump(exclude={"battery_id"})
|
| 31 |
+
features["voltage_range"] = features["peak_voltage"] - features["min_voltage"]
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
result = registry_v3.predict(features)
|
| 35 |
+
except Exception as exc:
|
| 36 |
+
raise HTTPException(status_code=500, detail=str(exc))
|
| 37 |
+
|
| 38 |
+
return PredictResponse(
|
| 39 |
+
battery_id=req.battery_id,
|
| 40 |
+
cycle_number=req.cycle_number,
|
| 41 |
+
soh_pct=result["soh_pct"],
|
| 42 |
+
rul_cycles=result["rul_cycles"],
|
| 43 |
+
degradation_state=result["degradation_state"],
|
| 44 |
+
confidence_lower=result["confidence_lower"],
|
| 45 |
+
confidence_upper=result["confidence_upper"],
|
| 46 |
+
model_used=result["model_used"],
|
| 47 |
+
model_version=result.get("model_version", "3.0.0"),
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ββ Batch prediction βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
@router.post("/predict/batch", response_model=BatchPredictResponse)
|
| 53 |
+
async def predict_batch_v3(req: BatchPredictRequest):
|
| 54 |
+
"""Predict SOH for multiple cycles using v3 models."""
|
| 55 |
+
results = registry_v3.predict_batch(req.battery_id, req.cycles)
|
| 56 |
+
predictions = [
|
| 57 |
+
PredictResponse(
|
| 58 |
+
battery_id=req.battery_id,
|
| 59 |
+
cycle_number=r["cycle_number"],
|
| 60 |
+
soh_pct=r["soh_pct"],
|
| 61 |
+
rul_cycles=r["rul_cycles"],
|
| 62 |
+
degradation_state=r["degradation_state"],
|
| 63 |
+
confidence_lower=r.get("confidence_lower"),
|
| 64 |
+
confidence_upper=r.get("confidence_upper"),
|
| 65 |
+
model_used=r["model_used"],
|
| 66 |
+
model_version=r.get("model_version", "3.0.0"),
|
| 67 |
+
)
|
| 68 |
+
for r in results
|
| 69 |
+
]
|
| 70 |
+
return BatchPredictResponse(battery_id=req.battery_id, predictions=predictions)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ββ Recommendations (v3) βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
+
@router.post("/recommend", response_model=RecommendationResponse)
|
| 75 |
+
async def recommend_v3(req: RecommendationRequest):
|
| 76 |
+
"""Get operational recommendations using v3 models."""
|
| 77 |
+
import itertools
|
| 78 |
+
|
| 79 |
+
temps = [4.0, 24.0, 43.0]
|
| 80 |
+
currents = [0.5, 1.0, 2.0, 4.0]
|
| 81 |
+
cutoffs = [2.0, 2.2, 2.5, 2.7]
|
| 82 |
+
|
| 83 |
+
EOL_THRESHOLD = 70.0
|
| 84 |
+
deg_rate = 0.2
|
| 85 |
+
if req.current_soh > EOL_THRESHOLD:
|
| 86 |
+
baseline_rul = (req.current_soh - EOL_THRESHOLD) / deg_rate
|
| 87 |
+
else:
|
| 88 |
+
baseline_rul = 0.0
|
| 89 |
+
|
| 90 |
+
base_features = {
|
| 91 |
+
"cycle_number": req.current_cycle,
|
| 92 |
+
"ambient_temperature": req.ambient_temperature,
|
| 93 |
+
"peak_voltage": 4.19,
|
| 94 |
+
"min_voltage": 2.61,
|
| 95 |
+
"voltage_range": 4.19 - 2.61,
|
| 96 |
+
"avg_current": 1.82,
|
| 97 |
+
"avg_temp": req.ambient_temperature + 8.0,
|
| 98 |
+
"temp_rise": 15.0,
|
| 99 |
+
"cycle_duration": 3690.0,
|
| 100 |
+
"Re": 0.045,
|
| 101 |
+
"Rct": 0.069,
|
| 102 |
+
"delta_capacity": -0.005,
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
candidates = []
|
| 106 |
+
for t, c, v in itertools.product(temps, currents, cutoffs):
|
| 107 |
+
feat = {**base_features, "ambient_temperature": t, "avg_current": c,
|
| 108 |
+
"min_voltage": v, "voltage_range": 4.19 - v,
|
| 109 |
+
"avg_temp": t + 8.0}
|
| 110 |
+
result = registry_v3.predict(feat)
|
| 111 |
+
rul = result.get("rul_cycles", 0) or 0
|
| 112 |
+
candidates.append((rul, t, c, v, result["soh_pct"]))
|
| 113 |
+
|
| 114 |
+
candidates.sort(reverse=True)
|
| 115 |
+
top = candidates[: req.top_k]
|
| 116 |
+
|
| 117 |
+
recs = []
|
| 118 |
+
for rank, (rul, t, c, v, soh) in enumerate(top, 1):
|
| 119 |
+
improvement = rul - baseline_rul
|
| 120 |
+
pct = (improvement / baseline_rul * 100) if baseline_rul > 0 else 0
|
| 121 |
+
recs.append(SingleRecommendation(
|
| 122 |
+
rank=rank,
|
| 123 |
+
ambient_temperature=t,
|
| 124 |
+
discharge_current=c,
|
| 125 |
+
cutoff_voltage=v,
|
| 126 |
+
predicted_rul=rul,
|
| 127 |
+
rul_improvement=improvement,
|
| 128 |
+
rul_improvement_pct=round(pct, 1),
|
| 129 |
+
explanation=f"Operate at {t}Β°C, {c}A, cutoff {v}V for ~{rul:.0f} cycles RUL",
|
| 130 |
+
))
|
| 131 |
+
|
| 132 |
+
return RecommendationResponse(
|
| 133 |
+
battery_id=req.battery_id,
|
| 134 |
+
current_soh=req.current_soh,
|
| 135 |
+
recommendations=recs,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ββ Model listing βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 140 |
+
@router.get("/models")
|
| 141 |
+
async def list_models_v3():
|
| 142 |
+
"""List all v3 registered models."""
|
| 143 |
+
return registry_v3.list_models()
|
api/routers/simulate.py
CHANGED
|
@@ -28,12 +28,12 @@ from fastapi import APIRouter
|
|
| 28 |
from pydantic import BaseModel, Field
|
| 29 |
|
| 30 |
from api.model_registry import (
|
| 31 |
-
FEATURE_COLS_SCALAR, classify_degradation, soh_to_color, registry_v2,
|
| 32 |
)
|
| 33 |
|
| 34 |
log = logging.getLogger(__name__)
|
| 35 |
|
| 36 |
-
router = APIRouter(prefix="/api/
|
| 37 |
|
| 38 |
# -- Physics constants --------------------------------------------------------
|
| 39 |
_EA_OVER_R = 6200.0 # Ea/R in Kelvin
|
|
|
|
| 28 |
from pydantic import BaseModel, Field
|
| 29 |
|
| 30 |
from api.model_registry import (
|
| 31 |
+
FEATURE_COLS_SCALAR, classify_degradation, soh_to_color, registry_v3 as registry_v2,
|
| 32 |
)
|
| 33 |
|
| 34 |
log = logging.getLogger(__name__)
|
| 35 |
|
| 36 |
+
router = APIRouter(prefix="/api/v3", tags=["simulation"])
|
| 37 |
|
| 38 |
# -- Physics constants --------------------------------------------------------
|
| 39 |
_EA_OVER_R = 6200.0 # Ea/R in Kelvin
|
frontend/src/App.tsx
CHANGED
|
@@ -14,9 +14,9 @@ type Tab = "simulation" | "predict" | "graphs" | "recommend" | "metrics" | "pape
|
|
| 14 |
|
| 15 |
export default function App() {
|
| 16 |
const [activeTab, setActiveTab] = useState<Tab>("simulation");
|
| 17 |
-
const [apiVersion, setVersion] = useState<"v1" | "v2">(getApiVersion());
|
| 18 |
|
| 19 |
-
const handleVersionChange = (v: "v1" | "v2") => {
|
| 20 |
setApiVersion(v);
|
| 21 |
setVersion(v);
|
| 22 |
};
|
|
|
|
| 14 |
|
| 15 |
export default function App() {
|
| 16 |
const [activeTab, setActiveTab] = useState<Tab>("simulation");
|
| 17 |
+
const [apiVersion, setVersion] = useState<"v1" | "v2" | "v3">(getApiVersion());
|
| 18 |
|
| 19 |
+
const handleVersionChange = (v: "v1" | "v2" | "v3") => {
|
| 20 |
setApiVersion(v);
|
| 21 |
setVersion(v);
|
| 22 |
};
|
frontend/src/api.ts
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
import axios from "axios";
|
| 2 |
|
| 3 |
-
/** Active API version β toggle between v1 (legacy) and
|
| 4 |
-
let _apiVersion: "v1" | "v2" = "
|
| 5 |
|
| 6 |
export const getApiVersion = () => _apiVersion;
|
| 7 |
-
export const setApiVersion = (v: "v1" | "v2") => {
|
| 8 |
_apiVersion = v;
|
| 9 |
};
|
| 10 |
|
|
@@ -68,7 +68,7 @@ export interface ModelVersionGroups {
|
|
| 68 |
|
| 69 |
// ββ Version management βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
export interface VersionInfo {
|
| 71 |
-
id: string; // "v1" | "v2"
|
| 72 |
display: string; // "Version 1" | "Version 2"
|
| 73 |
loaded: boolean;
|
| 74 |
model_count: number;
|
|
|
|
| 1 |
import axios from "axios";
|
| 2 |
|
| 3 |
+
/** Active API version β toggle between v1 (legacy), v2, and v3 (latest). */
|
| 4 |
+
let _apiVersion: "v1" | "v2" | "v3" = "v3";
|
| 5 |
|
| 6 |
export const getApiVersion = () => _apiVersion;
|
| 7 |
+
export const setApiVersion = (v: "v1" | "v2" | "v3") => {
|
| 8 |
_apiVersion = v;
|
| 9 |
};
|
| 10 |
|
|
|
|
| 68 |
|
| 69 |
// ββ Version management βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
export interface VersionInfo {
|
| 71 |
+
id: string; // "v1" | "v2" | "v3"
|
| 72 |
display: string; // "Version 1" | "Version 2"
|
| 73 |
loaded: boolean;
|
| 74 |
model_count: number;
|
frontend/src/components/VersionSelector.tsx
CHANGED
|
@@ -16,8 +16,8 @@ import {
|
|
| 16 |
import { fetchVersions, loadVersion, VersionInfo } from "../api";
|
| 17 |
|
| 18 |
interface Props {
|
| 19 |
-
activeVersion: "v1" | "v2";
|
| 20 |
-
onSwitch: (v: "v1" | "v2") => void;
|
| 21 |
}
|
| 22 |
|
| 23 |
export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
@@ -78,7 +78,7 @@ export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
| 78 |
};
|
| 79 |
|
| 80 |
const handleSwitch = (version: string) => {
|
| 81 |
-
onSwitch(version as "v1" | "v2");
|
| 82 |
setOpen(false);
|
| 83 |
};
|
| 84 |
|
|
|
|
| 16 |
import { fetchVersions, loadVersion, VersionInfo } from "../api";
|
| 17 |
|
| 18 |
interface Props {
|
| 19 |
+
activeVersion: "v1" | "v2" | "v3";
|
| 20 |
+
onSwitch: (v: "v1" | "v2" | "v3") => void;
|
| 21 |
}
|
| 22 |
|
| 23 |
export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
|
|
| 78 |
};
|
| 79 |
|
| 80 |
const handleSwitch = (version: string) => {
|
| 81 |
+
onSwitch(version as "v1" | "v2" | "v3");
|
| 82 |
setOpen(false);
|
| 83 |
};
|
| 84 |
|
notebooks/02_feature_engineering.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/03_classical_ml.ipynb
CHANGED
|
@@ -4,18 +4,25 @@
|
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
-
"# 03 β Classical ML Models (
|
| 8 |
-
"## SOH Regression:
|
| 9 |
"\n",
|
| 10 |
-
"**
|
| 11 |
-
"- Load preprocessed `battery_features.csv` from NB02\n",
|
| 12 |
-
"-
|
| 13 |
-
"- Train 8 core models with
|
|
|
|
| 14 |
"- Target: β₯95% within-Β±5% SOH accuracy on all models\n",
|
| 15 |
-
"- Save artifacts to `artifacts/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"\n",
|
| 17 |
"**Models (8 total):**\n",
|
| 18 |
-
"1. ExtraTrees (tree-based,
|
| 19 |
"2. GradientBoosting (sequential ensemble)\n",
|
| 20 |
"3. RandomForest (bagging ensemble)\n",
|
| 21 |
"4. XGBoost (boosted trees with tuning)\n",
|
|
@@ -27,9 +34,17 @@
|
|
| 27 |
},
|
| 28 |
{
|
| 29 |
"cell_type": "code",
|
| 30 |
-
"execution_count":
|
| 31 |
"metadata": {},
|
| 32 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
"source": [
|
| 34 |
"import sys, os\n",
|
| 35 |
"sys.path.insert(0, os.path.abspath('..'))\n",
|
|
@@ -52,76 +67,126 @@
|
|
| 52 |
"from xgboost import XGBRegressor\n",
|
| 53 |
"from lightgbm import LGBMRegressor\n",
|
| 54 |
"\n",
|
|
|
|
|
|
|
| 55 |
"print('Setup complete.')"
|
| 56 |
]
|
| 57 |
},
|
| 58 |
{
|
| 59 |
"cell_type": "code",
|
| 60 |
-
"execution_count":
|
| 61 |
"metadata": {},
|
| 62 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
"source": [
|
| 64 |
-
"# Setup
|
| 65 |
-
"
|
| 66 |
-
"
|
| 67 |
-
"
|
| 68 |
-
"
|
| 69 |
"\n",
|
| 70 |
-
"print(f'
|
| 71 |
-
"print(f'
|
| 72 |
-
"print(f'
|
|
|
|
| 73 |
]
|
| 74 |
},
|
| 75 |
{
|
| 76 |
"cell_type": "code",
|
| 77 |
-
"execution_count":
|
| 78 |
"metadata": {},
|
| 79 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
"source": [
|
| 81 |
-
"# Load preprocessed features from NB02\n",
|
| 82 |
-
"features_df = pd.read_csv(
|
| 83 |
"print(f'Dataset shape: {features_df.shape}')\n",
|
| 84 |
"print(f'Batteries: {sorted(features_df[\"battery_id\"].unique())}')\n",
|
| 85 |
-
"print(f'SOH range: {features_df[\"SoH\"].min():.1f}% β {features_df[\"SoH\"].max():.1f}%')"
|
|
|
|
| 86 |
]
|
| 87 |
},
|
| 88 |
{
|
| 89 |
"cell_type": "code",
|
| 90 |
-
"execution_count":
|
| 91 |
"metadata": {},
|
| 92 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
"source": [
|
| 94 |
-
"#
|
| 95 |
-
"
|
| 96 |
-
"
|
|
|
|
|
|
|
|
|
|
| 97 |
"\n",
|
| 98 |
-
"
|
| 99 |
-
" bat_df = features_df[features_df['battery_id'] == bid].sort_values('cycle_number')\n",
|
| 100 |
-
" cut_idx = int(len(bat_df) * 0.8)\n",
|
| 101 |
-
" train_parts.append(bat_df.iloc[:cut_idx])\n",
|
| 102 |
-
" test_parts.append(bat_df.iloc[cut_idx:])\n",
|
| 103 |
"\n",
|
| 104 |
-
"train_df
|
| 105 |
-
"test_df
|
|
|
|
|
|
|
| 106 |
"\n",
|
| 107 |
-
"
|
| 108 |
-
"print(f'
|
| 109 |
"print(f'Train SOH: {train_df[\"SoH\"].min():.1f}% β {train_df[\"SoH\"].max():.1f}%')\n",
|
| 110 |
-
"print(f'Test SOH:
|
| 111 |
]
|
| 112 |
},
|
| 113 |
{
|
| 114 |
"cell_type": "code",
|
| 115 |
-
"execution_count":
|
| 116 |
"metadata": {},
|
| 117 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
"source": [
|
| 119 |
-
"#
|
| 120 |
-
"feature_cols = [\n",
|
| 121 |
-
"
|
| 122 |
-
" 'voltage_range', 'avg_current', 'avg_temp', 'temp_rise', 'cycle_duration',\n",
|
| 123 |
-
" 'Re', 'Rct', 'delta_capacity'\n",
|
| 124 |
-
"]\n",
|
| 125 |
"\n",
|
| 126 |
"X_train = train_df[feature_cols].values\n",
|
| 127 |
"y_train = train_df['SoH'].values\n",
|
|
@@ -129,28 +194,37 @@
|
|
| 129 |
"y_test = test_df['SoH'].values\n",
|
| 130 |
"\n",
|
| 131 |
"print(f'X_train: {X_train.shape}')\n",
|
| 132 |
-
"print(f'y_train: {y_train.shape}')"
|
|
|
|
|
|
|
| 133 |
]
|
| 134 |
},
|
| 135 |
{
|
| 136 |
"cell_type": "code",
|
| 137 |
-
"execution_count":
|
| 138 |
"metadata": {},
|
| 139 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
"source": [
|
| 141 |
-
"#
|
| 142 |
-
"scaler =
|
| 143 |
-
"X_train_scaled = scaler.
|
| 144 |
"X_test_scaled = scaler.transform(X_test)\n",
|
| 145 |
-
"\n",
|
| 146 |
-
"
|
| 147 |
-
"joblib.dump(scaler, V2_SCALERS / 'standard_scaler.joblib')\n",
|
| 148 |
-
"print('Scaler saved.')"
|
| 149 |
]
|
| 150 |
},
|
| 151 |
{
|
| 152 |
"cell_type": "code",
|
| 153 |
-
"execution_count":
|
| 154 |
"metadata": {},
|
| 155 |
"outputs": [],
|
| 156 |
"source": [
|
|
@@ -173,9 +247,17 @@
|
|
| 173 |
},
|
| 174 |
{
|
| 175 |
"cell_type": "code",
|
| 176 |
-
"execution_count":
|
| 177 |
"metadata": {},
|
| 178 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
"source": [
|
| 180 |
"# ExtraTrees (unscaled)\n",
|
| 181 |
"model_et = ExtraTreesRegressor(\n",
|
|
@@ -187,14 +269,22 @@
|
|
| 187 |
")\n",
|
| 188 |
"model_et.fit(X_train, y_train)\n",
|
| 189 |
"_, metrics_et = evaluate_model('ExtraTrees', model_et, X_test, y_test,\n",
|
| 190 |
-
"
|
| 191 |
]
|
| 192 |
},
|
| 193 |
{
|
| 194 |
"cell_type": "code",
|
| 195 |
-
"execution_count":
|
| 196 |
"metadata": {},
|
| 197 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
"source": [
|
| 199 |
"# GradientBoosting (unscaled)\n",
|
| 200 |
"model_gb = GradientBoostingRegressor(\n",
|
|
@@ -206,14 +296,22 @@
|
|
| 206 |
")\n",
|
| 207 |
"model_gb.fit(X_train, y_train)\n",
|
| 208 |
"_, metrics_gb = evaluate_model('GradientBoosting', model_gb, X_test, y_test,\n",
|
| 209 |
-
"
|
| 210 |
]
|
| 211 |
},
|
| 212 |
{
|
| 213 |
"cell_type": "code",
|
| 214 |
-
"execution_count":
|
| 215 |
"metadata": {},
|
| 216 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
"source": [
|
| 218 |
"# RandomForest (unscaled)\n",
|
| 219 |
"model_rf = RandomForestRegressor(\n",
|
|
@@ -225,14 +323,22 @@
|
|
| 225 |
")\n",
|
| 226 |
"model_rf.fit(X_train, y_train)\n",
|
| 227 |
"_, metrics_rf = evaluate_model('RandomForest', model_rf, X_test, y_test,\n",
|
| 228 |
-
"
|
| 229 |
]
|
| 230 |
},
|
| 231 |
{
|
| 232 |
"cell_type": "code",
|
| 233 |
-
"execution_count":
|
| 234 |
"metadata": {},
|
| 235 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
"source": [
|
| 237 |
"# XGBoost (unscaled, tuned hyperparameters)\n",
|
| 238 |
"model_xgb = XGBRegressor(\n",
|
|
@@ -247,14 +353,22 @@
|
|
| 247 |
")\n",
|
| 248 |
"model_xgb.fit(X_train, y_train)\n",
|
| 249 |
"_, metrics_xgb = evaluate_model('XGBoost', model_xgb, X_test, y_test,\n",
|
| 250 |
-
"
|
| 251 |
]
|
| 252 |
},
|
| 253 |
{
|
| 254 |
"cell_type": "code",
|
| 255 |
-
"execution_count":
|
| 256 |
"metadata": {},
|
| 257 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
"source": [
|
| 259 |
"# LightGBM (unscaled, tuned hyperparameters)\n",
|
| 260 |
"model_lgbm = LGBMRegressor(\n",
|
|
@@ -269,14 +383,22 @@
|
|
| 269 |
")\n",
|
| 270 |
"model_lgbm.fit(X_train, y_train)\n",
|
| 271 |
"_, metrics_lgbm = evaluate_model('LightGBM', model_lgbm, X_test, y_test,\n",
|
| 272 |
-
"
|
| 273 |
]
|
| 274 |
},
|
| 275 |
{
|
| 276 |
"cell_type": "code",
|
| 277 |
-
"execution_count":
|
| 278 |
"metadata": {},
|
| 279 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
"source": [
|
| 281 |
"# SVR (scaled)\n",
|
| 282 |
"model_svr = SVR(\n",
|
|
@@ -286,14 +408,22 @@
|
|
| 286 |
")\n",
|
| 287 |
"model_svr.fit(X_train_scaled, y_train)\n",
|
| 288 |
"_, metrics_svr = evaluate_model('SVR', model_svr, X_test_scaled, y_test,\n",
|
| 289 |
-
"
|
| 290 |
]
|
| 291 |
},
|
| 292 |
{
|
| 293 |
"cell_type": "code",
|
| 294 |
-
"execution_count":
|
| 295 |
"metadata": {},
|
| 296 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
"source": [
|
| 298 |
"# Ridge (scaled)\n",
|
| 299 |
"model_ridge = Ridge(\n",
|
|
@@ -301,14 +431,22 @@
|
|
| 301 |
")\n",
|
| 302 |
"model_ridge.fit(X_train_scaled, y_train)\n",
|
| 303 |
"_, metrics_ridge = evaluate_model('Ridge', model_ridge, X_test_scaled, y_test,\n",
|
| 304 |
-
"
|
| 305 |
]
|
| 306 |
},
|
| 307 |
{
|
| 308 |
"cell_type": "code",
|
| 309 |
-
"execution_count":
|
| 310 |
"metadata": {},
|
| 311 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
"source": [
|
| 313 |
"# KNN-5 (scaled, with distance weighting)\n",
|
| 314 |
"model_knn5 = KNeighborsRegressor(\n",
|
|
@@ -318,7 +456,7 @@
|
|
| 318 |
")\n",
|
| 319 |
"model_knn5.fit(X_train_scaled, y_train)\n",
|
| 320 |
"_, metrics_knn5 = evaluate_model('KNN-5', model_knn5, X_test_scaled, y_test,\n",
|
| 321 |
-
"
|
| 322 |
]
|
| 323 |
},
|
| 324 |
{
|
|
@@ -326,6 +464,37 @@
|
|
| 326 |
"execution_count": null,
|
| 327 |
"metadata": {},
|
| 328 |
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
"source": [
|
| 330 |
"# Collect results\n",
|
| 331 |
"all_metrics = [\n",
|
|
@@ -337,7 +506,7 @@
|
|
| 337 |
"results_df = results_df.sort_values('within_5pct', ascending=False)\n",
|
| 338 |
"\n",
|
| 339 |
"print('\\n' + '='*70)\n",
|
| 340 |
-
"print('FINAL RESULTS -
|
| 341 |
"print('='*70)\n",
|
| 342 |
"print(results_df.to_string(index=False))\n",
|
| 343 |
"\n",
|
|
@@ -345,15 +514,15 @@
|
|
| 345 |
"n_passed = (results_df['within_5pct'] >= 95.0).sum()\n",
|
| 346 |
"print(f'\\nPassed (β₯95%): {n_passed}/8')\n",
|
| 347 |
"\n",
|
| 348 |
-
"# Save results\n",
|
| 349 |
-
"results_df.to_csv(
|
| 350 |
-
"print(f'\\nResults saved to {
|
| 351 |
]
|
| 352 |
}
|
| 353 |
],
|
| 354 |
"metadata": {
|
| 355 |
"kernelspec": {
|
| 356 |
-
"display_name": "
|
| 357 |
"language": "python",
|
| 358 |
"name": "python3"
|
| 359 |
},
|
|
|
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
+
"# 03 β Classical ML Models (v3) β 8-Model Optimized Training\n",
|
| 8 |
+
"## SOH Regression: Cross-Battery Generalization Split\n",
|
| 9 |
"\n",
|
| 10 |
+
"**v3 Pipeline (bug fixes over v2):**\n",
|
| 11 |
+
"- Load preprocessed `battery_features.csv` from NB02 (18 features)\n",
|
| 12 |
+
"- **Cross-battery grouped split** (v2 bug: used intra-battery 80/20 β data leakage)\n",
|
| 13 |
+
"- Train 8 core models with 18 features (v2 had 12)\n",
|
| 14 |
+
"- Proper NaN imputation (no more `fillna(0)` for Re/Rct)\n",
|
| 15 |
"- Target: β₯95% within-Β±5% SOH accuracy on all models\n",
|
| 16 |
+
"- Save artifacts to `artifacts/v3/`\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"**v3 Bug Fixes:**\n",
|
| 19 |
+
"1. Split: Intra-battery β cross-battery (no leakage)\n",
|
| 20 |
+
"2. Features: 12 β 18 (6 new physics-informed features)\n",
|
| 21 |
+
"3. Imputation: `fillna(0)` β ffill/bfill/median (already done in NB02)\n",
|
| 22 |
+
"4. Scaler: single consistent scaler from NB02 training split\n",
|
| 23 |
"\n",
|
| 24 |
"**Models (8 total):**\n",
|
| 25 |
+
"1. ExtraTrees (tree-based, unscaled)\n",
|
| 26 |
"2. GradientBoosting (sequential ensemble)\n",
|
| 27 |
"3. RandomForest (bagging ensemble)\n",
|
| 28 |
"4. XGBoost (boosted trees with tuning)\n",
|
|
|
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"cell_type": "code",
|
| 37 |
+
"execution_count": 1,
|
| 38 |
"metadata": {},
|
| 39 |
+
"outputs": [
|
| 40 |
+
{
|
| 41 |
+
"name": "stdout",
|
| 42 |
+
"output_type": "stream",
|
| 43 |
+
"text": [
|
| 44 |
+
"Setup complete.\n"
|
| 45 |
+
]
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
"source": [
|
| 49 |
"import sys, os\n",
|
| 50 |
"sys.path.insert(0, os.path.abspath('..'))\n",
|
|
|
|
| 67 |
"from xgboost import XGBRegressor\n",
|
| 68 |
"from lightgbm import LGBMRegressor\n",
|
| 69 |
"\n",
|
| 70 |
+
"from src.utils.config import get_version_paths, ensure_version_dirs, FEATURE_COLS_V3\n",
|
| 71 |
+
"\n",
|
| 72 |
"print('Setup complete.')"
|
| 73 |
]
|
| 74 |
},
|
| 75 |
{
|
| 76 |
"cell_type": "code",
|
| 77 |
+
"execution_count": 2,
|
| 78 |
"metadata": {},
|
| 79 |
+
"outputs": [
|
| 80 |
+
{
|
| 81 |
+
"name": "stdout",
|
| 82 |
+
"output_type": "stream",
|
| 83 |
+
"text": [
|
| 84 |
+
"v3 Results: E:\\VIT\\aiBatteryLifecycle\\artifacts\\v3\\results\n",
|
| 85 |
+
"v3 Models: E:\\VIT\\aiBatteryLifecycle\\artifacts\\v3\\models\\classical\n",
|
| 86 |
+
"v3 Scalers: E:\\VIT\\aiBatteryLifecycle\\artifacts\\v3\\scalers\n",
|
| 87 |
+
"v3 Features: E:\\VIT\\aiBatteryLifecycle\\artifacts\\v3\\features\n"
|
| 88 |
+
]
|
| 89 |
+
}
|
| 90 |
+
],
|
| 91 |
"source": [
|
| 92 |
+
"# Setup v3 paths\n",
|
| 93 |
+
"v3 = get_version_paths('v3')\n",
|
| 94 |
+
"ensure_version_dirs('v3')\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"V3_FEATURES = v3['root'] / 'features'\n",
|
| 97 |
"\n",
|
| 98 |
+
"print(f'v3 Results: {v3[\"results\"]}')\n",
|
| 99 |
+
"print(f'v3 Models: {v3[\"models_classical\"]}')\n",
|
| 100 |
+
"print(f'v3 Scalers: {v3[\"scalers\"]}')\n",
|
| 101 |
+
"print(f'v3 Features: {V3_FEATURES}')"
|
| 102 |
]
|
| 103 |
},
|
| 104 |
{
|
| 105 |
"cell_type": "code",
|
| 106 |
+
"execution_count": 3,
|
| 107 |
"metadata": {},
|
| 108 |
+
"outputs": [
|
| 109 |
+
{
|
| 110 |
+
"name": "stdout",
|
| 111 |
+
"output_type": "stream",
|
| 112 |
+
"text": [
|
| 113 |
+
"Dataset shape: (2678, 25)\n",
|
| 114 |
+
"Batteries: ['B0005', 'B0006', 'B0007', 'B0018', 'B0025', 'B0026', 'B0027', 'B0028', 'B0029', 'B0030', 'B0031', 'B0032', 'B0033', 'B0034', 'B0036', 'B0038', 'B0039', 'B0040', 'B0041', 'B0042', 'B0043', 'B0044', 'B0045', 'B0046', 'B0047', 'B0048', 'B0053', 'B0054', 'B0055', 'B0056']\n",
|
| 115 |
+
"SOH range: 2.2% β 122.2%\n",
|
| 116 |
+
"NaN count: 0\n"
|
| 117 |
+
]
|
| 118 |
+
}
|
| 119 |
+
],
|
| 120 |
"source": [
|
| 121 |
+
"# Load preprocessed features from NB02 (v3: 18 features, already imputed)\n",
|
| 122 |
+
"features_df = pd.read_csv(V3_FEATURES / 'battery_features.csv')\n",
|
| 123 |
"print(f'Dataset shape: {features_df.shape}')\n",
|
| 124 |
"print(f'Batteries: {sorted(features_df[\"battery_id\"].unique())}')\n",
|
| 125 |
+
"print(f'SOH range: {features_df[\"SoH\"].min():.1f}% β {features_df[\"SoH\"].max():.1f}%')\n",
|
| 126 |
+
"print(f'NaN count: {features_df[FEATURE_COLS_V3].isna().sum().sum()}')"
|
| 127 |
]
|
| 128 |
},
|
| 129 |
{
|
| 130 |
"cell_type": "code",
|
| 131 |
+
"execution_count": 4,
|
| 132 |
"metadata": {},
|
| 133 |
+
"outputs": [
|
| 134 |
+
{
|
| 135 |
+
"name": "stdout",
|
| 136 |
+
"output_type": "stream",
|
| 137 |
+
"text": [
|
| 138 |
+
"Train: 2163 samples from 24 batteries\n",
|
| 139 |
+
"Test: 515 samples from 6 batteries\n",
|
| 140 |
+
"Train batteries: ['B0005', 'B0006', 'B0007', 'B0018', 'B0025', 'B0026', 'B0029', 'B0030', 'B0032', 'B0033', 'B0034', 'B0038', 'B0039', 'B0040', 'B0041', 'B0044', 'B0045', 'B0046', 'B0047', 'B0048', 'B0053', 'B0054', 'B0055', 'B0056']\n",
|
| 141 |
+
"Test batteries: ['B0027', 'B0028', 'B0031', 'B0036', 'B0042', 'B0043']\n",
|
| 142 |
+
"Overlap: NONE β (no leakage)\n",
|
| 143 |
+
"Train SOH: 2.2% β 101.8%\n",
|
| 144 |
+
"Test SOH: 2.8% β 122.2%\n"
|
| 145 |
+
]
|
| 146 |
+
}
|
| 147 |
+
],
|
| 148 |
"source": [
|
| 149 |
+
"# ββ v3 FIX: Cross-battery grouped split (no data leakage) ββ\n",
|
| 150 |
+
"# v2 bug: intra-battery 80/20 chronological split per battery\n",
|
| 151 |
+
"# β All batteries appear in both train AND test β inflated RΒ²\n",
|
| 152 |
+
"# v3 fix: entire batteries in train OR test, never both\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"from src.data.preprocessing import group_battery_split\n",
|
| 155 |
"\n",
|
| 156 |
+
"train_df, test_df = group_battery_split(features_df, train_ratio=0.8)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
"\n",
|
| 158 |
+
"print(f'Train: {len(train_df)} samples from {train_df[\"battery_id\"].nunique()} batteries')\n",
|
| 159 |
+
"print(f'Test: {len(test_df)} samples from {test_df[\"battery_id\"].nunique()} batteries')\n",
|
| 160 |
+
"print(f'Train batteries: {sorted(train_df[\"battery_id\"].unique())}')\n",
|
| 161 |
+
"print(f'Test batteries: {sorted(test_df[\"battery_id\"].unique())}')\n",
|
| 162 |
"\n",
|
| 163 |
+
"overlap = set(train_df['battery_id']) & set(test_df['battery_id'])\n",
|
| 164 |
+
"print(f'Overlap: {overlap if overlap else \"NONE β (no leakage)\"}')\n",
|
| 165 |
"print(f'Train SOH: {train_df[\"SoH\"].min():.1f}% β {train_df[\"SoH\"].max():.1f}%')\n",
|
| 166 |
+
"print(f'Test SOH: {test_df[\"SoH\"].min():.1f}% β {test_df[\"SoH\"].max():.1f}%')"
|
| 167 |
]
|
| 168 |
},
|
| 169 |
{
|
| 170 |
"cell_type": "code",
|
| 171 |
+
"execution_count": 5,
|
| 172 |
"metadata": {},
|
| 173 |
+
"outputs": [
|
| 174 |
+
{
|
| 175 |
+
"name": "stdout",
|
| 176 |
+
"output_type": "stream",
|
| 177 |
+
"text": [
|
| 178 |
+
"Using 18 features: ['cycle_number', 'ambient_temperature', 'peak_voltage', 'min_voltage', 'voltage_range', 'avg_current', 'avg_temp', 'temp_rise', 'cycle_duration', 'Re', 'Rct', 'delta_capacity', 'capacity_retention', 'cumulative_energy', 'dRe_dn', 'dRct_dn', 'soh_rolling_mean', 'voltage_slope']\n",
|
| 179 |
+
"X_train: (2163, 18)\n",
|
| 180 |
+
"y_train: (2163,)\n",
|
| 181 |
+
"X_test: (515, 18)\n",
|
| 182 |
+
"y_test: (515,)\n"
|
| 183 |
+
]
|
| 184 |
+
}
|
| 185 |
+
],
|
| 186 |
"source": [
|
| 187 |
+
"# v3: Use all 18 features (12 base + 6 physics-informed)\n",
|
| 188 |
+
"feature_cols = [c for c in FEATURE_COLS_V3 if c in features_df.columns]\n",
|
| 189 |
+
"print(f'Using {len(feature_cols)} features: {feature_cols}')\n",
|
|
|
|
|
|
|
|
|
|
| 190 |
"\n",
|
| 191 |
"X_train = train_df[feature_cols].values\n",
|
| 192 |
"y_train = train_df['SoH'].values\n",
|
|
|
|
| 194 |
"y_test = test_df['SoH'].values\n",
|
| 195 |
"\n",
|
| 196 |
"print(f'X_train: {X_train.shape}')\n",
|
| 197 |
+
"print(f'y_train: {y_train.shape}')\n",
|
| 198 |
+
"print(f'X_test: {X_test.shape}')\n",
|
| 199 |
+
"print(f'y_test: {y_test.shape}')"
|
| 200 |
]
|
| 201 |
},
|
| 202 |
{
|
| 203 |
"cell_type": "code",
|
| 204 |
+
"execution_count": 6,
|
| 205 |
"metadata": {},
|
| 206 |
+
"outputs": [
|
| 207 |
+
{
|
| 208 |
+
"name": "stdout",
|
| 209 |
+
"output_type": "stream",
|
| 210 |
+
"text": [
|
| 211 |
+
"Scaler loaded from NB02 (fitted on training batteries only).\n",
|
| 212 |
+
" Mean range: [-0.0001, 3282.9991]\n"
|
| 213 |
+
]
|
| 214 |
+
}
|
| 215 |
+
],
|
| 216 |
"source": [
|
| 217 |
+
"# Load scaler from NB02 (v3: consistent with training split)\n",
|
| 218 |
+
"scaler = joblib.load(v3['scalers'] / 'v3_features_standard.joblib')\n",
|
| 219 |
+
"X_train_scaled = scaler.transform(X_train)\n",
|
| 220 |
"X_test_scaled = scaler.transform(X_test)\n",
|
| 221 |
+
"print(f'Scaler loaded from NB02 (fitted on training batteries only).')\n",
|
| 222 |
+
"print(f' Mean range: [{scaler.mean_.min():.4f}, {scaler.mean_.max():.4f}]')"
|
|
|
|
|
|
|
| 223 |
]
|
| 224 |
},
|
| 225 |
{
|
| 226 |
"cell_type": "code",
|
| 227 |
+
"execution_count": 7,
|
| 228 |
"metadata": {},
|
| 229 |
"outputs": [],
|
| 230 |
"source": [
|
|
|
|
| 247 |
},
|
| 248 |
{
|
| 249 |
"cell_type": "code",
|
| 250 |
+
"execution_count": 8,
|
| 251 |
"metadata": {},
|
| 252 |
+
"outputs": [
|
| 253 |
+
{
|
| 254 |
+
"name": "stdout",
|
| 255 |
+
"output_type": "stream",
|
| 256 |
+
"text": [
|
| 257 |
+
"ExtraTrees | RΒ²=0.9701 | MAE=3.20 | Within-5%=75.1% | β FAIL\n"
|
| 258 |
+
]
|
| 259 |
+
}
|
| 260 |
+
],
|
| 261 |
"source": [
|
| 262 |
"# ExtraTrees (unscaled)\n",
|
| 263 |
"model_et = ExtraTreesRegressor(\n",
|
|
|
|
| 269 |
")\n",
|
| 270 |
"model_et.fit(X_train, y_train)\n",
|
| 271 |
"_, metrics_et = evaluate_model('ExtraTrees', model_et, X_test, y_test,\n",
|
| 272 |
+
" v3['models_classical'] / 'extra_trees.joblib')"
|
| 273 |
]
|
| 274 |
},
|
| 275 |
{
|
| 276 |
"cell_type": "code",
|
| 277 |
+
"execution_count": 9,
|
| 278 |
"metadata": {},
|
| 279 |
+
"outputs": [
|
| 280 |
+
{
|
| 281 |
+
"name": "stdout",
|
| 282 |
+
"output_type": "stream",
|
| 283 |
+
"text": [
|
| 284 |
+
"GradientBoosting | RΒ²=0.9860 | MAE=1.38 | Within-5%=95.1% | β PASS\n"
|
| 285 |
+
]
|
| 286 |
+
}
|
| 287 |
+
],
|
| 288 |
"source": [
|
| 289 |
"# GradientBoosting (unscaled)\n",
|
| 290 |
"model_gb = GradientBoostingRegressor(\n",
|
|
|
|
| 296 |
")\n",
|
| 297 |
"model_gb.fit(X_train, y_train)\n",
|
| 298 |
"_, metrics_gb = evaluate_model('GradientBoosting', model_gb, X_test, y_test,\n",
|
| 299 |
+
" v3['models_classical'] / 'gradient_boosting.joblib')"
|
| 300 |
]
|
| 301 |
},
|
| 302 |
{
|
| 303 |
"cell_type": "code",
|
| 304 |
+
"execution_count": 10,
|
| 305 |
"metadata": {},
|
| 306 |
+
"outputs": [
|
| 307 |
+
{
|
| 308 |
+
"name": "stdout",
|
| 309 |
+
"output_type": "stream",
|
| 310 |
+
"text": [
|
| 311 |
+
"RandomForest | RΒ²=0.9814 | MAE=1.83 | Within-5%=91.3% | β FAIL\n"
|
| 312 |
+
]
|
| 313 |
+
}
|
| 314 |
+
],
|
| 315 |
"source": [
|
| 316 |
"# RandomForest (unscaled)\n",
|
| 317 |
"model_rf = RandomForestRegressor(\n",
|
|
|
|
| 323 |
")\n",
|
| 324 |
"model_rf.fit(X_train, y_train)\n",
|
| 325 |
"_, metrics_rf = evaluate_model('RandomForest', model_rf, X_test, y_test,\n",
|
| 326 |
+
" v3['models_classical'] / 'random_forest.joblib')"
|
| 327 |
]
|
| 328 |
},
|
| 329 |
{
|
| 330 |
"cell_type": "code",
|
| 331 |
+
"execution_count": 11,
|
| 332 |
"metadata": {},
|
| 333 |
+
"outputs": [
|
| 334 |
+
{
|
| 335 |
+
"name": "stdout",
|
| 336 |
+
"output_type": "stream",
|
| 337 |
+
"text": [
|
| 338 |
+
"XGBoost | RΒ²=0.9866 | MAE=1.58 | Within-5%=93.8% | β FAIL\n"
|
| 339 |
+
]
|
| 340 |
+
}
|
| 341 |
+
],
|
| 342 |
"source": [
|
| 343 |
"# XGBoost (unscaled, tuned hyperparameters)\n",
|
| 344 |
"model_xgb = XGBRegressor(\n",
|
|
|
|
| 353 |
")\n",
|
| 354 |
"model_xgb.fit(X_train, y_train)\n",
|
| 355 |
"_, metrics_xgb = evaluate_model('XGBoost', model_xgb, X_test, y_test,\n",
|
| 356 |
+
" v3['models_classical'] / 'xgboost.joblib')"
|
| 357 |
]
|
| 358 |
},
|
| 359 |
{
|
| 360 |
"cell_type": "code",
|
| 361 |
+
"execution_count": 12,
|
| 362 |
"metadata": {},
|
| 363 |
+
"outputs": [
|
| 364 |
+
{
|
| 365 |
+
"name": "stdout",
|
| 366 |
+
"output_type": "stream",
|
| 367 |
+
"text": [
|
| 368 |
+
"LightGBM | RΒ²=0.9826 | MAE=1.98 | Within-5%=89.5% | β FAIL\n"
|
| 369 |
+
]
|
| 370 |
+
}
|
| 371 |
+
],
|
| 372 |
"source": [
|
| 373 |
"# LightGBM (unscaled, tuned hyperparameters)\n",
|
| 374 |
"model_lgbm = LGBMRegressor(\n",
|
|
|
|
| 383 |
")\n",
|
| 384 |
"model_lgbm.fit(X_train, y_train)\n",
|
| 385 |
"_, metrics_lgbm = evaluate_model('LightGBM', model_lgbm, X_test, y_test,\n",
|
| 386 |
+
" v3['models_classical'] / 'lightgbm.joblib')"
|
| 387 |
]
|
| 388 |
},
|
| 389 |
{
|
| 390 |
"cell_type": "code",
|
| 391 |
+
"execution_count": 13,
|
| 392 |
"metadata": {},
|
| 393 |
+
"outputs": [
|
| 394 |
+
{
|
| 395 |
+
"name": "stdout",
|
| 396 |
+
"output_type": "stream",
|
| 397 |
+
"text": [
|
| 398 |
+
"SVR | RΒ²=0.8898 | MAE=4.92 | Within-5%=79.0% | β FAIL\n"
|
| 399 |
+
]
|
| 400 |
+
}
|
| 401 |
+
],
|
| 402 |
"source": [
|
| 403 |
"# SVR (scaled)\n",
|
| 404 |
"model_svr = SVR(\n",
|
|
|
|
| 408 |
")\n",
|
| 409 |
"model_svr.fit(X_train_scaled, y_train)\n",
|
| 410 |
"_, metrics_svr = evaluate_model('SVR', model_svr, X_test_scaled, y_test,\n",
|
| 411 |
+
" v3['models_classical'] / 'svr.joblib')"
|
| 412 |
]
|
| 413 |
},
|
| 414 |
{
|
| 415 |
"cell_type": "code",
|
| 416 |
+
"execution_count": 14,
|
| 417 |
"metadata": {},
|
| 418 |
+
"outputs": [
|
| 419 |
+
{
|
| 420 |
+
"name": "stdout",
|
| 421 |
+
"output_type": "stream",
|
| 422 |
+
"text": [
|
| 423 |
+
"Ridge | RΒ²=0.9656 | MAE=3.23 | Within-5%=88.9% | β FAIL\n"
|
| 424 |
+
]
|
| 425 |
+
}
|
| 426 |
+
],
|
| 427 |
"source": [
|
| 428 |
"# Ridge (scaled)\n",
|
| 429 |
"model_ridge = Ridge(\n",
|
|
|
|
| 431 |
")\n",
|
| 432 |
"model_ridge.fit(X_train_scaled, y_train)\n",
|
| 433 |
"_, metrics_ridge = evaluate_model('Ridge', model_ridge, X_test_scaled, y_test,\n",
|
| 434 |
+
" v3['models_classical'] / 'ridge.joblib')"
|
| 435 |
]
|
| 436 |
},
|
| 437 |
{
|
| 438 |
"cell_type": "code",
|
| 439 |
+
"execution_count": 15,
|
| 440 |
"metadata": {},
|
| 441 |
+
"outputs": [
|
| 442 |
+
{
|
| 443 |
+
"name": "stdout",
|
| 444 |
+
"output_type": "stream",
|
| 445 |
+
"text": [
|
| 446 |
+
"KNN-5 | RΒ²=0.7555 | MAE=11.02 | Within-5%=34.2% | β FAIL\n"
|
| 447 |
+
]
|
| 448 |
+
}
|
| 449 |
+
],
|
| 450 |
"source": [
|
| 451 |
"# KNN-5 (scaled, with distance weighting)\n",
|
| 452 |
"model_knn5 = KNeighborsRegressor(\n",
|
|
|
|
| 456 |
")\n",
|
| 457 |
"model_knn5.fit(X_train_scaled, y_train)\n",
|
| 458 |
"_, metrics_knn5 = evaluate_model('KNN-5', model_knn5, X_test_scaled, y_test,\n",
|
| 459 |
+
" v3['models_classical'] / 'knn_k5.joblib')"
|
| 460 |
]
|
| 461 |
},
|
| 462 |
{
|
|
|
|
| 464 |
"execution_count": null,
|
| 465 |
"metadata": {},
|
| 466 |
"outputs": [],
|
| 467 |
+
"source": []
|
| 468 |
+
},
|
| 469 |
+
{
|
| 470 |
+
"cell_type": "code",
|
| 471 |
+
"execution_count": 16,
|
| 472 |
+
"metadata": {},
|
| 473 |
+
"outputs": [
|
| 474 |
+
{
|
| 475 |
+
"name": "stdout",
|
| 476 |
+
"output_type": "stream",
|
| 477 |
+
"text": [
|
| 478 |
+
"\n",
|
| 479 |
+
"======================================================================\n",
|
| 480 |
+
"FINAL RESULTS β v3 Classical ML (Cross-Battery Split, 18 Features)\n",
|
| 481 |
+
"======================================================================\n",
|
| 482 |
+
" model r2 mae within_5pct\n",
|
| 483 |
+
"GradientBoosting 0.985984 1.383230 95.145631\n",
|
| 484 |
+
" XGBoost 0.986594 1.576671 93.786408\n",
|
| 485 |
+
" RandomForest 0.981407 1.834184 91.262136\n",
|
| 486 |
+
" LightGBM 0.982554 1.976782 89.514563\n",
|
| 487 |
+
" Ridge 0.965638 3.225993 88.932039\n",
|
| 488 |
+
" SVR 0.889759 4.923939 79.029126\n",
|
| 489 |
+
" ExtraTrees 0.970125 3.201794 75.145631\n",
|
| 490 |
+
" KNN-5 0.755476 11.023403 34.174757\n",
|
| 491 |
+
"\n",
|
| 492 |
+
"Passed (β₯95%): 1/8\n",
|
| 493 |
+
"\n",
|
| 494 |
+
"Results saved to E:\\VIT\\aiBatteryLifecycle\\artifacts\\v3\\results\\v3_classical_soh_results.csv\n"
|
| 495 |
+
]
|
| 496 |
+
}
|
| 497 |
+
],
|
| 498 |
"source": [
|
| 499 |
"# Collect results\n",
|
| 500 |
"all_metrics = [\n",
|
|
|
|
| 506 |
"results_df = results_df.sort_values('within_5pct', ascending=False)\n",
|
| 507 |
"\n",
|
| 508 |
"print('\\n' + '='*70)\n",
|
| 509 |
+
"print('FINAL RESULTS β v3 Classical ML (Cross-Battery Split, 18 Features)')\n",
|
| 510 |
"print('='*70)\n",
|
| 511 |
"print(results_df.to_string(index=False))\n",
|
| 512 |
"\n",
|
|
|
|
| 514 |
"n_passed = (results_df['within_5pct'] >= 95.0).sum()\n",
|
| 515 |
"print(f'\\nPassed (β₯95%): {n_passed}/8')\n",
|
| 516 |
"\n",
|
| 517 |
+
"# Save results (v3: consistent naming)\n",
|
| 518 |
+
"results_df.to_csv(v3['results'] / 'v3_classical_soh_results.csv', index=False)\n",
|
| 519 |
+
"print(f'\\nResults saved to {v3[\"results\"] / \"v3_classical_soh_results.csv\"}')"
|
| 520 |
]
|
| 521 |
}
|
| 522 |
],
|
| 523 |
"metadata": {
|
| 524 |
"kernelspec": {
|
| 525 |
+
"display_name": "venv",
|
| 526 |
"language": "python",
|
| 527 |
"name": "python3"
|
| 528 |
},
|
notebooks/04_lstm_rnn.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/05_transformer.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/06_dynamic_graph.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/07_vae_lstm.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/08_ensemble.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/09_evaluation.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/download_models.py
CHANGED
|
@@ -58,7 +58,7 @@ def _hf_kwargs(allow_patterns: list | None = None,
|
|
| 58 |
return kwargs
|
| 59 |
|
| 60 |
|
| 61 |
-
def _key_models(version: str = "
|
| 62 |
base = ARTIFACTS_DIR / version / "models" / "classical"
|
| 63 |
return [base / f"{m}.joblib" for m in ("random_forest", "xgboost", "lightgbm")]
|
| 64 |
|
|
@@ -68,7 +68,7 @@ def version_loaded(version: str) -> bool:
|
|
| 68 |
return all(p.exists() for p in _key_models(version))
|
| 69 |
|
| 70 |
|
| 71 |
-
def already_downloaded(version: str = "
|
| 72 |
"""Return True only when all three BestEnsemble component models are present."""
|
| 73 |
missing = [p for p in _key_models(version) if not p.exists()]
|
| 74 |
if missing:
|
|
@@ -102,7 +102,7 @@ def download_version(version: str) -> None:
|
|
| 102 |
|
| 103 |
|
| 104 |
def download_all() -> None:
|
| 105 |
-
"""Download all versions (v1 + v2) from HF Hub."""
|
| 106 |
_ensure_hub()
|
| 107 |
from huggingface_hub import snapshot_download
|
| 108 |
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
|
@@ -126,8 +126,8 @@ def main() -> None:
|
|
| 126 |
download_version(args.version)
|
| 127 |
return
|
| 128 |
|
| 129 |
-
# Default: ensure
|
| 130 |
-
if already_downloaded("
|
| 131 |
print("[download_models] Artifacts already present β skipping download")
|
| 132 |
return
|
| 133 |
|
|
|
|
| 58 |
return kwargs
|
| 59 |
|
| 60 |
|
| 61 |
+
def _key_models(version: str = "v3") -> list:
|
| 62 |
base = ARTIFACTS_DIR / version / "models" / "classical"
|
| 63 |
return [base / f"{m}.joblib" for m in ("random_forest", "xgboost", "lightgbm")]
|
| 64 |
|
|
|
|
| 68 |
return all(p.exists() for p in _key_models(version))
|
| 69 |
|
| 70 |
|
| 71 |
+
def already_downloaded(version: str = "v3") -> bool:
|
| 72 |
"""Return True only when all three BestEnsemble component models are present."""
|
| 73 |
missing = [p for p in _key_models(version) if not p.exists()]
|
| 74 |
if missing:
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
def download_all() -> None:
|
| 105 |
+
"""Download all versions (v1 + v2 + v3) from HF Hub."""
|
| 106 |
_ensure_hub()
|
| 107 |
from huggingface_hub import snapshot_download
|
| 108 |
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 126 |
download_version(args.version)
|
| 127 |
return
|
| 128 |
|
| 129 |
+
# Default: ensure v3 (latest) is present
|
| 130 |
+
if already_downloaded("v3"):
|
| 131 |
print("[download_models] Artifacts already present β skipping download")
|
| 132 |
return
|
| 133 |
|
scripts/upload_models_to_hub.py
CHANGED
|
@@ -78,18 +78,17 @@ artifacts/
|
|
| 78 |
βββ results/ # Validation JSONs
|
| 79 |
```
|
| 80 |
|
| 81 |
-
## Model Performance Summary
|
| 82 |
-
|
| 83 |
-
| Rank | Model | RΒ² | MAE |
|
| 84 |
-
|------|-------|----|-----|------
|
| 85 |
-
| 1 |
|
| 86 |
-
| 2 |
|
| 87 |
-
| 3 |
|
| 88 |
-
| 4 |
|
| 89 |
-
| 5 |
|
| 90 |
-
| 6 |
|
| 91 |
-
| 7 |
|
| 92 |
-
| 8 | VAE-LSTM | 0.730 | 7.82 | 9.98 | Generative |
|
| 93 |
|
| 94 |
## Usage
|
| 95 |
|
|
@@ -154,7 +153,7 @@ def main():
|
|
| 154 |
# 3. Upload each version directly at repo root: v1/ and v2/ (NOT under artifacts/)
|
| 155 |
# Split into one commit per subdirectory so no single commit is too large
|
| 156 |
# (the 100 MB random_forest.joblib would time out a combined commit).
|
| 157 |
-
for version in ["v1", "v2"]:
|
| 158 |
version_path = ARTIFACTS / version
|
| 159 |
if not version_path.exists():
|
| 160 |
print(f" [skip] {version_path} does not exist")
|
|
|
|
| 78 |
βββ results/ # Validation JSONs
|
| 79 |
```
|
| 80 |
|
| 81 |
+
## Model Performance Summary (v3)
|
| 82 |
+
|
| 83 |
+
| Rank | Model | RΒ² | MAE | Family |
|
| 84 |
+
|------|-------|----|-----|--------|
|
| 85 |
+
| 1 | XGBoost | 0.9866 | 1.58 | Classical |
|
| 86 |
+
| 2 | GradientBoosting | 0.9860 | 1.38 | Classical |
|
| 87 |
+
| 3 | LightGBM | 0.9826 | 1.98 | Classical |
|
| 88 |
+
| 4 | RandomForest | 0.9814 | 1.83 | Classical |
|
| 89 |
+
| 5 | ExtraTrees | 0.9701 | 3.20 | Classical |
|
| 90 |
+
| 6 | TFT | 0.8751 | 3.88 | Transformer |
|
| 91 |
+
| 7 | Weighted Avg Ensemble | 0.8991 | 3.51 | Ensemble |
|
|
|
|
| 92 |
|
| 93 |
## Usage
|
| 94 |
|
|
|
|
| 153 |
# 3. Upload each version directly at repo root: v1/ and v2/ (NOT under artifacts/)
|
| 154 |
# Split into one commit per subdirectory so no single commit is too large
|
| 155 |
# (the 100 MB random_forest.joblib would time out a combined commit).
|
| 156 |
+
for version in ["v1", "v2", "v3"]:
|
| 157 |
version_path = ARTIFACTS / version
|
| 158 |
if not version_path.exists():
|
| 159 |
print(f" [skip] {version_path} does not exist")
|
src/data/features.py
CHANGED
|
@@ -259,3 +259,76 @@ def build_battery_feature_dataset(
|
|
| 259 |
result["coulombic_efficiency"] = np.nan
|
| 260 |
|
| 261 |
return result.reset_index(drop=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
result["coulombic_efficiency"] = np.nan
|
| 260 |
|
| 261 |
return result.reset_index(drop=True)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# ββ v3 enhanced features ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 265 |
+
def add_v3_features(df: pd.DataFrame) -> pd.DataFrame:
|
| 266 |
+
"""Add v3 physics-informed features on top of the base feature dataset.
|
| 267 |
+
|
| 268 |
+
New features (6 total):
|
| 269 |
+
- capacity_retention: Q_n / Q_1 per battery (0-1, monotonically decreasing)
|
| 270 |
+
- cumulative_energy: cumulative Ah throughput (proxy for total energy cycled)
|
| 271 |
+
- dRe_dn: impedance growth rate (ΞRe per cycle, forward diff)
|
| 272 |
+
- dRct_dn: impedance growth rate (ΞRct per cycle)
|
| 273 |
+
- soh_rolling_mean: 5-cycle rolling mean SOH (noise-smoothed degradation)
|
| 274 |
+
- voltage_slope: cycle-over-cycle voltage midpoint slope (dV_mid/dn)
|
| 275 |
+
|
| 276 |
+
Parameters
|
| 277 |
+
----------
|
| 278 |
+
df : pd.DataFrame
|
| 279 |
+
Output from ``build_battery_feature_dataset()``.
|
| 280 |
+
|
| 281 |
+
Returns
|
| 282 |
+
-------
|
| 283 |
+
pd.DataFrame
|
| 284 |
+
Same dataframe with 6 new columns appended.
|
| 285 |
+
"""
|
| 286 |
+
out = df.copy()
|
| 287 |
+
|
| 288 |
+
# ββ capacity_retention: Q_n / Q_1 per battery βββββββββββββββββββββββββββ
|
| 289 |
+
first_cap = out.groupby("battery_id")["Capacity"].transform("first")
|
| 290 |
+
out["capacity_retention"] = out["Capacity"] / first_cap.replace(0, np.nan)
|
| 291 |
+
|
| 292 |
+
# ββ cumulative_energy: cumulative Ah throughput βββββββββββββββββββββββββ
|
| 293 |
+
out["cumulative_energy"] = out.groupby("battery_id")["Capacity"].cumsum()
|
| 294 |
+
|
| 295 |
+
# ββ impedance growth rates (dRe/dn, dRct/dn) βββββββββββββββββββββββββββ
|
| 296 |
+
out["dRe_dn"] = out.groupby("battery_id")["Re"].diff().fillna(0)
|
| 297 |
+
out["dRct_dn"] = out.groupby("battery_id")["Rct"].diff().fillna(0)
|
| 298 |
+
|
| 299 |
+
# ββ SOH rolling mean (5-cycle window) βββββββββββββββββββββββββββββββββββ
|
| 300 |
+
out["soh_rolling_mean"] = out.groupby("battery_id")["SoH"].transform(
|
| 301 |
+
lambda s: s.rolling(window=5, min_periods=1, center=False).mean()
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# ββ voltage_slope: cycle-over-cycle mid-voltage change ββββββββββββββββββ
|
| 305 |
+
if "peak_voltage" in out.columns and "min_voltage" in out.columns:
|
| 306 |
+
v_mid = (out["peak_voltage"] + out["min_voltage"]) / 2.0
|
| 307 |
+
out["voltage_slope"] = v_mid.groupby(out["battery_id"]).diff().fillna(0)
|
| 308 |
+
else:
|
| 309 |
+
out["voltage_slope"] = 0.0
|
| 310 |
+
|
| 311 |
+
return out
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def impute_features(df: pd.DataFrame) -> pd.DataFrame:
|
| 315 |
+
"""Fix NaN handling: forward-fill within battery, then group median.
|
| 316 |
+
|
| 317 |
+
Bug fix for v2 which used ``fillna(0)`` β physically impossible for Re/Rct.
|
| 318 |
+
"""
|
| 319 |
+
out = df.copy()
|
| 320 |
+
numeric_cols = out.select_dtypes(include=[np.number]).columns
|
| 321 |
+
|
| 322 |
+
# Step 1: forward fill within each battery (temporal continuity)
|
| 323 |
+
for col in numeric_cols:
|
| 324 |
+
out[col] = out.groupby("battery_id")[col].transform(
|
| 325 |
+
lambda s: s.ffill().bfill()
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# Step 2: remaining NaN β global median (cross-battery)
|
| 329 |
+
for col in numeric_cols:
|
| 330 |
+
if out[col].isna().any():
|
| 331 |
+
median_val = out[col].median()
|
| 332 |
+
out[col] = out[col].fillna(median_val if pd.notna(median_val) else 0)
|
| 333 |
+
|
| 334 |
+
return out
|
src/utils/config.py
CHANGED
|
@@ -30,8 +30,8 @@ SCALERS_DIR = ARTIFACTS_DIR / "scalers"
|
|
| 30 |
FIGURES_DIR = ARTIFACTS_DIR / "figures"
|
| 31 |
LOGS_DIR = ARTIFACTS_DIR / "logs"
|
| 32 |
|
| 33 |
-
# Currently active artifact version (changed when v2 was validated)
|
| 34 |
-
ACTIVE_VERSION: str = "v2"
|
| 35 |
|
| 36 |
# Ensure all legacy artifact directories exist (backward compat)
|
| 37 |
for _d in (MODELS_DIR, SCALERS_DIR, FIGURES_DIR, LOGS_DIR,
|
|
@@ -105,6 +105,23 @@ DROPOUT = 0.2
|
|
| 105 |
LATENT_DIM = 16 # For VAE
|
| 106 |
|
| 107 |
# ββ Feature col lists (duplicated from preprocessing for easy import) ββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
FEATURE_COLS_SCALAR = [
|
| 109 |
"cycle_number", "ambient_temperature",
|
| 110 |
"peak_voltage", "min_voltage", "voltage_range",
|
|
|
|
| 30 |
FIGURES_DIR = ARTIFACTS_DIR / "figures"
|
| 31 |
LOGS_DIR = ARTIFACTS_DIR / "logs"
|
| 32 |
|
| 33 |
+
# Currently active artifact version (changed when v3 is validated)
|
| 34 |
+
ACTIVE_VERSION: str = "v3"
|
| 35 |
|
| 36 |
# Ensure all legacy artifact directories exist (backward compat)
|
| 37 |
for _d in (MODELS_DIR, SCALERS_DIR, FIGURES_DIR, LOGS_DIR,
|
|
|
|
| 105 |
LATENT_DIM = 16 # For VAE
|
| 106 |
|
| 107 |
# ββ Feature col lists (duplicated from preprocessing for easy import) ββββββββ
|
| 108 |
+
FEATURE_COLS_V2 = [
|
| 109 |
+
"cycle_number", "ambient_temperature",
|
| 110 |
+
"peak_voltage", "min_voltage", "voltage_range",
|
| 111 |
+
"avg_current", "avg_temp", "temp_rise",
|
| 112 |
+
"cycle_duration", "Re", "Rct", "delta_capacity",
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
# v3 adds 6 physics-informed features on top of v2's 12
|
| 116 |
+
FEATURE_COLS_V3 = FEATURE_COLS_V2 + [
|
| 117 |
+
"capacity_retention", # Q_n / Q_1 per battery (0-1 ratio)
|
| 118 |
+
"cumulative_energy", # cumulative Ah throughput
|
| 119 |
+
"dRe_dn", # impedance growth rate (ΞRe per cycle)
|
| 120 |
+
"dRct_dn", # impedance growth rate (ΞRct per cycle)
|
| 121 |
+
"soh_rolling_mean", # 5-cycle rolling mean SOH (smoothed)
|
| 122 |
+
"voltage_slope", # cycle-over-cycle voltage midpoint slope
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
FEATURE_COLS_SCALAR = [
|
| 126 |
"cycle_number", "ambient_temperature",
|
| 127 |
"peak_voltage", "min_voltage", "voltage_range",
|
src/utils/plotting.py
CHANGED
|
@@ -36,8 +36,8 @@ except OSError:
|
|
| 36 |
sns.set_context("paper", font_scale=1.3)
|
| 37 |
|
| 38 |
|
| 39 |
-
def save_fig(fig: plt.Figure, name: str, tight: bool = True) -> Path:
|
| 40 |
-
"""Save figure as PNG to ``FIGURES_DIR``.
|
| 41 |
|
| 42 |
Parameters
|
| 43 |
----------
|
|
@@ -47,6 +47,8 @@ def save_fig(fig: plt.Figure, name: str, tight: bool = True) -> Path:
|
|
| 47 |
Base filename (without extension).
|
| 48 |
tight:
|
| 49 |
Whether to call ``tight_layout()`` before saving.
|
|
|
|
|
|
|
| 50 |
|
| 51 |
Returns
|
| 52 |
-------
|
|
@@ -55,7 +57,9 @@ def save_fig(fig: plt.Figure, name: str, tight: bool = True) -> Path:
|
|
| 55 |
"""
|
| 56 |
if tight:
|
| 57 |
fig.tight_layout()
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
fig.savefig(path, dpi=FIG_DPI, bbox_inches="tight")
|
| 60 |
return path
|
| 61 |
|
|
|
|
| 36 |
sns.set_context("paper", font_scale=1.3)
|
| 37 |
|
| 38 |
|
| 39 |
+
def save_fig(fig: plt.Figure, name: str, tight: bool = True, directory: Path | None = None) -> Path:
    """Write *fig* to disk as ``<name>.png`` and return the resulting path.

    Parameters
    ----------
    fig:
        The matplotlib figure to persist.
    name:
        Base filename (without extension).
    tight:
        Whether to call ``tight_layout()`` before saving.
    directory:
        Target directory. Defaults to ``FIGURES_DIR`` (artifacts/figures/).

    Returns
    -------
    Path
        Location of the saved PNG file.
    """
    if tight:
        fig.tight_layout()
    # Resolve the destination and create it on demand so callers never have to.
    target_dir = FIGURES_DIR if directory is None else directory
    target_dir.mkdir(parents=True, exist_ok=True)
    out_path = target_dir / f"{name}.png"
    fig.savefig(out_path, dpi=FIG_DPI, bbox_inches="tight")
    return out_path
|
| 65 |
|