belal243 commited on
Commit
ca0fe02
·
verified ·
1 Parent(s): cac243b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +161 -29
main.py CHANGED
@@ -4,26 +4,46 @@ import numpy as np
4
  import joblib
5
  import random
6
  import os
 
 
7
  from fastapi import FastAPI
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
 
10
  from contextlib import asynccontextmanager
 
11
 
12
- # --- 1. MODEL ARCHITECTURE ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class Mish(nn.Module):
14
  def forward(self, x): return x * torch.tanh(nn.functional.softplus(x))
15
 
16
  class FourierFeatureMapping(nn.Module):
17
- def __init__(self, input_dim=7, mapping_size=32, scale=10.0):
18
  super().__init__()
19
  self.register_buffer('B', torch.randn(input_dim, mapping_size) * scale)
20
  def forward(self, x):
21
  proj = 2 * np.pi * (x @ self.B)
22
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
23
 
 
24
  class VoltagePINN(nn.Module):
25
  def __init__(self):
26
  super().__init__()
 
27
  self.fourier = FourierFeatureMapping(input_dim=7, mapping_size=32)
28
  self.network = nn.Sequential(
29
  nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
@@ -38,41 +58,84 @@ class VoltagePINN(nn.Module):
38
  def forward(self, x):
39
  return self.network(self.fourier(x))
40
 
41
- # --- 2. ASSETS ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  ml_assets = {}
43
 
44
  @asynccontextmanager
45
  async def lifespan(app: FastAPI):
46
- # Load Scaler
 
 
47
  try:
48
- scaler = joblib.load("scaling_stats_v3.joblib")
49
- ml_assets["scaler"] = scaler
50
- print("✅ Scaler Loaded")
51
- except: print("⚠️ Scaler not found")
 
 
 
 
 
 
 
 
52
 
53
- # Load Model
54
  try:
55
- checkpoint = torch.load("voltage_model_v3.pt", map_location='cpu')
56
- state_dict = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) else checkpoint
57
- model = VoltagePINN()
58
- model.load_state_dict(state_dict)
59
- model.eval()
60
- ml_assets["model"] = model
61
- print("✅ PINN Model Loaded")
62
- except: print("⚠️ Model not found")
 
 
 
 
 
 
 
 
 
 
63
  yield
64
  ml_assets.clear()
65
 
66
- app = FastAPI(title="D.E.C.O.D.E. API", lifespan=lifespan)
 
67
 
68
- # CORS (Essential for your Dashboard to work)
69
- app.add_middleware(
70
- CORSMiddleware,
71
- allow_origins=["*"],
72
- allow_credentials=True,
73
- allow_methods=["*"],
74
- allow_headers=["*"],
75
- )
76
 
77
  class GridData(BaseModel):
78
  p_load: float
@@ -81,11 +144,20 @@ class GridData(BaseModel):
81
  solar_gen: float
82
  hour: int
83
 
 
 
 
 
 
 
 
84
  @app.get("/")
85
- def home(): return {"status": "D.E.C.O.D.E. Online", "version": "Hybrid-v3"}
 
86
 
87
- @app.post("/predict")
88
- def predict(data: GridData):
 
89
  # Hybrid Logic (Physics-Informed Fallback for Stability)
90
  net_load = data.p_load - (data.wind_gen + data.solar_gen)
91
 
@@ -107,3 +179,63 @@ def predict(data: GridData):
107
  "status": status,
108
  "net_load": round(net_load, 2)
109
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import joblib
5
  import random
6
  import os
7
+ import threading
8
+ import time
9
  from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
12
+ from pyngrok import ngrok
13
  from contextlib import asynccontextmanager
14
+ import uvicorn
15
 
16
+ # ==========================================
17
+ # 0. CLEANUP & AUTH
18
+ # ==========================================
19
+ print("🧹 Cleaning up old sessions...")
20
+ ngrok.kill()
21
+ os.system("fuser -k 8000/tcp")
22
+
23
+ # Using the token from the second script
24
+ NGROK_TOKEN = "38wFP4o09vV9sgwJUoWM2euBk7J_2WnPVQ4XfT7pT9ndARhwJ"
25
+ ngrok.set_auth_token(NGROK_TOKEN)
26
+ print("🔑 New Ngrok Token Applied.")
27
+
28
+ # ==========================================
29
+ # 1. MODEL ARCHITECTURES
30
+ # ==========================================
31
  class Mish(nn.Module):
