Spaces:
Sleeping
Sleeping
File size: 5,582 Bytes
fd94cbc b2f845a fd94cbc de15c4b b2f845a fd94cbc b2f845a fd94cbc b2f845a fd94cbc de15c4b b2f845a fd94cbc b2f845a fd94cbc 9cd2352 b2f845a fd94cbc 505462c fd94cbc 505462c fd94cbc 505462c fd94cbc 505462c fd94cbc b2f845a fd94cbc b2f845a fd94cbc b2f845a fd94cbc b2f845a fd94cbc b2f845a fd94cbc b2f845a fd94cbc c5bb78d fd94cbc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | import os
import json
import logging
from typing import Any, Dict, Union

import pandas as pd
import uvicorn
from catboost import CatBoostClassifier
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict
# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
# Single application instance; every route decorator below registers on it.
app = FastAPI(
    version="1.0.0",
    title="PPD Risk Assessment API",
    description="AI-powered screening tool for Postpartum Depression Risk (Top 20 Features)",
)
# Enable CORS.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable. Set it to your frontend URL(s) in production; when it is unset or
# empty, any origin ("*") is allowed, as before.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentials: Starlette then
    # reflects the request's Origin header, which would let any site make
    # credentialed calls. The endpoints in this file read no cookies or auth
    # headers, so credentials are only enabled for an explicit origin list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ==========================================
# 2. ARTIFACT PATH SETUP
# ==========================================
# Resolve artifacts relative to this file (not the working directory), so the
# same path is found regardless of where the server is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ARTIFACTS_DIR = os.path.join(BASE_DIR, "artifacts_final")  # Hugging Face compatible path
_artifacts_present = os.path.exists(ARTIFACTS_DIR)
print("ARTIFACTS DIR:", ARTIFACTS_DIR)
print("EXISTS:", _artifacts_present)
# ==========================================
# 3. LOAD ARTIFACTS
# ==========================================
# Runs once at import time. Any failure is printed and re-raised, so the
# server does not start without its model, metadata and UI schema.
print(" Loading AI Models and Config...")


