Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,13 +35,13 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
| 35 |
ARTIFACTS_DIR = os.path.join(BASE_DIR, "artifacts_final")
|
| 36 |
|
| 37 |
MODEL_FILE = "catboost_model_top20.cbm"
|
| 38 |
-
|
| 39 |
UI_FILE = "model_ui_schema.json"
|
| 40 |
|
| 41 |
-
THRESHOLD = 0.3
|
| 42 |
|
| 43 |
# =========================================================
|
| 44 |
-
# 3. LOAD ARTIFACTS
|
| 45 |
# =========================================================
|
| 46 |
|
| 47 |
print("⏳ Loading AI Models and Config...")
|
|
@@ -54,20 +54,19 @@ print("📄 Files found:", os.listdir(ARTIFACTS_DIR))
|
|
| 54 |
|
| 55 |
# ---- Load Model ----
|
| 56 |
model_path = os.path.join(ARTIFACTS_DIR, MODEL_FILE)
|
| 57 |
-
if not os.path.exists(model_path):
|
| 58 |
-
raise FileNotFoundError(f"Model not found at {model_path}")
|
| 59 |
-
|
| 60 |
model = CatBoostClassifier()
|
| 61 |
model.load_model(model_path)
|
| 62 |
print("✅ CatBoost model loaded")
|
| 63 |
|
| 64 |
-
# ---- Load
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
TOP_FEATURES
|
| 70 |
-
print(f"✅ Metadata loaded ({len(TOP_FEATURES)} features)")
|
| 71 |
|
| 72 |
# ---- Load UI Schema ----
|
| 73 |
ui_path = os.path.join(ARTIFACTS_DIR, UI_FILE)
|
|
@@ -118,28 +117,19 @@ class PatientData(BaseModel):
|
|
| 118 |
# =========================================================
|
| 119 |
|
| 120 |
def preprocess_input(raw_data: Dict[str, Any]) -> pd.DataFrame:
|
| 121 |
-
"""
|
| 122 |
-
- Lowercases categorical strings
|
| 123 |
-
- Ensures all Top-20 features exist
|
| 124 |
-
- Orders columns exactly as training
|
| 125 |
-
"""
|
| 126 |
clean_data = {}
|
| 127 |
|
| 128 |
-
for key, value in raw_data.items():
|
| 129 |
-
if isinstance(value, str):
|
| 130 |
-
clean_data[key] = value.lower()
|
| 131 |
-
else:
|
| 132 |
-
clean_data[key] = value
|
| 133 |
|
| 134 |
df = pd.DataFrame([clean_data])
|
| 135 |
|
| 136 |
-
#
|
| 137 |
for col in TOP_FEATURES:
|
| 138 |
if col not in df.columns:
|
| 139 |
df[col] = "unknown"
|
| 140 |
|
| 141 |
-
|
| 142 |
-
return df
|
| 143 |
|
| 144 |
# =========================================================
|
| 145 |
# 6. API ENDPOINTS
|
|
@@ -155,10 +145,6 @@ def health():
|
|
| 155 |
|
| 156 |
@app.get("/config")
|
| 157 |
def get_ui_config():
|
| 158 |
-
"""
|
| 159 |
-
Returns your provided UI JSON
|
| 160 |
-
Used by frontend to auto-render form
|
| 161 |
-
"""
|
| 162 |
return ui_schema
|
| 163 |
|
| 164 |
@app.post("/predict")
|
|
@@ -166,7 +152,6 @@ def predict(payload: PatientData):
|
|
| 166 |
try:
|
| 167 |
input_df = preprocess_input(payload.data)
|
| 168 |
|
| 169 |
-
# Probability of positive (PPD risk)
|
| 170 |
risk_prob = model.predict_proba(input_df)[0][1]
|
| 171 |
is_high_risk = risk_prob >= THRESHOLD
|
| 172 |
|
|
@@ -186,7 +171,7 @@ def predict(payload: PatientData):
|
|
| 186 |
raise HTTPException(status_code=500, detail=str(e))
|
| 187 |
|
| 188 |
# =========================================================
|
| 189 |
-
# 7. RUNNER (HF
|
| 190 |
# =========================================================
|
| 191 |
|
| 192 |
if __name__ == "__main__":
|
|
|
|
| 35 |
ARTIFACTS_DIR = os.path.join(BASE_DIR, "artifacts_final")
|
| 36 |
|
| 37 |
MODEL_FILE = "catboost_model_top20.cbm"
|
| 38 |
+
TOP_FEATURES_FILE = "top20_features.csv"
|
| 39 |
UI_FILE = "model_ui_schema.json"
|
| 40 |
|
| 41 |
+
THRESHOLD = 0.3
|
| 42 |
|
| 43 |
# =========================================================
|
| 44 |
+
# 3. LOAD ARTIFACTS
|
| 45 |
# =========================================================
|
| 46 |
|
| 47 |
print("⏳ Loading AI Models and Config...")
|
|
|
|
| 54 |
|
| 55 |
# ---- Load Model ----
|
| 56 |
model_path = os.path.join(ARTIFACTS_DIR, MODEL_FILE)
|
|
|
|
|
|
|
|
|
|
| 57 |
model = CatBoostClassifier()
|
| 58 |
model.load_model(model_path)
|
| 59 |
print("✅ CatBoost model loaded")
|
| 60 |
|
| 61 |
+
# ---- Load Top 20 Features (SOURCE OF TRUTH) ----
|
| 62 |
+
features_path = os.path.join(ARTIFACTS_DIR, TOP_FEATURES_FILE)
|
| 63 |
+
TOP_FEATURES = (
|
| 64 |
+
pd.read_csv(features_path, header=None)[0]
|
| 65 |
+
.astype(str)
|
| 66 |
+
.tolist()
|
| 67 |
+
)
|
| 68 |
|
| 69 |
+
print(f"✅ Loaded {len(TOP_FEATURES)} top features")
|
|
|
|
| 70 |
|
| 71 |
# ---- Load UI Schema ----
|
| 72 |
ui_path = os.path.join(ARTIFACTS_DIR, UI_FILE)
|
|
|
|
| 117 |
# =========================================================
|
| 118 |
|
| 119 |
def preprocess_input(raw_data: Dict[str, Any]) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
clean_data = {}
|
| 121 |
|
| 122 |
+
for k, v in raw_data.items():
|
| 123 |
+
clean_data[k] = v.lower() if isinstance(v, str) else v
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
df = pd.DataFrame([clean_data])
|
| 126 |
|
| 127 |
+
# Ensure all required features exist
|
| 128 |
for col in TOP_FEATURES:
|
| 129 |
if col not in df.columns:
|
| 130 |
df[col] = "unknown"
|
| 131 |
|
| 132 |
+
return df[TOP_FEATURES]
|
|
|
|
| 133 |
|
| 134 |
# =========================================================
|
| 135 |
# 6. API ENDPOINTS
|
|
|
|
| 145 |
|
| 146 |
@app.get("/config")
|
| 147 |
def get_ui_config():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
return ui_schema
|
| 149 |
|
| 150 |
@app.post("/predict")
|
|
|
|
| 152 |
try:
|
| 153 |
input_df = preprocess_input(payload.data)
|
| 154 |
|
|
|
|
| 155 |
risk_prob = model.predict_proba(input_df)[0][1]
|
| 156 |
is_high_risk = risk_prob >= THRESHOLD
|
| 157 |
|
|
|
|
| 171 |
raise HTTPException(status_code=500, detail=str(e))
|
| 172 |
|
| 173 |
# =========================================================
|
| 174 |
+
# 7. RUNNER (HF)
|
| 175 |
# =========================================================
|
| 176 |
|
| 177 |
if __name__ == "__main__":
|