File size: 9,153 Bytes
548cba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""
SmartContainer Risk Engine β€” FastAPI Backend
=============================================
Offline-Train, Online-Serve architecture.
Models are loaded once at startup and kept in memory for fast inference.

Workflow:
    1. Run train_offline.py to train and save models to saved_models/
    2. Start this server:  uvicorn main:app --reload
    3. POST a CSV to /api/predict-batch  β†’  receive final_predictions.csv

Expected upload schema (no Clearance_Status column):
    Container_ID, Declaration_Date (YYYY-MM-DD), Declaration_Time,
    Trade_Regime (Import / Export / Transit), Origin_Country,
    Destination_Port, Destination_Country, HS_Code, Importer_ID,
    Exporter_ID, Declared_Value, Declared_Weight, Measured_Weight,
    Shipping_Line, Dwell_Time_Hours
"""
import asyncio
import io
import os
import joblib
import httpx
import pandas as pd
from contextlib import asynccontextmanager
from dotenv import load_dotenv

load_dotenv()

from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

# ── News API key (server-side only, never exposed to frontend) ────────────
GNEWS_API_KEY = os.environ.get("GNEWS_API_KEY", "")

from src.config import TRAIN_PATH
from src.features import preprocess_and_engineer
from src.model import prepare_features, inference_predict, explain_and_save

# ── Global model / data store (populated at startup) ──────────────────────
# Keys after a successful startup: xgb, lgb, cat, iso, iso_rmin, iso_rmax,
# train_df_raw (absent if TRAIN_PATH was missing — handlers must not assume it).
_store: dict = {}
SAVED_MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "saved_models")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the global artifact store at startup; empty it on shutdown."""

    def _load_artifacts() -> None:
        # Blocking joblib/pandas reads — executed off the event loop below.
        print("[Startup] Loading models from saved_models/ ...")
        for key, fname in (
            ("xgb", "xgb_model.pkl"),
            ("lgb", "lgb_model.pkl"),
            ("cat", "cat_model.pkl"),
        ):
            _store[key] = joblib.load(os.path.join(SAVED_MODELS_DIR, fname))

        detector = joblib.load(os.path.join(SAVED_MODELS_DIR, "anomaly_detector.pkl"))
        _store["iso"] = detector["iso"]
        _store["iso_rmin"] = detector["rmin"]
        _store["iso_rmax"] = detector["rmax"]

        # The training CSV is optional at startup: log loudly if absent,
        # cache it otherwise. Handlers check for the key before use.
        print(f"[Startup] Looking for training data at: {TRAIN_PATH}")
        if os.path.exists(TRAIN_PATH):
            _store["train_df_raw"] = pd.read_csv(TRAIN_PATH)
            print(f"[Startup] Cached train data: {_store['train_df_raw'].shape}")
        else:
            print(f"🚨 FATAL ERROR: The file {TRAIN_PATH} does not exist on Hugging Face! Did you upload the CSV?")

        print("βœ… [Startup] All models ready!")

    # Offload the heavy loading to a worker thread so Uvicorn stays responsive.
    await asyncio.to_thread(_load_artifacts)

    yield
    _store.clear()

app = FastAPI(
    title="SmartContainer Risk Engine",
    version="1.0.0",
    lifespan=lifespan,  # models loaded once at startup via the lifespan handler
)

# Wide-open CORS so the React frontend can call this API from any origin.
# NOTE(review): browsers reject `allow_origins=["*"]` combined with
# `allow_credentials=True` for credentialed requests — pin explicit origins
# if cookie/auth-header flows are ever needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/api")
def server_status():
    return {"status": "ok", "message": "Server is running"}

@app.get("/health")
async def health():
    return {"status": "ok", "artifacts": list(_store.keys())}


