appledog00 committed on
Commit
fd94cbc
·
verified ·
1 Parent(s): c5bb78d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -116
app.py CHANGED
@@ -1,137 +1,174 @@
 
1
  import json
2
- from pathlib import Path
3
- from typing import Dict, Any
4
-
5
  import pandas as pd
 
6
  from fastapi import FastAPI, HTTPException
 
 
7
  from catboost import CatBoostClassifier
 
8
 
9
- # =========================
10
- # CONFIG
11
- # =========================
12
- ARTIFACTS_DIR = Path("artifacts_final")
13
- MODEL_FILE = ARTIFACTS_DIR / "catboost_model_top20.cbm"
14
- UI_SCHEMA_FILE = ARTIFACTS_DIR / "model_ui_schema.json"
15
- TOP_FEATURES_FILE = ARTIFACTS_DIR / "top20_features.csv"
16
-
17
- THRESHOLD = 0.41 # ✅ FINAL OPERATING THRESHOLD
18
 
19
- # =========================
20
- # APP INIT (with docs enabled)
21
- # =========================
22
  app = FastAPI(
23
  title="PPD Risk Assessment API",
24
- description="Hybrid ML-based screening API for Postpartum Depression risk",
25
- version="1.0.0",
26
- docs_url="/docs", # Swagger UI
27
- redoc_url="/redoc" # ReDoc UI
28
  )
29
 
30
- # =========================
31
- # LOAD ARTIFACTS
32
- # =========================
33
- print("⏳ Loading AI Models and Config...")
34
- print(f"📁 Expected artifacts path: {ARTIFACTS_DIR.resolve()}")
35
-
36
- if not ARTIFACTS_DIR.exists():
37
- raise RuntimeError("❌ artifacts_final folder not found")
38
-
39
- print(f"📄 Files found: {[f.name for f in ARTIFACTS_DIR.iterdir()]}")
40
-
41
- # --- Load model ---
42
- model = CatBoostClassifier()
43
- model.load_model(str(MODEL_FILE))
44
- print("✅ CatBoost model loaded")
45
-
46
- # --- Load UI schema ---
47
- with open(UI_SCHEMA_FILE, "r") as f:
48
- ui_schema = json.load(f)
49
-
50
- # --- Load top features safely ---
51
- if not TOP_FEATURES_FILE.exists():
52
- raise RuntimeError("❌ top20_features.csv not found")
53
-
54
- TOP_FEATURES = (
55
- pd.read_csv(TOP_FEATURES_FILE, header=None)
56
- .iloc[:, 0]
57
- .astype(str)
58
- .tolist()
59
  )
60
 
61
- print(f"✅ Loaded {len(TOP_FEATURES)} top features")
62
-
63
- # =========================
64
- # HELPERS
65
- # =========================
66
- def build_input_dataframe(payload: Dict[str, Any]) -> pd.DataFrame:
67
- """
68
- Build a single-row dataframe aligned with TOP_FEATURES
69
- """
70
- row = {}
71
- for feature in TOP_FEATURES:
72
- if feature not in payload:
73
- raise HTTPException(
74
- status_code=400,
75
- detail=f"Missing required feature: {feature}"
76
- )
77
- row[feature] = payload[feature]
78
-
79
- return pd.DataFrame([row])
80
-
81
- # =========================
82
- # ROUTES
83
- # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @app.get("/")
85
- def health_check():
86
- """Simple health check"""
87
- return {
88
- "status": "online",
89
- "model": "CatBoost Top-20",
90
- "threshold": THRESHOLD,
91
- "features_used": len(TOP_FEATURES)
92
- }
93
-
94
- @app.get("/ui-schema")
95
- def get_ui_schema():
96
- """Return the UI schema for frontend forms"""
97
  return ui_schema
98
 
99
  @app.post("/predict")
100
- def predict_risk(payload: Dict[str, Any]):
101
- """
102
- Predict PPD risk probability given a dictionary of feature values
103
- Example payload:
104
- {
105
- "feature1": 3.2,
106
- "feature2": 1.0,
107
- ...
108
- }
109
- """
110
  try:
111
- input_df = build_input_dataframe(payload)
112
-
113
- # CatBoost handles categoricals internally
114
- risk_prob = float(model.predict_proba(input_df)[0][1])
115
- is_high_risk = risk_prob >= THRESHOLD
116
-
117
  return {
118
- "ppd_risk_probability": round(risk_prob, 4),
119
- "threshold": THRESHOLD,
120
- "risk_label": "HIGH RISK" if is_high_risk else "LOW RISK",
121
- "screening_positive": bool(is_high_risk)
 
122
  }
123
-
124
- except HTTPException:
125
- raise
126
  except Exception as e:
127
- raise HTTPException(
128
- status_code=500,
129
- detail=f"Prediction failed: {str(e)}"
130
- )
131
-
132
- # =========================
133
- # RUNNING LOCALLY (optional)
134
- # =========================
135
  if __name__ == "__main__":
136
- import uvicorn
137
- uvicorn.run(app, host="0.0.0.0", port=7860, reload=True)
 
1
+ import os
2
  import json
3
+ import uvicorn
 
 
4
  import pandas as pd
5
+ import numpy as np
6
  from fastapi import FastAPI, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel, ConfigDict
9
  from catboost import CatBoostClassifier
10
+ from typing import Dict, Any
11
 
12
+ # ==========================================
13
+ # 1. SETUP & CONFIGURATION
14
+ # ==========================================
 
 
 
 
 
 
15
 
 
 
 
16
  app = FastAPI(
17
  title="PPD Risk Assessment API",
18
+ description="AI-powered screening tool for Postpartum Depression Risk (Top 20 Features)",
19
+ version="1.0.0"
 
 
20
  )
