NickPatel17 committed on
Commit
548cba6
·
verified ·
1 Parent(s): cc1279f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +245 -234
main.py CHANGED
@@ -1,234 +1,245 @@
1
- """
2
- SmartContainer Risk Engine β€” FastAPI Backend
3
- =============================================
4
- Offline-Train, Online-Serve architecture.
5
- Models are loaded once at startup and kept in memory for fast inference.
6
-
7
- Workflow:
8
- 1. Run train_offline.py to train and save models to saved_models/
9
- 2. Start this server: uvicorn main:app --reload
10
- 3. POST a CSV to /api/predict-batch β†’ receive final_predictions.csv
11
-
12
- Expected upload schema (no Clearance_Status column):
13
- Container_ID, Declaration_Date (YYYY-MM-DD), Declaration_Time,
14
- Trade_Regime (Import / Export / Transit), Origin_Country,
15
- Destination_Port, Destination_Country, HS_Code, Importer_ID,
16
- Exporter_ID, Declared_Value, Declared_Weight, Measured_Weight,
17
- Shipping_Line, Dwell_Time_Hours
18
- """
19
-
20
- import io
21
- import os
22
- import joblib
23
- import httpx
24
- import pandas as pd
25
- from contextlib import asynccontextmanager
26
- from dotenv import load_dotenv
27
-
28
- load_dotenv()
29
-
30
- from fastapi import FastAPI, File, HTTPException, Query, UploadFile
31
- from fastapi.middleware.cors import CORSMiddleware
32
- from fastapi.responses import StreamingResponse
33
-
34
- # ── News API key (server-side only, never exposed to frontend) ────────────
35
- GNEWS_API_KEY = os.environ.get("GNEWS_API_KEY", "")
36
-
37
- from src.config import TRAIN_PATH
38
- from src.features import preprocess_and_engineer
39
- from src.model import prepare_features, inference_predict, explain_and_save
40
-
41
- # ── Global model / data store (populated at startup) ──────────────────────
42
- _store: dict = {}
43
- SAVED_MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "saved_models")
44
-
45
-
46
- @asynccontextmanager
47
- async def lifespan(app: FastAPI):
48
- """Load all artifacts into memory once at startup; release on shutdown."""
49
- print("[Startup] Loading models from saved_models/ ...")
50
- _store["xgb"] = joblib.load(os.path.join(SAVED_MODELS_DIR, "xgb_model.pkl"))
51
- _store["lgb"] = joblib.load(os.path.join(SAVED_MODELS_DIR, "lgb_model.pkl"))
52
- _store["cat"] = joblib.load(os.path.join(SAVED_MODELS_DIR, "cat_model.pkl"))
53
- detector = joblib.load(os.path.join(SAVED_MODELS_DIR, "anomaly_detector.pkl"))
54
- _store["iso"] = detector["iso"]
55
- _store["iso_rmin"] = detector["rmin"]
56
- _store["iso_rmax"] = detector["rmax"]
57
-
58
- # Cache Historical Data once β€” copied per-request to avoid in-place mutation.
59
- _store["train_df_raw"] = pd.read_csv(TRAIN_PATH)
60
- print(f"[Startup] Cached train data: {_store['train_df_raw'].shape}")
61
- print("[Startup] All models ready.")
62
- yield
63
- _store.clear()
64
-
65
-
66
- app = FastAPI(
67
- title="SmartContainer Risk Engine",
68
- version="1.0.0",
69
- lifespan=lifespan,
70
- )
71
-
72
- app.add_middleware(
73
- CORSMiddleware,
74
- allow_origins=["*"],
75
- allow_credentials=True,
76
- allow_methods=["*"],
77
- allow_headers=["*"],
78
- )
79
-
80
- @app.get("/api")
81
- def server_status():
82
- return {"status": "ok", "message": "Server is running"}
83
-
84
- @app.get("/health")
85
- async def health():
86
- return {"status": "ok", "artifacts": list(_store.keys())}
87
-
88
-
89
- @app.post("/api/predict-batch")
90
- async def predict_batch(file: UploadFile = File(...)):
91
- """
92
- Accept a container manifest CSV (no Clearance_Status column).
93
- Returns final_predictions.csv as a streaming download.
94
-
95
- Output columns: Container_ID, Risk_Score, Risk_Level, Explanation_Summary
96
- """
97
- if not file.filename.lower().endswith(".csv"):
98
- raise HTTPException(status_code=400, detail="Only .csv files are accepted.")
99
-
100
- # ── Read uploaded test data ───────────────────────────────────────────
101
- contents = await file.read()
102
- try:
103
- test_df = pd.read_csv(io.BytesIO(contents))
104
- except Exception as exc:
105
- raise HTTPException(status_code=400, detail=f"Could not parse CSV: {exc}")
106
-
107
- # ── Fresh copy of cached train data prevents in-place mutation leaking
108
- # across concurrent requests. ─
109
- train_df = _store["train_df_raw"].copy()
110
-
111
- # ── Feature engineering: stats fitted on train_df, mapped to test_df ──
112
- X_train, X_test, y_train, train_ids, test_ids = preprocess_and_engineer(
113
- train_df, test_df
114
- )
115
-
116
- # ── Drop zero-variance Trade_ columns (same step as offline training) ─
117
- X_train, X_test = prepare_features(X_train, X_test)
118
-
119
- # ── Safe index alignment before all downstream ops ─────────────────────
120
- X_test = X_test.reset_index(drop=True)
121
- test_ids = test_ids.reset_index(drop=True)
122
-
123
- # ── Inference: inject anomaly score + weighted ensemble predict ────────
124
- X_test_enriched, proba, predictions, risk_scores = inference_predict(
125
- _store["xgb"],
126
- _store["lgb"],
127
- _store["cat"],
128
- _store["iso"],
129
- _store["iso_rmin"],
130
- _store["iso_rmax"],
131
- X_test,
132
- )
133
-
134
- # ── SHAP explanations via XGBoost + build output DataFrame ────────────
135
- # X_test_enriched already has Anomaly_Score; test_ids is 0-indexed.
136
- output = explain_and_save(
137
- _store["xgb"], X_test_enriched, test_ids, predictions, risk_scores
138
- )
139
-
140
- # Integrity guard: lengths must match before streaming
141
- if len(output) != len(test_ids):
142
- raise HTTPException(
143
- status_code=500,
144
- detail=f"Row count mismatch: output={len(output)}, ids={len(test_ids)}",
145
- )
146
-
147
- # ── Stream result as CSV (index=False β†’ no 'Unnamed: 0' column) ────────
148
- stream = io.StringIO()
149
- output.to_csv(stream, index=False)
150
- stream.seek(0)
151
-
152
- return StreamingResponse(
153
- iter([stream.getvalue()]),
154
- media_type="text/csv",
155
- headers={"Content-Disposition": "attachment; filename=final_predictions.csv"},
156
- )
157
-
158
-
159
- # ═══════════════════════════════════════════════════════════════════════════
160
- # TRADE INTELLIGENCE β€” News endpoint (GNews upstream)
161
- # ═══════════════════════════════════════════════════════════════════════════
162
- _CATEGORY_TERMS = {
163
- "congestion": "congestion",
164
- "shipping": "shipping",
165
- "container": "container",
166
- "trade": "trade",
167
- "terminal": "terminal",
168
- }
169
-
170
-
171
- @app.get("/api/trade/trade-intelligence/news")
172
- async def trade_intelligence_news(
173
- keyword: str = Query(..., min_length=1),
174
- category: str = Query("all"),
175
- limit: int = Query(10, ge=1, le=50),
176
- ):
177
- """
178
- Proxy to GNews API. Maps upstream response to the article schema
179
- expected by the React frontend.
180
- """
181
- if not GNEWS_API_KEY:
182
- raise HTTPException(
183
- status_code=401,
184
- detail="News API key is not configured on the server.",
185
- )
186
-
187
- # Build search query β€” use OR to broaden instead of AND-narrowing
188
- if category != "all" and category in _CATEGORY_TERMS:
189
- search_q = f"{keyword} OR {_CATEGORY_TERMS[category]}"
190
- else:
191
- search_q = keyword
192
-
193
- params = {
194
- "q": search_q,
195
- "language": "en",
196
- "pageSize": str(limit),
197
- "apiKey": GNEWS_API_KEY,
198
- }
199
-
200
- try:
201
- async with httpx.AsyncClient(timeout=15.0) as client:
202
- resp = await client.get("https://newsapi.org/v2/everything", params=params)
203
- except httpx.TimeoutException:
204
- raise HTTPException(status_code=504)
205
- except httpx.RequestError:
206
- raise HTTPException(status_code=502)
207
-
208
- # Map upstream status codes to what the frontend expects
209
- if resp.status_code == 401 or resp.status_code == 403:
210
- raise HTTPException(status_code=401)
211
- if resp.status_code == 429:
212
- raise HTTPException(status_code=429)
213
- if resp.status_code >= 500:
214
- raise HTTPException(status_code=502)
215
- if resp.status_code != 200:
216
- raise HTTPException(status_code=500)
217
-
218
- data = resp.json()
219
- raw_articles = data.get("articles", [])
220
-
221
- articles = [
222
- {
223
- "title": a.get("title", ""),
224
- "description": a.get("description"),
225
- "url": a.get("url", ""),
226
- "image_url": a.get("image"),
227
- "source_name": (a.get("source") or {}).get("name", "Unknown"),
228
- "published_at": a.get("publishedAt", ""),
229
- }
230
- for a in raw_articles
231
- ]
232
-
233
- return {"articles": articles}
234
-
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
SmartContainer Risk Engine — FastAPI Backend
=============================================
Offline-Train, Online-Serve architecture.
Models are loaded once at startup and kept in memory for fast inference.

