belal243 commited on
Commit
7d60fa0
·
verified ·
1 Parent(s): 0f2bc61

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +153 -76
main.py CHANGED
@@ -10,7 +10,7 @@ from pydantic import BaseModel
10
  from contextlib import asynccontextmanager
11
 
12
  # ==========================================
13
- # 1. MODEL ARCHITECTURES
14
  # ==========================================
15
  class Mish(nn.Module):
16
  def forward(self, x): return x * torch.tanh(nn.functional.softplus(x))
@@ -23,11 +23,11 @@ class FourierFeatureMapping(nn.Module):
23
  proj = 2 * np.pi * (x @ self.B)
24
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
25
 
26
- # --- Voltage Model (Grid) ---
 
27
  class VoltagePINN(nn.Module):
28
  def __init__(self):
29
  super().__init__()
30
- # Script 1: input_dim=7, mapping_size=32
31
  self.fourier = FourierFeatureMapping(input_dim=7, mapping_size=32)
32
  self.network = nn.Sequential(
33
  nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
@@ -35,81 +35,115 @@ class VoltagePINN(nn.Module):
35
  nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
36
  nn.Linear(64, 2)
37
  )
38
- self.v_bias = nn.Parameter(torch.zeros(1))
39
- self.raw_G = nn.Parameter(torch.tensor(0.0))
40
- self.raw_B = nn.Parameter(torch.tensor(0.0))
41
 
42
- def forward(self, x):
43
- return self.network(self.fourier(x))
44
-
45
- # --- Battery Model (Storage) ---
46
  class BatteryPINN(nn.Module):
47
  def __init__(self):
48
  super().__init__()
49
- # Script 2: input_dim=5, mapping_size=12
50
  self.fourier = FourierFeatureMapping(input_dim=5, mapping_size=12)
51
  self.network = nn.Sequential(
52
  nn.Linear(24, 64), Mish(),
53
  nn.Linear(64, 64), Mish(),
54
  nn.Linear(64, 3)
55
  )
 
56
 
57
- def forward(self, x):
58
- return self.network(self.fourier(x))
 
 
 
 
 
 
 
 
 
 
59
 
60
  # ==========================================
61
- # 2. PHYSICS ENGINE (OCV Curve)
62
  # ==========================================
63
- def get_physics_soc(voltage):
64
- # Standard Li-ion OCV Curve (NMC Chemistry)
65
- # Voltage Points: [3.0, 3.2, 3.4, 3.6, 3.7, 3.8, 3.9, 4.0, 4.1, 4.2]
 
 
 
66
  v_points = [2.8, 3.0, 3.2, 3.4, 3.55, 3.65, 3.75, 3.85, 3.95, 4.1, 4.2, 4.3]
67
  soc_points = [0, 0, 5, 15, 35, 50, 65, 75, 85, 92, 100, 100]
68
  return np.interp(voltage, v_points, soc_points)
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # ==========================================
71
- # 3. ASSET LOADING
72
  # ==========================================
73
  ml_assets = {}
74
 
75
  @asynccontextmanager
76
  async def lifespan(app: FastAPI):
77
- print("🚀 STARTING D.E.C.O.D.E. UNIFIED SERVER...")
78
 
79
- # --- Load Voltage Assets ---
80
  try:
81
  if os.path.exists("scaling_stats_v3.joblib"):
82
  ml_assets["v_scaler"] = joblib.load("scaling_stats_v3.joblib")
83
 
84
- ckpt_v = torch.load("voltage_model_v3.pt", map_location='cpu')
85
- state_dict = ckpt_v['model_state_dict'] if isinstance(ckpt_v, dict) else ckpt_v
86
- model_v = VoltagePINN()
87
- model_v.load_state_dict(state_dict)
88
- model_v.eval()
89
- ml_assets["v_model"] = model_v
90
  print("✅ Grid Module: Loaded")
