chinmay0805 committed on
Commit
06f9287
·
0 Parent(s):

initial commit

Browse files
Files changed (4) hide show
  1. .gitignore +23 -0
  2. app.py +226 -0
  3. smart_grid_env.py +239 -0
  4. train.py +133 -0
.gitignore ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python Cache
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Virtual Environment
7
+ venv/
8
+
9
+ # Model Checkpoints & Artifacts
10
+ best_model/
11
+ checkpoints/
12
+ *.zip
13
+
14
+ # Logs and Evaluations
15
+ eval_logs/
16
+ tb_logs/
17
+ *.npz
18
+
19
+ # Pickled Data (e.g., normalization vectors)
20
+ *.pkl
21
+
22
+ # IDE / Editor
23
+ .vscode/
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py — Streamlit dashboard for SmartGridEnv PPO agent
3
+
4
+ Fixes vs. original:
5
+ - Loads VecNormalize stats (vec_normalize.pkl) alongside the PPO model
6
+ so observations are correctly normalised at inference time
7
+ - int(action.item()) fixes numpy array comparison in action_to_label()
8
+ - Added a rule-based baseline agent for comparison
9
+ - Richer charts: cost-per-hour bar chart + solar/demand/battery area chart
10
+ - Step-level info table logged per episode
11
+ - Graceful error handling throughout
12
+ """
13
+
14
+ import os
15
+ import time
16
+ import numpy as np
17
+ import pandas as pd
18
+ import streamlit as st
19
+ from stable_baselines3 import PPO
20
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
21
+
22
+ from smart_grid_env import SmartGridEnv
23
+
24
# ── Page config ───────────────────────────────────────────────────────────────
# Wide layout gives the charts and the side-by-side comparison room to breathe.
st.set_page_config(layout="wide", page_title="Smart Grid AI Control")
st.title("Smart Grid Energy Management System")
st.markdown("### PPO Reinforcement Learning Agent vs. Rule-Based Baseline")

# ── Sidebar ───────────────────────────────────────────────────────────────────
# All user-tunable inputs live in the sidebar; the three values read here
# (sim_speed, agent_choice, run_btn) drive the main simulation further down.
st.sidebar.header("Simulation Settings")

# Delay between simulated hours, in seconds.
sim_speed = st.sidebar.slider(
    "Speed (sec / step)",
    min_value=0.05,
    max_value=2.0,
    value=0.3,
)

# Which controller(s) to simulate.
agent_choice = st.sidebar.radio(
    "Agent to run",
    options=["PPO Agent", "Rule-Based Baseline", "Compare Both"],
)

# True only on the rerun triggered by clicking the button.
run_btn = st.sidebar.button(label="▶ Start 24-Hour Simulation", type="primary")

st.sidebar.markdown("---")
st.sidebar.markdown(
    "**Rule-based logic:** charge when price < 0.20 and battery < 80%, "
    "discharge when price > 0.40 and battery > 20%, else hold."
)
43
+
44
+
45
+ # ── Helpers ───────────────────────────────────────────────────────────────────
46
def action_to_label(action: int) -> str:
    """Return the dashboard label for a discrete action id.

    Known ids are 0 (hold), 1 (charge) and 2 (discharge); any other
    hashable value yields ``"?"``.
    """
    labels = {
        0: "Hold ⏸",
        1: "Charge ⬆",
        2: "Discharge ⬇",
    }
    try:
        return labels[action]
    except KeyError:
        return "?"
48
+
49
+
50
def rule_based_action(obs: np.ndarray) -> int:
    """Price-threshold baseline policy used as a sanity check for the PPO agent.

    Args:
        obs: Four values ``(battery, solar, demand, price)``; only the
            battery level and the price influence the decision.

    Returns:
        1 (charge) when the price is below 0.20 and the battery is below 80,
        2 (discharge) when the price is above 0.40 and the battery is above 20,
        otherwise 0 (hold).
    """
    soc, _solar, _demand, tariff = obs

    cheap_and_room_left = tariff < 0.20 and soc < 80.0
    if cheap_and_room_left:
        return 1  # cheap electricity → charge

    pricey_and_energy_left = tariff > 0.40 and soc > 20.0
    if pricey_and_energy_left:
        return 2  # expensive electricity → use battery

    return 0  # hold
58
+
59
+
60
def load_ppo_model():
    """Load the trained PPO model together with its normalisation statistics.

    Returns:
        ``(model, vec_env)`` on success. ``(None, None)`` when
        ``ppo_smart_grid.zip`` does not exist or loading raises; in the latter
        case the error is reported in the UI via ``st.error``.
    """
    if not os.path.exists("ppo_smart_grid.zip"):
        return None, None
    try:
        env = DummyVecEnv([SmartGridEnv])
        if os.path.exists("vec_normalize.pkl"):
            env = VecNormalize.load("vec_normalize.pkl", env)
            # Inference only: freeze the running statistics and keep rewards raw.
            env.training = False
            env.norm_reward = False
        else:
            # train.py always trains under VecNormalize, so a policy fed raw
            # observations will behave poorly. Make that visible instead of
            # silently running with the wrong input scale.
            st.warning(
                "vec_normalize.pkl not found — the PPO agent will receive "
                "un-normalised observations and may act poorly."
            )
        model = PPO.load("ppo_smart_grid", env=env)
        return model, env
    except Exception as e:  # UI boundary: report the failure, don't crash the app
        st.error(f"Could not load model: {e}")
        return None, None
75
+
76
+
77
def run_episode(agent: str, model=None, vec_env=None, speed: float = 0.3):
    """Run a single 24-step episode while streaming live metrics to the page.

    Args:
        agent: ``'ppo'`` to drive the episode with ``model`` inside ``vec_env``;
            any other value (``'rule'``) uses the rule-based baseline.
        model: Trained policy exposing ``predict``. Required for PPO mode.
        vec_env: Vectorised (optionally normalised) environment the policy
            acts in. Required for PPO mode.
        speed: Seconds to sleep after each step so the UI updates are visible.

    Returns:
        ``(df, total_cost)`` where ``df`` is a DataFrame with one row per hour
        and ``total_cost`` is the sum of the per-step ``info["cost"]`` values.

    If PPO mode is requested but ``model`` or ``vec_env`` is missing, the
    rule-based baseline is used instead.
    """
    use_ppo = agent == "ppo" and model is not None and vec_env is not None

    # Exactly ONE environment is stepped per episode. The policy must be
    # evaluated on the same trajectory it observes: stepping a second,
    # independently-seeded env for the info dict would report prices and
    # costs the agent never saw.
    raw_env = None
    if use_ppo:
        obs_vec = vec_env.reset()
    else:
        raw_env = SmartGridEnv()
        obs_raw, _ = raw_env.reset()

    records = []
    total_cost = 0.0

    live_battery = st.empty()
    live_price = st.empty()
    live_cost = st.empty()
    live_chart = st.empty()

    try:
        for step in range(24):
            # ---- Pick action and advance the environment ----
            if use_ppo:
                action_arr, _ = model.predict(obs_vec, deterministic=True)
                action = int(action_arr.item())  # numpy → plain int
                obs_vec, _, _, infos = vec_env.step(action_arr)
                # Vectorised envs return one info dict per sub-env.
                info = infos[0]
            else:
                action = rule_based_action(obs_raw)
                obs_raw, _, _, _, info = raw_env.step(action)

            cost = info["cost"]
            total_cost += cost

            battery = info["battery_soc"]
            solar = info["solar_kw"]
            demand = info["demand_kw"]
            price = info["price"]

            # ---- Live metrics ----
            live_battery.metric("🔋 Battery SoC", f"{battery:.1f} %", action_to_label(action))
            live_price.metric("💲 Grid Price", f"${price:.3f}/kWh")
            live_cost.metric("💰 Running Cost", f"${total_cost:.2f}", delta_color="inverse")

            records.append({
                "Hour": step + 1,
                "Battery (%)": round(battery, 2),
                "Solar (kW)": round(solar, 2),
                "Demand (kW)": round(demand, 2),
                "Price ($/kWh)": round(price, 3),
                "Step Cost ($)": round(cost, 3),
                "Action": action_to_label(action),
            })

            # ---- Live chart (updates every step) ----
            df_so_far = pd.DataFrame(records).set_index("Hour")
            live_chart.line_chart(
                df_so_far[["Battery (%)", "Solar (kW)", "Demand (kW)"]],
                height=250,
            )

            time.sleep(speed)
    finally:
        if raw_env is not None:
            raw_env.close()

    return pd.DataFrame(records), total_cost
144
+
145
+
146
def show_results(df: pd.DataFrame, total_cost: float, label: str):
    """Render the end-of-episode summary for one agent.

    Shows a success banner with the total cost, a bar chart of hourly cost
    next to a line chart of battery/solar/demand, and an expandable table
    containing the full step log.

    Args:
        df: Step-level records; must contain the ``Hour`` column plus the
            cost, battery, solar and demand columns charted below.
        total_cost: Episode cost shown in the banner.
        label: Agent name shown in the banner.
    """
    st.success(f"**{label}** — Total 24-hour cost: **${total_cost:.2f}**")

    by_hour = df.set_index("Hour")
    cost_series = by_hour[["Step Cost ($)"]]
    energy_series = by_hour[["Battery (%)", "Solar (kW)", "Demand (kW)"]]

    left, right = st.columns(2)

    with left:
        st.subheader("Hourly step cost ($)")
        st.bar_chart(cost_series, height=220)

    with right:
        st.subheader("Battery, solar and demand")
        st.line_chart(energy_series, height=220)

    with st.expander("📋 Full step-by-step log"):
        st.dataframe(df, use_container_width=True)
162
+
163
+
164
# ── Main simulation ───────────────────────────────────────────────────────────
# Runs only on the rerun triggered by the sidebar button; otherwise the
# landing help text is shown.
if run_btn:
    model, vec_env = load_ppo_model()
    ppo_available = model is not None

    # What actually runs in the "PPO" slot: the trained agent, or the
    # rule-based baseline when no usable model could be loaded.
    agent = "ppo" if ppo_available else "rule"
    ppo_label = "PPO Agent" if ppo_available else "Rule-Based (fallback)"

    if not ppo_available and agent_choice in ("PPO Agent", "Compare Both"):
        # load_ppo_model() returns None both when the archive is missing and
        # when loading raised, so the message must cover both cases.
        st.warning(
            "PPO model unavailable (ppo_smart_grid.zip missing or failed to load) — "
            "run `python train.py` first. Falling back to rule-based agent."
        )

    # ---- PPO only ----
    if agent_choice == "PPO Agent":
        st.markdown(f"### {ppo_label}")
        df, cost = run_episode(agent, model, vec_env, sim_speed)
        show_results(df, cost, ppo_label)

    # ---- Rule-based only ----
    elif agent_choice == "Rule-Based Baseline":
        st.markdown("### Rule-Based Baseline")
        df, cost = run_episode("rule", speed=sim_speed)
        show_results(df, cost, "Rule-Based Baseline")

    # ---- Compare both ----
    else:
        tab_ppo, tab_rule = st.tabs([" PPO Agent", " Rule-Based Baseline"])

        with tab_ppo:
            st.markdown(f"#### {ppo_label} — running...")
            df_ppo, cost_ppo = run_episode(agent, model, vec_env, sim_speed)
            show_results(df_ppo, cost_ppo, ppo_label)

        with tab_rule:
            st.markdown("#### Rule-Based Baseline — running...")
            df_rule, cost_rule = run_episode("rule", speed=sim_speed)
            show_results(df_rule, cost_rule, "Rule-Based Baseline")

        # ---- Side-by-side cost summary ----
        st.markdown("---")
        st.subheader("Cost comparison")
        if ppo_available:
            c1, c2, c3 = st.columns(3)
            c1.metric("PPO total cost", f"${cost_ppo:.2f}")
            c2.metric("Rule-based total cost", f"${cost_rule:.2f}")
            saving = cost_rule - cost_ppo
            c3.metric("PPO saving", f"${saving:.2f}",
                      delta=f"{'better' if saving > 0 else 'worse'} than rule-based",
                      delta_color="normal" if saving > 0 else "inverse")
        else:
            # Both tabs ran the rule-based agent; presenting their difference
            # as a "PPO saving" would be misleading.
            st.info(
                "No PPO model was available, so both runs used the rule-based "
                "agent and no PPO comparison is shown."
            )

else:
    st.info("Configure settings in the sidebar and click **▶ Start 24-Hour Simulation**.")
    st.markdown("""
