Spaces:
Running
Running
Commit ·
06f9287
0
Parent(s):
inital commit
Browse files- .gitignore +23 -0
- app.py +226 -0
- smart_grid_env.py +239 -0
- train.py +133 -0
.gitignore
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python Cache
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# Virtual Environment
|
| 7 |
+
venv/
|
| 8 |
+
|
| 9 |
+
# Model Checkpoints & Artifacts
|
| 10 |
+
best_model/
|
| 11 |
+
checkpoints/
|
| 12 |
+
*.zip
|
| 13 |
+
|
| 14 |
+
# Logs and Evaluations
|
| 15 |
+
eval_logs/
|
| 16 |
+
tb_logs/
|
| 17 |
+
*.npz
|
| 18 |
+
|
| 19 |
+
# Pickled Data (e.g., normalization vectors)
|
| 20 |
+
*.pkl
|
| 21 |
+
|
| 22 |
+
# IDE / Editor
|
| 23 |
+
.vscode/
|
app.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app.py — Streamlit dashboard for SmartGridEnv PPO agent
|
| 3 |
+
|
| 4 |
+
Fixes vs. original:
|
| 5 |
+
- Loads VecNormalize stats (vec_normalize.pkl) alongside the PPO model
|
| 6 |
+
so observations are correctly normalised at inference time
|
| 7 |
+
- int(action.item()) fixes numpy array comparison in action_to_text()
|
| 8 |
+
- Added a rule-based baseline agent for comparison
|
| 9 |
+
- Richer charts: cost-per-hour bar chart + solar/demand/battery area chart
|
| 10 |
+
- Step-level info table logged per episode
|
| 11 |
+
- Graceful error handling throughout
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import streamlit as st
|
| 19 |
+
from stable_baselines3 import PPO
|
| 20 |
+
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
| 21 |
+
|
| 22 |
+
from smart_grid_env import SmartGridEnv
|
| 23 |
+
|
| 24 |
+
# ── Page config ───────────────────────────────────────────────────────────────
|
| 25 |
+
st.set_page_config(page_title="Smart Grid AI Control", layout="wide")
|
| 26 |
+
st.title("Smart Grid Energy Management System")
|
| 27 |
+
st.markdown("### PPO Reinforcement Learning Agent vs. Rule-Based Baseline")
|
| 28 |
+
|
| 29 |
+
# ── Sidebar ───────────────────────────────────────────────────────────────────
|
| 30 |
+
st.sidebar.header("Simulation Settings")
|
| 31 |
+
sim_speed = st.sidebar.slider("Speed (sec / step)", 0.05, 2.0, 0.3)
|
| 32 |
+
agent_choice = st.sidebar.radio(
|
| 33 |
+
"Agent to run",
|
| 34 |
+
["PPO Agent", "Rule-Based Baseline", "Compare Both"],
|
| 35 |
+
)
|
| 36 |
+
run_btn = st.sidebar.button("▶ Start 24-Hour Simulation", type="primary")
|
| 37 |
+
|
| 38 |
+
st.sidebar.markdown("---")
|
| 39 |
+
st.sidebar.markdown(
|
| 40 |
+
"**Rule-based logic:** charge when price < 0.20 and battery < 80%, "
|
| 41 |
+
"discharge when price > 0.40 and battery > 20%, else hold."
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ── Helpers ───────────────────────────────────────────────────────────────────
|
| 46 |
+
def action_to_label(action: int) -> str:
|
| 47 |
+
return {0: "Hold ⏸", 1: "Charge ⬆", 2: "Discharge ⬇"}.get(action, "?")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def rule_based_action(obs: np.ndarray) -> int:
|
| 51 |
+
"""Simple price-threshold rule — useful as a sanity-check baseline."""
|
| 52 |
+
battery, solar, demand, price = obs
|
| 53 |
+
if price < 0.20 and battery < 80.0:
|
| 54 |
+
return 1 # cheap electricity → charge
|
| 55 |
+
if price > 0.40 and battery > 20.0:
|
| 56 |
+
return 2 # expensive electricity → use battery
|
| 57 |
+
return 0 # hold
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def load_ppo_model():
|
| 61 |
+
"""Load trained PPO model + normalisation stats. Returns (model, vec_env) or None."""
|
| 62 |
+
if not os.path.exists("ppo_smart_grid.zip"):
|
| 63 |
+
return None, None
|
| 64 |
+
try:
|
| 65 |
+
env = DummyVecEnv([SmartGridEnv])
|
| 66 |
+
if os.path.exists("vec_normalize.pkl"):
|
| 67 |
+
env = VecNormalize.load("vec_normalize.pkl", env)
|
| 68 |
+
env.training = False
|
| 69 |
+
env.norm_reward = False
|
| 70 |
+
model = PPO.load("ppo_smart_grid", env=env)
|
| 71 |
+
return model, env
|
| 72 |
+
except Exception as e:
|
| 73 |
+
st.error(f"Could not load model: {e}")
|
| 74 |
+
return None, None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def run_episode(agent: str, model=None, vec_env=None, speed: float = 0.3):
|
| 78 |
+
"""
|
| 79 |
+
Run a single 24-step episode and return a DataFrame of step-level data.
|
| 80 |
+
agent: 'ppo' | 'rule'
|
| 81 |
+
"""
|
| 82 |
+
raw_env = SmartGridEnv()
|
| 83 |
+
obs_raw, _ = raw_env.reset()
|
| 84 |
+
|
| 85 |
+
# PPO uses the normalised vec_env; rule-based uses raw env directly
|
| 86 |
+
if agent == "ppo" and vec_env is not None:
|
| 87 |
+
obs_vec = vec_env.reset()
|
| 88 |
+
|
| 89 |
+
records = []
|
| 90 |
+
total_cost = 0.0
|
| 91 |
+
|
| 92 |
+
live_battery = st.empty()
|
| 93 |
+
live_price = st.empty()
|
| 94 |
+
live_cost = st.empty()
|
| 95 |
+
live_chart = st.empty()
|
| 96 |
+
|
| 97 |
+
for step in range(24):
|
| 98 |
+
# ---- Pick action ----
|
| 99 |
+
if agent == "ppo" and model is not None:
|
| 100 |
+
action_arr, _ = model.predict(obs_vec, deterministic=True)
|
| 101 |
+
action = int(action_arr.item()) # numpy → plain int
|
| 102 |
+
obs_vec, _, _, _ = vec_env.step(action_arr)
|
| 103 |
+
# Step the raw env with the same action to get proper info dict
|
| 104 |
+
obs_raw, reward, terminated, _, info = raw_env.step(action)
|
| 105 |
+
else:
|
| 106 |
+
action = rule_based_action(obs_raw)
|
| 107 |
+
obs_raw, reward, terminated, _, info = raw_env.step(action)
|
| 108 |
+
|
| 109 |
+
cost = info["cost"]
|
| 110 |
+
total_cost += cost
|
| 111 |
+
|
| 112 |
+
battery = info["battery_soc"]
|
| 113 |
+
solar = info["solar_kw"]
|
| 114 |
+
demand = info["demand_kw"]
|
| 115 |
+
price = info["price"]
|
| 116 |
+
|
| 117 |
+
# ---- Live metrics ----
|
| 118 |
+
col1, col2, col3 = live_battery, live_price, live_cost
|
| 119 |
+
live_battery.metric("🔋 Battery SoC", f"{battery:.1f} %", action_to_label(action))
|
| 120 |
+
live_price.metric( "💲 Grid Price", f"${price:.3f}/kWh")
|
| 121 |
+
live_cost.metric( "💰 Running Cost", f"${total_cost:.2f}", delta_color="inverse")
|
| 122 |
+
|
| 123 |
+
records.append({
|
| 124 |
+
"Hour": step + 1,
|
| 125 |
+
"Battery (%)": round(battery, 2),
|
| 126 |
+
"Solar (kW)": round(solar, 2),
|
| 127 |
+
"Demand (kW)": round(demand, 2),
|
| 128 |
+
"Price ($/kWh)": round(price, 3),
|
| 129 |
+
"Step Cost ($)": round(cost, 3),
|
| 130 |
+
"Action": action_to_label(action),
|
| 131 |
+
})
|
| 132 |
+
|
| 133 |
+
# ---- Live chart (updates every step) ----
|
| 134 |
+
df_so_far = pd.DataFrame(records).set_index("Hour")
|
| 135 |
+
live_chart.line_chart(
|
| 136 |
+
df_so_far[["Battery (%)", "Solar (kW)", "Demand (kW)"]],
|
| 137 |
+
height=250,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
time.sleep(speed)
|
| 141 |
+
|
| 142 |
+
raw_env.close()
|
| 143 |
+
return pd.DataFrame(records), total_cost
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def show_results(df: pd.DataFrame, total_cost: float, label: str):
|
| 147 |
+
st.success(f"**{label}** — Total 24-hour cost: **${total_cost:.2f}**")
|
| 148 |
+
|
| 149 |
+
col_a, col_b = st.columns(2)
|
| 150 |
+
with col_a:
|
| 151 |
+
st.subheader("Hourly step cost ($)")
|
| 152 |
+
st.bar_chart(df.set_index("Hour")[["Step Cost ($)"]], height=220)
|
| 153 |
+
with col_b:
|
| 154 |
+
st.subheader("Battery, solar and demand")
|
| 155 |
+
st.line_chart(
|
| 156 |
+
df.set_index("Hour")[["Battery (%)", "Solar (kW)", "Demand (kW)"]],
|
| 157 |
+
height=220,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
with st.expander("📋 Full step-by-step log"):
|
| 161 |
+
st.dataframe(df, use_container_width=True)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ── Main simulation ───────────────────────────────────────────────────────────
|
| 165 |
+
if run_btn:
|
| 166 |
+
model, vec_env = load_ppo_model()
|
| 167 |
+
ppo_available = model is not None
|
| 168 |
+
|
| 169 |
+
if not ppo_available and agent_choice in ("PPO Agent", "Compare Both"):
|
| 170 |
+
st.warning(
|
| 171 |
+
"ppo_smart_grid.zip not found — run `python train.py` first. "
|
| 172 |
+
"Falling back to rule-based agent."
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# ---- PPO only ----
|
| 176 |
+
if agent_choice == "PPO Agent":
|
| 177 |
+
st.markdown("### PPO Agent")
|
| 178 |
+
agent = "ppo" if ppo_available else "rule"
|
| 179 |
+
df, cost = run_episode(agent, model, vec_env, sim_speed)
|
| 180 |
+
show_results(df, cost, "PPO Agent" if ppo_available else "Rule-Based (fallback)")
|
| 181 |
+
|
| 182 |
+
# ---- Rule-based only ----
|
| 183 |
+
elif agent_choice == "Rule-Based Baseline":
|
| 184 |
+
st.markdown("### Rule-Based Baseline")
|
| 185 |
+
df, cost = run_episode("rule", speed=sim_speed)
|
| 186 |
+
show_results(df, cost, "Rule-Based Baseline")
|
| 187 |
+
|
| 188 |
+
# ---- Compare both ----
|
| 189 |
+
else:
|
| 190 |
+
tab_ppo, tab_rule = st.tabs([" PPO Agent", " Rule-Based Baseline"])
|
| 191 |
+
|
| 192 |
+
with tab_ppo:
|
| 193 |
+
st.markdown("#### PPO Agent — running...")
|
| 194 |
+
agent = "ppo" if ppo_available else "rule"
|
| 195 |
+
df_ppo, cost_ppo = run_episode(agent, model, vec_env, sim_speed)
|
| 196 |
+
show_results(df_ppo, cost_ppo, "PPO Agent" if ppo_available else "Rule-Based (fallback)")
|
| 197 |
+
|
| 198 |
+
with tab_rule:
|
| 199 |
+
st.markdown("#### Rule-Based Baseline — running...")
|
| 200 |
+
df_rule, cost_rule = run_episode("rule", speed=sim_speed)
|
| 201 |
+
show_results(df_rule, cost_rule, "Rule-Based Baseline")
|
| 202 |
+
|
| 203 |
+
# ---- Side-by-side cost summary ----
|
| 204 |
+
st.markdown("---")
|
| 205 |
+
st.subheader("Cost comparison")
|
| 206 |
+
c1, c2, c3 = st.columns(3)
|
| 207 |
+
c1.metric("PPO total cost", f"${cost_ppo:.2f}")
|
| 208 |
+
c2.metric("Rule-based total cost", f"${cost_rule:.2f}")
|
| 209 |
+
saving = cost_rule - cost_ppo
|
| 210 |
+
c3.metric("PPO saving", f"${saving:.2f}",
|
| 211 |
+
delta=f"{'better' if saving > 0 else 'worse'} than rule-based",
|
| 212 |
+
delta_color="normal" if saving > 0 else "inverse")
|
| 213 |
+
|
| 214 |
+
else:
|
| 215 |
+
st.info("Configure settings in the sidebar and click **▶ Start 24-Hour Simulation**.")
|
| 216 |
+
st.markdown("""
|
| 217 |
+
#### How it works
|
| 218 |
+
| Component | Detail |
|
| 219 |
+
|---|---|
|
| 220 |
+
| Environment | 24-step episode (1 step = 1 hour) |
|
| 221 |
+
| Observation | Battery SoC, solar generation, house demand, grid price |
|
| 222 |
+
| Actions | Hold / Charge from grid / Discharge battery |
|
| 223 |
+
| Reward | Negative net grid cost (includes solar sell-back revenue) |
|
| 224 |
+
| Agent | PPO with MLP policy, trained via `stable-baselines3` |
|
| 225 |
+
| Baseline | Simple price-threshold rule for comparison |
|
| 226 |
+
""")
|
smart_grid_env.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
from gymnasium import spaces
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SmartGridEnv(gym.Env):
|
| 7 |
+
"""
|
| 8 |
+
Smart Grid Battery Management Environment (Gymnasium-compatible)
|
| 9 |
+
|
| 10 |
+
Goal: Minimize daily electricity cost by intelligently charging/discharging
|
| 11 |
+
a home battery, using solar generation, and interacting with the grid.
|
| 12 |
+
|
| 13 |
+
Observation (4 values, all normalized to [0, 1]):
|
| 14 |
+
[battery_soc, solar_gen_kw, house_demand_kw, grid_price]
|
| 15 |
+
|
| 16 |
+
Action space (Discrete 3):
|
| 17 |
+
0 = Hold — do nothing beyond covering net load from grid/solar
|
| 18 |
+
1 = Charge — buy from grid to fill battery (10 kW rate)
|
| 19 |
+
2 = Discharge — draw from battery to cover load (10 kW rate)
|
| 20 |
+
|
| 21 |
+
Reward:
|
| 22 |
+
Negative net cost per step (agent learns to minimize cost).
|
| 23 |
+
Includes sell-back revenue when solar surplus is fed to grid.
|
| 24 |
+
Includes a small battery health penalty for extreme SoC operation.
|
| 25 |
+
|
| 26 |
+
Fixes vs. original:
|
| 27 |
+
- Observation space bounds match actual value ranges
|
| 28 |
+
- Solar/demand energy balance applied in all action branches
|
| 29 |
+
- Grid sell-back (feed-in tariff) modeled in HOLD branch
|
| 30 |
+
- Battery SoC clamped to [0, battery_capacity]
|
| 31 |
+
- Charging efficiency loss (90%) modeled
|
| 32 |
+
- Correlated time-series for price/solar/demand (no more i.i.d. jumps)
|
| 33 |
+
- Battery health penalty for operating near 0% or 100% SoC
|
| 34 |
+
- render() method added
|
| 35 |
+
- Fully compatible with check_env() and VecNormalize
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
metadata = {"render_modes": ["human"]}
|
| 39 |
+
|
| 40 |
+
# Physical constants
|
| 41 |
+
BATTERY_CAPACITY = 100.0 # kWh
|
| 42 |
+
CHARGE_RATE = 10.0 # kW (max charge/discharge per step)
|
| 43 |
+
CHARGE_EFFICIENCY = 0.90 # 90% round-trip efficiency
|
| 44 |
+
FEED_IN_TARIFF = 0.50 # sell surplus solar at 50% of grid price
|
| 45 |
+
SOC_PENALTY_COEF = 0.005 # small penalty for extreme battery levels
|
| 46 |
+
|
| 47 |
+
# Observation high limits (battery 0-100 %, solar 0-10 kW,
|
| 48 |
+
# demand 0-10 kW, price 0-1.5 $/kWh to cover double-peak)
|
| 49 |
+
OBS_HIGH = np.array([100.0, 10.0, 10.0, 1.5], dtype=np.float32)
|
| 50 |
+
OBS_LOW = np.zeros(4, dtype=np.float32)
|
| 51 |
+
|
| 52 |
+
def __init__(self, render_mode=None):
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.render_mode = render_mode
|
| 55 |
+
|
| 56 |
+
# --- Action / Observation Spaces ---
|
| 57 |
+
self.action_space = spaces.Discrete(3)
|
| 58 |
+
self.observation_space = spaces.Box(
|
| 59 |
+
low=self.OBS_LOW,
|
| 60 |
+
high=self.OBS_HIGH,
|
| 61 |
+
shape=(4,),
|
| 62 |
+
dtype=np.float32,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Internal state
|
| 66 |
+
self.current_step = 0
|
| 67 |
+
self.current_battery = self.BATTERY_CAPACITY * 0.5
|
| 68 |
+
self._state = self._make_initial_state()
|
| 69 |
+
|
| 70 |
+
# For correlated time-series generation
|
| 71 |
+
self._price_base = 0.2 # $/kWh (drifts each step)
|
| 72 |
+
self._demand_base = 3.0 # kW
|
| 73 |
+
self._solar_base = 0.0 # kW
|
| 74 |
+
|
| 75 |
+
# ------------------------------------------------------------------
|
| 76 |
+
# Gymnasium API
|
| 77 |
+
# ------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
def reset(self, seed=None, options=None):
|
| 80 |
+
super().reset(seed=seed)
|
| 81 |
+
self.current_step = 0
|
| 82 |
+
self.current_battery = self.BATTERY_CAPACITY * 0.5
|
| 83 |
+
self._price_base = 0.2
|
| 84 |
+
self._demand_base = 3.0
|
| 85 |
+
self._solar_base = 0.0
|
| 86 |
+
self._state = self._make_initial_state()
|
| 87 |
+
return self._state.copy(), {}
|
| 88 |
+
|
| 89 |
+
def step(self, action):
|
| 90 |
+
assert self.action_space.contains(action), f"Invalid action {action}"
|
| 91 |
+
|
| 92 |
+
battery, solar, demand, price = self._state
|
| 93 |
+
|
| 94 |
+
# ---- 1. Compute net load BEFORE battery action ----
|
| 95 |
+
# Positive → house needs more than solar provides (must buy or discharge)
|
| 96 |
+
# Negative → solar surplus (can sell back or charge)
|
| 97 |
+
net_load = demand - solar
|
| 98 |
+
|
| 99 |
+
grid_cost = 0.0 # positive = paying, negative = earning
|
| 100 |
+
|
| 101 |
+
# ---- 2. Execute battery action ----
|
| 102 |
+
if action == 1: # CHARGE from grid
|
| 103 |
+
# How much can we actually charge?
|
| 104 |
+
headroom = self.BATTERY_CAPACITY - battery
|
| 105 |
+
charge_requested = min(self.CHARGE_RATE, headroom)
|
| 106 |
+
# Efficiency: we buy more from grid than actually stored
|
| 107 |
+
grid_draw = charge_requested / self.CHARGE_EFFICIENCY
|
| 108 |
+
self.current_battery = np.clip(
|
| 109 |
+
battery + charge_requested,
|
| 110 |
+
0.0, self.BATTERY_CAPACITY
|
| 111 |
+
)
|
| 112 |
+
# Also cover net_load from grid
|
| 113 |
+
grid_cost = (max(0.0, net_load) + grid_draw) * price
|
| 114 |
+
# If solar surplus even after load, get sell-back credit
|
| 115 |
+
grid_cost -= max(0.0, -net_load) * price * self.FEED_IN_TARIFF
|
| 116 |
+
|
| 117 |
+
elif action == 2: # DISCHARGE battery to cover load
|
| 118 |
+
# How much battery can supply?
|
| 119 |
+
discharge_requested = min(self.CHARGE_RATE, battery)
|
| 120 |
+
self.current_battery = np.clip(
|
| 121 |
+
battery - discharge_requested,
|
| 122 |
+
0.0, self.BATTERY_CAPACITY
|
| 123 |
+
)
|
| 124 |
+
# Remaining load after battery contribution
|
| 125 |
+
residual_load = net_load - discharge_requested
|
| 126 |
+
if residual_load > 0:
|
| 127 |
+
grid_cost = residual_load * price # still need some grid
|
| 128 |
+
else:
|
| 129 |
+
grid_cost = residual_load * price * self.FEED_IN_TARIFF # surplus → sell
|
| 130 |
+
|
| 131 |
+
else: # HOLD — let solar + grid balance the load
|
| 132 |
+
if net_load > 0:
|
| 133 |
+
grid_cost = net_load * price # buy deficit from grid
|
| 134 |
+
else:
|
| 135 |
+
grid_cost = net_load * price * self.FEED_IN_TARIFF # sell surplus
|
| 136 |
+
|
| 137 |
+
# ---- 3. Battery health penalty (discourages extreme SoC) ----
|
| 138 |
+
soc_frac = self.current_battery / self.BATTERY_CAPACITY
|
| 139 |
+
health_penalty = self.SOC_PENALTY_COEF * (
|
| 140 |
+
max(0.0, soc_frac - 0.9) + max(0.0, 0.1 - soc_frac)
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# ---- 4. Reward ----
|
| 144 |
+
reward = -(grid_cost + health_penalty)
|
| 145 |
+
|
| 146 |
+
# ---- 5. Advance time, generate next state ----
|
| 147 |
+
self.current_step += 1
|
| 148 |
+
terminated = self.current_step >= 24
|
| 149 |
+
truncated = False
|
| 150 |
+
|
| 151 |
+
if not terminated:
|
| 152 |
+
self._state = self._generate_next_state()
|
| 153 |
+
else:
|
| 154 |
+
self._state = np.zeros(4, dtype=np.float32)
|
| 155 |
+
|
| 156 |
+
info = {
|
| 157 |
+
"cost": float(grid_cost),
|
| 158 |
+
"battery_soc": float(self.current_battery),
|
| 159 |
+
"solar_kw": float(solar),
|
| 160 |
+
"demand_kw": float(demand),
|
| 161 |
+
"price": float(price),
|
| 162 |
+
"action": int(action),
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
if self.render_mode == "human":
|
| 166 |
+
self.render()
|
| 167 |
+
|
| 168 |
+
return self._state.copy(), float(reward), terminated, truncated, info
|
| 169 |
+
|
| 170 |
+
def render(self):
|
| 171 |
+
b, s, d, p = self._state
|
| 172 |
+
print(
|
| 173 |
+
f"[Hour {self.current_step:02d}] "
|
| 174 |
+
f"Battery={self.current_battery:.1f}% | "
|
| 175 |
+
f"Solar={s:.2f}kW | Demand={d:.2f}kW | Price=${p:.3f}/kWh"
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
def close(self):
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
# ------------------------------------------------------------------
|
| 182 |
+
# Helpers
|
| 183 |
+
# ------------------------------------------------------------------
|
| 184 |
+
|
| 185 |
+
def _make_initial_state(self) -> np.ndarray:
|
| 186 |
+
return np.array(
|
| 187 |
+
[self.current_battery, 0.0, 2.5, 0.15],
|
| 188 |
+
dtype=np.float32
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def _generate_next_state(self) -> np.ndarray:
|
| 192 |
+
"""
|
| 193 |
+
Correlated time-series generation so successive hours are smooth.
|
| 194 |
+
Each variable drifts toward a time-of-day mean with Gaussian noise.
|
| 195 |
+
"""
|
| 196 |
+
hour = self.current_step # 0–23
|
| 197 |
+
|
| 198 |
+
# --- Solar: bell-curve peaking at noon, zero at night ---
|
| 199 |
+
if 6 <= hour <= 18:
|
| 200 |
+
solar_mean = 5.0 * np.exp(-0.5 * ((hour - 12) / 3.5) ** 2)
|
| 201 |
+
else:
|
| 202 |
+
solar_mean = 0.0
|
| 203 |
+
self._solar_base += 0.3 * (solar_mean - self._solar_base)
|
| 204 |
+
next_solar = float(np.clip(
|
| 205 |
+
self._solar_base + self.np_random.normal(0, 0.4),
|
| 206 |
+
0.0, 10.0
|
| 207 |
+
))
|
| 208 |
+
|
| 209 |
+
# --- Demand: morning and evening peaks ---
|
| 210 |
+
demand_mean = (
|
| 211 |
+
2.5
|
| 212 |
+
+ 2.0 * np.exp(-0.5 * ((hour - 8) / 1.5) ** 2) # morning
|
| 213 |
+
+ 3.0 * np.exp(-0.5 * ((hour - 19) / 2.0) ** 2) # evening
|
| 214 |
+
)
|
| 215 |
+
self._demand_base += 0.3 * (demand_mean - self._demand_base)
|
| 216 |
+
next_demand = float(np.clip(
|
| 217 |
+
self._demand_base + self.np_random.normal(0, 0.5),
|
| 218 |
+
0.0, 10.0
|
| 219 |
+
))
|
| 220 |
+
|
| 221 |
+
# --- Price: cheap at night, expensive at peak (17–21) ---
|
| 222 |
+
if 17 <= hour <= 21:
|
| 223 |
+
price_mean = 0.65
|
| 224 |
+
elif 6 <= hour <= 9:
|
| 225 |
+
price_mean = 0.30
|
| 226 |
+
elif 23 <= hour or hour <= 5:
|
| 227 |
+
price_mean = 0.12
|
| 228 |
+
else:
|
| 229 |
+
price_mean = 0.22
|
| 230 |
+
self._price_base += 0.4 * (price_mean - self._price_base)
|
| 231 |
+
next_price = float(np.clip(
|
| 232 |
+
self._price_base + self.np_random.normal(0, 0.03),
|
| 233 |
+
0.05, 1.5
|
| 234 |
+
))
|
| 235 |
+
|
| 236 |
+
return np.array(
|
| 237 |
+
[self.current_battery, next_solar, next_demand, next_price],
|
| 238 |
+
dtype=np.float32
|
| 239 |
+
)
|
train.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train.py — PPO training script for SmartGridEnv
|
| 3 |
+
|
| 4 |
+
Fixes vs. original:
|
| 5 |
+
- check_env() validates the environment before training starts
|
| 6 |
+
- VecNormalize auto-normalizes observations and rewards for stable gradients
|
| 7 |
+
- 500,000 timesteps (was 10,000 — far too few for PPO to learn anything)
|
| 8 |
+
- EvalCallback saves the best model checkpoint automatically
|
| 9 |
+
- Hyperparameters tuned for this problem (n_steps, batch_size, ent_coef)
|
| 10 |
+
- vec_normalize stats saved alongside model (required for correct inference)
|
| 11 |
+
- TensorBoard logging enabled (optional — run: tensorboard --logdir ./tb_logs)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
from stable_baselines3 import PPO
|
| 16 |
+
from stable_baselines3.common.env_util import make_vec_env
|
| 17 |
+
from stable_baselines3.common.vec_env import VecNormalize
|
| 18 |
+
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
|
| 19 |
+
from stable_baselines3.common.env_checker import check_env
|
| 20 |
+
|
| 21 |
+
from smart_grid_env import SmartGridEnv
|
| 22 |
+
|
| 23 |
+
# ── 1. Validate environment ───────────────────────────────────────────────────
|
| 24 |
+
print("Checking environment...")
|
| 25 |
+
check_env(SmartGridEnv(), warn=True)
|
| 26 |
+
print("Environment check passed.\n")
|
| 27 |
+
|
| 28 |
+
# ── 2. Vectorised training environment (4 parallel workers) ───────────────────
|
| 29 |
+
N_ENVS = 4
|
| 30 |
+
train_env = make_vec_env(SmartGridEnv, n_envs=N_ENVS)
|
| 31 |
+
train_env = VecNormalize(
|
| 32 |
+
train_env,
|
| 33 |
+
norm_obs=True, # normalizes each obs dimension to ~N(0,1)
|
| 34 |
+
norm_reward=True, # normalizes reward scale — critical for PPO stability
|
| 35 |
+
clip_obs=10.0,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# ── 3. Separate evaluation environment (no reward normalisation) ───────────────
|
| 39 |
+
eval_env = make_vec_env(SmartGridEnv, n_envs=1)
|
| 40 |
+
eval_env = VecNormalize(
|
| 41 |
+
eval_env,
|
| 42 |
+
norm_obs=True,
|
| 43 |
+
norm_reward=False, # raw rewards for interpretable eval metrics
|
| 44 |
+
training=False, # stats are copied from train_env, not updated
|
| 45 |
+
clip_obs=10.0,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# ── 4. Define the PPO model ────────────────────────────────────────────────────
|
| 49 |
+
model = PPO(
|
| 50 |
+
policy = "MlpPolicy",
|
| 51 |
+
env = train_env,
|
| 52 |
+
verbose = 1,
|
| 53 |
+
tensorboard_log = "./tb_logs",
|
| 54 |
+
# --- Core PPO hyperparameters ---
|
| 55 |
+
n_steps = 1024, # steps collected per env per rollout
|
| 56 |
+
batch_size = 256, # minibatch size for gradient update
|
| 57 |
+
n_epochs = 10, # number of passes over each rollout buffer
|
| 58 |
+
gamma = 0.99, # discount factor (long-horizon cost matters)
|
| 59 |
+
gae_lambda = 0.95, # GAE smoothing
|
| 60 |
+
clip_range = 0.2, # PPO clip parameter
|
| 61 |
+
learning_rate = 3e-4, # Adam lr
|
| 62 |
+
ent_coef = 0.01, # entropy bonus (encourages exploration early on)
|
| 63 |
+
vf_coef = 0.5,
|
| 64 |
+
max_grad_norm = 0.5,
|
| 65 |
+
# --- Policy network architecture ---
|
| 66 |
+
policy_kwargs = dict(net_arch=[128, 128]), # 2-layer MLP, 128 units each
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# ── 5. Callbacks ───────────────────────────────────────────────────────────────
|
| 70 |
+
os.makedirs("./best_model", exist_ok=True)
|
| 71 |
+
os.makedirs("./checkpoints", exist_ok=True)
|
| 72 |
+
|
| 73 |
+
eval_callback = EvalCallback(
|
| 74 |
+
eval_env,
|
| 75 |
+
best_model_save_path = "./best_model",
|
| 76 |
+
log_path = "./eval_logs",
|
| 77 |
+
eval_freq = max(5_000 // N_ENVS, 1), # evaluate every ~5k env steps
|
| 78 |
+
n_eval_episodes = 20, # average over 20 full 24-hour episodes
|
| 79 |
+
deterministic = True,
|
| 80 |
+
render = False,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
checkpoint_callback = CheckpointCallback(
|
| 84 |
+
save_freq = max(50_000 // N_ENVS, 1),
|
| 85 |
+
save_path = "./checkpoints",
|
| 86 |
+
name_prefix= "ppo_smart_grid",
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# ── 6. Train ───────────────────────────────────────────────────────────────────
|
| 90 |
+
TOTAL_TIMESTEPS = 500_000
|
| 91 |
+
print(f"Training PPO for {TOTAL_TIMESTEPS:,} timesteps across {N_ENVS} parallel envs...")
|
| 92 |
+
print("Tip: run `tensorboard --logdir ./tb_logs` to monitor training live.\n")
|
| 93 |
+
|
| 94 |
+
model.learn(
|
| 95 |
+
total_timesteps = TOTAL_TIMESTEPS,
|
| 96 |
+
callback = [eval_callback, checkpoint_callback],
|
| 97 |
+
progress_bar = True,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# ── 7. Save final model + normalisation statistics ────────────────────────────
|
| 101 |
+
model.save("ppo_smart_grid")
|
| 102 |
+
train_env.save("vec_normalize.pkl") # MUST be saved — needed for inference
|
| 103 |
+
|
| 104 |
+
print("\nTraining complete!")
|
| 105 |
+
print(" Saved: ppo_smart_grid.zip")
|
| 106 |
+
print(" Saved: vec_normalize.pkl (required alongside the model for inference)")
|
| 107 |
+
print(" Best checkpoint: ./best_model/best_model.zip")
|
| 108 |
+
|
| 109 |
+
# ── 8. Quick sanity-check: run one episode with the trained agent ──────────────
|
| 110 |
+
print("\n--- Sanity check: one 24-hour episode ---")
|
| 111 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 112 |
+
|
| 113 |
+
test_env = DummyVecEnv([SmartGridEnv])
|
| 114 |
+
test_env = VecNormalize.load("vec_normalize.pkl", test_env)
|
| 115 |
+
test_env.training = False
|
| 116 |
+
test_env.norm_reward = False
|
| 117 |
+
|
| 118 |
+
obs = test_env.reset()
|
| 119 |
+
total_cost = 0.0
|
| 120 |
+
for hour in range(24):
|
| 121 |
+
action, _ = model.predict(obs, deterministic=True)
|
| 122 |
+
obs, reward, done, info = test_env.step(action)
|
| 123 |
+
total_cost += info[0]["cost"]
|
| 124 |
+
action_label = ["Hold", "Charge", "Discharge"][int(action[0])]
|
| 125 |
+
print(
|
| 126 |
+
f" Hour {hour+1:02d} | Action: {action_label:<10} | "
|
| 127 |
+
f"Battery: {info[0]['battery_soc']:5.1f}% | "
|
| 128 |
+
f"Price: ${info[0]['price']:.3f} | "
|
| 129 |
+
f"Step cost: ${info[0]['cost']:.3f}"
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
print(f"\nTotal 24-hour cost: ${total_cost:.2f}")
|
| 133 |
+
test_env.close()
|