belal243 commited on
Commit
13224c7
·
verified ·
1 Parent(s): a534474

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +80 -132
main.py CHANGED
@@ -24,7 +24,6 @@ class FourierFeatureMapping(nn.Module):
24
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
25
 
26
  # --- A. Voltage Model (Grid) ---
27
- # Input: 7 features (P_load, Q_load, etc.) -> Output: 2 (V_mag, V_angle)
28
  class VoltagePINN(nn.Module):
29
  def __init__(self):
30
  super().__init__()
@@ -38,7 +37,6 @@ class VoltagePINN(nn.Module):
38
  def forward(self, x): return self.network(self.fourier(x))
39
 
40
  # --- B. Battery Model (Storage) ---
41
- # Input: 5 features (Time, I, V, P, SoC_prev) -> Output: 3 (V, Temp, Ah)
42
  class BatteryPINN(nn.Module):
43
  def __init__(self):
44
  super().__init__()
@@ -46,12 +44,11 @@ class BatteryPINN(nn.Module):
46
  self.network = nn.Sequential(
47
  nn.Linear(24, 64), Mish(),
48
  nn.Linear(64, 64), Mish(),
49
- nn.Linear(64, 3)
50
  )
51
  def forward(self, x): return self.network(self.fourier(x))
52
 
53
  # --- C. Frequency Model (Stability) ---
54
- # Input: 4 features (Load, Wind, NetLoad, Imbalance) -> Output: 2 (Freq, ROCOF)
55
  class FrequencyPINN(nn.Module):
56
  def __init__(self):
57
  super().__init__()
@@ -63,33 +60,33 @@ class FrequencyPINN(nn.Module):
63
  )
64
  def forward(self, x): return self.network(self.fourier(x))
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # ==========================================
67
- # 2. PHYSICS ENGINES (The "Laws")
68
  # ==========================================
69
 
70
  def get_battery_physics_soc(voltage):
71
- """
72
- Returns SoC % based on Li-ion Open Circuit Voltage (OCV) curve.
73
- Acts as the ground truth to prevent AI drift.
74
- """
75
  v_points = [2.8, 3.0, 3.2, 3.4, 3.55, 3.65, 3.75, 3.85, 3.95, 4.1, 4.2, 4.3]
76
  soc_points = [0, 0, 5, 15, 35, 50, 65, 75, 85, 92, 100, 100]
77
  return np.interp(voltage, v_points, soc_points)
78
 
79
  def get_frequency_physics(data):
80
- """
81
- Returns Baseline Frequency & ROCOF using the Swing Equation.
82
- ROCOF = -P_imbalance / (2 * H * S_base)
83
- """
84
  f_nom = 60.0
85
- H = max(1.0, data.inertia_h) # Prevent division by zero
86
-
87
- # Physics Calculation (Assuming Imbalance is in MW and Base is 1000MW for pu)
88
  rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
89
-
90
- # Nadir approximation (Approx 2.0s duration for primary response)
91
  freq_nadir = f_nom + (rocof * 2.0)
92
-
93
  return freq_nadir, rocof
94
 
95
  # ==========================================
@@ -99,51 +96,51 @@ ml_assets = {}
99
 
100
  @asynccontextmanager
101
  async def lifespan(app: FastAPI):
102
- print("🚀 STARTING D.E.C.O.D.E. TRIDENT SERVER...")
103
 
104
- # --- 1. Load Voltage Assets ---
105
  try:
106
- if os.path.exists("scaling_stats_v3.joblib"):
107
- ml_assets["v_scaler"] = joblib.load("scaling_stats_v3.joblib")
108
-
109
  ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
110
  model = VoltagePINN()
111
  model.load_state_dict(ckpt['model_state_dict'] if isinstance(ckpt, dict) else ckpt, strict=False)
112
  model.eval()
113
  ml_assets["v_model"] = model
114
- print(" Grid Module: Loaded")
115
- except Exception as e: print(f"⚠️ Grid Module Error: {e}")
116
 
117
- # --- 2. Load Battery Assets ---
118
  try:
119
  if os.path.exists("battery_model.joblib"):
