aiBatteryLifeCycle / scripts /data /patch_dl_notebooks_v2.py
NeerajCodz's picture
feat: full project — ML simulation, dashboard UI, models on HF Hub
f381be8
"""
Patch notebooks 04-09 for v2:
- Replace battery-grouped split with intra-battery chronological split
- Update import lines to include get_version_paths
- Update artifact save/load paths to use v2 directories
"""
import json
import re
from pathlib import Path
# Directory holding the notebooks to patch. The original lookup resolved two
# levels above this file; since the script is published under ``scripts/data/``
# that points at ``scripts/notebooks``. Keep that location when it exists
# (unchanged behaviour), otherwise fall back to ``notebooks/`` one level higher
# (the repository root). If neither exists, the original path is used so the
# failure is the same FileNotFoundError as before.
# NOTE(review): confirm where the notebooks actually live and pin it explicitly.
_THIS_FILE = Path(__file__).resolve()
_NB_CANDIDATES = (
    _THIS_FILE.parent.parent / "notebooks",
    _THIS_FILE.parent.parent.parent / "notebooks",
)
NB_DIR = next((d for d in _NB_CANDIDATES if d.is_dir()), _NB_CANDIDATES[0])
# ──────────────────────────── helpers ────────────────────────────
def load_nb(name):
    """Parse the notebook file ``NB_DIR / name`` (UTF-8 JSON) and return it."""
    notebook_path = NB_DIR / name
    raw_text = notebook_path.read_text(encoding="utf-8")
    return json.loads(raw_text)
def save_nb(nb, name):
    """Write notebook *nb* to ``NB_DIR / name`` as UTF-8 JSON and report it.

    Uses one-space indentation and leaves non-ASCII characters unescaped.
    The file ends with a newline, as Jupyter's own writer does, so re-saving
    a notebook does not introduce a "no newline at end of file" diff.
    """
    with open(NB_DIR / name, "w", encoding="utf-8") as f:
        json.dump(nb, f, indent=1, ensure_ascii=False)
        f.write("\n")
    # The check mark was previously mis-encoded ("βœ“").
    print(f" ✓ Saved {name}")
def get_code_cells(nb):
    """Collect ``(index, cell)`` pairs for the code cells of *nb*.

    ``index`` is the cell's position in ``nb["cells"]`` counting every cell
    type, and ``cell`` is the original dict (not a copy), so callers can
    mutate it in place through ``cc[k][1]``.
    """
    code_cells = []
    for position, cell in enumerate(nb["cells"]):
        if cell["cell_type"] != "code":
            continue
        code_cells.append((position, cell))
    return code_cells
def set_source(cell, new_src):
    """Store *new_src* as ``cell["source"]`` in nbformat's list-of-lines form.

    Each line keeps its trailing newline except possibly the last one. A
    trailing newline in *new_src* does not leave a dangling ``""`` element at
    the end of the list, and an empty *new_src* becomes an empty list.
    Joining the result (see ``src``) gives back *new_src* unchanged.
    """
    lines = new_src.split("\n")
    source = [line + "\n" for line in lines[:-1]]
    # A final "" only appears when new_src ends with a newline (or is empty);
    # it carries no text, so it is not stored as a separate element.
    if lines[-1]:
        source.append(lines[-1])
    cell["source"] = source
def src(cell):
    """Return a cell's source as a single string (inverse of ``set_source``)."""
    pieces = cell["source"]
    joined = "".join(pieces)
    return joined
