Spaces:
Sleeping
Sleeping
File size: 5,184 Bytes
7128abc c84f188 7128abc c84f188 3b383e1 c84f188 c64fad1 c84f188 3b383e1 c64fad1 2e00ed2 c64fad1 a7573b1 c64fad1 a7573b1 c64fad1 a7573b1 3b383e1 061d037 c84f188 061d037 2e00ed2 a7573b1 061d037 a92d9eb 061d037 a92d9eb 061d037 a92d9eb 061d037 a92d9eb 061d037 a92d9eb 061d037 a92d9eb 2e00ed2 a92d9eb a7573b1 061d037 a92d9eb c84f188 c64fad1 a7573b1 c84f188 3b383e1 c64fad1 3b383e1 c84f188 061d037 c84f188 a92d9eb a7573b1 70b86fe a528ee4 17c692f a7573b1 17c692f 70b86fe a7573b1 3b383e1 c84f188 7128abc c84f188 3b383e1 c84f188 a7573b1 c84f188 3b383e1 c84f188 3b383e1 7128abc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# pulmoprobe_backend/app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib
import pandas as pd
from huggingface_hub import hf_hub_download
import os
import logging
# ------------------------------------------------------------
# Logging
# ------------------------------------------------------------
# Root logger at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------------------------------------
# FastAPI application
# ------------------------------------------------------------
app = FastAPI(title="PulmoProbe AI API")

# CORS is left wide open (any origin, method and header) so a separately
# hosted frontend can call this API.
# NOTE(review): FastAPI's CORS documentation states that wildcard origins,
# methods and headers must not be combined with allow_credentials=True.
# Confirm whether the frontend really sends credentials; if it does not,
# allow_credentials should be turned off, otherwise list the origins explicitly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# ------------------------------------------------------------
# Hugging Face Model Setup
# ------------------------------------------------------------
# Use a cache directory under /tmp and make sure it exists before downloading.
# The same path is passed explicitly as ``cache_dir`` below, so the download
# location does not depend on when huggingface_hub reads HF_HOME.
os.environ['HF_HOME'] = '/tmp/huggingface'
os.makedirs(os.environ['HF_HOME'], exist_ok=True)

MODEL_REPO_ID = "costaspinto/PulmoProbe"
MODEL_FILENAME = "best_model.joblib"

# The model is fetched and deserialised once, at import time. Any failure is
# logged with its traceback and re-raised as RuntimeError so the service does
# not start without a model.
try:
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        cache_dir=os.environ['HF_HOME'],
    )
    # SECURITY NOTE: joblib.load unpickles the file, which can execute
    # arbitrary code. Only keep this if MODEL_REPO_ID is a trusted repository.
    model = joblib.load(model_path)
    logger.info("✅ Model loaded successfully")
except Exception as e:
    # logger.exception records the full traceback alongside the message.
    logger.exception(f"❌ Failed to load model: {str(e)}")
    # Explicit chaining keeps the original error as __cause__.
    raise RuntimeError(f"Model loading failed: {str(e)}") from e
# ------------------------------------------------------------
# Define Input Schema (One-Hot Encoded)
# ------------------------------------------------------------
class OneHotPatientData(BaseModel):
    """Request body for ``/predict``: one patient as a flat feature record.

    Numeric measurements are floats; every other field is an int flag.
    NOTE(review): the int fields look like 0/1 one-hot indicators, but no
    range validation is applied here -- confirm against the frontend.
    """

    # --- Numeric measurements ---
    age: float
    bmi: float
    cholesterol_level: float

    # --- Comorbidity flags ---
    hypertension: int
    asthma: int
    cirrhosis: int
    other_cancer: int

    # --- Demographics / history ---
    gender_Male: int
    family_history_Yes: int

    # --- Country indicators ---
    country_Belgium: int
    country_Bulgaria: int
    country_Croatia: int
    country_Cyprus: int
    country_Czech_Republic: int
    country_Denmark: int
    country_Estonia: int
    country_Finland: int
    country_France: int
    country_Germany: int
    country_Greece: int
    country_Hungary: int
    country_Ireland: int
    country_Italy: int
    country_Latvia: int
    country_Lithuania: int
    country_Luxembourg: int
    country_Malta: int
    country_Netherlands: int
    country_Poland: int
    country_Portugal: int
    country_Romania: int
    country_Slovakia: int
    country_Slovenia: int
    country_Spain: int
    country_Sweden: int

    # --- Cancer stage indicators ---
    cancer_stage_Stage_II: int
    cancer_stage_Stage_III: int
    cancer_stage_Stage_IV: int

    # --- Smoking status indicators ---
    smoking_status_Former_Smoker: int
    smoking_status_Never_Smoked: int
    smoking_status_Passive_Smoker: int

    # --- Treatment type indicators ---
    treatment_type_Combined: int
    treatment_type_Radiation: int
    treatment_type_Surgery: int