Workflow:
1. Run train_offline.py to train and save models to saved_models/
2. Start this server: uvicorn main:app --reload
3. POST a CSV to /api/predict-batch → receive final_predictions.csv

Expected upload schema (no Clearance_Status column):
Container_ID, Declaration_Date (YYYY-MM-DD), Declaration_Time,
Trade_Regime (Import / Export / Transit), Origin_Country,
Destination_Port, Destination_Country, HS_Code, Importer_ID,
Exporter_ID, Declared_Value, Declared_Weight, Measured_Weight,
Shipping_Line, Dwell_Time_Hours
"""
import asyncio
import io
import os
import joblib
import httpx
import pandas as pd
from contextlib import asynccontextmanager
from dotenv import load_dotenv

# Load .env into the process environment BEFORE any os.environ reads below.
load_dotenv()

from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

# ── News API key (server-side only, never exposed to frontend) ────────────
# NOTE(review): the variable is named GNEWS_API_KEY, but the news endpoint
# below calls https://newsapi.org (NewsAPI) with NewsAPI-style parameters —
# confirm which provider this key actually belongs to.
GNEWS_API_KEY = os.environ.get("GNEWS_API_KEY", "")

from src.config import TRAIN_PATH
from src.features import preprocess_and_engineer
from src.model import prepare_features, inference_predict, explain_and_save

# ── Global model / data store (populated at startup) ──────────────────────
# Keys after a successful startup: "xgb", "lgb", "cat", "iso", "iso_rmin",
# "iso_rmax", and (if the CSV exists on disk) "train_df_raw".
_store: dict = {}
SAVED_MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "saved_models")
44
+
45
+
46
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all artifacts into memory once at startup; release on shutdown."""

    def _load_artifacts() -> None:
        # Blocking I/O (joblib unpickling, CSV parsing) lives in this nested
        # helper so it can be pushed onto a worker thread below.
        print("[Startup] Loading models from saved_models/ ...")
        for key, fname in (
            ("xgb", "xgb_model.pkl"),
            ("lgb", "lgb_model.pkl"),
            ("cat", "cat_model.pkl"),
        ):
            _store[key] = joblib.load(os.path.join(SAVED_MODELS_DIR, fname))

        detector = joblib.load(os.path.join(SAVED_MODELS_DIR, "anomaly_detector.pkl"))
        _store["iso"] = detector["iso"]
        _store["iso_rmin"] = detector["rmin"]
        _store["iso_rmax"] = detector["rmax"]

        # Check if the training CSV actually exists on the server!
        print(f"[Startup] Looking for training data at: {TRAIN_PATH}")
        if os.path.exists(TRAIN_PATH):
            _store["train_df_raw"] = pd.read_csv(TRAIN_PATH)
            print(f"[Startup] Cached train data: {_store['train_df_raw'].shape}")
        else:
            # Deliberately non-fatal: the server still starts, but
            # predict-batch will be unusable until the CSV is uploaded.
            print(f"🚨 FATAL ERROR: The file {TRAIN_PATH} does not exist on Hugging Face! Did you upload the CSV?")

        print("✅ [Startup] All models ready!")

    # Run the heavy loading in a separate thread so Uvicorn doesn't freeze
    await asyncio.to_thread(_load_artifacts)

    yield
    _store.clear()