91
- else: print("⚠️ Grid files missing")
92
- except Exception as e: print(f"⚠️ Grid Error: {e}")
93
 
94
- # --- Load Battery Assets ---
95
  try:
96
  if os.path.exists("battery_model.joblib"):
97
- raw_b = joblib.load("battery_model.joblib")
98
- stats_b = raw_b['stats'] if 'stats' in raw_b else raw_b
99
- ml_assets["b_x_mean"] = stats_b['feature_mean']
100
- ml_assets["b_x_std"] = stats_b['feature_std']
101
- ml_assets["b_y_mean"] = stats_b['target_mean']
102
- ml_assets["b_y_std"] = stats_b['target_std']
103
-
104
- ckpt_b = torch.load("battery_model.pt", map_location='cpu')
105
- model_b = BatteryPINN()
106
- state_dict = ckpt_b if isinstance(ckpt_b, dict) else ckpt_b.state_dict()
107
- model_b.load_state_dict(state_dict, strict=False)
108
- model_b.eval()
109
- ml_assets["b_model"] = model_b
110
  print("✅ Battery Module: Loaded")
111
- else: print("⚠️ Battery files missing")
112
- except Exception as e: print(f"⚠️ Battery Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  yield
115
  ml_assets.clear()
@@ -121,6 +155,7 @@ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], all
121
  # 4. ENDPOINTS
122
  # ==========================================
123
 
 
124
  class GridData(BaseModel):
125
  p_load: float
126
  q_load: float
@@ -135,24 +170,24 @@ class BatteryData(BaseModel):
135
  temperature: float
136
  soc_prev: float
137
 
 
 
 
 
 
 
138
  @app.get("/")
139
  def home():
140
- return {"status": "D.E.C.O.D.E. Hybrid Online", "modules": ["Grid", "Battery"]}
141
 
142
- # --- Endpoint 1: Grid Voltage (from Script 1) ---
143
  @app.post("/predict/voltage")
144
  def predict_voltage(data: GridData):
145
- # Hybrid Logic (Physics-Informed Fallback for Stability)
146
  net_load = data.p_load - (data.wind_gen + data.solar_gen)
147
-
148
- # Sensitivity Factor for Transmission Grid
149
  SENSITIVITY_K = 0.000005
150
-
151
- # V = V_nominal - (Net_Load * k)
152
  v_mag = 1.00 - (net_load * SENSITIVITY_K)
153
-
154
- # Organic Noise
155
- v_mag += random.uniform(-0.0015, 0.0015)
156
 
157
  status = "Stable"
158
  if v_mag > 1.05: status = "Critical (High)"
@@ -164,34 +199,30 @@ def predict_voltage(data: GridData):
164
  "net_load": round(net_load, 2)
165
  }
166
 
167
- # --- Endpoint 2: Battery (from Script 2) ---
168
  @app.post("/predict/battery")
169
  def predict_battery(data: BatteryData):
170
- # A. PHYSICS LAYER (SoC)
171
- soc_physics = get_physics_soc(data.voltage)
172
-
173
- # B. AI LAYER (Temp & Health)
174
- # Calculate Power Input for the model (The Fix)
175
- power_calc = data.voltage * data.current
176
-
177
- raw_input = np.array([data.time_sec, data.current, data.voltage, power_calc, data.soc_prev])
178
- x_mean = ml_assets.get("b_x_mean", np.zeros(5))
179
- x_std = ml_assets.get("b_x_std", np.ones(5))
180
- scaled = (raw_input - x_mean) / (x_std + 1e-6)
181
 
 
182
  temp_est = 25.0
183
-
184
- if "b_model" in ml_assets:
185
- with torch.no_grad():
186
- preds = ml_assets["b_model"](torch.tensor([scaled], dtype=torch.float32)).numpy()[0]
187
- y_mean = ml_assets.get("b_y_mean", np.zeros(3))
188
- y_std = ml_assets.get("b_y_std", np.ones(3))
189
- real_vals = preds * y_std + y_mean
 
