"""
SmartContainer Risk Engine — FastAPI Backend
=============================================
Offline-Train, Online-Serve architecture.
Models are loaded once at startup and kept in memory for fast inference.
Workflow:
1. Run train_offline.py to train and save models to saved_models/
2. Start this server: uvicorn main:app --reload
3. POST a CSV to /api/predict-batch → receive final_predictions.csv
Expected upload schema (no Clearance_Status column):
Container_ID, Declaration_Date (YYYY-MM-DD), Declaration_Time,
Trade_Regime (Import / Export / Transit), Origin_Country,
Destination_Port, Destination_Country, HS_Code, Importer_ID,
Exporter_ID, Declared_Value, Declared_Weight, Measured_Weight,
Shipping_Line, Dwell_Time_Hours
"""
import asyncio
import io
import os
import joblib
import httpx
import pandas as pd
from contextlib import asynccontextmanager
from dotenv import load_dotenv
load_dotenv()
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
# ── News API key (server-side only, never exposed to frontend) ──────────────
GNEWS_API_KEY = os.environ.get("GNEWS_API_KEY", "")
from src.config import TRAIN_PATH
from src.features import preprocess_and_engineer
from src.model import prepare_features, inference_predict, explain_and_save
# ── Global model / data store (populated at startup) ────────────────────────
_store: dict = {}
SAVED_MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "saved_models")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all artifacts into memory once at startup; release on shutdown.

    The heavy joblib/pandas loads run in a worker thread via
    ``asyncio.to_thread`` so the event loop (and Uvicorn startup) is not
    blocked while models are deserialized.
    """

    def load_heavy_artifacts():
        print("[Startup] Loading models from saved_models/ ...")
        _store["xgb"] = joblib.load(os.path.join(SAVED_MODELS_DIR, "xgb_model.pkl"))
        _store["lgb"] = joblib.load(os.path.join(SAVED_MODELS_DIR, "lgb_model.pkl"))
        _store["cat"] = joblib.load(os.path.join(SAVED_MODELS_DIR, "cat_model.pkl"))
        # The anomaly detector is saved as a bundle: the IsolationForest plus
        # the raw-score min/max used to normalize scores at inference time.
        detector = joblib.load(os.path.join(SAVED_MODELS_DIR, "anomaly_detector.pkl"))
        _store["iso"] = detector["iso"]
        _store["iso_rmin"] = detector["rmin"]
        _store["iso_rmax"] = detector["rmax"]
        # Check if the training CSV actually exists on the server. We warn
        # loudly but keep serving; /api/predict-batch guards on this key.
        print(f"[Startup] Looking for training data at: {TRAIN_PATH}")
        if not os.path.exists(TRAIN_PATH):
            print(f"FATAL ERROR: The file {TRAIN_PATH} does not exist on Hugging Face! Did you upload the CSV?")
        else:
            _store["train_df_raw"] = pd.read_csv(TRAIN_PATH)
            print(f"[Startup] Cached train data: {_store['train_df_raw'].shape}")
        # BUG FIX: this message was previously a string literal broken across
        # two source lines (a syntax error); re-joined into one statement.
        print("[Startup] All models ready!")

    # Run the heavy loading in a separate thread so Uvicorn doesn't freeze.
    await asyncio.to_thread(load_heavy_artifacts)
    yield
    # Shutdown: drop references so artifacts can be garbage-collected.
    _store.clear()
# FastAPI application instance; `lifespan` loads model artifacts at startup
# and clears them on shutdown.
app = FastAPI(
    title="SmartContainer Risk Engine",
    version="1.0.0",
    lifespan=lifespan,
)
# CORS: wide open for development.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (credentialed responses require an
# explicit origin, not the wildcard) — pin concrete origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/api")
def server_status():
return {"status": "ok", "message": "Server is running"}
@app.get("/health")
async def health():
return {"status": "ok", "artifacts": list(_store.keys())}
@app.post("/api/predict-batch")
async def predict_batch(file: UploadFile = File(...)):
"""
Accept a container manifest CSV (no Clearance_Status column).
Returns final_predictions.csv as a streaming download.
Output columns: Container_ID, Risk_Score, Risk_Level, Explanation_Summary
"""
if not file.filename.lower().endswith(".csv"):
raise HTTPException(status_code=400, detail="Only .csv files are accepted.")
# ββ Read uploaded test data βββββββββββββββββββββββββββββββββββββββββββ
contents = await file.read()
try:
test_df = pd.read_csv(io.BytesIO(contents))
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Could not parse CSV: {exc}")
# ββ Fresh copy of cached train data prevents in-place mutation leaking
# across concurrent requests. β
train_df = _store["train_df_raw"].copy()
# ββ Feature engineering: stats fitted on train_df, mapped to test_df ββ
X_train, X_test, y_train, train_ids, test_ids = preprocess_and_engineer(
train_df, test_df
)
# ββ Drop zero-variance Trade_ columns (same step as offline training) β
X_train, X_test = prepare_features(X_train, X_test)
# ββ Safe index alignment before all downstream ops βββββββββββββββββββββ
X_test = X_test.reset_index(drop=True)
test_ids = test_ids.reset_index(drop=True)
# ββ Inference: inject anomaly score + weighted ensemble predict ββββββββ
X_test_enriched, proba, predictions, risk_scores = inference_predict(
_store["xgb"],
_store["lgb"],
_store["cat"],
_store["iso"],
_store["iso_rmin"],
_store["iso_rmax"],
X_test,
)
# ββ SHAP explanations via XGBoost + build output DataFrame ββββββββββββ
# X_test_enriched already has Anomaly_Score; test_ids is 0-indexed.
output = explain_and_save(
_store["xgb"], X_test_enriched, test_ids, predictions, risk_scores
)
# Integrity guard: lengths must match before streaming
if len(output) != len(test_ids):
raise HTTPException(
status_code=500,
detail=f"Row count mismatch: output={len(output)}, ids={len(test_ids)}",
)
# ββ Stream result as CSV (index=False β no 'Unnamed: 0' column) ββββββββ
stream = io.StringIO()
output.to_csv(stream, index=False)
stream.seek(0)
return StreamingResponse(
iter([stream.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=final_predictions.csv"},
)
# ───────────────────────────────────────────────────────────────────────────
# TRADE INTELLIGENCE — News endpoint
# ───────────────────────────────────────────────────────────────────────────
# Extra search term OR-ed onto the user keyword for each supported news
# category (see trade_intelligence_news); any other category is ignored.
_CATEGORY_TERMS = {
    "congestion": "congestion",
    "shipping": "shipping",
    "container": "container",
    "trade": "trade",
    "terminal": "terminal",
}
@app.get("/api/trade/trade-intelligence/news")
async def trade_intelligence_news(
keyword: str = Query(..., min_length=1),
category: str = Query("all"),
limit: int = Query(10, ge=1, le=50),
):
"""
Proxy to GNews API. Maps upstream response to the article schema
expected by the React frontend.
"""
if not GNEWS_API_KEY:
raise HTTPException(
status_code=401,
detail="News API key is not configured on the server.",
)
# Build search query β use OR to broaden instead of AND-narrowing
if category != "all" and category in _CATEGORY_TERMS:
search_q = f"{keyword} OR {_CATEGORY_TERMS[category]}"
else:
search_q = keyword
params = {
"q": search_q,
"language": "en",
"pageSize": str(limit),
"apiKey": GNEWS_API_KEY,
}
try:
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get("https://newsapi.org/v2/everything", params=params)
except httpx.TimeoutException:
raise HTTPException(status_code=504)
except httpx.RequestError:
raise HTTPException(status_code=502)
# Map upstream status codes to what the frontend expects
if resp.status_code == 401 or resp.status_code == 403:
raise HTTPException(status_code=401)
if resp.status_code == 429:
raise HTTPException(status_code=429)
if resp.status_code >= 500:
raise HTTPException(status_code=502)
if resp.status_code != 200:
raise HTTPException(status_code=500)
data = resp.json()
raw_articles = data.get("articles", [])
articles = [
{
"title": a.get("title", ""),
"description": a.get("description"),
"url": a.get("url", ""),
"image_url": a.get("image"),
"source_name": (a.get("source") or {}).get("name", "Unknown"),
"published_at": a.get("publishedAt", ""),
}
for a in raw_articles
]
return {"articles": articles}
|