appledog00 commited on
Commit
de15c4b
Β·
verified Β·
1 Parent(s): 9cd2352

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -134
app.py CHANGED
@@ -1,179 +1,120 @@
1
- import os
2
  import json
3
- import uvicorn
4
- import pandas as pd
5
 
 
6
  from fastapi import FastAPI, HTTPException
7
- from fastapi.middleware.cors import CORSMiddleware
8
- from pydantic import BaseModel, ConfigDict
9
  from catboost import CatBoostClassifier
10
- from typing import Dict, Any
11
 
12
- # =========================================================
13
- # 1. FASTAPI SETUP
14
- # =========================================================
 
 
 
 
15
 
 
 
 
 
 
16
  app = FastAPI(
17
  title="PPD Risk Assessment API",
18
- description="AI-powered screening tool for Postpartum Depression (CatBoost Top-20)",
19
  version="1.0.0"
20
  )
21
 
22
- app.add_middleware(
23
- CORSMiddleware,
24
- allow_origins=["*"],
25
- allow_credentials=True,
26
- allow_methods=["*"],
27
- allow_headers=["*"],
28
- )
29
-
30
- # =========================================================
31
- # 2. PATH CONFIG (HF SAFE)
32
- # =========================================================
33
-
34
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
35
- ARTIFACTS_DIR = os.path.join(BASE_DIR, "artifacts_final")
36
-
37
- MODEL_FILE = "catboost_model_top20.cbm"
38
- TOP_FEATURES_FILE = "top20_features.csv"
39
- UI_FILE = "model_ui_schema.json"
40
-
41
- THRESHOLD = 0.3
42
-
43
- # =========================================================
44
- # 3. LOAD ARTIFACTS
45
- # =========================================================
46
-
47
  print("⏳ Loading AI Models and Config...")
48
- print("πŸ“ Expected artifacts path:", ARTIFACTS_DIR)
49
 
50
- if not os.path.isdir(ARTIFACTS_DIR):
51
- raise RuntimeError(f"Artifacts directory not found: {ARTIFACTS_DIR}")
52
 
53
- print("πŸ“„ Files found:", os.listdir(ARTIFACTS_DIR))
54
 
55
- # ---- Load Model ----
56
- model_path = os.path.join(ARTIFACTS_DIR, MODEL_FILE)
57
  model = CatBoostClassifier()
58
- model.load_model(model_path)
59
  print("βœ… CatBoost model loaded")
60
 
61
- # ---- Load Top 20 Features (SOURCE OF TRUTH) ----
62
- features_path = os.path.join(ARTIFACTS_DIR, TOP_FEATURES_FILE)
 
 
 
 
 
 
63
  TOP_FEATURES = (
64
- pd.read_csv(features_path, header=None)[0]
 
65
  .astype(str)
66
  .tolist()
67
  )
68
 
69
  print(f"βœ… Loaded {len(TOP_FEATURES)} top features")
70
 
71
- # ---- Load UI Schema ----
72
- ui_path = os.path.join(ARTIFACTS_DIR, UI_FILE)
73
- with open(ui_path, "r") as f:
74
- ui_schema = json.load(f)
75
-
76
- print("βœ… UI schema loaded")
77
- print(f"🚦 Threshold set to {THRESHOLD}")
78
-
79
- # =========================================================
80
- # 4. REQUEST SCHEMA
81
- # =========================================================
82
-
83
- class PatientData(BaseModel):
84
- data: Dict[str, Any]
85
-
86
- model_config = ConfigDict(
87
- json_schema_extra={
88
- "example": {
89
- "data": {
90
- "Need for Support": "high",
91
- "Recieved Support": "low",
92
- "Abuse": "no",
93
- "Disease before pregnancy": "none",
94
- "Pregnancy plan": "no",
95
- "Relationship with the in-laws": "bad",
96
- "Relationship with husband": "bad",
97
- "Occupation before latest pregnancy": "housewife",
98
- "Major changes or losses during pregnancy": "no",
99
- "Relationship with the newborn": "good",
100
- "Family type": "nuclear",
101
- "Diseases during pregnancy": "none",
102
- "Relationship between father and newborn": "good",
103
- "Husband's education level": "college",
104
- "Trust and share feelings": "no",
105
- "Birth compliancy": "no",
106
- "Education Level": "college",
107
- "Occupation After Your Latest Childbirth": "housewife",
108
- "Addiction": "none",
109
- "Age": 24
110
- }
111
- }
112
- }
113
- )
114
-
115
- # =========================================================
116
- # 5. PREPROCESSING
117
- # =========================================================
118
-
119
- def preprocess_input(raw_data: Dict[str, Any]) -> pd.DataFrame:
120
- clean_data = {}
121
-
122
- for k, v in raw_data.items():
123
- clean_data[k] = v.lower() if isinstance(v, str) else v
124
-
125
- df = pd.DataFrame([clean_data])
126
-
127
- # Ensure all required features exist
128
- for col in TOP_FEATURES:
129
- if col not in df.columns:
130
- df[col] = "unknown"
131
 
132
- return df[TOP_FEATURES]
133
 
134
- # =========================================================
135
- # 6. API ENDPOINTS
136
- # =========================================================
137
 
 
 
 
138
  @app.get("/")
139
- def health():
140
  return {
141
  "status": "online",
142
  "model": "CatBoost Top-20",
143
- "threshold": THRESHOLD
 
144
  }
145
 
146
- @app.get("/config")
147
- def get_ui_config():
 