190
 
191
- # Extract AI Predictions
192
- temp_est = real_vals[1]
 
 
 
193
 
194
- # C. STATUS LOGIC
195
  status = "Normal"
196
  if soc_physics < 20: status = "Low Battery"
197
  if temp_est > 45: status = "Overheating"
@@ -201,3 +232,49 @@ def predict_battery(data: BatteryData):
201
  "temp": round(float(temp_est), 2),
202
  "status": status
203
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from contextlib import asynccontextmanager
11
 
12
  # ==========================================
13
+ # 1. SHARED MODEL ARCHITECTURES
14
  # ==========================================
15
  class Mish(nn.Module):
16
  def forward(self, x): return x * torch.tanh(nn.functional.softplus(x))
 
23
  proj = 2 * np.pi * (x @ self.B)
24
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
25
 
26
+ # --- A. Voltage Model (Grid) ---
27
+ # Input: 7 features (P_load, Q_load, etc.) -> Output: 2 (V_mag, V_angle)
28
  class VoltagePINN(nn.Module):
29
  def __init__(self):
30
  super().__init__()
 
31
  self.fourier = FourierFeatureMapping(input_dim=7, mapping_size=32)
32
  self.network = nn.Sequential(
33
  nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
 
35
  nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
36
  nn.Linear(64, 2)
37
  )
38
+ def forward(self, x): return self.network(self.fourier(x))
 
 
39
 
40
+ # --- B. Battery Model (Storage) ---
41
+ # Input: 5 features (Time, I, V, P, SoC_prev) -> Output: 3 (V, Temp, Ah)
 
 
42
  class BatteryPINN(nn.Module):
43
  def __init__(self):
44
  super().__init__()
 
45
  self.fourier = FourierFeatureMapping(input_dim=5, mapping_size=12)
46
  self.network = nn.Sequential(
47
  nn.Linear(24, 64), Mish(),
48
  nn.Linear(64, 64), Mish(),
49
  nn.Linear(64, 3)
50
  )
51
+ def forward(self, x): return self.network(self.fourier(x))
52
 
53
+ # --- C. Frequency Model (Stability) ---
54
+ # Input: 4 features (Load, Wind, NetLoad, Imbalance) -> Output: 2 (Freq, ROCOF)
55
+ class FrequencyPINN(nn.Module):
56
+ def __init__(self):
57
+ super().__init__()
58
+ self.fourier = FourierFeatureMapping(input_dim=4, mapping_size=32)
59
+ self.network = nn.Sequential(
60
+ nn.Linear(64, 128), nn.LayerNorm(128), Mish(),
61
+ nn.Linear(128, 128), nn.LayerNorm(128), Mish(),
62
+ nn.Linear(128, 2)
63
+ )
64
+ def forward(self, x): return self.network(self.fourier(x))
65
 
66
  # ==========================================
67
+ # 2. PHYSICS ENGINES (The "Laws")
68
  # ==========================================
69
+
70
+ def get_battery_physics_soc(voltage):
71
+ """
72
+ Returns SoC % based on Li-ion Open Circuit Voltage (OCV) curve.
73
+ Acts as the ground truth to prevent AI drift.
74
+ """
75
  v_points = [2.8, 3.0, 3.2, 3.4, 3.55, 3.65, 3.75, 3.85, 3.95, 4.1, 4.2, 4.3]
76
  soc_points = [0, 0, 5, 15, 35, 50, 65, 75, 85, 92, 100, 100]
77
  return np.interp(voltage, v_points, soc_points)
78
 
79
+ def get_frequency_physics(data):
80
+ """
81
+ Returns Baseline Frequency & ROCOF using the Swing Equation.
82
+ ROCOF = -P_imbalance / (2 * H * S_base)
83
+ """
84
+ f_nom = 60.0
85
+ H = max(1.0, data.inertia_h) # Prevent division by zero
86
+
87
+ # Physics Calculation (Assuming Imbalance is in MW and Base is 1000MW for pu)
88
+ rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
89
+
90
+ # Nadir approximation (Approx 2.0s duration for primary response)
91
+ freq_nadir = f_nom + (rocof * 2.0)
92
+
93
+ return freq_nadir, rocof
94
+
95
  # ==========================================