120
  raw = joblib.load("battery_model.joblib")
121
- stats = raw['stats'] if 'stats' in raw else raw
122
- ml_assets["b_stats"] = stats
123
-
124
- ckpt = torch.load("battery_model.pt", map_location='cpu')
125
  model = BatteryPINN()
126
- model.load_state_dict(ckpt if isinstance(ckpt, dict) else ckpt.state_dict(), strict=False)
127
  model.eval()
128
  ml_assets["b_model"] = model
129
- print(" Battery Module: Loaded")
130
- except Exception as e: print(f"⚠️ Battery Module Error: {e}")
131
 
132
- # --- 3. Load Frequency Assets ---
133
  try:
134
  if os.path.exists("DECODE_Frequency_Twin.pth"):
135
  ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
136
  model = FrequencyPINN()
137
- if 'model_state_dict' in ckpt: model.load_state_dict(ckpt['model_state_dict'], strict=False)
138
- else: model.load_state_dict(ckpt, strict=False)
139
  model.eval()
140
  ml_assets["f_model"] = model
141
-
142
- # Manual Stats from Audit (Robustness Fix)
143
  ml_assets["f_mean"] = np.array([60000.0, 30000.0, 30000.0, 0.0])
144
  ml_assets["f_std"] = np.array([20000.0, 15000.0, 15000.0, 10000.0])
145
- print(" Frequency Module: Loaded")
146
- except Exception as e: print(f"⚠️ Frequency Module Error: {e}")
 
 
 
 
 
 
 
 
 
 
147
 
148
  yield
149
  ml_assets.clear()
@@ -152,129 +149,80 @@ app = FastAPI(title="D.E.C.O.D.E. Unified API", lifespan=lifespan)
152
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
153
 
154
  # ==========================================
155
- # 4. ENDPOINTS
156
  # ==========================================
157
 
158
- # --- Input Schemas ---
159
  class GridData(BaseModel):
160
- p_load: float
161
- q_load: float
162
- wind_gen: float
163
- solar_gen: float
164
- hour: int
165
 
166
  class BatteryData(BaseModel):
167
- time_sec: float
168
- current: float
169
- voltage: float
170
- temperature: float
171
- soc_prev: float
172
 
173
  class FreqData(BaseModel):
174
- load_mw: float
175
- wind_mw: float
176
- inertia_h: float
177
- power_imbalance_mw: float
178
 
179
  @app.get("/")
180
- def home():
181
- return {"status": "D.E.C.O.D.E. Trident Online", "modules": ["Voltage", "Battery", "Frequency"]}
182
 
183
- # --- Endpoint 1: Voltage (Grid) ---
184
  @app.post("/predict/voltage")
185
  def predict_voltage(data: GridData):
186
- # Physics-Informed Logic
187
  net_load = data.p_load - (data.wind_gen + data.solar_gen)
188
- SENSITIVITY_K = 0.000005
189
- v_mag = 1.00 - (net_load * SENSITIVITY_K)
190
- v_mag += random.uniform(-0.0015, 0.0015) # Organic Noise
191
-
192
- status = "Stable"
193
- if v_mag > 1.05: status = "Critical (High)"
194
- if v_mag < 0.95: status = "Critical (Low)"
195
 
196
- return {
197
- "voltage_pu": round(v_mag, 4),
198
- "status": status,
199
- "net_load": round(net_load, 2)
200
- }
201
-
202
- # --- Endpoint 2: Battery (Storage) ---
203
  @app.post("/predict/battery")
204
  def predict_battery(data: BatteryData):
205
- # A. Physics Layer (SoC)
206
  soc_physics = get_battery_physics_soc(data.voltage)
207
-
208
- # B. AI Layer (Temp)
209
  temp_est = 25.0
210
- if "b_model" in ml_assets and "b_stats" in ml_assets:
211
  try:
212
- # Fix: Calculate Power (P=VI) as confirmed by Audit
213
- power_calc = data.voltage * data.current
214
- raw_input = np.array([data.time_sec, data.current, data.voltage, power_calc, data.soc_prev])
215
-
216
  stats = ml_assets["b_stats"]