@app.post("/api/predict-batch")
async def predict_batch(file: UploadFile = File(...)):
    """
    Accept a container manifest CSV (no Clearance_Status column).
    Returns final_predictions.csv as a streaming download.

    Output columns: Container_ID, Risk_Score, Risk_Level, Explanation_Summary
    """
    if not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only .csv files are accepted.")

    # ── Read uploaded test data ───────────────────────────────────────────
    contents = await file.read()
    try:
        test_df = pd.read_csv(io.BytesIO(contents))
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Could not parse CSV: {exc}")

    # ── Fresh copy of cached train data prevents in-place mutation leaking
    #    across concurrent requests.                                        ─
    train_df = _store["train_df_raw"].copy()

    # ── Feature engineering: stats fitted on train_df, mapped to test_df ──
    X_train, X_test, y_train, train_ids, test_ids = preprocess_and_engineer(
        train_df, test_df
    )

    # ── Drop zero-variance Trade_ columns (same step as offline training) ─
    X_train, X_test = prepare_features(X_train, X_test)

    # ── Safe index alignment before all downstream ops ─────────────────────
    X_test    = X_test.reset_index(drop=True)
    test_ids  = test_ids.reset_index(drop=True)

    # ── Inference: inject anomaly score + weighted ensemble predict ────────
    X_test_enriched, proba, predictions, risk_scores = inference_predict(
        _store["xgb"],
        _store["lgb"],
        _store["cat"],
        _store["iso"],
        _store["iso_rmin"],
        _store["iso_rmax"],
        X_test,
    )

    # ── SHAP explanations via XGBoost + build output DataFrame ────────────
    # X_test_enriched already has Anomaly_Score; test_ids is 0-indexed.
    output = explain_and_save(
        _store["xgb"], X_test_enriched, test_ids, predictions, risk_scores
    )

    # Integrity guard: lengths must match before streaming
    if len(output) != len(test_ids):
        raise HTTPException(
            status_code=500,
            detail=f"Row count mismatch: output={len(output)}, ids={len(test_ids)}",
        )

    # ── Stream result as CSV (index=False β†’ no 'Unnamed: 0' column) ────────
    stream = io.StringIO()
    output.to_csv(stream, index=False)
    stream.seek(0)

    return StreamingResponse(
        iter([stream.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=final_predictions.csv"},
    )


# ═══════════════════════════════════════════════════════════════════════════
#  TRADE INTELLIGENCE — News endpoint (GNews upstream)
# ═══════════════════════════════════════════════════════════════════════════
# Maps a frontend category slug to the extra search term OR-ed onto the
# user's keyword; any slug not listed here (including "all") is ignored.
_CATEGORY_TERMS = {
    "congestion":  "congestion",
    "shipping":    "shipping",
    "container":   "container",
    "trade":       "trade",
    "terminal":    "terminal",
}


@app.get("/api/trade/trade-intelligence/news")
async def trade_intelligence_news(
    keyword:  str = Query(..., min_length=1),
    category: str = Query("all"),
    limit:    int = Query(10, ge=1, le=50),
):
    """
    Proxy to GNews API.  Maps upstream response to the article schema
    expected by the React frontend.
    """
    if not GNEWS_API_KEY:
        raise HTTPException(
            status_code=401,
            detail="News API key is not configured on the server.",
        )

    # Build search query β€” use OR to broaden instead of AND-narrowing
    if category != "all" and category in _CATEGORY_TERMS:
        search_q = f"{keyword} OR {_CATEGORY_TERMS[category]}"
    else:
        search_q = keyword

    params = {
        "q":        search_q,
        "language": "en",           
        "pageSize": str(limit),    
        "apiKey":   GNEWS_API_KEY,   
    }

    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get("https://newsapi.org/v2/everything", params=params)
    except httpx.TimeoutException:
        raise HTTPException(status_code=504)
    except httpx.RequestError:
        raise HTTPException(status_code=502)

    # Map upstream status codes to what the frontend expects
    if resp.status_code == 401 or resp.status_code == 403:
        raise HTTPException(status_code=401)
    if resp.status_code == 429:
        raise HTTPException(status_code=429)
    if resp.status_code >= 500:
        raise HTTPException(status_code=502)
    if resp.status_code != 200:
        raise HTTPException(status_code=500)

    data = resp.json()
    raw_articles = data.get("articles", [])

    articles = [
        {
            "title":        a.get("title", ""),
            "description":  a.get("description"),
            "url":          a.get("url", ""),
            "image_url":    a.get("image"),
            "source_name":  (a.get("source") or {}).get("name", "Unknown"),
            "published_at": a.get("publishedAt", ""),
        }
        for a in raw_articles
    ]

    return {"articles": articles}