# ──────────── v2 intra-battery split (sequence version) ────────────
# Notebook code injected in place of the battery-grouped split. It expects the
# notebook to already define ``np``, ``bids`` (one battery id per row) and
# ``X_multi`` / ``y_multi`` (row-aligned with ``bids``). It defines
# ``train_idx`` / ``test_idx``, ``X_train`` / ``y_train`` / ``X_test`` /
# ``y_test`` and ``n_shared``.
V2_SPLIT_SEQ = """\
# ── v2: intra-battery chronological split ──
# For each battery, first 80% (rounded down) of sequences → train, rest → test
# (assumes each battery's sequences appear in chronological order)
train_idx, test_idx = [], []
for bid in np.unique(bids):
    idxs = np.where(bids == bid)[0]
    n = len(idxs)
    cut = int(0.8 * n)
    train_idx.extend(idxs[:cut].tolist())
    test_idx.extend(idxs[cut:].tolist())
# dtype=int keeps the arrays usable as indices even if one of them is empty
train_idx = np.array(train_idx, dtype=int)
test_idx = np.array(test_idx, dtype=int)
X_train, y_train = X_multi[train_idx], y_multi[train_idx]
X_test, y_test = X_multi[test_idx], y_multi[test_idx]
# Batteries with a single sequence land entirely in test, so count the real overlap
n_shared = len(np.intersect1d(bids[train_idx], bids[test_idx]))
print(f"Train: {X_train.shape[0]} | Test: {X_test.shape[0]} | Batteries in both: {n_shared}/{len(np.unique(bids))}")"""
# ──────────────────────────── NB 04 ────────────────────────────
def patch_04():
    """Patch ``04_lstm_rnn.ipynb`` for v2 and write it back in place.

    - Code cell 0 gains the ``get_version_paths`` / ``ensure_version_dirs``
      imports and a ``v2`` paths setup.
    - Code cell 1 has its battery-grouped split swapped for ``V2_SPLIT_SEQ``.
    - Model, figure and results outputs in code cells 2-6 are redirected to
      v2 names/locations.

    The notebook is skipped (and not rewritten) when code cell 0 already holds
    the v2 paths setup. The replacements are not idempotent: a second pass
    would append the setup again and turn ``v2_`` prefixes into ``v2_v2_``.

    Raises:
        RuntimeError: if the battery-grouped split block is not found in code
            cell 1; nothing is written to disk in that case.
    """
    nb_name = "04_lstm_rnn.ipynb"
    print(f"Patching {nb_name} ...")
    nb = load_nb(nb_name)
    cc = get_code_cells(nb)

    # Skip notebooks this script has already patched.
    if "# v2 paths\nv2 = get_version_paths('v2')" in src(cc[0][1]):
        print(f"  {nb_name} already contains the v2 paths setup, skipping")
        return

    # Cell 0: imports — add get_version_paths, ensure_version_dirs
    s = src(cc[0][1])
    s = s.replace(
        "from src.utils.config import (\n ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR,",
        "from src.utils.config import (\n ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR,\n get_version_paths, ensure_version_dirs,"
    )
    # Add v2 paths setup at end
    s += "\n\n# v2 paths\nv2 = get_version_paths('v2')\nensure_version_dirs('v2')"
    set_source(cc[0][1], s)

    # Cell 1: replace battery-grouped split with intra-battery chrono split.
    # A silent no-op here would keep the v1 split while every output is
    # labelled v2, so fail loudly instead.
    s = src(cc[1][1])
    old_split = """# Battery-grouped split
unique_bids = np.unique(bids)
rng = np.random.RandomState(42)
rng.shuffle(unique_bids)
n_train = int(0.8 * len(unique_bids))
train_bats = set(unique_bids[:n_train])
test_bats = set(unique_bids[n_train:])
train_mask = np.isin(bids, list(train_bats))
test_mask = np.isin(bids, list(test_bats))
X_train, y_train = X_multi[train_mask], y_multi[train_mask]
X_test, y_test = X_multi[test_mask], y_multi[test_mask]
print(f"Train: {X_train.shape[0]} | Test: {X_test.shape[0]}")"""
    if old_split not in s:
        raise RuntimeError(f"{nb_name}: battery-grouped split not found in code cell 1")
    s = s.replace(old_split, V2_SPLIT_SEQ)
    set_source(cc[1][1], s)

    # Cell 2: model save path
    s = src(cc[2][1])
    s = s.replace(
        'MODELS_DIR / "deep" / f"{name.lower().replace(\' \', \'_\')}.pt"',
        'v2["models_deep"] / f"{name.lower().replace(\' \', \'_\')}.pt"'
    )
    set_source(cc[2][1], s)

    # Cells 3-5: prefix save_fig names with "v2_" (the figure directory itself
    # is not changed here)
    for idx in (3, 4, 5):
        s = src(cc[idx][1])
        s = s.replace('save_fig(fig, "', 'save_fig(fig, "v2_')
        set_source(cc[idx][1], s)

    # Cell 6: results save path
    s = src(cc[6][1])
    s = s.replace(
        'ARTIFACTS_DIR / "lstm_soh_results.csv"',
        'v2["results"] / "v2_lstm_soh_results.csv"'
    )
    s = s.replace(
        'Saved to artifacts/lstm_soh_results.csv',
        'Saved to v2 results'
    )
    set_source(cc[6][1], s)

    save_nb(nb, nb_name)
