appledog00 committed on
Commit
9cd2352
·
verified ·
1 Parent(s): 9d6af53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -31
app.py CHANGED
@@ -35,13 +35,13 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
35
  ARTIFACTS_DIR = os.path.join(BASE_DIR, "artifacts_final")
36
 
37
  MODEL_FILE = "catboost_model_top20.cbm"
38
- META_FILE = "catboost_metadata.json"
39
  UI_FILE = "model_ui_schema.json"
40
 
41
- THRESHOLD = 0.3 # calibrated decision threshold
42
 
43
  # =========================================================
44
- # 3. LOAD ARTIFACTS (FAIL FAST)
45
  # =========================================================
46
 
47
  print("⏳ Loading AI Models and Config...")
@@ -54,20 +54,19 @@ print("📄 Files found:", os.listdir(ARTIFACTS_DIR))
54
 
55
  # ---- Load Model ----
56
  model_path = os.path.join(ARTIFACTS_DIR, MODEL_FILE)
57
- if not os.path.exists(model_path):
58
- raise FileNotFoundError(f"Model not found at {model_path}")
59
-
60
  model = CatBoostClassifier()
61
  model.load_model(model_path)
62
  print("✅ CatBoost model loaded")
63
 
64
- # ---- Load Metadata ----
65
- meta_path = os.path.join(ARTIFACTS_DIR, META_FILE)
66
- with open(meta_path, "r") as f:
67
- metadata = json.load(f)
 
 
 
68
 
69
- TOP_FEATURES = metadata["top_features"]
70
- print(f"✅ Metadata loaded ({len(TOP_FEATURES)} features)")
71
 
72
  # ---- Load UI Schema ----
73
  ui_path = os.path.join(ARTIFACTS_DIR, UI_FILE)
@@ -118,28 +117,19 @@ class PatientData(BaseModel):
118
  # =========================================================
119
 
120
  def preprocess_input(raw_data: Dict[str, Any]) -> pd.DataFrame:
121
- """
122
- - Lowercases categorical strings
123
- - Ensures all Top-20 features exist
124
- - Orders columns exactly as training
125
- """
126
  clean_data = {}
127
 
128
- for key, value in raw_data.items():
129
- if isinstance(value, str):
130
- clean_data[key] = value.lower()
131
- else:
132
- clean_data[key] = value
133
 
134
  df = pd.DataFrame([clean_data])
135
 
136
- # Fill missing features safely
137
  for col in TOP_FEATURES:
138
  if col not in df.columns:
139
  df[col] = "unknown"
140
 
141
- df = df[TOP_FEATURES]
142
- return df
143
 
144
  # =========================================================
145
  # 6. API ENDPOINTS
@@ -155,10 +145,6 @@ def health():
155
 
156
  @app.get("/config")
157
  def get_ui_config():
158
- """
159
- Returns your provided UI JSON
160
- Used by frontend to auto-render form
161
- """
162
  return ui_schema
163
 
164
  @app.post("/predict")
@@ -166,7 +152,6 @@ def predict(payload: PatientData):
166
  try:
167
  input_df = preprocess_input(payload.data)
168
 
169
- # Probability of positive (PPD risk)
170
  risk_prob = model.predict_proba(input_df)[0][1]
171
  is_high_risk = risk_prob >= THRESHOLD
172
 
@@ -186,7 +171,7 @@ def predict(payload: PatientData):
186
  raise HTTPException(status_code=500, detail=str(e))
187
 
188
  # =========================================================
189
- # 7. RUNNER (HF / DOCKER)
190
  # =========================================================
191
 
192
  if __name__ == "__main__":
 
35
  ARTIFACTS_DIR = os.path.join(BASE_DIR, "artifacts_final")
36
 
37
  MODEL_FILE = "catboost_model_top20.cbm"
38
+ TOP_FEATURES_FILE = "top20_features.csv"
39
  UI_FILE = "model_ui_schema.json"
40
 
41
+ THRESHOLD = 0.3
42
 
43
  # =========================================================
44
+ # 3. LOAD ARTIFACTS
45
  # =========================================================
46
 
47
  print("⏳ Loading AI Models and Config...")
 
54
 
55
  # ---- Load Model ----
56
  model_path = os.path.join(ARTIFACTS_DIR, MODEL_FILE)
 
 
 
57
  model = CatBoostClassifier()
58
  model.load_model(model_path)
59
  print("✅ CatBoost model loaded")
60
 
61
+ # ---- Load Top 20 Features (SOURCE OF TRUTH) ----
62
+ features_path = os.path.join(ARTIFACTS_DIR, TOP_FEATURES_FILE)
63
+ TOP_FEATURES = (
64
+ pd.read_csv(features_path, header=None)[0]
65
+ .astype(str)
66
+ .tolist()
67
+ )
68
 
69
+ print(f"✅ Loaded {len(TOP_FEATURES)} top features")
 
70
 
71
  # ---- Load UI Schema ----
72
  ui_path = os.path.join(ARTIFACTS_DIR, UI_FILE)
 
117
  # =========================================================
118
 
119
  def preprocess_input(raw_data: Dict[str, Any]) -> pd.DataFrame:
 
 
 
 
 
120
  clean_data = {}
121
 
122
+ for k, v in raw_data.items():
123
+ clean_data[k] = v.lower() if isinstance(v, str) else v
 
 
 
124
 
125
  df = pd.DataFrame([clean_data])
126
 
127
+ # Ensure all required features exist
128
  for col in TOP_FEATURES:
129
  if col not in df.columns:
130
  df[col] = "unknown"
131
 
132
+ return df[TOP_FEATURES]
 
133
 
134
  # =========================================================
135
  # 6. API ENDPOINTS
 
145
 
146
  @app.get("/config")
147
  def get_ui_config():
 
 
 
 
148
  return ui_schema
149
 
150
  @app.post("/predict")
 
152
  try:
153
  input_df = preprocess_input(payload.data)
154
 
 
155
  risk_prob = model.predict_proba(input_df)[0][1]
156
  is_high_risk = risk_prob >= THRESHOLD
157
 
 
171
  raise HTTPException(status_code=500, detail=str(e))
172
 
173
  # =========================================================
174
+ # 7. RUNNER (HF)
175
  # =========================================================
176
 
177
  if __name__ == "__main__":