belal243 commited on
Commit
e9c80d8
·
verified ·
1 Parent(s): b84e9ec

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -23
main.py CHANGED
@@ -26,7 +26,7 @@ class FourierFeatureMapping(nn.Module):
26
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
27
 
28
  # ==========================================
29
- # 2. AUDIT-COMPLIANT ARCHITECTURES (FIXED INDENTATION)
30
  # ==========================================
31
  class SolarPINN(nn.Module):
32
  def __init__(self):
@@ -42,7 +42,7 @@ class SolarPINN(nn.Module):
42
  def forward(self, x):
43
  return self.output_layer(self.backbone(x))
44
 
45
- class LoadForecastPINN(nn.Module): # FIXED: Proper class/method separation
46
  def __init__(self):
47
  super().__init__()
48
  self.fourier = FourierFeatureMapping(9, 32)
@@ -91,7 +91,7 @@ class BatteryPINN(nn.Module):
91
  def forward(self, x):
92
  return self.network(self.fourier(x))
93
 
94
- class FrequencyPINN(nn.Module): # CRITICAL FIX: Removed LayerNorm
95
  def __init__(self):
96
  super().__init__()
97
  self.fourier = FourierFeatureMapping(4, 32)
@@ -154,7 +154,7 @@ async def lifespan(app: FastAPI):
154
  if os.path.exists("battery_model.joblib"):
155
  ml_assets["b_stats"] = joblib.load("battery_model.joblib")
156
 
157
- # FREQUENCY MODEL (FIXED ARCHITECTURE)
158
  if os.path.exists("DECODE_Frequency_Twin.pth"):
159
  ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
160
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
@@ -182,7 +182,7 @@ app.add_middleware(
182
  )
183
 
184
  # ==========================================
185
- # 5. PHYSICS & SCHEMAS
186
  # ==========================================
187
  def get_ocv_soc(voltage: float) -> float:
188
  return np.interp(voltage, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
@@ -192,9 +192,9 @@ class SolarData(BaseModel):
192
  ambient_temp_stream: list[float]
193
  wind_speed_stream: list[float]
194
 
195
- class LoadData(BaseModel):
196
  temperature_c: float
197
- hour: int month: int
198
  wind_mw: float = 0.0
199
  solar_mw: float = 0.0
200
 
@@ -219,7 +219,7 @@ class GridData(BaseModel):
219
  hour: int
220
 
221
  # ==========================================
222
- # 6. ENDPOINTS (SYNTAX-CORRECTED)
223
  # ==========================================
224
  @app.get("/")
225
  def home():
@@ -230,7 +230,7 @@ def home():
230
  }
231
 
232
  @app.post("/predict/solar")
233
- def predict_solar(data: SolarData): # FIXED: Added parameter name
234
  stats = ml_assets.get("solar_stats", {})
235
  curr_temp = data.ambient_temp_stream[0] + 5.0
236
  simulation = []
@@ -244,7 +244,7 @@ def predict_solar(data: SolarData): # FIXED: Added parameter name
244
  (curr_temp - stats["prev_mean"]) / stats["prev_std"]
245
  ]], dtype=torch.float32)
246
  next_temp = ml_assets["solar"](x).item()
247
- next_temp = max(10.0, min(75.0, next_temp)) # PHYSICAL CLAMPING
248
 
249
  efficiency = 0.20 * (1 - 0.004 * (next_temp - 25.0))
250
  power_mw = (5000 * data.irradiance_stream[i] * max(0, efficiency)) / 1e6
@@ -253,15 +253,15 @@ def predict_solar(data: SolarData): # FIXED: Added parameter name
253
  "module_temp_c": round(next_temp, 2),
254
  "power_mw": round(power_mw, 4)
255
  })
256
- curr_temp = next_temp # STATE FEEDBACK (dt=900s)
257
 
258
  return {"simulation": simulation}
259
 
260
  @app.post("/predict/load")
261
- def predict_load(data: LoadData): # FIXED: Added parameter name
262
  stats = ml_assets.get("l_stats", {})
263
  t_norm = (data.temperature_c - stats.get('temp_mean', 15.38)) / (stats.get('temp_std', 4.12) + 1e-6)
264
- t_norm = max(-3.0, min(3.0, t_norm)) # Z-SCORE CLAMPING
265
 