217
- scaled = (raw_input - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
218
-
219
  with torch.no_grad():
220
  preds = ml_assets["b_model"](torch.tensor([scaled], dtype=torch.float32)).numpy()[0]
221
- # Denormalize Temp (Index 1)
222
  temp_est = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
223
  except: pass
 
224
 
225
- # C. Status Logic
226
- status = "Normal"
227
- if soc_physics < 20: status = "Low Battery"
228
- if temp_est > 45: status = "Overheating"
229
-
230
- return {
231
- "soc": round(float(soc_physics), 2),
232
- "temp": round(float(temp_est), 2),
233
- "status": status
234
- }
235
-
236
- # --- Endpoint 3: Frequency (Stability) ---
237
  @app.post("/predict/frequency")
238
  def predict_frequency(data: FreqData):
239
- # A. Physics Layer (Swing Equation)
240
- freq_phys, rocof_phys = get_frequency_physics(data)
241
-
242
- # B. AI Layer (PINN)
243
- freq_ai, rocof_ai = 60.0, 0.0
244
  if "f_model" in ml_assets:
245
  try:
246
- net_load = data.load_mw - data.wind_mw
247
- x = np.array([data.load_mw, data.wind_mw, net_load, data.power_imbalance_mw])
248
-
249
- # Normalize using Manual Stats
250
- x_norm = (x - ml_assets["f_mean"]) / (ml_assets["f_std"] + 1e-6)
251
-
252
  with torch.no_grad():
253
  preds = ml_assets["f_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
254
-
255
- # AI outputs deviations (Unscaled)
256
- freq_dev_ai = preds[0] * 0.5
257
- rocof_ai_raw = preds[1] * 0.2
258
-
259
- freq_ai = 60.0 + freq_dev_ai
260
- rocof_ai = rocof_ai_raw
261
  except: pass
262
-
263
- # C. Hybrid Fusion (30% AI / 70% Physics)
264
- final_freq = (freq_ai * 0.3) + (freq_phys * 0.7)
265
- final_rocof = (rocof_ai * 0.3) + (rocof_phys * 0.7)
 
 
 
 
266
 
267
- # Safety Clamping
268
- final_freq = max(58.5, min(61.0, final_freq))
 
 
269
 
270
- # Status Logic
271
- status = "Stable"
272
- if abs(final_rocof) > 0.15: status = "Inertia Alert"
273
- if final_freq < 59.6: status = "Critical Frequency"
 
 
 
274
 
275
- return {
276
- "frequency_hz": round(float(final_freq), 4),
277
- "rocof_hz_s": round(float(final_rocof), 4),
278
- "inertia_used": round(float(data.inertia_h), 2),
279
- "status": status
280
- }
 
24
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
25
 
26
  # --- A. Voltage Model (Grid) ---
 
27
  class VoltagePINN(nn.Module):
28
  def __init__(self):
29
  super().__init__()
 
37
  def forward(self, x): return self.network(self.fourier(x))
38
 
39
  # --- B. Battery Model (Storage) ---
 
40
  class BatteryPINN(nn.Module):
41
  def __init__(self):
42
  super().__init__()
 
44
  self.network = nn.Sequential(
45
  nn.Linear(24, 64), Mish(),
46
  nn.Linear(64, 64), Mish(),
47
+ nn.Linear(64, 3)
48
  )
49
  def forward(self, x): return self.network(self.fourier(x))
50
 
51
  # --- C. Frequency Model (Stability) ---
 
52
  class FrequencyPINN(nn.Module):
53
  def __init__(self):
54
  super().__init__()
 
60
  )
61
  def forward(self, x): return self.network(self.fourier(x))
62
 
63
+ # --- D. Load Model (Forecast) ---
64
+ class LoadPINN(nn.Module):
65
+ def __init__(self):
66
+ super().__init__()
67
+ self.fourier = FourierFeatureMapping(input_dim=9, mapping_size=32)
68
+ self.network = nn.Sequential(
69
+ nn.Linear(64, 128), nn.LayerNorm(128), Mish(),
70
+ nn.Linear(128, 128), nn.LayerNorm(128), Mish(),
71
+ nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
72
+ nn.Linear(64, 1)
73
+ )
74
+ def forward(self, x): return self.network(self.fourier(x))
75
+
76
  # ==========================================
77
+ # 2. PHYSICS ENGINES
78
  # ==========================================
79
 
80
  def get_battery_physics_soc(voltage):
 
 
 
 
81
  v_points = [2.8, 3.0, 3.2, 3.4, 3.55, 3.65, 3.75, 3.85, 3.95, 4.1, 4.2, 4.3]
82
  soc_points = [0, 0, 5, 15, 35, 50, 65, 75, 85, 92, 100, 100]
83
  return np.interp(voltage, v_points, soc_points)
84
 
85
  def get_frequency_physics(data):
 
 
 
 
86
  f_nom = 60.0
87
+ H = max(1.0, data.inertia_h)
 
 
88
  rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
 
 
89
  freq_nadir = f_nom + (rocof * 2.0)
 
90
  return freq_nadir, rocof
91
 
92
  # ==========================================
 
96
 
97
  @asynccontextmanager
98
  async def lifespan(app: FastAPI):
99
+ print("🚀 STARTING D.E.C.O.D.E. UNIFIED SERVER...")
100
 
101
+ # 1. Load Voltage
102
  try:
103
+ if os.path.exists("voltage_model_v3.pt"):
 
 
104
  ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
105
  model = VoltagePINN()
106
  model.load_state_dict(ckpt['model_state_dict'] if isinstance(ckpt, dict) else ckpt, strict=False)
107
  model.eval()
108
  ml_assets["v_model"] = model
109
+ except Exception as e: print(f"⚠️ Voltage Error: {e}")
 
110
 
111
+ # 2. Load Battery
112
  try:
113
  if os.path.exists("battery_model.joblib"):
114
  raw = joblib.load("battery_model.joblib")
115
+ ml_assets["b_stats"] = raw['stats'] if 'stats' in raw else raw
 
 
 
116
  model = BatteryPINN()
117
+ model.load_state_dict(torch.load("battery_model.pt", map_location='cpu'), strict=False)
118
  model.eval()
119
  ml_assets["b_model"] = model
120
+ except Exception as e: print(f"⚠️ Battery Error: {e}")
 
121
 
122
+ # 3. Load Frequency
123
  try:
124
  if os.path.exists("DECODE_Frequency_Twin.pth"):
125
  ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
126
  model = FrequencyPINN()
127
+ model.load_state_dict(ckpt['model_state_dict'] if isinstance(ckpt, dict) else ckpt, strict=False)
 
128
  model.eval()
129
  ml_assets["f_model"] = model
 
 
130
  ml_assets["f_mean"] = np.array([60000.0, 30000.0, 30000.0, 0.0])
131
  ml_assets["f_std"] = np.array([20000.0, 15000.0, 15000.0, 10000.0])
132
+ except Exception as e: print(f"⚠️ Frequency Error: {e}")
133
+
134
+ # 4. Load Forecast
135
+ try:
136
+ if os.path.exists("load_model.pt"):
137
+ model = LoadPINN()
138
+ model.load_state_dict(torch.load("load_model.pt", map_location='cpu'), strict=False)
139
+ model.eval()
140
+ ml_assets["l_model"] = model
141
+ stats = joblib.load("Load_stats.joblib")
142
+ ml_assets["l_stats"] = stats
143
+ except Exception as e: print(f"⚠️ Load Error: {e}")
144
 
145
  yield
146
  ml_assets.clear()
 
149
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
150
 
151
  # ==========================================
152
+ # 4. SCHEMAS & ENDPOINTS
153
  # ==========================================
154
 
 
155
  class GridData(BaseModel):
156
+ p_load: float; q_load: float; wind_gen: float; solar_gen: float; hour: int
 
 
 
 
157
 
158
  class BatteryData(BaseModel):
159
+ time_sec: float; current: float; voltage: float; temperature: float; soc_prev: float
 
 
 
 
160
 
161
  class FreqData(BaseModel):
162
+ load_mw: float; wind_mw: float; inertia_h: float; power_imbalance_mw: float
163
+
164
+ class LoadData(BaseModel):
165
+ temperature_c: float; hour: int; month: int; wind_mw: float = 0.0; solar_mw: float = 0.0
166
 
167
  @app.get("/")
168
+ def home():
169
+ return {"status": "D.E.C.O.D.E. Unified Digital Twin Online"}
170
 
 
171
  @app.post("/predict/voltage")
172
  def predict_voltage(data: GridData):
 
173
  net_load = data.p_load - (data.wind_gen + data.solar_gen)
174
+ v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015)
175
+ return {"voltage_pu": round(v_mag, 4), "status": "Stable" if 0.95 < v_mag < 1.05 else "Critical"}
 
 
 
 
 