21
 
22
+ # Enable CORS
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"], # Replace "*" with your frontend URL in production
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  )
30
 
31
+ # ==========================================
32
+ # 2. ARTIFACT PATH SETUP
33
+ # ==========================================
34
+
35
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
36
+ ARTIFACTS_DIR = os.path.normpath(os.path.join(BASE_DIR, "..", "artifacts_final"))
37
+
38
+ print("ARTIFACTS DIR:", ARTIFACTS_DIR)
39
+ print("EXISTS:", os.path.exists(ARTIFACTS_DIR))
40
+
41
+ # ==========================================
42
+ # 3. LOAD ARTIFACTS
43
+ # ==========================================
44
+
45
+ print(" Loading AI Models and Config...")
46
+
47
+ try:
48
+ # A. Load CatBoost Model
49
+ model_path = os.path.join(ARTIFACTS_DIR, "catboost_model_top20.cbm")
50
+ if not os.path.exists(model_path):
51
+ raise FileNotFoundError(f"Model not found at {model_path}")
52
+
53
+ model = CatBoostClassifier()
54
+ model.load_model(model_path)
55
+ print(" Model Loaded.")
56
+
57
+ # B. Load Metadata
58
+ meta_path = os.path.join(ARTIFACTS_DIR, "catboost_metadata.json")
59
+ with open(meta_path, "r") as f:
60
+ metadata = json.load(f)
61
+
62
+ # ✅ Correct key from your metadata
63
+ TOP_FEATURES = metadata["features_used"]
64
+ THRESHOLD = metadata["thresholds"]["optimal_balanced"]
65
+
66
+ print(f" Metadata Loaded. Threshold set to: {THRESHOLD}")
67
+
68
+ # C. Load UI Schema
69
+ ui_path = os.path.join(ARTIFACTS_DIR, "model_ui_schema.json")
70
+ with open(ui_path, "r") as f:
71
+ ui_schema = json.load(f)
72
+ print(" UI Schema Loaded.")
73
+
74
+ except Exception as e:
75
+ print(f" CRITICAL ERROR LOADING ARTIFACTS: {e}")
76
+ raise e
77
+
78
+ # ==========================================
79
+ # 4. DATA VALIDATION
80
+ # ==========================================
81
+
82
+ class PatientData(BaseModel):
83
+ data: Dict[str, Any]
84
+
85
+ model_config = ConfigDict(
86
+ json_schema_extra={
87
+ "example": {
88
+ "data": {
89
+ "Need for Support": "High",
90
+ "Recieved Support": "Low",
91
+ "Abuse": "Yes",
92
+ "Disease before pregnancy": "None",
93
+ "Occupation before latest pregnancy": "Housewife",
94
+ "Pregnancy plan": "Unplanned",
95
+ "Relationship with husband": "Bad",
96
+ "Major changes or losses during pregnancy": "Yes",
97
+ "Relationship with the in-laws": "Bad",
98
+ "Birth compliancy": "No",
99
+ "Relationship between father and newborn": "Bad",
100
+ "Education Level": "Secondary",
101
+ "Family type": "Nuclear",
102
+ "Diseases during pregnancy": "Yes",
103
+ "Trust and share feelings": "No",
104
+ "Relationship with the newborn": "Average",
105
+ "Occupation After Your Latest Childbirth": "Unemployed",
106
+ "Age": 24,
107
+ "Addiction": "No",
108
+ "Husband's education level": "Secondary"
109
+ }
110
+ }
111
+ }
112
+ )
113
+
114
+ # ==========================================
115
+ # 5. HELPER FUNCTION
116
+ # ==========================================
117
+
118
+ def preprocess_input(raw_data: dict) -> pd.DataFrame:
119
+ clean_data = {}
120
+
121
+ for k, v in raw_data.items():
122
+ if isinstance(v, str):
123
+ clean_data[k] = v.lower()
124
+ else:
125
+ clean_data[k] = v
126
+
127
+ df = pd.DataFrame([clean_data])
128
+
129
+ # Fill missing features
130
+ for col in TOP_FEATURES:
131
+ if col not in df.columns:
132
+ df[col] = "unknown"
133
+
134
+ df = df[TOP_FEATURES]
135
+
136
+ return df
137
+
138
+ # ==========================================
139
+ # 6. API ENDPOINTS
140
+ # ==========================================
141
+
142
  @app.get("/")
143
+ def home():
144
+ return {"status": "online", "model": "CatBoost Top20", "threshold": THRESHOLD}
145
+
146
+ @app.get("/config")
147
+ def get_ui_config():
 
 
 
 
 
 
 
148
  return ui_schema
149
 
150
  @app.post("/predict")
151
+ def predict_risk(payload: PatientData):
 
 
 
 
 
 
 
 
 
152
  try:
153
+ input_df = preprocess_input(payload.data)
154
+ risk_prob = model.predict_proba(input_df)[0][1]
155
+ is_high_risk = bool(risk_prob >= THRESHOLD)
156
+
 
 
157
  return {
158
+ "prediction": "HIGH RISK" if is_high_risk else "LOW RISK",
159
+ "risk_probability": round(float(risk_prob), 4),
160
+ "threshold_used": THRESHOLD,
161
+ "flag": 1 if is_high_risk else 0,
162
+ "clinical_note": "Refer to specialist" if is_high_risk else "Standard monitoring"
163
  }
164
+
 
 
165
  except Exception as e:
166
+ raise HTTPException(status_code=500, detail=str(e))
167
+
168
+ # ==========================================
169
+ # 7. RUNNER
170
+ # ==========================================
171
+
 
 
172
  if __name__ == "__main__":
173
+ print(" Server starting...")
174
+ uvicorn.run(app, host="0.0.0.0", port=7860)