# PulmoProbe / app.py
# costaspinto's picture
# Update app.py
# 17c692f verified
# pulmoprobe_backend/app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib
import pandas as pd
from huggingface_hub import hf_hub_download
import os
import logging
# ------------------------------------------------------------
# Setup Logging
# ------------------------------------------------------------
# Module-level logger, named after this module so its records can be told
# apart from those emitted by third-party libraries.
logger = logging.getLogger(__name__)

# Configure the root logger at INFO. basicConfig() is a no-op when the root
# logger already has handlers, so a hosting server's own logging setup wins.
logging.basicConfig(level=logging.INFO)
# ------------------------------------------------------------
# FastAPI Initialization
# ------------------------------------------------------------
app = FastAPI(title="PulmoProbe AI API")

# Allow CORS for frontend communication.
#
# Any origin may call the API, so credentialed requests are NOT enabled:
# combining wildcard origins with allow_credentials=True is not a valid CORS
# configuration (the FastAPI docs state the wildcard cannot be used together
# with credentials) and would let arbitrary sites make cookie-bearing
# requests. No endpoint in this file relies on cookies or auth headers.
# If credentials are ever needed, replace "*" with an explicit origin list
# and set allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ------------------------------------------------------------
# Hugging Face Model Setup
# ------------------------------------------------------------
# Use a writable location for the Hugging Face cache.
# NOTE(review): huggingface_hub is imported before this assignment, so the
# library may already have resolved its default paths from the environment;
# the explicit cache_dir passed below is what directs this download — confirm
# whether HF_HOME needs to be set before the import instead.
os.environ['HF_HOME'] = '/tmp/huggingface'
os.makedirs(os.environ['HF_HOME'], exist_ok=True)

MODEL_REPO_ID = "costaspinto/PulmoProbe"
MODEL_FILENAME = "best_model.joblib"

# Download (or reuse from cache) and deserialize the model once at import
# time, so the service refuses to start when no model is available.
try:
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        cache_dir=os.environ['HF_HOME'],
    )
    # SECURITY: joblib.load unpickles the file, and unpickling can execute
    # arbitrary code. This is only acceptable while MODEL_REPO_ID is a
    # trusted repository; consider pinning a specific revision.
    model = joblib.load(model_path)
    logger.info("✅ Model loaded successfully")
except Exception as e:
    # logger.exception records the traceback alongside the message.
    logger.exception("❌ Failed to load model: %s", e)
    # Chain explicitly so the original cause is preserved for debugging.
    raise RuntimeError(f"Model loading failed: {e}") from e
# ------------------------------------------------------------
# Define Input Schema (One-Hot Encoded)
# ------------------------------------------------------------
class OneHotPatientData(BaseModel):
    """Request body for ``/predict``: one patient's features, one-hot encoded.

    Every field is required. The indicator fields are plain ``int`` and are
    not range-checked by this schema; presumably clients send 0 or 1
    (verify against the frontend). Categories without a field of their own
    (for example a first cancer stage) are presumably the baseline level of
    the encoding — TODO confirm against the training pipeline.
    """

    # --- Continuous measurements ---
    age: float
    bmi: float
    cholesterol_level: float

    # --- Comorbidity indicators ---
    hypertension: int
    asthma: int
    cirrhosis: int
    other_cancer: int

    # --- Demographics / history indicators ---
    gender_Male: int
    family_history_Yes: int

    # --- Country (one-hot) ---
    country_Belgium: int
    country_Bulgaria: int
    country_Croatia: int
    country_Cyprus: int
    country_Czech_Republic: int
    country_Denmark: int
    country_Estonia: int
    country_Finland: int
    country_France: int
    country_Germany: int
    country_Greece: int
    country_Hungary: int
    country_Ireland: int
    country_Italy: int
    country_Latvia: int
    country_Lithuania: int
    country_Luxembourg: int
    country_Malta: int
    country_Netherlands: int
    country_Poland: int
    country_Portugal: int
    country_Romania: int
    country_Slovakia: int
    country_Slovenia: int
    country_Spain: int
    country_Sweden: int

    # --- Cancer stage (one-hot) ---
    cancer_stage_Stage_II: int
    cancer_stage_Stage_III: int
    cancer_stage_Stage_IV: int

    # --- Smoking status (one-hot) ---
    smoking_status_Former_Smoker: int
    smoking_status_Never_Smoked: int
    smoking_status_Passive_Smoker: int

    # --- Treatment type (one-hot) ---
    treatment_type_Combined: int
    treatment_type_Radiation: int
    treatment_type_Surgery: int
