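# Page-header residue from the Hugging Face Spaces view this file was copied
# from ("Spaces:", status "Sleeping" shown twice); not part of the program.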
import json
import logging
import os
from typing import Any, Dict, Optional, Sequence

import pandas as pd
import uvicorn
from catboost import CatBoostClassifier
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict
# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
app = FastAPI(
    title="PPD Risk Assessment API",
    description="AI-powered screening tool for Postpartum Depression Risk (Top 20 Features)",
    version="1.0.0",
)

# Enable CORS.
# Allowed origins are read from CORS_ALLOW_ORIGINS as a comma-separated list,
# e.g. "https://app.example.org,https://admin.example.org". When the variable
# is unset or blank, every origin is allowed ("*"), which matches the previous
# hard-coded behavior. Set it to your frontend URL(s) in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

# NOTE(review): the CORS spec does not allow a wildcard origin on credentialed
# requests, and Starlette's docs say origins should be listed explicitly when
# allow_credentials=True. Nothing in this file reads cookies or auth headers,
# so consider allow_credentials=False once origins are pinned.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ==========================================
# 2. ARTIFACT PATH SETUP
# ==========================================
# Resolve artifacts relative to this source file rather than the current
# working directory, so the server finds them wherever it is launched from.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
ARTIFACTS_DIR = os.path.join(BASE_DIR, "artifacts_final")  # Hugging Face compatible path

# Startup diagnostics: where the artifacts are expected and whether that folder exists.
_artifacts_dir_exists = os.path.exists(ARTIFACTS_DIR)
print("ARTIFACTS DIR:", ARTIFACTS_DIR)
print("EXISTS:", _artifacts_dir_exists)
# ==========================================
# 3. LOAD ARTIFACTS
# ==========================================
def _require_file(path, label):
    """Return *path* unchanged if it exists, else raise FileNotFoundError.

    The error message has the form "<label> not found at <path>".
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"{label} not found at {path}")
    return path


def _load_json(path, label):
    """Check that the JSON artifact at *path* exists, then parse and return it."""
    # JSON is UTF-8 by spec; without an explicit encoding, open() falls back to
    # the platform locale encoding and can mis-decode non-ASCII text.
    with open(_require_file(path, label), "r", encoding="utf-8") as fh:
        return json.load(fh)


print(" Loading AI Models and Config...")
try:
    # A. Load CatBoost Model
    model_path = _require_file(
        os.path.join(ARTIFACTS_DIR, "catboost_model_top20.cbm"), "Model"
    )
    model = CatBoostClassifier()
    model.load_model(model_path)
    print(" Model Loaded.")

    # B. Load Metadata: the feature list (column order fed to the model) and
    # the decision threshold, which defaults to 0.5 when none is stored.
    meta_path = os.path.join(ARTIFACTS_DIR, "catboost_metadata.json")
    metadata = _load_json(meta_path, "Metadata")
    TOP_FEATURES = metadata.get("features_used") or []
    if not TOP_FEATURES:
        # preprocess_input selects exactly these columns, so an empty list
        # means every request reaches the model with no features.
        print(" WARNING: metadata has no 'features_used'; inputs will have no columns.")
    # Coerce to float so a threshold stored as a JSON string still compares
    # numerically at prediction time instead of failing on every request.
    THRESHOLD = float((metadata.get("thresholds") or {}).get("optimal_balanced", 0.5))
    print(f" Metadata Loaded. Threshold set to: {THRESHOLD}")

    # C. Load UI Schema (served as-is by get_ui_config)
    ui_path = os.path.join(ARTIFACTS_DIR, "model_ui_schema.json")
    ui_schema = _load_json(ui_path, "UI schema")
    print(" UI Schema Loaded.")
except Exception as e:
    # The service cannot run without its artifacts: report, then abort startup.
    print(f" CRITICAL ERROR LOADING ARTIFACTS: {e}")
    raise
# ==========================================
# 4. DATA VALIDATION
# ==========================================
# Sample request payload surfaced in the OpenAPI docs. Keys are feature names
# (presumably matching the training columns verbatim, including the
# "Recieved Support" spelling — confirm before renaming anything).
_PATIENT_EXAMPLE: Dict[str, Any] = {
    "Need for Support": "High",
    "Recieved Support": "Low",
    "Abuse": "Yes",
    "Disease before pregnancy": "None",
    "Occupation before latest pregnancy": "Housewife",
    "Pregnancy plan": "Unplanned",
    "Relationship with husband": "Bad",
    "Major changes or losses during pregnancy": "Yes",
    "Relationship with the in-laws": "Bad",
    "Birth compliancy": "No",
    "Relationship between father and newborn": "Bad",
    "Education Level": "Secondary",
    "Family type": "Nuclear",
    "Diseases during pregnancy": "Yes",
    "Trust and share feelings": "No",
    "Relationship with the newborn": "Average",
    "Occupation After Your Latest Childbirth": "Unemployed",
    "Age": 24,
    "Addiction": "No",
    "Husband's education level": "Secondary",
}


class PatientData(BaseModel):
    # Request body for a prediction: one free-form mapping of feature name to
    # value under the "data" key. Validation only requires "data" to be an
    # object with string keys; which features are present is not checked here.
    # (Comments rather than a docstring, so the generated JSON schema is unchanged.)
    model_config = ConfigDict(json_schema_extra={"example": {"data": _PATIENT_EXAMPLE}})

    data: Dict[str, Any]
# ==========================================
# 5. HELPER FUNCTION
# ==========================================
def preprocess_input(
    raw_data: Dict[str, Any],
    features: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
    """Turn one raw request payload into the single-row frame the model expects.

    Args:
        raw_data: Mapping of feature name -> value as sent by the client.
            Keys that are not in ``features`` are ignored.
        features: Column names in the order the model expects. Defaults to
            the module-level ``TOP_FEATURES`` read from the metadata file.

    Returns:
        A one-row DataFrame whose columns are exactly ``features``, in order.
        String values are stripped of surrounding whitespace and lower-cased.
        Features that are absent, or sent explicitly as ``None``/JSON null,
        get the placeholder ``"unknown"``. Other values pass through unchanged.
    """
    if features is None:
        features = TOP_FEATURES
    raw_data = raw_data or {}

    row = {}
    for col in features:
        value = raw_data.get(col)
        if value is None:
            # Treat an explicit null the same as an omitted field, so the
            # model never receives a bare None.
            row[col] = "unknown"
        elif isinstance(value, str):
            row[col] = value.strip().lower()
        else:
            row[col] = value
    return pd.DataFrame([row], columns=list(features))
# ==========================================
# 6. API ENDPOINTS
# ==========================================
# This handler had no route decorator, so FastAPI never exposed it.
# NOTE(review): "/" is assumed as the path — confirm against any client.
@app.get("/")
def home():
    """Health check: service status, model name and active decision threshold."""
    return {"status": "online", "model": "CatBoost Top20", "threshold": THRESHOLD}
# This handler had no route decorator, so FastAPI never exposed it.
# NOTE(review): "/ui-config" is an assumed path — confirm against the frontend.
@app.get("/ui-config")
def get_ui_config():
    """Return the UI schema (model_ui_schema.json) loaded at startup."""
    return ui_schema
# This handler had no route decorator, so FastAPI never exposed it.
# NOTE(review): "/predict" is an assumed path — confirm against the frontend.
# NOTE(review): the 500 response echoes str(e) to the client, which can leak
# internals; kept for compatibility with existing callers.
@app.post("/predict")
def predict_risk(payload: PatientData):
    """Score one patient record and classify it against the decision threshold.

    Returns the label ("HIGH RISK"/"LOW RISK"), the positive-class probability
    rounded to 4 decimals, the threshold used, a 0/1 flag and a clinical note.
    Responds with HTTP 500 if preprocessing or model inference fails.
    """
    try:
        input_df = preprocess_input(payload.data)
        # Column 1 of predict_proba is the positive class of a binary classifier.
        risk_prob = float(model.predict_proba(input_df)[0][1])
        is_high_risk = risk_prob >= THRESHOLD
    except Exception as e:
        # Keep the full traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("PPD risk prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "prediction": "HIGH RISK" if is_high_risk else "LOW RISK",
        "risk_probability": round(risk_prob, 4),
        "threshold_used": THRESHOLD,
        "flag": 1 if is_high_risk else 0,
        "clinical_note": "Refer to specialist" if is_high_risk else "Standard monitoring",
    }
# ==========================================
# 7. RUNNER
# ==========================================
if __name__ == "__main__":
    print(" Server starting...")
    # 7860 is the default app port for Hugging Face Spaces; PORT lets other
    # hosts override it without code changes. Bind all interfaces so the
    # server is reachable from outside a container.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))