File size: 3,401 Bytes
227f031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bc94d6
227f031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
import os
import sys
from pathlib import Path

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, Field
import numpy as np
import joblib
import torch
from catboost import CatBoostClassifier
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from pytorch_tabnet.tab_model import TabNetClassifier


# Project root: the parent of the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parents[1]

# Put the project root at the very front of the module search path (only once),
# presumably so project-local packages resolve regardless of the launch cwd.
_base_dir_str = str(BASE_DIR)
if _base_dir_str not in sys.path:
    sys.path.insert(0, _base_dir_str)
del _base_dir_str


# ---- Configuration ----------------------------------------------------------
MODEL_DIR = BASE_DIR / "models"
FUSION_PATH = BASE_DIR.joinpath("fusion", "fusion_model_metadata.joblib")

# Input feature names. The prediction endpoint builds its input row in exactly
# this order, so this list defines the column layout every model receives.
FEATURES = [
    "rainfall_mm",
    "humidity_pct",
    "temp_c",
    "vegetation_index",
    "elevation_m",
    "proximity_to_water_km",
    "wind_speed_mps",
    "surface_water_presence",
    "daylight_hours",
    "season_label",
]

# ---- Load models --------------------------------------------------------------
# Every artefact is loaded once at import time, so a missing or unreadable file
# makes the module fail to import instead of failing on each request.
print(" Loading models...")

cat_model = CatBoostClassifier()
cat_model.load_model(str(MODEL_DIR / "catboost.cbm"))

xgb_model = joblib.load(MODEL_DIR / "xgboost.joblib")
lgb_model = joblib.load(MODEL_DIR / "lightgbm.joblib")

tabnet_model = TabNetClassifier()
# NOTE(review): the doubled ".zip.zip" suffix is probably an artefact of
# TabNet's save_model() appending ".zip" to a path that already ended in
# ".zip". Confirm the file on disk has this exact name before "fixing" it.
tabnet_model.load_model(str(MODEL_DIR / "tabnet.zip.zip"))
# Scaler applied to the raw feature row before it is passed to TabNet only.
tabnet_scaler = joblib.load(MODEL_DIR / "tabnet_scaler.joblib")

# Mapping read by the prediction endpoint via .get("models") / .get("weights").
fusion_meta = joblib.load(FUSION_PATH)
print(" All models and fusion metadata loaded successfully.")

# ---- FastAPI application ------------------------------------------------------
# The endpoints below are registered on this instance via @app decorators.
app = FastAPI(version="1.0", title="Malaria Risk Prediction API")

#  INPUT SCHEMA
class MalariaInput(BaseModel):
    """Request body for ``POST /predict/``: one value for every model feature.

    Field names must match the entries of ``FEATURES``; the endpoint reads
    them by name to build the model input row.
    """

    rainfall_mm: float
    humidity_pct: float
    temp_c: float
    vegetation_index: float
    elevation_m: float
    proximity_to_water_km: float
    wind_speed_mps: float
    surface_water_presence: float
    daylight_hours: float
    # Binary season flag (0 = Dry, 1 = Rainy). Any other value is rejected with
    # a validation error instead of being silently fed to the models.
    season_label: int = Field(..., ge=0, le=1, description="0 = Dry, 1 = Rainy")


#  PREDICTION ENDPOINT

# Fused scores at or above this value are labelled "High"; everything else is
# labelled "Medium".
# NOTE(review): there is no "Low" bucket, so even a score of 0.0 is reported as
# "Medium" -- confirm this is intended.
HIGH_RISK_THRESHOLD = 0.60


def _model_probabilities(x):
    """Return each base model's class-1 probability for the single row in *x*.

    *x* is the unscaled 1 x len(FEATURES) float array. CatBoost, XGBoost and
    LightGBM receive it as-is; TabNet receives it after ``tabnet_scaler``.
    """
    # float32 cast matches the dtype previously produced via torch `.float()`.
    x_tabnet = np.asarray(tabnet_scaler.transform(x), dtype=np.float32)
    return {
        "catboost": float(cat_model.predict_proba(x)[0][1]),
        "xgboost": float(xgb_model.predict_proba(x)[0][1]),
        "lightgbm": float(lgb_model.predict_proba(x)[0][1]),
        "tabnet": float(tabnet_model.predict_proba(x_tabnet)[0][1]),
    }


def _fusion_weights():
    """Return the fusion weights from ``fusion_meta`` as ``{model_name: weight}``.

    Raises:
        ValueError: if the metadata's ``"models"`` and ``"weights"`` sequences
            differ in length (``zip`` would otherwise silently drop entries).
    """
    names = list(fusion_meta.get("models", []))
    values = list(fusion_meta.get("weights", []))
    if len(names) != len(values):
        raise ValueError(
            f"Fusion metadata lists {len(names)} models but {len(values)} weights."
        )
    return {name: float(w) for name, w in zip(names, values)}


def _fuse(preds, weights):
    """Return the weighted sum of *preds* over the models that have a weight.

    Weights are used as-is (not renormalised), so the result is a probability
    only if the weights of the models present sum to 1.

    Raises:
        ValueError: if no model in *preds* has a fusion weight -- the sum would
            then be a meaningless 0.0 that looks like a real prediction.
    """
    shared = [name for name in preds if name in weights]
    if not shared:
        raise ValueError("Fusion metadata has no weights for any loaded model.")
    return float(sum(preds[name] * weights[name] for name in shared))


@app.post("/predict/")
def predict(data: MalariaInput):
    """Score one observation with every base model and fuse the results.

    Returns:
        JSONResponse with ``model_outputs`` (per-model probabilities),
        ``risk_score`` (fused score rounded to 3 decimals) and ``risk_label``.
        Any failure -- including malformed fusion metadata -- is returned as
        HTTP 500 with ``{"error": ..., "message": "Prediction failed."}``.
    """
    try:
        # Column order follows FEATURES; every model receives this layout.
        x = np.array([[getattr(data, f) for f in FEATURES]], dtype=float)

        preds = _model_probabilities(x)
        risk_score = _fuse(preds, _fusion_weights())
        risk_label = "High" if risk_score >= HIGH_RISK_THRESHOLD else "Medium"

        response = {
            "model_outputs": preds,
            "risk_score": round(risk_score, 3),
            "risk_label": risk_label,
        }
        return JSONResponse(content=jsonable_encoder(response))

    except Exception as e:
        # Record the full traceback server-side before answering with a 500.
        logging.getLogger(__name__).exception("Prediction failed")
        # NOTE(review): str(e) exposes internal error details to API clients.
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "message": "Prediction failed."},
        )


# ---- Root endpoint ------------------------------------------------------------
@app.get("/")
def root():
    """Liveness check: return a fixed message confirming the API is serving."""
    status_text = " Malaria Risk Prediction API is running!"
    return dict(message=status_text)