# ------------------------------------------------------------
# Root Endpoint
# ------------------------------------------------------------
@app.get("/")
def read_root():
    """Return a static welcome payload; usable as a basic liveness check."""
    greeting = {"message": "Welcome to PulmoProbe AI API"}
    return greeting
# ------------------------------------------------------------
# Prediction Endpoint
# ------------------------------------------------------------
# Column names, in order, of the single-row DataFrame handed to the model.
# Several contain spaces or title-cased numerals and therefore differ from
# the identifier-safe field names on OneHotPatientData.
# NOTE(review): assumed to match the columns the model was trained on.
_MODEL_FEATURE_ORDER = [
    'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
    'cirrhosis', 'other_cancer', 'gender_Male',
    'country_Belgium', 'country_Bulgaria', 'country_Croatia', 'country_Cyprus',
    'country_Czech Republic', 'country_Denmark', 'country_Estonia', 'country_Finland',
    'country_France', 'country_Germany', 'country_Greece', 'country_Hungary',
    'country_Ireland', 'country_Italy', 'country_Latvia', 'country_Lithuania',
    'country_Luxembourg', 'country_Malta', 'country_Netherlands', 'country_Poland',
    'country_Portugal', 'country_Romania', 'country_Slovakia', 'country_Slovenia',
    'country_Spain', 'country_Sweden',
    'cancer_stage_Stage Ii', 'cancer_stage_Stage Iii', 'cancer_stage_Stage Iv',
    'family_history_Yes',
    'smoking_status_Former Smoker', 'smoking_status_Never Smoked', 'smoking_status_Passive Smoker',
    'treatment_type_Combined', 'treatment_type_Radiation', 'treatment_type_Surgery',
]

# Model column name -> request field name, for every column whose name is not
# itself a field of the request schema. Without this mapping these seven
# features were looked up under the wrong key and always sent to the model
# as 0, regardless of what the client submitted.
_COLUMN_TO_FIELD = {
    'country_Czech Republic': 'country_Czech_Republic',
    'cancer_stage_Stage Ii': 'cancer_stage_Stage_II',
    'cancer_stage_Stage Iii': 'cancer_stage_Stage_III',
    'cancer_stage_Stage Iv': 'cancer_stage_Stage_IV',
    'smoking_status_Former Smoker': 'smoking_status_Former_Smoker',
    'smoking_status_Never Smoked': 'smoking_status_Never_Smoked',
    'smoking_status_Passive Smoker': 'smoking_status_Passive_Smoker',
}


@app.post("/predict")
def predict(data: OneHotPatientData):
    """Score one patient and return a risk label with its probability.

    Args:
        data: One-hot encoded patient features (validated by the schema).

    Returns:
        A dict with two keys:
        ``"risk"`` — ``"High Risk of Non-Survival"`` when the model's
        probability for class index 1 exceeds 0.5, otherwise
        ``"Low Risk of Non-Survival"``;
        ``"confidence"`` — that same class-1 probability formatted as a
        percentage with one decimal. Note it is the class-1 probability even
        when the label is "Low Risk".

    Raises:
        HTTPException: status 500 with the error text if anything fails
            while building the input row or running the model.
    """
    try:
        input_dict = data.dict()
        logger.info("Incoming data: %s", input_dict)

        # Translate model column names to schema field names. Index strictly
        # (no .get(..., 0)) so a column without a matching request field
        # fails loudly instead of silently feeding 0 to the model.
        model_row = {
            column: input_dict[_COLUMN_TO_FIELD.get(column, column)]
            for column in _MODEL_FEATURE_ORDER
        }
        input_df = pd.DataFrame([model_row], columns=_MODEL_FEATURE_ORDER)

        probabilities = model.predict_proba(input_df)[0]
        # Class index 1 is treated as the high-risk class.
        confidence_high_risk = probabilities[1]
        risk_level = (
            "High Risk of Non-Survival"
            if confidence_high_risk > 0.5
            else "Low Risk of Non-Survival"
        )
        return {
            "risk": risk_level,
            "confidence": f"{confidence_high_risk*100:.1f}%",
        }
    except Exception as e:
        # logger.exception keeps the traceback, which the bare message lost.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e