belal243 commited on
Commit
adda067
·
verified ·
1 Parent(s): 13224c7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +113 -153
main.py CHANGED
@@ -10,8 +10,9 @@ from pydantic import BaseModel
10
  from contextlib import asynccontextmanager
11
 
12
  # ==========================================
13
- # 1. SHARED MODEL ARCHITECTURES
14
  # ==========================================
 
15
  class Mish(nn.Module):
16
  def forward(self, x): return x * torch.tanh(nn.functional.softplus(x))
17
 
@@ -23,11 +24,45 @@ class FourierFeatureMapping(nn.Module):
23
  proj = 2 * np.pi * (x @ self.B)
24
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
25
 
26
- # --- A. Voltage Model (Grid) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class VoltagePINN(nn.Module):
28
  def __init__(self):
29
  super().__init__()
30
- self.fourier = FourierFeatureMapping(input_dim=7, mapping_size=32)
31
  self.network = nn.Sequential(
32
  nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
33
  nn.Linear(256, 128), nn.LayerNorm(128), Mish(),
@@ -36,11 +71,11 @@ class VoltagePINN(nn.Module):
36
  )
37
  def forward(self, x): return self.network(self.fourier(x))
38
 
39
- # --- B. Battery Model (Storage) ---
40
  class BatteryPINN(nn.Module):
41
  def __init__(self):
42
  super().__init__()
43
- self.fourier = FourierFeatureMapping(input_dim=5, mapping_size=12)
44
  self.network = nn.Sequential(
45
  nn.Linear(24, 64), Mish(),
46
  nn.Linear(64, 64), Mish(),
@@ -48,100 +83,45 @@ class BatteryPINN(nn.Module):
48
  )
49
  def forward(self, x): return self.network(self.fourier(x))
50
 
51
- # --- C. Frequency Model (Stability) ---
52
  class FrequencyPINN(nn.Module):
53
  def __init__(self):
54
  super().__init__()
55
- self.fourier = FourierFeatureMapping(input_dim=4, mapping_size=32)
56
- self.network = nn.Sequential(
57
  nn.Linear(64, 128), nn.LayerNorm(128), Mish(),
58
  nn.Linear(128, 128), nn.LayerNorm(128), Mish(),
59
- nn.Linear(128, 2)
60
- )
61
- def forward(self, x): return self.network(self.fourier(x))
62
-
63
- # --- D. Load Model (Forecast) ---
64
- class LoadPINN(nn.Module):
65
- def __init__(self):
66
- super().__init__()
67
- self.fourier = FourierFeatureMapping(input_dim=9, mapping_size=32)
68
- self.network = nn.Sequential(
69
- nn.Linear(64, 128), nn.LayerNorm(128), Mish(),
70
  nn.Linear(128, 128), nn.LayerNorm(128), Mish(),
71
- nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
72
- nn.Linear(64, 1)
73
  )
74
- def forward(self, x): return self.network(self.fourier(x))
75
-
76
- # ==========================================
77
- # 2. PHYSICS ENGINES
78
- # ==========================================
79
-
80
- def get_battery_physics_soc(voltage):
81
- v_points = [2.8, 3.0, 3.2, 3.4, 3.55, 3.65, 3.75, 3.85, 3.95, 4.1, 4.2, 4.3]
82
- soc_points = [0, 0, 5, 15, 35, 50, 65, 75, 85, 92, 100, 100]
83
- return np.interp(voltage, v_points, soc_points)
84
-
85
- def get_frequency_physics(data):
86
- f_nom = 60.0
87
- H = max(1.0, data.inertia_h)
88
- rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
89
- freq_nadir = f_nom + (rocof * 2.0)
90
- return freq_nadir, rocof
91
 
92
  # ==========================================
93
- # 3. ASSET LOADING (LIFESPAN)
94
  # ==========================================
95
  ml_assets = {}
96
 
97
  @asynccontextmanager
98
  async def lifespan(app: FastAPI):