# ------------------------------------------------------------
# Root Endpoint
# ------------------------------------------------------------
@app.get("/")
def read_root():
    """Return the static welcome payload served at the API root."""
    welcome = {"message": "Welcome to PulmoProbe AI API"}
    return welcome
# ------------------------------------------------------------
# Prediction Endpoint
# ------------------------------------------------------------
@app.post("/predict")
def predict(data: OneHotPatientData):
    """Score one patient record with the loaded model.

    Args:
        data: Validated request body holding the patient's feature values.

    Returns:
        A dict with ``"risk"`` (a human-readable label) and ``"confidence"``
        (the second ``predict_proba`` value formatted as a percentage with
        one decimal place).

    Raises:
        HTTPException: status 500, with the error text as detail, if anything
            fails while building the input frame or running the model.
    """
    # Column names and order handed to the model.
    # NOTE(review): presumably these are the training-time column names --
    # confirm against the training pipeline. Several differ from the request
    # schema's field names in spacing and casing (e.g. 'country_Czech Republic'
    # vs country_Czech_Republic, 'cancer_stage_Stage Ii' vs
    # cancer_stage_Stage_II).
    feature_order = [
        'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
        'cirrhosis', 'other_cancer', 'gender_Male',
        'country_Belgium', 'country_Bulgaria', 'country_Croatia', 'country_Cyprus',
        'country_Czech Republic', 'country_Denmark', 'country_Estonia', 'country_Finland',
        'country_France', 'country_Germany', 'country_Greece', 'country_Hungary',
        'country_Ireland', 'country_Italy', 'country_Latvia', 'country_Lithuania',
        'country_Luxembourg', 'country_Malta', 'country_Netherlands', 'country_Poland',
        'country_Portugal', 'country_Romania', 'country_Slovakia', 'country_Slovenia',
        'country_Spain', 'country_Sweden',
        'cancer_stage_Stage Ii', 'cancer_stage_Stage Iii', 'cancer_stage_Stage Iv',
        'family_history_Yes',
        'smoking_status_Former Smoker', 'smoking_status_Never Smoked', 'smoking_status_Passive Smoker',
        'treatment_type_Combined', 'treatment_type_Radiation', 'treatment_type_Surgery',
    ]
    try:
        input_dict = data.dict()
        # NOTE(review): this writes the full patient record to the logs;
        # confirm that is acceptable for this deployment.
        logger.info(f"Incoming data: {input_dict}")

        # Index the submitted values by lower-cased field name so they can be
        # matched to the model's column names regardless of casing.
        submitted = {name.lower(): value for name, value in input_dict.items()}

        # Map each model column to its request field by replacing spaces with
        # underscores and ignoring case. A plain lookup by column name would
        # miss every column whose spelling differs from the schema and silently
        # feed 0 for it, discarding the caller's value.
        input_dict_complete = {
            col: submitted.get(col.replace(" ", "_").lower(), 0)
            for col in feature_order
        }
        input_df = pd.DataFrame([input_dict_complete], columns=feature_order)

        probabilities = model.predict_proba(input_df)[0]
        # Assumes index 1 is the "non-survival" class -- TODO confirm against
        # model.classes_.
        confidence_high_risk = probabilities[1]
        if confidence_high_risk > 0.5:
            risk_level = "High Risk of Non-Survival"
        else:
            risk_level = "Low Risk of Non-Survival"

        return {
            "risk": risk_level,
            "confidence": f"{confidence_high_risk*100:.1f}%",
        }
    except Exception as e:
        # Log with traceback, then surface the failure as a 500 response.
        logger.exception(f"Prediction error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e