# ──────────────────────────── NB 05 ────────────────────────────
def patch_05():
    """Patch ``05_transformer.ipynb`` for v2 and write it back in place.

    - Code cell 0 gains the v2 path imports and setup.
    - Code cell 1 gets the intra-battery chronological split.
    - The BatteryGPT, TFT, iTransformer and Physics iTransformer save paths
      (cells 2-5) move to ``v2["models_deep"]``.
    - Figure names in cell 6 get a ``v2_`` prefix.
    - The results CSV (cell 7) moves to ``v2["results"]``.

    Skipped without rewriting when code cell 0 already has the v2 paths setup,
    because the replacements are not idempotent.

    Raises:
        RuntimeError: if the battery-grouped split block is missing from code
            cell 1; nothing is written to disk in that case.
    """
    nb_name = "05_transformer.ipynb"
    print(f"Patching {nb_name} ...")
    nb = load_nb(nb_name)
    cc = get_code_cells(nb)

    # Re-running would append the v2 setup twice and double the "v2_" prefixes.
    if "# v2 paths\nv2 = get_version_paths('v2')" in src(cc[0][1]):
        print(f"  {nb_name} already contains the v2 paths setup, skipping")
        return

    # Cell 0: imports
    s = src(cc[0][1])
    s = s.replace(
        "from src.utils.config import (\n ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR,",
        "from src.utils.config import (\n ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR,\n get_version_paths, ensure_version_dirs,"
    )
    s += "\n\n# v2 paths\nv2 = get_version_paths('v2')\nensure_version_dirs('v2')"
    set_source(cc[0][1], s)

    # Cell 1: data + split (fail loudly rather than silently keep the v1 split)
    s = src(cc[1][1])
    old_split = """# Battery-grouped split
unique_bids = np.unique(bids)
rng = np.random.RandomState(42)
rng.shuffle(unique_bids)
n_train = int(0.8 * len(unique_bids))
train_bats = set(unique_bids[:n_train])
test_bats = set(unique_bids[n_train:])
train_mask = np.isin(bids, list(train_bats))
test_mask = np.isin(bids, list(test_bats))
X_train, y_train = X_multi[train_mask], y_multi[train_mask]
X_test, y_test = X_multi[test_mask], y_multi[test_mask]
print(f"Train: {X_train.shape[0]} | Test: {X_test.shape[0]}")"""
    if old_split not in s:
        raise RuntimeError(f"{nb_name}: battery-grouped split not found in code cell 1")
    s = s.replace(old_split, V2_SPLIT_SEQ)
    set_source(cc[1][1], s)

    # Cells 2-5: each cell saves one model; point it at the v2 deep-model dir
    for idx, fname in (
        (2, "batterygpt.pt"),
        (3, "tft.pt"),
        (4, "itransformer.keras"),
        (5, "physics_itransformer.keras"),
    ):
        s = src(cc[idx][1])
        s = s.replace(f'MODELS_DIR / "deep" / "{fname}"', f'v2["models_deep"] / "{fname}"')
        set_source(cc[idx][1], s)

    # Cell 6: prefix save_fig names with "v2_"
    s = src(cc[6][1])
    s = s.replace('save_fig(fig, "', 'save_fig(fig, "v2_')
    set_source(cc[6][1], s)

    # Cell 7: results save
    s = src(cc[7][1])
    s = s.replace('ARTIFACTS_DIR / "transformer_soh_results.csv"', 'v2["results"] / "v2_transformer_soh_results.csv"')
    set_source(cc[7][1], s)

    save_nb(nb, nb_name)