76
+
77
app = FastAPI(
    title="SmartContainer Risk Engine",
    version="1.0.0",
    lifespan=lifespan,  # artifacts loaded once at startup (see lifespan)
)

# Wide-open CORS: any origin may call this API.
# NOTE(review): the Fetch/CORS spec forbids combining credentials with a
# wildcard origin — browsers reject credentialed responses when
# Access-Control-Allow-Origin is "*". If credentialed requests are really
# needed, list explicit origins; otherwise drop allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
90
+
91
@app.get("/api")
def server_status():
    """Liveness probe: confirms the HTTP layer is up and responding."""
    payload = {"status": "ok", "message": "Server is running"}
    return payload
94
+
95
@app.get("/health")
async def health():
    """Health check: reports which startup artifacts are currently in memory."""
    loaded_artifacts = list(_store.keys())
    return {"status": "ok", "artifacts": loaded_artifacts}
98
+
99
+
100
def _run_inference_pipeline(test_df: pd.DataFrame) -> pd.DataFrame:
    """CPU-bound scoring pipeline; runs in a worker thread, off the event loop.

    Fits feature statistics on the cached training data, scores *test_df*
    with the weighted ensemble + anomaly detector, and returns the output
    DataFrame (Container_ID, Risk_Score, Risk_Level, Explanation_Summary).
    """
    # ── Fresh copy of cached train data prevents in-place mutation leaking
    # across concurrent requests. ─
    train_df = _store["train_df_raw"].copy()

    # ── Feature engineering: stats fitted on train_df, mapped to test_df ──
    X_train, X_test, y_train, train_ids, test_ids = preprocess_and_engineer(
        train_df, test_df
    )

    # ── Drop zero-variance Trade_ columns (same step as offline training) ─
    X_train, X_test = prepare_features(X_train, X_test)

    # ── Safe index alignment before all downstream ops ─────────────────────
    X_test = X_test.reset_index(drop=True)
    test_ids = test_ids.reset_index(drop=True)

    # ── Inference: inject anomaly score + weighted ensemble predict ────────
    X_test_enriched, proba, predictions, risk_scores = inference_predict(
        _store["xgb"],
        _store["lgb"],
        _store["cat"],
        _store["iso"],
        _store["iso_rmin"],
        _store["iso_rmax"],
        X_test,
    )

    # ── SHAP explanations via XGBoost + build output DataFrame ────────────
    # X_test_enriched already has Anomaly_Score; test_ids is 0-indexed.
    output = explain_and_save(
        _store["xgb"], X_test_enriched, test_ids, predictions, risk_scores
    )

    # Integrity guard: lengths must match before streaming
    if len(output) != len(test_ids):
        raise HTTPException(
            status_code=500,
            detail=f"Row count mismatch: output={len(output)}, ids={len(test_ids)}",
        )
    return output