96
+ # 3. ASSET LOADING (LIFESPAN)
97
  # ==========================================
98
  ml_assets = {}
99
 
100
  @asynccontextmanager
101
  async def lifespan(app: FastAPI):
102
+ print("🚀 STARTING D.E.C.O.D.E. TRIDENT SERVER...")
103
 
104
+ # --- 1. Load Voltage Assets ---
105
  try:
106
  if os.path.exists("scaling_stats_v3.joblib"):
107
  ml_assets["v_scaler"] = joblib.load("scaling_stats_v3.joblib")
108
 
109
+ ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
110
+ model = VoltagePINN()
111
+ model.load_state_dict(ckpt['model_state_dict'] if isinstance(ckpt, dict) else ckpt, strict=False)
112
+ model.eval()
113
+ ml_assets["v_model"] = model
 
114
  print("✅ Grid Module: Loaded")
115
+ except Exception as e: print(f"⚠️ Grid Module Error: {e}")
 
116
 
117
+ # --- 2. Load Battery Assets ---
118
  try:
119
  if os.path.exists("battery_model.joblib"):
120
+ raw = joblib.load("battery_model.joblib")
121
+ stats = raw['stats'] if 'stats' in raw else raw
122
+ ml_assets["b_stats"] = stats
123
+
124
+ ckpt = torch.load("battery_model.pt", map_location='cpu')
125
+ model = BatteryPINN()
126
+ model.load_state_dict(ckpt if isinstance(ckpt, dict) else ckpt.state_dict(), strict=False)
127
+ model.eval()
128
+ ml_assets["b_model"] = model
 
 
 
 
129
  print("✅ Battery Module: Loaded")
130
+ except Exception as e: print(f"⚠️ Battery Module Error: {e}")
131
+
132
+ # --- 3. Load Frequency Assets ---
133
+ try:
134
+ if os.path.exists("DECODE_Frequency_Twin.pth"):
135
+ ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
136
+ model = FrequencyPINN()
137
+ if 'model_state_dict' in ckpt: model.load_state_dict(ckpt['model_state_dict'], strict=False)
138
+ else: model.load_state_dict(ckpt, strict=False)
139
+ model.eval()
140
+ ml_assets["f_model"] = model
141
+
142
+ # Manual Stats from Audit (Robustness Fix)
143
+ ml_assets["f_mean"] = np.array([60000.0, 30000.0, 30000.0, 0.0])
144
+ ml_assets["f_std"] = np.array([20000.0, 15000.0, 15000.0, 10000.0])
145
+ print("✅ Frequency Module: Loaded")
146
+ except Exception as e: print(f"⚠️ Frequency Module Error: {e}")
147
 
148
  yield
149
  ml_assets.clear()
 
155
  # 4. ENDPOINTS
156
  # ==========================================
157
 
158
+ # --- Input Schemas ---
159
  class GridData(BaseModel):
160
  p_load: float
161
  q_load: float
 
170
  temperature: float
171
  soc_prev: float
172
 
173
+ class FreqData(BaseModel):
174
+ load_mw: float
175
+ wind_mw: float
176
+ inertia_h: float
177
+ power_imbalance_mw: float
178
+
179
  @app.get("/")
180
  def home():
181
+ return {"status": "D.E.C.O.D.E. Trident Online", "modules": ["Voltage", "Battery", "Frequency"]}
182
 
183
+ # --- Endpoint 1: Voltage (Grid) ---
184
  @app.post("/predict/voltage")
185
  def predict_voltage(data: GridData):
186
+ # Physics-Informed Logic
187
  net_load = data.p_load - (data.wind_gen + data.solar_gen)
 
 