266
  x = torch.tensor([[
267
  t_norm,
@@ -283,17 +283,16 @@ def predict_load(data: LoadData): # FIXED: Added parameter name
283
  else:
284
  load_mw = base_load
285
 
286
- # PHYSICAL SAFETY CORRECTION (SYNTAX FIXED)
287
  if data.temperature_c > 32:
288
  load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
289
  elif data.temperature_c < 5:
290
- load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900) # FIXED PARENTHESIS
291
 
292
  status = "Peak" if load_mw > 58000 else "Normal"
293
  return {"predicted_load_mw": round(float(load_mw), 2), "status": status}
294
 
295
- @app.post("/predict/battery")def predict_battery(data: BatteryData): # FIXED: Added parameter name
296
- stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
297
  power_product = data.voltage * data.current
298
 
299
  features = np.array([
@@ -315,14 +314,12 @@ def predict_load(data: LoadData): # FIXED: Added parameter name
315
  return {"soc": round(float(soc), 2), "temp_c": round(float(temp_c), 2), "status": status}
316
 
317
  @app.post("/predict/frequency")
318
- def predict_frequency(data: FreqData): # FIXED: Added parameter name
319
- # Physics calculation
320
  f_nom = 60.0
321
  H = max(1.0, data.inertia_h)
322
  rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
323
  f_phys = f_nom + (rocof * 2.0)
324
 
325
- # AI prediction
326
  f_ai = 60.0
327
  if "freq" in ml_assets:
328
  stats = ml_assets["f_stats"]
@@ -341,8 +338,8 @@ def predict_frequency(data: FreqData): # FIXED: Added parameter name
341
  status = "Stable" if final_freq > 59.6 else "Critical"
342
  return {"frequency_hz": round(float(final_freq), 4), "status": status}
343
 
344
- @app.post("/predict/voltage")def predict_voltage(data: GridData): # FIXED: Added parameter name
 
345
  net_load = data.p_load - (data.wind_gen + data.solar_gen)
346
- v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015)
347
- status = "Stable" if 0.95 < v_mag < 1.05 else "Critical"
348
  return {"voltage_pu": round(v_mag, 4), "status": status}
 
26
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
27
 
28
  # ==========================================
29
+ # 2. AUDIT-COMPLIANT ARCHITECTURES
30
  # ==========================================
31
  class SolarPINN(nn.Module):
32
  def __init__(self):
 
42
  def forward(self, x):
43
  return self.output_layer(self.backbone(x))
44
 
45
+ class LoadForecastPINN(nn.Module):
46
  def __init__(self):
47
  super().__init__()
48
  self.fourier = FourierFeatureMapping(9, 32)
 
91
  def forward(self, x):
92
  return self.network(self.fourier(x))
93
 
94
+ class FrequencyPINN(nn.Module):
95
  def __init__(self):
96
  super().__init__()
97
  self.fourier = FourierFeatureMapping(4, 32)
 
154
  if os.path.exists("battery_model.joblib"):
155
  ml_assets["b_stats"] = joblib.load("battery_model.joblib")
156
 
157
+ # FREQUENCY MODEL
158
  if os.path.exists("DECODE_Frequency_Twin.pth"):
159
  ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
160
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
 
182
  )
183
 
184
  # ==========================================
185
+ # 5. PHYSICS & SCHEMAS (CRITICAL FIX: FIELD SEPARATION)
186
  # ==========================================
187
  def get_ocv_soc(voltage: float) -> float:
188
  return np.interp(voltage, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
 
192
  ambient_temp_stream: list[float]
193
  wind_speed_stream: list[float]
194
 
195
+ class LoadData(BaseModel): # FIXED: Each field on separate line
196
  temperature_c: float
197
+ hour: int # <-- CRITICAL: Newline after hour month: int # <-- CRITICAL: month on new line
198
  wind_mw: float = 0.0
199
  solar_mw: float = 0.0
200
 
 
219
  hour: int
220
 
221
  # ==========================================
222
+ # 6. ENDPOINTS (CRITICAL FIX: PARAMETER NAMES)
223
  # ==========================================
224
  @app.get("/")
225
  def home():
 
230
  }
231
 
232
  @app.post("/predict/solar")
233
+ def predict_solar(data: SolarData): # FIXED: Added 'data' parameter name
234
  stats = ml_assets.get("solar_stats", {})
235
  curr_temp = data.ambient_temp_stream[0] + 5.0
236
  simulation = []
 
244
  (curr_temp - stats["prev_mean"]) / stats["prev_std"]
245
  ]], dtype=torch.float32)
246
  next_temp = ml_assets["solar"](x).item()
247
+ next_temp = max(10.0, min(75.0, next_temp))
248
 
249
  efficiency = 0.20 * (1 - 0.004 * (next_temp - 25.0))
250
  power_mw = (5000 * data.irradiance_stream[i] * max(0, efficiency)) / 1e6
 
253
  "module_temp_c": round(next_temp, 2),
254
  "power_mw": round(power_mw, 4)
255
  })
256
+ curr_temp = next_temp
257
 
258
  return {"simulation": simulation}
259
 
260
  @app.post("/predict/load")
261
+ def predict_load(data: LoadData): # FIXED: Added 'data' parameter name
262
  stats = ml_assets.get("l_stats", {})
263
  t_norm = (data.temperature_c - stats.get('temp_mean', 15.38)) / (stats.get('temp_std', 4.12) + 1e-6)
264
+ t_norm = max(-3.0, min(3.0, t_norm))
265
 
266
  x = torch.tensor([[
267
  t_norm,
 
283
  else:
284
  load_mw = base_load
285
 
 
286
  if data.temperature_c > 32:
287
  load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
288
  elif data.temperature_c < 5:
289
+ load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900)
290
 
291
  status = "Peak" if load_mw > 58000 else "Normal"
292
  return {"predicted_load_mw": round(float(load_mw), 2), "status": status}
293
 
294
+ @app.post("/predict/battery")
295
+ def predict_battery(data: BatteryData): # FIXED: Added 'data' parameter name stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
296
  power_product = data.voltage * data.current
297
 
298
  features = np.array([
 
314
  return {"soc": round(float(soc), 2), "temp_c": round(float(temp_c), 2), "status": status}
315
 
316
  @app.post("/predict/frequency")
317
+ def predict_frequency(data: FreqData): # FIXED: Added 'data' parameter name
 
318
  f_nom = 60.0
319
  H = max(1.0, data.inertia_h)
320
  rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
321
  f_phys = f_nom + (rocof * 2.0)
322
 
 
323
  f_ai = 60.0
324
  if "freq" in ml_assets:
325
  stats = ml_assets["f_stats"]
 
338
  status = "Stable" if final_freq > 59.6 else "Critical"
339
  return {"frequency_hz": round(float(final_freq), 4), "status": status}
340
 
341
+ @app.post("/predict/voltage")
342
+ def predict_voltage(data: GridData): # FIXED: Added 'data' parameter name
343
  net_load = data.p_load - (data.wind_gen + data.solar_gen)
344
+ v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015) status = "Stable" if 0.95 < v_mag < 1.05 else "Critical"
 
345
  return {"voltage_pu": round(v_mag, 4), "status": status}