# ──────────────────────────── NB 06 ────────────────────────────
def patch_06():
    """Patch ``06_dynamic_graph.ipynb`` for v2 and write it back in place.

    - Code cell 0 gains the v2 path imports and setup.
    - Code cell 1 gets the intra-battery chronological split.
    - The model save path in cell 4 moves to ``v2["models_deep"]``.
    - Figure names in cells 5-7 get a ``v2_`` prefix.
    - The results JSON in cell 7 moves to ``v2["results"]``.

    Skipped without rewriting when code cell 0 already has the v2 paths setup,
    because the replacements are not idempotent.

    Raises:
        RuntimeError: if the battery-grouped split block is missing from code
            cell 1; nothing is written to disk in that case.
    """
    nb_name = "06_dynamic_graph.ipynb"
    print(f"Patching {nb_name} ...")
    nb = load_nb(nb_name)
    cc = get_code_cells(nb)

    # Re-running would append the v2 setup twice and double the "v2_" prefixes.
    if "# v2 paths\nv2 = get_version_paths('v2')" in src(cc[0][1]):
        print(f"  {nb_name} already contains the v2 paths setup, skipping")
        return

    # Cell 0: imports
    s = src(cc[0][1])
    s = s.replace(
        "from src.utils.config import (\n ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR,",
        "from src.utils.config import (\n ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR,\n get_version_paths, ensure_version_dirs,"
    )
    s += "\n\n# v2 paths\nv2 = get_version_paths('v2')\nensure_version_dirs('v2')"
    set_source(cc[0][1], s)

    # Cell 1: data + split (fail loudly rather than silently keep the v1 split)
    s = src(cc[1][1])
    old_split = """# Battery-grouped split
unique_bids = np.unique(bids)
rng = np.random.RandomState(42)
rng.shuffle(unique_bids)
n_train = int(0.8 * len(unique_bids))
train_bats = set(unique_bids[:n_train])
test_bats = set(unique_bids[n_train:])
train_mask = np.isin(bids, list(train_bats))
test_mask = np.isin(bids, list(test_bats))
X_train, y_train = X_multi[train_mask], y_multi[train_mask]
X_test, y_test = X_multi[test_mask], y_multi[test_mask]
print(f"Train: {X_train.shape[0]} | Test: {X_test.shape[0]}")"""
    if old_split not in s:
        raise RuntimeError(f"{nb_name}: battery-grouped split not found in code cell 1")
    s = s.replace(old_split, V2_SPLIT_SEQ)
    set_source(cc[1][1], s)

    # Cell 4: model save
    s = src(cc[4][1])
    s = s.replace('MODELS_DIR / "deep" / "dynamic_graph_itransformer.keras"',
                  'v2["models_deep"] / "dynamic_graph_itransformer.keras"')
    set_source(cc[4][1], s)

    # Cells 5-7: prefix save_fig names with "v2_"
    for idx in (5, 6, 7):
        s = src(cc[idx][1])
        s = s.replace('save_fig(fig, "', 'save_fig(fig, "v2_')
        set_source(cc[idx][1], s)

    # Cell 7: results json save
    s = src(cc[7][1])
    s = s.replace('ARTIFACTS_DIR / "dg_itransformer_results.json"',
                  'v2["results"] / "v2_dg_itransformer_results.json"')
    set_source(cc[7][1], s)

    save_nb(nb, nb_name)