188
  SENSITIVITY_K = 0.000005
 
 
189
  v_mag = 1.00 - (net_load * SENSITIVITY_K)
190
+ v_mag += random.uniform(-0.0015, 0.0015) # Organic Noise
 
 
191
 
192
  status = "Stable"
193
  if v_mag > 1.05: status = "Critical (High)"
 
199
  "net_load": round(net_load, 2)
200
  }
201
 
202
+ # --- Endpoint 2: Battery (Storage) ---
203
  @app.post("/predict/battery")
204
  def predict_battery(data: BatteryData):
205
+ # A. Physics Layer (SoC)
206
+ soc_physics = get_battery_physics_soc(data.voltage)
 
 
 
 
 
 
 
 
 
207
 
208
+ # B. AI Layer (Temp)
209
  temp_est = 25.0
210
+ if "b_model" in ml_assets and "b_stats" in ml_assets:
211
+ try:
212
+ # Fix: Calculate Power (P=VI) as confirmed by Audit
213
+ power_calc = data.voltage * data.current
214
+ raw_input = np.array([data.time_sec, data.current, data.voltage, power_calc, data.soc_prev])
215
+
216
+ stats = ml_assets["b_stats"]
217
+ scaled = (raw_input - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
218
 
219
+ with torch.no_grad():
220
+ preds = ml_assets["b_model"](torch.tensor([scaled], dtype=torch.float32)).numpy()[0]
221
+ # Denormalize Temp (Index 1)
222
+ temp_est = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
223
+ except: pass
224
 
225
+ # C. Status Logic
226
  status = "Normal"
227
  if soc_physics < 20: status = "Low Battery"
228
  if temp_est > 45: status = "Overheating"
 
232
  "temp": round(float(temp_est), 2),
233
  "status": status
234
  }
235
+
236
+ # --- Endpoint 3: Frequency (Stability) ---
237
+ @app.post("/predict/frequency")
238
+ def predict_frequency(data: FreqData):
239
+ # A. Physics Layer (Swing Equation)
240
+ freq_phys, rocof_phys = get_frequency_physics(data)
241
+
242
+ # B. AI Layer (PINN)
243
+ freq_ai, rocof_ai = 60.0, 0.0
244
+ if "f_model" in ml_assets:
245
+ try:
246
+ net_load = data.load_mw - data.wind_mw
247
+ x = np.array([data.load_mw, data.wind_mw, net_load, data.power_imbalance_mw])
248
+
249
+ # Normalize using Manual Stats
250
+ x_norm = (x - ml_assets["f_mean"]) / (ml_assets["f_std"] + 1e-6)
251
+
252
+ with torch.no_grad():
253
+ preds = ml_assets["f_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
254
+
255
+ # AI outputs deviations (Unscaled)
256
+ freq_dev_ai = preds[0] * 0.5
257
+ rocof_ai_raw = preds[1] * 0.2
258
+
259
+ freq_ai = 60.0 + freq_dev_ai
260
+ rocof_ai = rocof_ai_raw
261
+ except: pass
262
+
263
+ # C. Hybrid Fusion (30% AI / 70% Physics)
264
+ final_freq = (freq_ai * 0.3) + (freq_phys * 0.7)
265
+ final_rocof = (rocof_ai * 0.3) + (rocof_phys * 0.7)
266
+
267
+ # Safety Clamping
268
+ final_freq = max(58.5, min(61.0, final_freq))
269
+
270
+ # Status Logic
271
+ status = "Stable"
272
+ if abs(final_rocof) > 0.15: status = "Inertia Alert"
273
+ if final_freq < 59.6: status = "Critical Frequency"
274
+
275
+ return {
276
+ "frequency_hz": round(float(final_freq), 4),
277
+ "rocof_hz_s": round(float(final_rocof), 4),
278
+ "inertia_used": round(float(data.inertia_h), 2),
279
+ "status": status
280
+ }