99
- print("🚀 STARTING D.E.C.O.D.E. UNIFIED SERVER...")
100
-
101
- # 1. Load Voltage
102
- try:
103
- if os.path.exists("voltage_model_v3.pt"):
104
- ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
105
- model = VoltagePINN()
106
- model.load_state_dict(ckpt['model_state_dict'] if isinstance(ckpt, dict) else ckpt, strict=False)
107
- model.eval()
108
- ml_assets["v_model"] = model
109
- except Exception as e: print(f"⚠️ Voltage Error: {e}")
110
-
111
- # 2. Load Battery
112
- try:
113
- if os.path.exists("battery_model.joblib"):
114
- raw = joblib.load("battery_model.joblib")
115
- ml_assets["b_stats"] = raw['stats'] if 'stats' in raw else raw
116
- model = BatteryPINN()
117
- model.load_state_dict(torch.load("battery_model.pt", map_location='cpu'), strict=False)
118
- model.eval()
119
- ml_assets["b_model"] = model
120
- except Exception as e: print(f"⚠️ Battery Error: {e}")
121
-
122
- # 3. Load Frequency
123
- try:
124
- if os.path.exists("DECODE_Frequency_Twin.pth"):
125
- ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
126
- model = FrequencyPINN()
127
- model.load_state_dict(ckpt['model_state_dict'] if isinstance(ckpt, dict) else ckpt, strict=False)
128
- model.eval()
129
- ml_assets["f_model"] = model
130
- ml_assets["f_mean"] = np.array([60000.0, 30000.0, 30000.0, 0.0])
131
- ml_assets["f_std"] = np.array([20000.0, 15000.0, 15000.0, 10000.0])
132
- except Exception as e: print(f"⚠️ Frequency Error: {e}")
133
-
134
- # 4. Load Forecast
135
- try:
136
- if os.path.exists("load_model.pt"):
137
- model = LoadPINN()
138
- model.load_state_dict(torch.load("load_model.pt", map_location='cpu'), strict=False)
139
- model.eval()
140
- ml_assets["l_model"] = model
141
- stats = joblib.load("Load_stats.joblib")
142
- ml_assets["l_stats"] = stats
143
- except Exception as e: print(f"⚠️ Load Error: {e}")
144
-
145
  yield
146
  ml_assets.clear()
147
 
@@ -149,80 +129,60 @@ app = FastAPI(title="D.E.C.O.D.E. Unified API", lifespan=lifespan)
149
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
150
 
151
  # ==========================================
152
- # 4. SCHEMAS & ENDPOINTS
153
  # ==========================================
 
154
 
155
- class GridData(BaseModel):
156
- p_load: float; q_load: float; wind_gen: float; solar_gen: float; hour: int
157
-
158
- class BatteryData(BaseModel):
159
- time_sec: float; current: float; voltage: float; temperature: float; soc_prev: float
160
-
161
- class FreqData(BaseModel):
162
- load_mw: float; wind_mw: float; inertia_h: float; power_imbalance_mw: float
163
 
164
- class LoadData(BaseModel):
165
- temperature_c: float; hour: int; month: int; wind_mw: float = 0.0; solar_mw: float = 0.0
166
-
167
- @app.get("/")
168
- def home():
169
- return {"status": "D.E.C.O.D.E. Unified Digital Twin Online"}
170
 