148
  return ui_schema
149
 
 
150
  @app.post("/predict")
151
- def predict(payload: PatientData):
152
  try:
153
- input_df = preprocess_input(payload.data)
154
 
155
- risk_prob = model.predict_proba(input_df)[0][1]
 
156
  is_high_risk = risk_prob >= THRESHOLD
157
 
158
  return {
159
- "prediction": "HIGH RISK" if is_high_risk else "LOW RISK",
160
- "risk_probability": round(float(risk_prob), 4),
161
- "threshold_used": THRESHOLD,
162
- "flag": int(is_high_risk),
163
- "clinical_note": (
164
- "Refer to specialist"
165
- if is_high_risk
166
- else "Standard monitoring"
167
- )
168
  }
169
 
 
 
170
  except Exception as e:
171
- raise HTTPException(status_code=500, detail=str(e))
172
-
173
- # =========================================================
174
- # 7. RUNNER (HF)
175
- # =========================================================
176
-
177
- if __name__ == "__main__":
178
- print("πŸš€ Starting server...")
179
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
  import json
2
+ from pathlib import Path
3
+ from typing import Dict, Any
4
 
5
+ import pandas as pd
6
  from fastapi import FastAPI, HTTPException
 
 
7
  from catboost import CatBoostClassifier
 
8
 
9
+ # =========================
10
+ # CONFIG
11
+ # =========================
12
+ ARTIFACTS_DIR = Path("artifacts_final")
13
+ MODEL_FILE = ARTIFACTS_DIR / "catboost_model_top20.cbm"
14
+ UI_SCHEMA_FILE = ARTIFACTS_DIR / "model_ui_schema.json"
15
+ TOP_FEATURES_FILE = ARTIFACTS_DIR / "top20_features.csv"
16
 
17
+ THRESHOLD = 0.41 # βœ… FINAL OPERATING THRESHOLD
18
+
19
+ # =========================
20
+ # APP INIT
21
+ # =========================
22
  app = FastAPI(
23
  title="PPD Risk Assessment API",
24
+ description="Hybrid ML-based screening API for Postpartum Depression risk",
25
  version="1.0.0"
26
  )
27
 
28
+ # =========================
29
+ # LOAD ARTIFACTS
30
+ # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  print("⏳ Loading AI Models and Config...")
32
+ print(f"πŸ“ Expected artifacts path: {ARTIFACTS_DIR.resolve()}")
33
 
34
+ if not ARTIFACTS_DIR.exists():
35
+ raise RuntimeError("❌ artifacts_final folder not found")
36
 
37
+ print(f"πŸ“„ Files found: {[f.name for f in ARTIFACTS_DIR.iterdir()]}")
38
 
39
+ # --- Load model ---
 
40
  model = CatBoostClassifier()
41
+ model.load_model(str(MODEL_FILE))
42
  print("βœ… CatBoost model loaded")
43
 
44
+ # --- Load UI schema ---
45
+ with open(UI_SCHEMA_FILE, "r") as f:
46
+ ui_schema = json.load(f)
47
+
48
+ # --- Load top features safely ---
49
+ if not TOP_FEATURES_FILE.exists():
50
+ raise RuntimeError("❌ top20_features.csv not found")
51
+
52
  TOP_FEATURES = (
53
+ pd.read_csv(TOP_FEATURES_FILE, header=None)
54
+ .iloc[:, 0]
55
  .astype(str)
56
  .tolist()
57
  )
58
 
59
  print(f"βœ… Loaded {len(TOP_FEATURES)} top features")
60
 
61
+ # =========================
62
+ # HELPERS
63
+ # =========================
64
+ def build_input_dataframe(payload: Dict[str, Any]) -> pd.DataFrame:
65
+ """
66
+ Build a single-row dataframe aligned with TOP_FEATURES
67
+ """
68
+ row = {}
69
+ for feature in TOP_FEATURES:
70
+ if feature not in payload:
71
+ raise HTTPException(
72
+ status_code=400,
73
+ detail=f"Missing required feature: {feature}"
74
+ )
75
+ row[feature] = payload[feature]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ return pd.DataFrame([row])
78
 
 
 
 
79
 
80
+ # =========================
81
+ # ROUTES
82
+ # =========================
83
  @app.get("/")
84
+ def health_check():
85
  return {
86
  "status": "online",
87
  "model": "CatBoost Top-20",
88
+ "threshold": THRESHOLD,
89
+ "features_used": len(TOP_FEATURES)
90
  }
91
 
92
+
93
+ @app.get("/ui-schema")
94
+ def get_ui_schema():
95
  return ui_schema
96
 
97
+
98
  @app.post("/predict")
99
+ def predict_risk(payload: Dict[str, Any]):
100
  try:
101
+ input_df = build_input_dataframe(payload)
102
 
103
+ # CatBoost handles categoricals internally
104
+ risk_prob = float(model.predict_proba(input_df)[0][1])
105
  is_high_risk = risk_prob >= THRESHOLD
106
 
107
  return {
108
+ "ppd_risk_probability": round(risk_prob, 4),
109
+ "threshold": THRESHOLD,
110
+ "risk_label": "HIGH RISK" if is_high_risk else "LOW RISK",
111
+ "screening_positive": bool(is_high_risk)
 
 
 
 
 
112
  }
113
 
114
+ except HTTPException:
115
+ raise
116
  except Exception as e:
117
+ raise HTTPException(
118
+ status_code=500,
119
+ detail=f"Prediction failed: {str(e)}"
120
+ )