32
  def forward(self, x): return x * torch.tanh(nn.functional.softplus(x))
33
 
34
  class FourierFeatureMapping(nn.Module):
35
+ def __init__(self, input_dim, mapping_size, scale=10.0):
36
  super().__init__()
37
  self.register_buffer('B', torch.randn(input_dim, mapping_size) * scale)
38
  def forward(self, x):
39
  proj = 2 * np.pi * (x @ self.B)
40
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
41
 
42
+ # --- Voltage Model (Grid) ---
43
  class VoltagePINN(nn.Module):
44
  def __init__(self):
45
  super().__init__()
46
+ # Script 1: input_dim=7, mapping_size=32
47
  self.fourier = FourierFeatureMapping(input_dim=7, mapping_size=32)
48
  self.network = nn.Sequential(
49
  nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
 
58
  def forward(self, x):
59
  return self.network(self.fourier(x))
60
 
61
+ # --- Battery Model (Storage) ---
62
+ class BatteryPINN(nn.Module):
63
+ def __init__(self):
64
+ super().__init__()
65
+ # Script 2: input_dim=5, mapping_size=12
66
+ self.fourier = FourierFeatureMapping(input_dim=5, mapping_size=12)
67
+ self.network = nn.Sequential(
68
+ nn.Linear(24, 64), Mish(),
69
+ nn.Linear(64, 64), Mish(),
70
+ nn.Linear(64, 3)
71
+ )
72
+
73
+ def forward(self, x):
74
+ return self.network(self.fourier(x))
75
+
76
+ # ==========================================
77
+ # 2. PHYSICS ENGINE (OCV Curve)
78
+ # ==========================================
79
+ def get_physics_soc(voltage):
80
+ # Standard Li-ion OCV Curve (NMC Chemistry)
81
+ # Voltage Points: [3.0, 3.2, 3.4, 3.6, 3.7, 3.8, 3.9, 4.0, 4.1, 4.2]
82
+ v_points = [2.8, 3.0, 3.2, 3.4, 3.55, 3.65, 3.75, 3.85, 3.95, 4.1, 4.2, 4.3]
83
+ soc_points = [0, 0, 5, 15, 35, 50, 65, 75, 85, 92, 100, 100]
84
+ return np.interp(voltage, v_points, soc_points)
85
+
86
+ # ==========================================
87
+ # 3. ASSET LOADING
88
+ # ==========================================
89
  ml_assets = {}
90
 
91
  @asynccontextmanager
92
  async def lifespan(app: FastAPI):
93
+ print("🚀 STARTING D.E.C.O.D.E. UNIFIED SERVER...")
94
+
95
+ # --- Load Voltage Assets ---
96
  try:
97
+ if os.path.exists("scaling_stats_v3.joblib"):
98
+ ml_assets["v_scaler"] = joblib.load("scaling_stats_v3.joblib")
99
+
100
+ ckpt_v = torch.load("voltage_model_v3.pt", map_location='cpu')
101
+ state_dict = ckpt_v['model_state_dict'] if isinstance(ckpt_v, dict) else ckpt_v
102
+ model_v = VoltagePINN()
103
+ model_v.load_state_dict(state_dict)
104
+ model_v.eval()
105
+ ml_assets["v_model"] = model_v
106
+ print("✅ Grid Module: Loaded")
107
+ else: print("⚠️ Grid files missing")
108
+ except Exception as e: print(f"⚠️ Grid Error: {e}")
109
 
110
+ # --- Load Battery Assets ---
111
  try:
112
+ if os.path.exists("battery_model.joblib"):
113
+ raw_b = joblib.load("battery_model.joblib")
114
+ stats_b = raw_b['stats'] if 'stats' in raw_b else raw_b
115
+ ml_assets["b_x_mean"] = stats_b['feature_mean']
116
+ ml_assets["b_x_std"] = stats_b['feature_std']
117
+ ml_assets["b_y_mean"] = stats_b['target_mean']
118
+ ml_assets["b_y_std"] = stats_b['target_std']
119
+
120
+ ckpt_b = torch.load("battery_model.pt", map_location='cpu')
121
+ model_b = BatteryPINN()
122
+ state_dict = ckpt_b if isinstance(ckpt_b, dict) else ckpt_b.state_dict()
123
+ model_b.load_state_dict(state_dict, strict=False)
124
+ model_b.eval()
125
+ ml_assets["b_model"] = model_b
126
+ print("✅ Battery Module: Loaded")
127
+ else: print("⚠️ Battery files missing")
128
+ except Exception as e: print(f"⚠️ Battery Error: {e}")
129
+
130
  yield
131
  ml_assets.clear()
132
 
133
+ app = FastAPI(title="D.E.C.O.D.E. Unified API", lifespan=lifespan)
134
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
135
 
136
+ # ==========================================
137
+ # 4. ENDPOINTS
138
+ # ==========================================
 
 
 
 
 
139
 
140
  class GridData(BaseModel):
141
  p_load: float
 
144
  solar_gen: float
145
  hour: int
146
 
147
+ class BatteryData(BaseModel):
148
+ time_sec: float
149
+ current: float
150
+ voltage: float
151
+ temperature: float
152
+ soc_prev: float
153
+
154
  @app.get("/")
155
+ def home():
156
+ return {"status": "D.E.C.O.D.E. Hybrid Online", "modules": ["Grid", "Battery"]}
157
 
158
+ # --- Endpoint 1: Grid Voltage (from Script 1) ---
159
+ @app.post("/predict/voltage")
160
+ def predict_voltage(data: GridData):
161
  # Hybrid Logic (Physics-Informed Fallback for Stability)
162
  net_load = data.p_load - (data.wind_gen + data.solar_gen)
163
 
 
179
  "status": status,
180
  "net_load": round(net_load, 2)
181
  }