176
 
 
 
 
 
 
 
 
177
  @app.post("/predict/battery")
178
  def predict_battery(data: BatteryData):
 
179
  soc_physics = get_battery_physics_soc(data.voltage)
 
 
180
  temp_est = 25.0
181
+ if "b_model" in ml_assets:
182
  try:
 
 
 
 
183
  stats = ml_assets["b_stats"]
184
+ scaled = (np.array([data.time_sec, data.current, data.voltage, data.voltage*data.current, data.soc_prev]) - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
 
185
  with torch.no_grad():
186
  preds = ml_assets["b_model"](torch.tensor([scaled], dtype=torch.float32)).numpy()[0]
 
187
  temp_est = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
188
  except: pass
189
+ return {"soc": round(float(soc_physics), 2), "temp": round(float(temp_est), 2), "status": "Normal" if temp_est < 45 else "Overheating"}
190
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  @app.post("/predict/frequency")
192
  def predict_frequency(data: FreqData):
193
+ f_phys, r_phys = get_frequency_physics(data)
194
+ f_ai = 60.0
 
 
 
195
  if "f_model" in ml_assets:
196
  try:
197
+ x_norm = (np.array([data.load_mw, data.wind_mw, data.load_mw-data.wind_mw, data.power_imbalance_mw]) - ml_assets["f_mean"]) / (ml_assets["f_std"] + 1e-6)
 
 
 
 
 
198
  with torch.no_grad():
199
  preds = ml_assets["f_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
200
+ f_ai = 60.0 + (preds[0] * 0.5)
 
 
 
 
 
 
201
  except: pass
202
+ final_f = max(58.5, min(61.0, (f_ai * 0.3) + (f_phys * 0.7)))
203
+ return {"frequency_hz": round(float(final_f), 4), "status": "Stable" if final_f > 59.6 else "Critical"}
204
+
205
+ @app.post("/predict/load")
206
+ def predict_load(data: LoadData):
207
+ stats = ml_assets.get("l_stats", {})
208
+ t_mean, t_std = stats.get('temp_mean', 15.38), stats.get('temp_std', 4.12)
209
+ t_norm = max(-3.0, min(3.0, (data.temperature_c - t_mean) / (t_std + 1e-6)))
210
 
211
+ x_norm = np.array([t_norm, max(0, data.temperature_c-18)/10, max(0, 18-data.temperature_c)/10,
212
+ np.sin(2*np.pi*data.hour/24), np.cos(2*np.pi*data.hour/24),
213
+ np.sin(2*np.pi*data.month/12), np.cos(2*np.pi*data.month/12),
214
+ data.wind_mw/10000, data.solar_mw/10000], dtype=np.float32)
215
 
216
+ load_mw = stats.get('load_mean', 35000.0)
217
+ if "l_model" in ml_assets:
218
+ try:
219
+ with torch.no_grad():
220
+ preds = ml_assets["l_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
221
+ load_mw = (preds[0] * stats.get('load_std', 9773.80)) + load_mw
222
+ except: pass
223
 
224
+ # Physics Overrides
225
+ if data.temperature_c > 32 and load_mw < 45000: load_mw = 45000 + (data.temperature_c - 32) * 1200
226
+ elif data.temperature_c < 5 and load_mw < 42000: load_mw = 42000 + (5 - data.temperature_c) * 900
227
+
228
+ return {"predicted_load_mw": round(float(load_mw), 2), "status": "Normal" if load_mw < 58000 else "Peak Load Alert"}