#### How it works
| Component | Detail |
|---|---|
| Environment | 24-step episode (1 step = 1 hour) |
| Observation | Battery SoC, solar generation, house demand, grid price |
| Actions | Hold / Charge from grid / Discharge battery |
| Reward | Negative net grid cost (includes solar sell-back revenue) |
| Agent | PPO with MLP policy, trained via `stable-baselines3` |
| Baseline | Simple price-threshold rule for comparison |
""")
smart_grid_env.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import numpy as np
3
+ from gymnasium import spaces
4
+
5
+
6
class SmartGridEnv(gym.Env):
    """
    Smart Grid Battery Management Environment (Gymnasium-compatible)

    Goal: Minimize daily electricity cost by intelligently charging/discharging
    a home battery, using solar generation, and interacting with the grid.

    Observation (4 float32 values in raw physical units, bounded by
    OBS_LOW / OBS_HIGH — they are NOT rescaled to [0, 1] by this class):
        [battery_soc, solar_gen_kw, house_demand_kw, grid_price]

    Action space (Discrete 3):
        0 = Hold      — do nothing beyond covering net load from grid/solar
        1 = Charge    — buy from grid to fill battery (10 kW rate)
        2 = Discharge — draw from battery to cover load (10 kW rate)

    Reward:
        Negative net cost per step (agent learns to minimize cost).
        Includes sell-back revenue when surplus energy is fed to the grid.
        Includes a small battery health penalty for extreme SoC operation.

    Episode length: 24 steps; `terminated` becomes True on the 24th step.

    Fixes vs. original:
        - Observation space bounds match actual value ranges
        - Solar/demand energy balance applied in all action branches
        - Grid sell-back (feed-in tariff) modeled in every action branch
        - Battery SoC clamped to [0, battery_capacity]
        - Charging efficiency loss (90%) modeled
        - Correlated time-series for price/solar/demand (no more i.i.d. jumps)
        - Battery health penalty for operating near 0% or 100% SoC
        - render() method added
        - Fully compatible with check_env() and VecNormalize
    """

    metadata = {"render_modes": ["human"]}

    # Physical constants
    BATTERY_CAPACITY = 100.0  # kWh
    CHARGE_RATE = 10.0  # max energy moved in or out of the battery per step
    CHARGE_EFFICIENCY = 0.90  # applied on charging only: grid draw = stored / 0.90
    FEED_IN_TARIFF = 0.50  # surplus energy is sold at 50% of the current grid price
    SOC_PENALTY_COEF = 0.005  # small penalty for extreme battery levels

    # Observation high limits (battery 0-100 %, solar 0-10 kW,
    # demand 0-10 kW, price 0-1.5 $/kWh to cover double-peak)
    OBS_HIGH = np.array([100.0, 10.0, 10.0, 1.5], dtype=np.float32)
    OBS_LOW = np.zeros(4, dtype=np.float32)

    def __init__(self, render_mode=None):
        """Create the environment.

        Args:
            render_mode: ``"human"`` makes every ``step`` call ``render()``;
                any other value disables rendering.
        """
        super().__init__()
        self.render_mode = render_mode

        # --- Action / Observation Spaces ---
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=self.OBS_LOW,
            high=self.OBS_HIGH,
            shape=(4,),
            dtype=np.float32,
        )

        # Internal state
        self.current_step = 0
        self.current_battery = self.BATTERY_CAPACITY * 0.5
        self._state = self._make_initial_state()

        # For correlated time-series generation: each base value is pulled
        # toward a time-of-day mean in _generate_next_state().
        self._price_base = 0.2  # $/kWh (drifts each step)
        self._demand_base = 3.0  # kW
        self._solar_base = 0.0  # kW

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------

    def reset(self, seed=None, options=None):
        """Start a new episode at hour 0 with the battery at 50% capacity.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.current_step = 0
        self.current_battery = self.BATTERY_CAPACITY * 0.5
        self._price_base = 0.2
        self._demand_base = 3.0
        self._solar_base = 0.0
        self._state = self._make_initial_state()
        return self._state.copy(), {}

    def step(self, action):
        """Apply one battery action for the current hour and advance time.

        Args:
            action: 0 (hold), 1 (charge) or 2 (discharge).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``observation`` is the state for the NEXT hour (all zeros once the
            episode has terminated). ``info`` reports the solar, demand and
            price of the hour just simulated, the battery level AFTER the
            action, the grid cost of the step (without the health penalty)
            and the action taken.

        Raises:
            AssertionError: if ``action`` is not in the action space.
        """
        assert self.action_space.contains(action), f"Invalid action {action}"

        battery, solar, demand, price = self._state

        # ---- 1. Compute net load BEFORE battery action ----
        # Positive → house needs more than solar provides (must buy or discharge)
        # Negative → solar surplus (can sell back or charge)
        net_load = demand - solar

        grid_cost = 0.0  # positive = paying, negative = earning

        # ---- 2. Execute battery action ----
        if action == 1:  # CHARGE from grid
            # How much can we actually charge?
            headroom = self.BATTERY_CAPACITY - battery
            charge_requested = min(self.CHARGE_RATE, headroom)
            # Efficiency: we buy more from grid than actually stored
            grid_draw = charge_requested / self.CHARGE_EFFICIENCY
            self.current_battery = np.clip(
                battery + charge_requested,
                0.0, self.BATTERY_CAPACITY
            )
            # Also cover net_load from grid
            grid_cost = (max(0.0, net_load) + grid_draw) * price
            # Any solar surplus is sold back (it is not used to charge the battery)
            grid_cost -= max(0.0, -net_load) * price * self.FEED_IN_TARIFF

        elif action == 2:  # DISCHARGE battery to cover load
            # How much battery can supply? Always the full rate if available,
            # independent of how large the load actually is.
            discharge_requested = min(self.CHARGE_RATE, battery)
            self.current_battery = np.clip(
                battery - discharge_requested,
                0.0, self.BATTERY_CAPACITY
            )
            # Remaining load after battery contribution
            residual_load = net_load - discharge_requested
            if residual_load > 0:
                grid_cost = residual_load * price  # still need some grid
            else:
                grid_cost = residual_load * price * self.FEED_IN_TARIFF  # surplus → sell

        else:  # HOLD — let solar + grid balance the load
            if net_load > 0:
                grid_cost = net_load * price  # buy deficit from grid
            else:
                grid_cost = net_load * price * self.FEED_IN_TARIFF  # sell surplus

        # ---- 3. Battery health penalty (discourages extreme SoC) ----
        # Non-zero only when the post-action SoC is above 90% or below 10%.
        soc_frac = self.current_battery / self.BATTERY_CAPACITY
        health_penalty = self.SOC_PENALTY_COEF * (
            max(0.0, soc_frac - 0.9) + max(0.0, 0.1 - soc_frac)
        )

        # ---- 4. Reward ----
        reward = -(grid_cost + health_penalty)

        # ---- 5. Advance time, generate next state ----
        self.current_step += 1
        terminated = self.current_step >= 24
        truncated = False

        if not terminated:
            self._state = self._generate_next_state()
        else:
            # Placeholder terminal observation; no further hour is simulated.
            self._state = np.zeros(4, dtype=np.float32)

        info = {
            "cost": float(grid_cost),
            "battery_soc": float(self.current_battery),
            "solar_kw": float(solar),
            "demand_kw": float(demand),
            "price": float(price),
            "action": int(action),
        }

        if self.render_mode == "human":
            self.render()

        return self._state.copy(), float(reward), terminated, truncated, info

    def render(self):
        """Print the current hour, battery level and the current state values.

        When called from ``step`` this shows the state generated for the
        upcoming hour (zeros after termination), not the hour just simulated.
        """
        b, s, d, p = self._state
        print(
            f"[Hour {self.current_step:02d}] "
            f"Battery={self.current_battery:.1f}% | "
            f"Solar={s:.2f}kW | Demand={d:.2f}kW | Price=${p:.3f}/kWh"
        )

    def close(self):
        """No resources to release."""
        pass

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _make_initial_state(self) -> np.ndarray:
        """Return the hour-0 observation: current battery, no solar, 2.5 kW demand, 0.15 $/kWh."""
        return np.array(
            [self.current_battery, 0.0, 2.5, 0.15],
            dtype=np.float32
        )

    def _generate_next_state(self) -> np.ndarray:
        """
        Correlated time-series generation so successive hours are smooth.
        Each variable drifts toward a time-of-day mean with Gaussian noise.

        Random draws are taken in the fixed order solar → demand → price from
        ``self.np_random``; changing that order changes seeded trajectories.
        """
        hour = self.current_step  # already incremented by step(); 1–23 when called from there

        # --- Solar: bell-curve peaking at noon, zero at night ---
        if 6 <= hour <= 18:
            solar_mean = 5.0 * np.exp(-0.5 * ((hour - 12) / 3.5) ** 2)
        else:
            solar_mean = 0.0
        # Move 30% of the way toward the hourly mean, then add noise.
        self._solar_base += 0.3 * (solar_mean - self._solar_base)
        next_solar = float(np.clip(
            self._solar_base + self.np_random.normal(0, 0.4),
            0.0, 10.0
        ))

        # --- Demand: morning and evening peaks ---
        demand_mean = (
            2.5
            + 2.0 * np.exp(-0.5 * ((hour - 8) / 1.5) ** 2)  # morning
            + 3.0 * np.exp(-0.5 * ((hour - 19) / 2.0) ** 2)  # evening
        )
        self._demand_base += 0.3 * (demand_mean - self._demand_base)
        next_demand = float(np.clip(
            self._demand_base + self.np_random.normal(0, 0.5),
            0.0, 10.0
        ))

        # --- Price: cheap at night, expensive at peak (17–21) ---
        if 17 <= hour <= 21:
            price_mean = 0.65
        elif 6 <= hour <= 9:
            price_mean = 0.30
        elif 23 <= hour or hour <= 5:
            price_mean = 0.12
        else:
            price_mean = 0.22
        # Price tracks its mean faster (40% per step) and never drops below 0.05.
        self._price_base += 0.4 * (price_mean - self._price_base)
        next_price = float(np.clip(
            self._price_base + self.np_random.normal(0, 0.03),
            0.05, 1.5
        ))

        return np.array(
            [self.current_battery, next_solar, next_demand, next_price],
            dtype=np.float32
        )