# ──────────────────────────── NB 07 ────────────────────────────
def patch_07():
    """Patch ``07_vae_lstm.ipynb`` for v2 and write it back in place.

    - Code cell 0 gains the v2 path imports and setup.
    - Code cell 1 gets the intra-battery chronological split.
    - The VAE-LSTM save path in cell 2 moves to ``v2["models_deep"]``.
    - Figure names in cells 4-6 get a ``v2_`` prefix.
    - The results JSON in cell 6 moves to ``v2["results"]``.

    Skipped without rewriting when code cell 0 already has the v2 paths setup,
    because the replacements are not idempotent.

    Raises:
        RuntimeError: if the battery-grouped split block is missing from code
            cell 1; nothing is written to disk in that case.
    """
    nb_name = "07_vae_lstm.ipynb"
    print(f"Patching {nb_name} ...")
    nb = load_nb(nb_name)
    cc = get_code_cells(nb)

    # Re-running would append the v2 setup twice and double the "v2_" prefixes.
    if "# v2 paths\nv2 = get_version_paths('v2')" in src(cc[0][1]):
        print(f"  {nb_name} already contains the v2 paths setup, skipping")
        return

    # Cell 0: imports
    s = src(cc[0][1])
    s = s.replace(
        "from src.utils.config import (\n ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR,",
        "from src.utils.config import (\n ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR,\n get_version_paths, ensure_version_dirs,"
    )
    s += "\n\n# v2 paths\nv2 = get_version_paths('v2')\nensure_version_dirs('v2')"
    set_source(cc[0][1], s)

    # Cell 1: data + split (fail loudly rather than silently keep the v1 split)
    s = src(cc[1][1])
    old_split = """# Battery-grouped split
unique_bids = np.unique(bids)
rng = np.random.RandomState(42)
rng.shuffle(unique_bids)
n_train = int(0.8 * len(unique_bids))
train_bats = set(unique_bids[:n_train])
test_bats = set(unique_bids[n_train:])
train_mask = np.isin(bids, list(train_bats))
test_mask = np.isin(bids, list(test_bats))
X_train, y_train = X_multi[train_mask], y_multi[train_mask]
X_test, y_test = X_multi[test_mask], y_multi[test_mask]
print(f"Train: {X_train.shape[0]} | Test: {X_test.shape[0]}")"""
    if old_split not in s:
        raise RuntimeError(f"{nb_name}: battery-grouped split not found in code cell 1")
    s = s.replace(old_split, V2_SPLIT_SEQ)
    set_source(cc[1][1], s)

    # Cell 2: model save
    s = src(cc[2][1])
    s = s.replace('MODELS_DIR / "deep" / "vae_lstm.pt"', 'v2["models_deep"] / "vae_lstm.pt"')
    set_source(cc[2][1], s)

    # Cells 4-6: prefix save_fig names with "v2_"
    for idx in (4, 5, 6):
        s = src(cc[idx][1])
        s = s.replace('save_fig(fig, "', 'save_fig(fig, "v2_')
        set_source(cc[idx][1], s)

    # Cell 6: results json save
    s = src(cc[6][1])
    s = s.replace('ARTIFACTS_DIR / "vae_lstm_results.json"',
                  'v2["results"] / "v2_vae_lstm_results.json"')
    set_source(cc[6][1], s)

    save_nb(nb, nb_name)