@app.post("/api/predict-batch")
async def predict_batch(file: UploadFile = File(...)):
    """
    Accept a container manifest CSV (no Clearance_Status column).
    Returns final_predictions.csv as a streaming download.

    Output columns: Container_ID, Risk_Score, Risk_Level, Explanation_Summary

    Raises:
        HTTPException 400: non-CSV upload, unparsable CSV, or empty CSV.
        HTTPException 503: training data was never cached at startup.
        HTTPException 500: internal row-count mismatch.
    """
    # file.filename can be None for some clients — guard before .lower().
    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only .csv files are accepted.")

    # Startup tolerates a missing train CSV (see lifespan); without it the
    # pipeline would KeyError into a 500. Fail fast with a clear message.
    if "train_df_raw" not in _store:
        raise HTTPException(
            status_code=503,
            detail="Training data is not loaded on the server; predictions are unavailable.",
        )

    # ── Read uploaded test data ───────────────────────────────────────────
    contents = await file.read()
    try:
        test_df = pd.read_csv(io.BytesIO(contents))
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Could not parse CSV: {exc}") from exc

    if test_df.empty:
        raise HTTPException(status_code=400, detail="Uploaded CSV contains no rows.")

    # Run the CPU-heavy pipeline in a worker thread so the event loop stays
    # responsive for health checks and the news proxy.
    output = await asyncio.to_thread(_run_inference_pipeline, test_df)

    # ── Stream result as CSV (index=False → no 'Unnamed: 0' column) ────────
    stream = io.StringIO()
    output.to_csv(stream, index=False)
    stream.seek(0)

    return StreamingResponse(
        iter([stream.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=final_predictions.csv"},
    )
168
+
169
+
170
# ═══════════════════════════════════════════════════════════════════════════
# TRADE INTELLIGENCE — News endpoint
# NOTE(review): earlier comments call the upstream "GNews", but the handler
# below queries https://newsapi.org (NewsAPI) with NewsAPI-style parameters.
# ═══════════════════════════════════════════════════════════════════════════
# Category → extra search term OR-ed onto the user keyword to broaden results.
_CATEGORY_TERMS = {
    "congestion": "congestion",
    "shipping": "shipping",
    "container": "container",
    "trade": "trade",
    "terminal": "terminal",
}
180
+
181
+
182
@app.get("/api/trade/trade-intelligence/news")
async def trade_intelligence_news(
    keyword: str = Query(..., min_length=1),
    category: str = Query("all"),
    limit: int = Query(10, ge=1, le=50),
):
    """
    Proxy to the NewsAPI.org /v2/everything endpoint. Maps the upstream
    response to the article schema expected by the React frontend.

    Status mapping for the frontend: 401 (missing/rejected key),
    429 (rate limited), 502 (upstream failure / bad payload),
    504 (upstream timeout), 500 (anything unexpected).
    """
    if not GNEWS_API_KEY:
        raise HTTPException(
            status_code=401,
            detail="News API key is not configured on the server.",
        )

    # Build search query — use OR to broaden instead of AND-narrowing
    if category != "all" and category in _CATEGORY_TERMS:
        search_q = f"{keyword} OR {_CATEGORY_TERMS[category]}"
    else:
        search_q = keyword

    # NewsAPI.org parameter names: q / language / pageSize / apiKey.
    params = {
        "q": search_q,
        "language": "en",
        "pageSize": str(limit),
        "apiKey": GNEWS_API_KEY,
    }

    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get("https://newsapi.org/v2/everything", params=params)
    except httpx.TimeoutException as exc:
        raise HTTPException(status_code=504, detail="News provider timed out.") from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail="Could not reach news provider.") from exc

    # Map upstream status codes to what the frontend expects
    if resp.status_code in (401, 403):
        raise HTTPException(status_code=401, detail="News provider rejected the API key.")
    if resp.status_code == 429:
        raise HTTPException(status_code=429, detail="News provider rate limit exceeded.")
    if resp.status_code >= 500:
        raise HTTPException(status_code=502, detail="News provider error.")
    if resp.status_code != 200:
        raise HTTPException(status_code=500, detail="Unexpected news provider response.")

    # A 200 with a non-JSON body should surface as an upstream fault, not a
    # raw 500 traceback.
    try:
        data = resp.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="News provider returned invalid JSON.") from exc

    # `or []` also covers an explicit `"articles": null` in the payload.
    raw_articles = data.get("articles") or []

    # NewsAPI article fields: title, description, url, urlToImage, source,
    # publishedAt. BUG FIX: the original read a.get("image") — a GNews field
    # name that NewsAPI never returns — so image_url was always null.
    # `or ""` also normalizes explicit nulls (e.g. "[Removed]" articles) to
    # the declared string schema.
    articles = [
        {
            "title": a.get("title") or "",
            "description": a.get("description"),
            "url": a.get("url") or "",
            "image_url": a.get("urlToImage") or a.get("image"),
            "source_name": (a.get("source") or {}).get("name", "Unknown"),
            "published_at": a.get("publishedAt") or "",
        }
        for a in raw_articles
    ]

    return {"articles": articles}
245
+