182
+
183
+ # --- Endpoint 2: Battery (from Script 2) ---
184
+ @app.post("/predict/battery")
185
+ def predict_battery(data: BatteryData):
186
+ # A. PHYSICS LAYER (SoC)
187
+ soc_physics = get_physics_soc(data.voltage)
188
+
189
+ # B. AI LAYER (Temp & Health)
190
+ # Calculate Power Input for the model (The Fix)
191
+ power_calc = data.voltage * data.current
192
+
193
+ raw_input = np.array([data.time_sec, data.current, data.voltage, power_calc, data.soc_prev])
194
+ x_mean = ml_assets.get("b_x_mean", np.zeros(5))
195
+ x_std = ml_assets.get("b_x_std", np.ones(5))
196
+ scaled = (raw_input - x_mean) / (x_std + 1e-6)
197
+
198
+ temp_est = 25.0
199
+
200
+ if "b_model" in ml_assets:
201
+ with torch.no_grad():
202
+ preds = ml_assets["b_model"](torch.tensor([scaled], dtype=torch.float32)).numpy()[0]
203
+ y_mean = ml_assets.get("b_y_mean", np.zeros(3))
204
+ y_std = ml_assets.get("b_y_std", np.ones(3))
205
+ real_vals = preds * y_std + y_mean
206
+
207
+ # Extract AI Predictions
208
+ temp_est = real_vals[1]
209
+
210
+ # C. STATUS LOGIC
211
+ status = "Normal"
212
+ if soc_physics < 20: status = "Low Battery"
213
+ if temp_est > 45: status = "Overheating"
214
+
215
+ return {
216
+ "soc": round(float(soc_physics), 2),
217
+ "temp": round(float(temp_est), 2),
218
+ "status": status
219
+ }
220
+
221
+ # ==========================================
222
+ # 5. LAUNCH
223
+ # ==========================================
224
+ def run_server():
225
+ config = uvicorn.Config(app, port=8000, log_level="error")
226
+ server = uvicorn.Server(config)
227
+ server.install_signal_handlers = lambda: None
228
+ server.run()
229
+
230
+ t = threading.Thread(target=run_server)
231
+ t.start()
232
+
233
+ print("⏳ Initializing D.E.C.O.D.E. Unified API...", end="")
234
+ time.sleep(5)
235
+
236
+ try:
237
+ public_url = ngrok.connect(8000).public_url
238
+ print(f"\n🚀 API LIVE: {public_url}")
239
+ print(f"🔗 DOCS: {public_url}/docs")
240
+ except Exception as e:
241
+ print(f"\n❌ Connection Error: {e}")