def _require_file(path: str, label: str) -> str:
    """Return ``path`` if it exists, otherwise raise ``FileNotFoundError``.

    The error message has the form "<label> not found at <path>".
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"{label} not found at {path}")
    return path


try:
    # A. Load CatBoost Model
    model_path = _require_file(
        os.path.join(ARTIFACTS_DIR, "catboost_model_top20.cbm"), "Model"
    )
    model = CatBoostClassifier()
    model.load_model(model_path)
    print(" Model Loaded.")

    # B. Load Metadata
    meta_path = _require_file(
        os.path.join(ARTIFACTS_DIR, "catboost_metadata.json"), "Metadata"
    )
    with open(meta_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    # Ordered feature names; preprocess_input selects and orders request
    # columns by this list.
    TOP_FEATURES = metadata.get("features_used", [])
    if not TOP_FEATURES:
        # An empty list would reduce every request to a zero-column frame;
        # fail fast at startup instead of on every prediction.
        raise ValueError(f"No 'features_used' listed in {meta_path}")
    # Decision threshold on the predicted probability. It is coerced to float
    # so that a value stored as a JSON string still compares numerically.
    THRESHOLD = float(metadata.get("thresholds", {}).get("optimal_balanced", 0.5))
    print(f" Metadata Loaded. Threshold set to: {THRESHOLD}")

    # C. Load UI Schema
    ui_path = _require_file(
        os.path.join(ARTIFACTS_DIR, "model_ui_schema.json"), "UI schema"
    )
    with open(ui_path, "r", encoding="utf-8") as f:
        ui_schema = json.load(f)
    print(" UI Schema Loaded.")
except Exception as e:
    print(f" CRITICAL ERROR LOADING ARTIFACTS: {e}")
    # A bare raise re-raises the original exception with its traceback intact.
    raise
# ==========================================
# 4. DATA VALIDATION
# ==========================================
class PatientData(BaseModel):
    """Request body for ``POST /predict``.

    ``data`` maps feature names to one patient's raw answers, e.g.
    ``{"Age": 24, "Abuse": "Yes"}``. The example in ``model_config`` shows a
    complete request.
    """

    # Values are restricted to JSON scalars. A nested object or list cannot be
    # placed into a single-row model input, so such values are rejected at
    # request validation instead of failing later inside /predict.
    data: Dict[str, Union[str, int, float, bool, None]]

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "data": {
                    "Need for Support": "High",
                    "Recieved Support": "Low",
                    "Abuse": "Yes",
                    "Disease before pregnancy": "None",
                    "Occupation before latest pregnancy": "Housewife",
                    "Pregnancy plan": "Unplanned",
                    "Relationship with husband": "Bad",
                    "Major changes or losses during pregnancy": "Yes",
                    "Relationship with the in-laws": "Bad",
                    "Birth compliancy": "No",
                    "Relationship between father and newborn": "Bad",
                    "Education Level": "Secondary",
                    "Family type": "Nuclear",
                    "Diseases during pregnancy": "Yes",
                    "Trust and share feelings": "No",
                    "Relationship with the newborn": "Average",
                    "Occupation After Your Latest Childbirth": "Unemployed",
                    "Age": 24,
                    "Addiction": "No",
                    "Husband's education level": "Secondary"
                }
            }
        }
    )
# ==========================================
# 5. HELPER FUNCTION
# ==========================================
def preprocess_input(raw_data: dict, *, features=None, fill_value="unknown") -> pd.DataFrame:
    """Turn one request's feature mapping into a single-row model input frame.

    String values are lower-cased and non-string values pass through
    unchanged. Features absent from ``raw_data`` are filled with
    ``fill_value``, keys that are not features are dropped, and the columns
    come out in exactly the order of ``features``.

    Args:
        raw_data: Mapping of feature name -> raw value for one patient.
        features: Ordered feature names to emit. Defaults to the module-level
            ``TOP_FEATURES`` read from the metadata file.
        fill_value: Value used for features missing from ``raw_data``.

    Returns:
        A one-row DataFrame whose columns are ``features``.
    """
    if features is None:
        features = TOP_FEATURES
    clean_data = {
        key: value.lower() if isinstance(value, str) else value
        for key, value in raw_data.items()
    }
    # reindex adds missing columns, drops unknown keys and fixes the column
    # order in a single step.
    return pd.DataFrame([clean_data]).reindex(columns=features, fill_value=fill_value)
# ==========================================
# 6. API ENDPOINTS
# ==========================================
@app.get("/")
def home():
    """Health check: report that the service is up and which threshold is active."""
    response = {"status": "online", "model": "CatBoost Top20"}
    response["threshold"] = THRESHOLD
    return response
@app.get("/config")
def get_ui_config():
    """Return the UI schema that was loaded from model_ui_schema.json at startup."""
    schema = ui_schema
    return schema
@app.post("/predict")
def predict_risk(payload: PatientData):
    """Score one patient and label the result against ``THRESHOLD``.

    Args:
        payload: Validated request body; ``payload.data`` is passed to
            ``preprocess_input``.

    Returns:
        A dict with the label ("HIGH RISK"/"LOW RISK"), the probability of
        the high-risk class rounded to 4 decimals, the threshold used, a 0/1
        flag and a short clinical note.

    Raises:
        HTTPException: status 500 carrying the underlying error message when
            preprocessing, inference or the threshold comparison fails.
    """
    try:
        input_df = preprocess_input(payload.data)
        # One input row -> take row 0; column 1 is treated as the high-risk class.
        risk_prob = float(model.predict_proba(input_df)[0][1])
        is_high_risk = bool(risk_prob >= THRESHOLD)
    except Exception as e:
        # Keep the full traceback in the server logs; the client only sees str(e).
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "prediction": "HIGH RISK" if is_high_risk else "LOW RISK",
        "risk_probability": round(risk_prob, 4),
        "threshold_used": THRESHOLD,
        "flag": 1 if is_high_risk else 0,
        "clinical_note": "Refer to specialist" if is_high_risk else "Standard monitoring",
    }
# ==========================================
# 7. RUNNER
# ==========================================
if __name__ == "__main__":
    # The PORT environment variable overrides the port; 7860 remains the default.
    port = int(os.environ.get("PORT", "7860"))
    print(" Server starting...")
    uvicorn.run(app, host="0.0.0.0", port=port)
|