171
- @app.post("/predict/voltage")
172
- def predict_voltage(data: GridData):
173
- net_load = data.p_load - (data.wind_gen + data.solar_gen)
174
- v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015)
175
- return {"voltage_pu": round(v_mag, 4), "status": "Stable" if 0.95 < v_mag < 1.05 else "Critical"}
176
-
177
- @app.post("/predict/battery")
178
- def predict_battery(data: BatteryData):
179
- soc_physics = get_battery_physics_soc(data.voltage)
180
- temp_est = 25.0
181
- if "b_model" in ml_assets:
182
- try:
183
- stats = ml_assets["b_stats"]
184
- scaled = (np.array([data.time_sec, data.current, data.voltage, data.voltage*data.current, data.soc_prev]) - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
185
- with torch.no_grad():
186
- preds = ml_assets["b_model"](torch.tensor([scaled], dtype=torch.float32)).numpy()[0]
187
- temp_est = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
188
- except: pass
189
- return {"soc": round(float(soc_physics), 2), "temp": round(float(temp_est), 2), "status": "Normal" if temp_est < 45 else "Overheating"}
190
-
191
- @app.post("/predict/frequency")
192
- def predict_frequency(data: FreqData):
193
- f_phys, r_phys = get_frequency_physics(data)
194
- f_ai = 60.0
195
- if "f_model" in ml_assets:
196
- try:
197
- x_norm = (np.array([data.load_mw, data.wind_mw, data.load_mw-data.wind_mw, data.power_imbalance_mw]) - ml_assets["f_mean"]) / (ml_assets["f_std"] + 1e-6)
198
- with torch.no_grad():
199
- preds = ml_assets["f_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
200
- f_ai = 60.0 + (preds[0] * 0.5)
201
- except: pass
202
- final_f = max(58.5, min(61.0, (f_ai * 0.3) + (f_phys * 0.7)))
203
- return {"frequency_hz": round(float(final_f), 4), "status": "Stable" if final_f > 59.6 else "Critical"}
204
 
205
  @app.post("/predict/load")
206
  def predict_load(data: LoadData):
207
- stats = ml_assets.get("l_stats", {})
208
- t_mean, t_std = stats.get('temp_mean', 15.38), stats.get('temp_std', 4.12)
209
- t_norm = max(-3.0, min(3.0, (data.temperature_c - t_mean) / (t_std + 1e-6)))
210
-
211
- x_norm = np.array([t_norm, max(0, data.temperature_c-18)/10, max(0, 18-data.temperature_c)/10,
212
  np.sin(2*np.pi*data.hour/24), np.cos(2*np.pi*data.hour/24),
213
  np.sin(2*np.pi*data.month/12), np.cos(2*np.pi*data.month/12),
214
- data.wind_mw/10000, data.solar_mw/10000], dtype=np.float32)
215
-
216
- load_mw = stats.get('load_mean', 35000.0)
217
- if "l_model" in ml_assets:
218
- try:
219
- with torch.no_grad():
220
- preds = ml_assets["l_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
221
- load_mw = (preds[0] * stats.get('load_std', 9773.80)) + load_mw
222
- except: pass
223
-
224
- # Physics Overrides
225
- if data.temperature_c > 32 and load_mw < 45000: load_mw = 45000 + (data.temperature_c - 32) * 1200
226
- elif data.temperature_c < 5 and load_mw < 42000: load_mw = 42000 + (5 - data.temperature_c) * 900
227
-
228
- return {"predicted_load_mw": round(float(load_mw), 2), "status": "Normal" if load_mw < 58000 else "Peak Load Alert"}
 
 
 
 
 
 
10
  from contextlib import asynccontextmanager
11
 
12
  # ==========================================
13
+ # 1. UNIQUE ARCHITECTURES (PER AUDIT)
14
  # ==========================================
15
+
16
  class Mish(nn.Module):
17
  def forward(self, x): return x * torch.tanh(nn.functional.softplus(x))
18
 
 
24
  proj = 2 * np.pi * (x @ self.B)
25
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
26
 
27
+ # --- Solar: 128-neuron State-Space PINN ---
28
+ class SolarPINN(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.backbone = nn.Sequential(
32
+ nn.Linear(4, 128), Mish(),
33
+ nn.Linear(128, 128), Mish(),
34
+ nn.Linear(128, 1)
35
+ )
36
+ # Physics Tensors from Audit
37
+ self.log_thermal_mass = nn.Parameter(torch.tensor(7.1546))
38
+ self.log_h_conv = nn.Parameter(torch.tensor(1.8767))
39
+ def forward(self, x): return self.backbone(x)
40
+
41
+ # --- Load: Fourier Residual Architecture ---
42
+ class LoadForecastPINN(nn.Module):
43
+ def __init__(self):
44
+ super().__init__()
45
+ self.fourier = FourierFeatureMapping(9, 32)
46
+ self.input_layer = nn.Linear(64, 128)
47
+ self.res_blocks = nn.ModuleList([
48
+ nn.Sequential(
49
+ nn.Linear(128, 128),
50
+ nn.BatchNorm1d(128),
51
+ Mish(),
52
+ nn.Linear(128, 128)
53
+ ) for _ in range(3)
54
+ ])
55
+ self.output_layer = nn.Linear(128, 1)
56
+ def forward(self, x):
57
+ x = self.input_layer(self.fourier(x))
58
+ for block in self.res_blocks: x = x + block(x)
59
+ return self.output_layer(x)
60
+
61
+ # --- Voltage: 256-dim Multi-Layer PINN ---
62
  class VoltagePINN(nn.Module):
63
  def __init__(self):
64
  super().__init__()
65
+ self.fourier = FourierFeatureMapping(7, 32)
66
  self.network = nn.Sequential(
67
  nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
68
  nn.Linear(256, 128), nn.LayerNorm(128), Mish(),
 
71
  )
72
  def forward(self, x): return self.network(self.fourier(x))
73
 
74
+ # --- Battery: 24-dim Linear PINN ---
75
  class BatteryPINN(nn.Module):
76
  def __init__(self):
77
  super().__init__()
78
+ self.fourier = FourierFeatureMapping(5, 12)
79
  self.network = nn.Sequential(
80
  nn.Linear(24, 64), Mish(),
81
  nn.Linear(64, 64), Mish(),
 
83
  )
84
  def forward(self, x): return self.network(self.fourier(x))
85
 
86
+ # --- Frequency: Stability Twin ---
87
  class FrequencyPINN(nn.Module):
88
  def __init__(self):
89
  super().__init__()
90
+ self.fourier = FourierFeatureMapping(4, 32)
91
+ self.net = nn.Sequential(
92
  nn.Linear(64, 128), nn.LayerNorm(128), Mish(),
93
  nn.Linear(128, 128), nn.LayerNorm(128), Mish(),
 
 
 
 
 
 
 
 
 
 
 
94
  nn.Linear(128, 128), nn.LayerNorm(128), Mish(),
95
+ nn.Linear(128, 2)
 
96
  )
97
+ def forward(self, x): return self.net(self.fourier(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  # ==========================================
100
+ # 2. ASSET LOADING (LIFESPAN)
101
  # ==========================================
102
  ml_assets = {}
103
 
104
  @asynccontextmanager
105
  async def lifespan(app: FastAPI):
106
+ # Load Models based on Audit Shapes
107
+ loaders = {
108
+ "solar": ("solar_model.pt", SolarPINN()),
109
+ "load": ("load_model.pt", LoadForecastPINN()),
110
+ "voltage": ("voltage_model_v3.pt", VoltagePINN()),
111
+ "battery": ("battery_model.pt", BatteryPINN()),
112
+ "freq": ("DECODE_Frequency_Twin.pth", FrequencyPINN())
113
+ }
114
+ for key, (path, model) in loaders.items():
115
+ if os.path.exists(path):
116
+ ckpt = torch.load(path, map_location='cpu')
117
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
118
+ model.load_state_dict(sd, strict=False)
119
+ ml_assets[key] = model.eval()
120
+
121
+ # Load All Stats
122
+ if os.path.exists("Load_stats.joblib"): ml_assets["l_stats"] = joblib.load("Load_stats.joblib")
123
+ if os.path.exists("battery_model.joblib"): ml_assets["b_stats"] = joblib.load("battery_model.joblib")
124
+ if os.path.exists("scaling_stats_v3.joblib"): ml_assets["v_stats"] = joblib.load("scaling_stats_v3.joblib")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  yield
126
  ml_assets.clear()
127
 
 
129
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
130
 
131
  # ==========================================
132
+ # 3. SCHEMAS & PHYSICS
133
  # ==========================================
134
+ def get_ocv_soc(v): return np.interp(v, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
135
 
136
+ class SolarData(BaseModel): irradiance_stream: list[float]; ambient_temp_stream: list[float]; wind_speed_stream: list[float]
137
+ class LoadData(BaseModel): temperature_c: float; hour: int; month: int; wind_mw: float = 0; solar_mw: float = 0
138
+ class BatteryData(BaseModel): time_sec: float; current: float; voltage: float; temperature: float; soc_prev: float
139
+ class FreqData(BaseModel): load_mw: float; wind_mw: float; inertia_h: float; power_imbalance_mw: float
140
+ class GridData(BaseModel): p_load: float; q_load: float; wind_gen: float; solar_gen: float; hour: int
 
 
 
141
 
142
+ # ==========================================
143
+ # 4. CALIBRATED ENDPOINTS
144
+ # ==========================================
 
 
 
145
 
146
+ @app.post("/predict/solar")
147
+ def predict_solar(data: SolarData):
148
+ # Constraint: Recursive Simulation @ 900s dt
149
+ curr_temp = data.ambient_temp_stream[0] + 5.0
150
+ sim = []
151
+ with torch.no_grad():
152
+ for i in range(len(data.irradiance_stream)):
153
+ x = torch.tensor([[(data.irradiance_stream[i]-450)/250, (data.ambient_temp_stream[i]-25)/10,
154
+ data.wind_speed_stream[i]/10.0, (curr_temp-35)/15]], dtype=torch.float32)
155
+ # Physical Clamping
156
+ next_t = max(10.0, min(75.0, ml_assets["solar"](x).item()))
157
+ eff = 0.20 * (1 - 0.004 * (next_t - 25.0))
158
+ sim.append({"temp": round(next_t, 2), "mw": round((5000 * data.irradiance_stream[i] * max(0, eff)) / 1e6, 4)})
159
+ curr_temp = next_t
160
+ return {"simulation": sim}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  @app.post("/predict/load")
163
  def predict_load(data: LoadData):
164
+ # Constraint: Hard Z-Score Clamping at +/-3 to prevent Inverted Load Paradox
165
+ t_norm = max(-3.0, min(3.0, (data.temperature_c - 15.38) / 4.12))
166
+ x = torch.tensor([[t_norm, max(0, data.temperature_c-18)/10, max(0, 18-data.temperature_c)/10,
 
 
167
  np.sin(2*np.pi*data.hour/24), np.cos(2*np.pi*data.hour/24),
168
  np.sin(2*np.pi*data.month/12), np.cos(2*np.pi*data.month/12),
169
+ data.wind_mw/10000, data.solar_mw/10000]], dtype=torch.float32)
170
+ load_mw = 35000.0
171
+ if "load" in ml_assets:
172
+ with torch.no_grad(): load_mw = (ml_assets["load"](x).item() * 9773.8) + 35000.0
173
+ # Physical Safety Correction
174
+ if data.temperature_c > 32: load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
175
+ elif data.temperature_c < 5: load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900)
176
+ return {"mw": round(load_mw, 2)}
177
+
178
+ @app.post("/predict/battery")
179
+ def predict_battery(data: BatteryData):
180
+ # Constraint: Feature Engineering (Power Product V*I)
181
+ p_prod = data.voltage * data.current
182
+ stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
183
+ raw = np.array([data.time_sec, data.current, data.voltage, p_prod, data.soc_prev])
184
+ x_scaled = (raw - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
185
+ with torch.no_grad():
186
+ preds = ml_assets["battery"](torch.tensor([x_scaled], dtype=torch.float32)).numpy()[0]
187
+ temp = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
188
+ return {"soc": round(get_ocv_soc(data.voltage), 2), "temp": round(temp, 2)}