appledog00 commited on
Commit
b2f845a
Β·
verified Β·
1 Parent(s): 97afca9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -193
app.py CHANGED
@@ -1,193 +1,186 @@
1
- import os
2
- import json
3
- import uvicorn
4
- import pandas as pd
5
- import numpy as np
6
- from fastapi import FastAPI, HTTPException
7
- from fastapi.middleware.cors import CORSMiddleware
8
- from pydantic import BaseModel, ConfigDict
9
- from catboost import CatBoostClassifier
10
- from typing import Dict, Any
11
-
12
- # ==========================================
13
- # 1. SETUP & CONFIGURATION
14
- # ==========================================
15
-
16
- app = FastAPI(
17
- title="PPD Risk Assessment API",
18
- description="AI-powered screening tool for Postpartum Depression Risk (Top 20 Features)",
19
- version="1.0.0"
20
- )
21
-
22
- # Enable CORS (Allows your frontend/website to talk to this API)
23
- app.add_middleware(
24
- CORSMiddleware,
25
- allow_origins=["*"], # In production, replace "*" with your frontend URL
26
- allow_credentials=True,
27
- allow_methods=["*"],
28
- allow_headers=["*"],
29
- )
30
-
31
- ARTIFACTS_DIR = "artifacts"
32
-
33
- # ==========================================
34
- # 2. LOAD ARTIFACTS (The Brain)
35
- # ==========================================
36
-
37
- print("⏳ Loading AI Models and Config...")
38
-
39
- try:
40
- # A. Load Model
41
- model_path = os.path.join(ARTIFACTS_DIR, "catboost_model_top20.cbm")
42
- if not os.path.exists(model_path):
43
- raise FileNotFoundError(f"Model not found at {model_path}")
44
-
45
- model = CatBoostClassifier()
46
- model.load_model(model_path)
47
- print("βœ… Model Loaded.")
48
-
49
- # B. Load Metadata (Thresholds & Feature List)
50
- meta_path = os.path.join(ARTIFACTS_DIR, "catboost_metadata.json")
51
- with open(meta_path, "r") as f:
52
- metadata = json.load(f)
53
-
54
- TOP_FEATURES = metadata["top_features"]
55
-
56
- # --- THRESHOLD CONFIGURATION ---
57
- # Originally 0.3, but we updated to 0.5 to reduce False Positives
58
- # based on your testing (Patient with 43% risk should be Low Risk).
59
- THRESHOLD = 0.3
60
-
61
- print(f"βœ… Metadata Loaded. Threshold set to: {THRESHOLD}")
62
-
63
- # C. Load UI Schema (For Frontend Dynamic Forms)
64
- ui_path = os.path.join(ARTIFACTS_DIR, "model_ui_schema.json")
65
- with open(ui_path, "r") as f:
66
- ui_schema = json.load(f)
67
- print("βœ… UI Schema Loaded.")
68
-
69
- except Exception as e:
70
- print(f"❌ CRITICAL ERROR LOADING ARTIFACTS: {e}")
71
- raise e
72
-
73
-
74
- # ==========================================
75
- # 3. DATA VALIDATION (Pydantic)
76
- # ==========================================
77
-
78
- class PatientData(BaseModel):
79
- data: Dict[str, Any]
80
-
81
- # Updated for Pydantic V2 (No warnings)
82
- model_config = ConfigDict(
83
- json_schema_extra={
84
- "example": {
85
- "data": {
86
- "Need for Support": "High",
87
- "Recieved Support": "Low",
88
- "Abuse": "Yes",
89
- "Disease before pregnancy": "None",
90
- "Occupation before latest pregnancy": "Housewife",
91
- "Pregnancy plan": "Unplanned",
92
- "Relationship with husband": "Bad",
93
- "Major changes or losses during pregnancy": "Yes",
94
- "Relationship with the in-laws": "Bad",
95
- "Birth compliancy": "No",
96
- "Relationship between father and newborn": "Bad",
97
- "Education Level": "Secondary",
98
- "Family type": "Nuclear",
99
- "Diseases during pregnancy": "Yes",
100
- "Trust and share feelings": "No",
101
- "Relationship with the newborn": "Average",
102
- "Occupation After Your Latest Childbirth": "Unemployed",
103
- "Age": 24,
104
- "Addiction": "No",
105
- "Husband's education level": "Secondary"
106
- }
107
- }
108
- }
109
- )
110
-
111
- # ==========================================
112
- # 4. HELPER FUNCTIONS
113
- # ==========================================
114
-
115
- def preprocess_input(raw_data: dict) -> pd.DataFrame:
116
- """
117
- Cleans input dictionary: lowercases strings, handles missing cols, sorts cols.
118
- """
119
- clean_data = {}
120
-
121
- # 1. Lowercase string inputs to match model training
122
- for k, v in raw_data.items():
123
- if isinstance(v, str):
124
- clean_data[k] = v.lower()
125
- else:
126
- clean_data[k] = v
127
-
128
- # 2. Create DataFrame
129
- df = pd.DataFrame([clean_data])
130
-
131
- # 3. Ensure all Top 20 columns exist (fill missing with 'unknown')
132
- # This prevents crashing if the user leaves a field blank
133
- for col in TOP_FEATURES:
134
- if col not in df.columns:
135
- df[col] = "unknown"
136
-
137
- # 4. Reorder columns to match exactly what the model expects
138
- df = df[TOP_FEATURES]
139
-
140
- return df
141
-
142
- # ==========================================
143
- # 5. API ENDPOINTS
144
- # ==========================================
145
-
146
- @app.get("/")
147
- def home():
148
- """Health check endpoint."""
149
- return {"status": "online", "model": "CatBoost Top20", "threshold": THRESHOLD}
150
-
151
- @app.get("/config")
152
- def get_ui_config():
153
- """
154
- Returns the UI Schema (Dropdown options, Labels).
155
- Your Frontend (React/Streamlit) should call this to build the form automatically.
156
- """
157
- return ui_schema
158
-
159
- @app.post("/predict")
160
- def predict_risk(payload: PatientData):
161
- """
162
- Main prediction endpoint.
163
- Accepts patient data -> Returns Risk Status & Probability.
164
- """
165
- try:
166
- # 1. Preprocess Data
167
- input_df = preprocess_input(payload.data)
168
-
169
- # 2. Get Probability of Risk (Class 1)
170
- # [0][1] index gets the probability of "Positive" (Risk)
171
- risk_prob = model.predict_proba(input_df)[0][1]
172
-
173
- # 3. Apply Threshold Logic
174
- is_high_risk = bool(risk_prob >= THRESHOLD)
175
-
176
- return {
177
- "prediction": "HIGH RISK" if is_high_risk else "LOW RISK",
178
- "risk_probability": round(float(risk_prob), 4),
179
- "threshold_used": THRESHOLD,
180
- "flag": 1 if is_high_risk else 0,
181
- "clinical_note": "Refer to specialist" if is_high_risk else "Standard monitoring"
182
- }
183
-
184
- except Exception as e:
185
- raise HTTPException(status_code=500, detail=str(e))
186
-
187
- # ==========================================
188
- # 6. RUNNER
189
- # ==========================================
190
- if __name__ == "__main__":
191
- # Updated to 0.0.0.0 and 7860 for Docker/Hugging Face compatibility
192
- print("πŸš€ Server starting...")
193
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import os
2
+ import json
3
+ import uvicorn
4
+ import pandas as pd
5
+ import numpy as np
6
+
7
+ from fastapi import FastAPI, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel, ConfigDict
10
+ from catboost import CatBoostClassifier
11
+ from typing import Dict, Any
12
+
13
+ # ==========================================
14
+ # 1. APP SETUP
15
+ # ==========================================
16
+
17
+ app = FastAPI(
18
+ title="PPD Risk Assessment API",
19
+ description="AI-powered screening tool for Postpartum Depression Risk (Top 20 Features)",
20
+ version="1.0.0"
21
+ )
22
+
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ # ==========================================
32
+ # 2. PATH CONFIG (HF SAFE)
33
+ # ==========================================
34
+
35
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
36
+ ARTIFACTS_DIR = os.path.join(BASE_DIR, "artifact_final")
37
+
38
+ MODEL_FILE = "catboost_model_top20.cbm"
39
+ META_FILE = "catboost_metadata.json"
40
+ UI_FILE = "model_ui_schema.json"
41
+
42
+ THRESHOLD = 0.3 # calibrated threshold
43
+
44
+ # ==========================================
45
+ # 3. LOAD ARTIFACTS (FAIL FAST)
46
+ # ==========================================
47
+
48
+ print("⏳ Loading AI Models and Config...")
49
+
50
+ try:
51
+ print("πŸ“ Artifacts directory:", ARTIFACTS_DIR)
52
+ print("πŸ“„ Files found:", os.listdir(ARTIFACTS_DIR))
53
+
54
+ # --- Load Model ---
55
+ model_path = os.path.join(ARTIFACTS_DIR, MODEL_FILE)
56
+ if not os.path.exists(model_path):
57
+ raise FileNotFoundError(f"Model not found at {model_path}")
58
+
59
+ model = CatBoostClassifier()
60
+ model.load_model(model_path)
61
+ print("βœ… CatBoost model loaded")
62
+
63
+ # --- Load Metadata ---
64
+ meta_path = os.path.join(ARTIFACTS_DIR, META_FILE)
65
+ with open(meta_path, "r") as f:
66
+ metadata = json.load(f)
67
+
68
+ TOP_FEATURES = metadata["top_features"]
69
+ print(f"βœ… Metadata loaded ({len(TOP_FEATURES)} features)")
70
+
71
+ # --- Load UI Schema ---
72
+ ui_path = os.path.join(ARTIFACTS_DIR, UI_FILE)
73
+ with open(ui_path, "r") as f:
74
+ ui_schema = json.load(f)
75
+
76
+ print("βœ… UI schema loaded")
77
+ print(f"🚦 Threshold set to {THRESHOLD}")
78
+
79
+ except Exception as e:
80
+ print("❌ CRITICAL ERROR LOADING ARTIFACTS")
81
+ raise e
82
+
83
+ # ==========================================
84
+ # 4. REQUEST SCHEMA
85
+ # ==========================================
86
+
87
+ class PatientData(BaseModel):
88
+ data: Dict[str, Any]
89
+
90
+ model_config = ConfigDict(
91
+ json_schema_extra={
92
+ "example": {
93
+ "data": {
94
+ "Need for Support": "High",
95
+ "Recieved Support": "Low",
96
+ "Abuse": "Yes",
97
+ "Disease before pregnancy": "None",
98
+ "Occupation before latest pregnancy": "Housewife",
99
+ "Pregnancy plan": "Unplanned",
100
+ "Relationship with husband": "Bad",
101
+ "Major changes or losses during pregnancy": "Yes",
102
+ "Relationship with the in-laws": "Bad",
103
+ "Birth compliancy": "No",
104
+ "Relationship between father and newborn": "Bad",
105
+ "Education Level": "Secondary",
106
+ "Family type": "Nuclear",
107
+ "Diseases during pregnancy": "Yes",
108
+ "Trust and share feelings": "No",
109
+ "Relationship with the newborn": "Average",
110
+ "Occupation After Your Latest Childbirth": "Unemployed",
111
+ "Age": 24,
112
+ "Addiction": "No",
113
+ "Husband's education level": "Secondary"
114
+ }
115
+ }
116
+ }
117
+ )
118
+
119
+ # ==========================================
120
+ # 5. PREPROCESSING
121
+ # ==========================================
122
+
123
+ def preprocess_input(raw_data: dict) -> pd.DataFrame:
124
+ clean_data = {}
125
+
126
+ for k, v in raw_data.items():
127
+ if isinstance(v, str):
128
+ clean_data[k] = v.lower()
129
+ else:
130
+ clean_data[k] = v
131
+
132
+ df = pd.DataFrame([clean_data])
133
+
134
+ # ensure all features exist
135
+ for col in TOP_FEATURES:
136
+ if col not in df.columns:
137
+ df[col] = "unknown"
138
+
139
+ df = df[TOP_FEATURES]
140
+ return df
141
+
142
+ # ==========================================
143
+ # 6. API ENDPOINTS
144
+ # ==========================================
145
+
146
+ @app.get("/")
147
+ def health():
148
+ return {
149
+ "status": "online",
150
+ "model": "CatBoost Top-20",
151
+ "threshold": THRESHOLD
152
+ }
153
+
154
+ @app.get("/config")
155
+ def get_ui_config():
156
+ return ui_schema
157
+
158
+ @app.post("/predict")
159
+ def predict(payload: PatientData):
160
+ try:
161
+ input_df = preprocess_input(payload.data)
162
+ risk_prob = model.predict_proba(input_df)[0][1]
163
+ is_high_risk = risk_prob >= THRESHOLD
164
+
165
+ return {
166
+ "prediction": "HIGH RISK" if is_high_risk else "LOW RISK",
167
+ "risk_probability": round(float(risk_prob), 4),
168
+ "threshold_used": THRESHOLD,
169
+ "flag": int(is_high_risk),
170
+ "clinical_note": (
171
+ "Refer to specialist"
172
+ if is_high_risk
173
+ else "Standard monitoring"
174
+ )
175
+ }
176
+
177
+ except Exception as e:
178
+ raise HTTPException(status_code=500, detail=str(e))
179
+
180
+ # ==========================================
181
+ # 7. RUNNER (HF / DOCKER)
182
+ # ==========================================
183
+
184
+ if __name__ == "__main__":
185
+ print("πŸš€ Starting server...")
186
+ uvicorn.run(app, host="0.0.0.0", port=7860)