# ──────────────────────────── NB 08 ────────────────────────────
def patch_08():
    """Patch ``08_ensemble.ipynb`` for v2 and write it back in place.

    - Code cell 0 gains the v2 path imports and setup.
    - Code cell 1 gets the intra-battery chronological split.
    - The model loading path in cell 2 reads from ``v2["models_deep"]``.
    - Figure names in cells 4-6 get a ``v2_`` prefix.
    - The results CSV in cell 5 moves to ``v2["results"]``.

    Skipped without rewriting when code cell 0 already has the v2 paths setup,
    because the replacements are not idempotent.

    Raises:
        RuntimeError: if the battery-grouped split block is missing from code
            cell 1; nothing is written to disk in that case.
    """
    nb_name = "08_ensemble.ipynb"
    print(f"Patching {nb_name} ...")
    nb = load_nb(nb_name)
    cc = get_code_cells(nb)

    # Re-running would append the v2 setup twice and double the "v2_" prefixes.
    if "# v2 paths\nv2 = get_version_paths('v2')" in src(cc[0][1]):
        print(f"  {nb_name} already contains the v2 paths setup, skipping")
        return

    # Cell 0: imports
    s = src(cc[0][1])
    s = s.replace(
        "from src.utils.config import ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR",
        "from src.utils.config import ARTIFACTS_DIR, FIGURES_DIR, MODELS_DIR, get_version_paths, ensure_version_dirs"
    )
    s += "\n\n# v2 paths\nv2 = get_version_paths('v2')\nensure_version_dirs('v2')"
    set_source(cc[0][1], s)

    # Cell 1: data + split (fail loudly rather than silently keep the v1 split)
    s = src(cc[1][1])
    old_split = """# Battery-grouped split
unique_bids = np.unique(bids)
rng = np.random.RandomState(42)
rng.shuffle(unique_bids)
n_train = int(0.8 * len(unique_bids))
train_bats = set(unique_bids[:n_train])
test_bats = set(unique_bids[n_train:])
train_mask = np.isin(bids, list(train_bats))
test_mask = np.isin(bids, list(test_bats))
X_train, y_train = X_multi[train_mask], y_multi[train_mask]
X_test, y_test = X_multi[test_mask], y_multi[test_mask]
print(f"Train: {X_train.shape[0]} | Test: {X_test.shape[0]}")"""
    if old_split not in s:
        raise RuntimeError(f"{nb_name}: battery-grouped split not found in code cell 1")
    s = s.replace(old_split, V2_SPLIT_SEQ)
    set_source(cc[1][1], s)

    # Cell 2: model loading paths — load from v2 deep models
    s = src(cc[2][1])
    s = s.replace('MODELS_DIR / "deep" / f"{name}.pt"', 'v2["models_deep"] / f"{name}.pt"')
    set_source(cc[2][1], s)

    # Cells 4-6: prefix save_fig names with "v2_"
    for idx in (4, 5, 6):
        s = src(cc[idx][1])
        s = s.replace('save_fig(fig, "', 'save_fig(fig, "v2_')
        set_source(cc[idx][1], s)

    # Cell 5: results save
    s = src(cc[5][1])
    s = s.replace('ARTIFACTS_DIR / "ensemble_results.csv"',
                  'v2["results"] / "v2_ensemble_results.csv"')
    set_source(cc[5][1], s)

    save_nb(nb, nb_name)