train.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train.py — PPO training script for SmartGridEnv
3
+
4
+ Fixes vs. original:
5
+ - check_env() validates the environment before training starts
6
+ - VecNormalize auto-normalizes observations and rewards for stable gradients
7
+ - 500,000 timesteps (was 10,000 — far too few for PPO to learn anything)
8
+ - EvalCallback saves the best model checkpoint automatically
9
+ - Hyperparameters tuned for this problem (n_steps, batch_size, ent_coef)
10
+ - vec_normalize stats saved alongside model (required for correct inference)
11
+ - TensorBoard logging enabled (optional — run: tensorboard --logdir ./tb_logs)
12
+ """
13
+
14
+ import os
15
+ from stable_baselines3 import PPO
16
+ from stable_baselines3.common.env_util import make_vec_env
17
+ from stable_baselines3.common.vec_env import VecNormalize
18
+ from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
19
+ from stable_baselines3.common.env_checker import check_env
20
+
21
+ from smart_grid_env import SmartGridEnv
22
+
23
# ── 1. Validate environment ───────────────────────────────────────────────────
# Fail fast on API/space mismatches before spending time on training.
print("Checking environment...")
check_env(SmartGridEnv(), warn=True)
print("Environment check passed.\n")

# ── 2. Vectorised training environment (4 env copies) ─────────────────────────
# NOTE(review): make_vec_env is called without an explicit vec_env_cls, so
# whether the copies run in separate worker processes depends on its default —
# confirm before relying on true parallelism.
N_ENVS = 4
train_env = make_vec_env(SmartGridEnv, n_envs=N_ENVS)
train_env = VecNormalize(
    train_env,
    norm_obs=True,  # normalizes each obs dimension to ~N(0,1)
    norm_reward=True,  # normalizes reward scale — critical for PPO stability
    clip_obs=10.0,
)

# ── 3. Separate evaluation environment (no reward normalisation) ───────────────
eval_env = make_vec_env(SmartGridEnv, n_envs=1)
eval_env = VecNormalize(
    eval_env,
    norm_obs=True,
    norm_reward=False,  # raw rewards for interpretable eval metrics
    # Frozen: this wrapper never updates its own statistics.
    # NOTE(review): nothing in this script copies train_env's statistics here
    # explicitly; presumably EvalCallback synchronises them — confirm.
    training=False,
    clip_obs=10.0,
)

# ── 4. Define the PPO model ────────────────────────────────────────────────────
model = PPO(
    policy = "MlpPolicy",
    env = train_env,
    verbose = 1,
    tensorboard_log = "./tb_logs",
    # --- Core PPO hyperparameters ---
    n_steps = 1024,  # steps collected per env per rollout
    batch_size = 256,  # minibatch size for gradient update
    n_epochs = 10,  # number of passes over each rollout buffer
    gamma = 0.99,  # discount factor (long-horizon cost matters)
    gae_lambda = 0.95,  # GAE smoothing
    clip_range = 0.2,  # PPO clip parameter
    learning_rate = 3e-4,  # Adam lr
    ent_coef = 0.01,  # entropy bonus (encourages exploration early on)
    vf_coef = 0.5,
    max_grad_norm = 0.5,
    # --- Policy network architecture ---
    policy_kwargs = dict(net_arch=[128, 128]),  # 2-layer MLP, 128 units each
)

# ── 5. Callbacks ───────────────────────────────────────────────────────────────
os.makedirs("./best_model", exist_ok=True)
os.makedirs("./checkpoints", exist_ok=True)

# Frequencies are divided by N_ENVS so that they correspond to roughly the
# stated number of total environment steps across all env copies.
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path = "./best_model",
    log_path = "./eval_logs",
    eval_freq = max(5_000 // N_ENVS, 1),  # evaluate every ~5k env steps
    n_eval_episodes = 20,  # average over 20 full 24-hour episodes
    deterministic = True,
    render = False,
)

checkpoint_callback = CheckpointCallback(
    save_freq = max(50_000 // N_ENVS, 1),
    save_path = "./checkpoints",
    name_prefix= "ppo_smart_grid",
)

# ── 6. Train ───────────────────────────────────────────────────────────────────
TOTAL_TIMESTEPS = 500_000
print(f"Training PPO for {TOTAL_TIMESTEPS:,} timesteps across {N_ENVS} parallel envs...")
print("Tip: run `tensorboard --logdir ./tb_logs` to monitor training live.\n")

model.learn(
    total_timesteps = TOTAL_TIMESTEPS,
    callback = [eval_callback, checkpoint_callback],
    progress_bar = True,
)

# ── 7. Save final model + normalisation statistics ────────────────────────────
# Both files are loaded together by the sanity check below; the statistics
# belong to this final model.
model.save("ppo_smart_grid")
train_env.save("vec_normalize.pkl")  # MUST be saved — needed for inference

print("\nTraining complete!")
print(" Saved: ppo_smart_grid.zip")
print(" Saved: vec_normalize.pkl (required alongside the model for inference)")
print(" Best checkpoint: ./best_model/best_model.zip")

# ── 8. Quick sanity-check: run one episode with the trained agent ──────────────
print("\n--- Sanity check: one 24-hour episode ---")
from stable_baselines3.common.vec_env import DummyVecEnv

# Rebuild a single-env wrapper from the saved statistics, frozen for inference.
test_env = DummyVecEnv([SmartGridEnv])
test_env = VecNormalize.load("vec_normalize.pkl", test_env)
test_env.training = False
test_env.norm_reward = False

obs = test_env.reset()
total_cost = 0.0
for hour in range(24):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = test_env.step(action)
    # info is a list with one dict per sub-env; index 0 is the only env here.
    total_cost += info[0]["cost"]
    action_label = ["Hold", "Charge", "Discharge"][int(action[0])]
    print(
        f" Hour {hour+1:02d} | Action: {action_label:<10} | "
        f"Battery: {info[0]['battery_soc']:5.1f}% | "
        f"Price: ${info[0]['price']:.3f} | "
        f"Step cost: ${info[0]['cost']:.3f}"
    )

print(f"\nTotal 24-hour cost: ${total_cost:.2f}")
test_env.close()