# ──────────────────────────── NB 09 ────────────────────────────
def patch_09():
    """Patch ``09_evaluation.ipynb`` for v2 and write it back in place.

    - Code cell 0 gains the ``get_version_paths`` import and a ``v2`` paths
      setup.
    - Result loading (cell 1), the unified results (cell 2) and the final
      rankings (cell 7) move to ``v2["results"]`` with ``v2_`` file names.
    - Figure names in cells 3 and 5 get a ``v2_`` prefix.
    - The CED test split in cell 5 becomes the intra-battery chronological
      split, using the same per-battery 80/20 rule and battery ordering as
      ``V2_SPLIT_SEQ``.

    Skipped without rewriting when code cell 0 already has the v2 paths setup,
    because the replacements are not idempotent.

    Raises:
        RuntimeError: if the battery-grouped CED split block is missing from
            code cell 5; nothing is written to disk in that case.
    """
    nb_name = "09_evaluation.ipynb"
    print(f"Patching {nb_name} ...")
    nb = load_nb(nb_name)
    cc = get_code_cells(nb)

    # Re-running would append the v2 setup twice and double the "v2_" prefixes.
    if "# v2 paths\nv2 = get_version_paths('v2')" in src(cc[0][1]):
        print(f"  {nb_name} already contains the v2 paths setup, skipping")
        return

    # Cell 0: imports
    s = src(cc[0][1])
    s = s.replace(
        "from src.utils.config import ARTIFACTS_DIR, FIGURES_DIR",
        "from src.utils.config import ARTIFACTS_DIR, FIGURES_DIR, get_version_paths"
    )
    s += "\n\n# v2 paths\nv2 = get_version_paths('v2')"
    set_source(cc[0][1], s)

    # Cell 1: result loading paths → v2
    s = src(cc[1][1])
    for fname in (
        "classical_soh_results.csv",
        "lstm_soh_results.csv",
        "transformer_soh_results.csv",
        "dg_itransformer_results.json",
        "vae_lstm_results.json",
        "ensemble_results.csv",
    ):
        s = s.replace(f'ARTIFACTS_DIR / "{fname}"', f'v2["results"] / "v2_{fname}"')
    set_source(cc[1][1], s)

    # Cell 2: unified results save
    s = src(cc[2][1])
    s = s.replace('ARTIFACTS_DIR / "unified_results.csv"',
                  'v2["results"] / "v2_unified_results.csv"')
    set_source(cc[2][1], s)

    # Cell 3: prefix save_fig names with "v2_"
    s = src(cc[3][1])
    s = s.replace('save_fig(fig, "', 'save_fig(fig, "v2_')
    set_source(cc[3][1], s)

    # Cell 5: CED split — replace battery-grouped with chrono split. The new
    # loop body must be indented, and dtype=int keeps an empty index array
    # valid for indexing.
    s = src(cc[5][1])
    old_split_09 = """unique_bids = np.unique(bids)
rng = np.random.RandomState(42)
rng.shuffle(unique_bids)
n_train = int(0.8 * len(unique_bids))
train_bats = set(unique_bids[:n_train])
test_mask = ~np.isin(bids, list(train_bats))
y_test = y_all[test_mask]
bids_test = bids[test_mask]"""
    new_split_09 = """# v2: intra-battery chronological split for CED
test_idx = []
for bid in np.unique(bids):
    idxs = np.where(bids == bid)[0]
    cut = int(0.8 * len(idxs))
    test_idx.extend(idxs[cut:].tolist())
test_idx = np.array(test_idx, dtype=int)
y_test = y_all[test_idx]
bids_test = bids[test_idx]"""
    if old_split_09 not in s:
        raise RuntimeError(f"{nb_name}: battery-grouped CED split not found in code cell 5")
    s = s.replace(old_split_09, new_split_09)
    # Also update save_fig if present
    s = s.replace('save_fig(fig, "', 'save_fig(fig, "v2_')
    set_source(cc[5][1], s)

    # Cell 7: final rankings save
    s = src(cc[7][1])
    s = s.replace('ARTIFACTS_DIR / "final_rankings.csv"',
                  'v2["results"] / "v2_final_rankings.csv"')
    set_source(cc[7][1], s)

    save_nb(nb, nb_name)
# ──────────────────────────── main ────────────────────────────
if __name__ == "__main__":
    # Patch notebooks 04 through 09 in order, one patcher per notebook.
    patchers = (patch_04, patch_05, patch_06, patch_07, patch_08, patch_09)
    for run_patch in patchers:
        run_patch()
    print("\nAll 6 notebooks patched for v2!")