# pulmoprobe_backend/app.py
"""PulmoProbe AI inference API.

At import time the module downloads a serialized model from the Hugging Face
Hub and loads it with joblib. It then exposes two endpoints:

* ``GET /``         -- welcome message.
* ``POST /predict`` -- takes a one-hot encoded patient record and returns a
  risk label plus the model's class-1 probability formatted as a percentage.
"""
import logging
import os

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

# ------------------------------------------------------------
# Setup Logging
# ------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------------------------------------
# FastAPI Initialization
# ------------------------------------------------------------
app = FastAPI(title="PulmoProbe AI API")

# Allow CORS for frontend communication.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# permissive setup; consider listing the real frontend origin(s) explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ------------------------------------------------------------
# Hugging Face Model Setup
# ------------------------------------------------------------
# Point the Hugging Face cache at a directory under /tmp and make sure it
# exists before downloading.
os.environ['HF_HOME'] = '/tmp/huggingface'
os.makedirs(os.environ['HF_HOME'], exist_ok=True)

MODEL_REPO_ID = "costaspinto/PulmoProbe"
MODEL_FILENAME = "best_model.joblib"

try:
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        cache_dir=os.environ['HF_HOME'],
    )
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code -- only load artifacts from a repository you trust.
    model = joblib.load(model_path)
    logger.info("✅ Model loaded successfully")
except Exception as exc:
    logger.error("❌ Failed to load model: %s", exc)
    # Chain the cause so the original traceback is preserved.
    raise RuntimeError(f"Model loading failed: {str(exc)}") from exc

# Column names, in the order they are handed to the model. Several of them
# contain spaces / title-cased numerals ("Stage Ii") and therefore differ from
# the request schema's field names ("Stage_II"); see _build_feature_row.
FEATURE_ORDER = [
    'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
    'cirrhosis', 'other_cancer', 'gender_Male',
    'country_Belgium', 'country_Bulgaria', 'country_Croatia',
    'country_Cyprus', 'country_Czech Republic', 'country_Denmark',
    'country_Estonia', 'country_Finland', 'country_France',
    'country_Germany', 'country_Greece', 'country_Hungary',
    'country_Ireland', 'country_Italy', 'country_Latvia',
    'country_Lithuania', 'country_Luxembourg', 'country_Malta',
    'country_Netherlands', 'country_Poland', 'country_Portugal',
    'country_Romania', 'country_Slovakia', 'country_Slovenia',
    'country_Spain', 'country_Sweden',
    'cancer_stage_Stage Ii', 'cancer_stage_Stage Iii',
    'cancer_stage_Stage Iv',
    'family_history_Yes',
    'smoking_status_Former Smoker', 'smoking_status_Never Smoked',
    'smoking_status_Passive Smoker',
    'treatment_type_Combined', 'treatment_type_Radiation',
    'treatment_type_Surgery',
]


def _normalize_feature_name(name):
    """Return *name* lower-cased with spaces replaced by underscores."""
    return name.replace(" ", "_").lower()


def _build_feature_row(input_dict):
    """Map request fields onto the model's column names.

    Names are compared after normalisation (spaces -> underscores,
    case-insensitive), so the schema field ``cancer_stage_Stage_II`` feeds the
    model column ``'cancer_stage_Stage Ii'``. Matching on the raw names, as
    the code did before, silently zero-filled every column whose spelling
    differs between the two.

    Args:
        input_dict: Request payload keyed by schema field name.

    Returns:
        A dict keyed by the entries of ``FEATURE_ORDER``, in that order.
        A column with no matching request field is set to 0 and a warning is
        logged.
    """
    by_normalized_name = {
        _normalize_feature_name(key): value
        for key, value in input_dict.items()
    }
    row = {}
    for column in FEATURE_ORDER:
        lookup_key = _normalize_feature_name(column)
        if lookup_key in by_normalized_name:
            row[column] = by_normalized_name[lookup_key]
        else:
            # Keep the previous zero-fill fallback, but make it visible.
            logger.warning(
                "Feature %r missing from request; defaulting to 0", column
            )
            row[column] = 0
    return row


# ------------------------------------------------------------
# Define Input Schema (One-Hot Encoded)
# ------------------------------------------------------------
class OneHotPatientData(BaseModel):
    """Request body for ``/predict``: one patient record, one-hot encoded."""

    age: float
    bmi: float
    cholesterol_level: float
    hypertension: int
    asthma: int
    cirrhosis: int
    other_cancer: int
    gender_Male: int
    family_history_Yes: int
    country_Belgium: int
    country_Bulgaria: int
    country_Croatia: int
    country_Cyprus: int
    country_Czech_Republic: int
    country_Denmark: int
    country_Estonia: int
    country_Finland: int
    country_France: int
    country_Germany: int
    country_Greece: int
    country_Hungary: int
    country_Ireland: int
    country_Italy: int
    country_Latvia: int
    country_Lithuania: int
    country_Luxembourg: int
    country_Malta: int
    country_Netherlands: int
    country_Poland: int
    country_Portugal: int
    country_Romania: int
    country_Slovakia: int
    country_Slovenia: int
    country_Spain: int
    country_Sweden: int
    cancer_stage_Stage_II: int
    cancer_stage_Stage_III: int
    cancer_stage_Stage_IV: int
    smoking_status_Former_Smoker: int
    smoking_status_Never_Smoked: int
    smoking_status_Passive_Smoker: int
    treatment_type_Combined: int
    treatment_type_Radiation: int
    treatment_type_Surgery: int


# ------------------------------------------------------------
# Root Endpoint
# ------------------------------------------------------------
@app.get("/")
def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to PulmoProbe AI API"}


# ------------------------------------------------------------
# Prediction Endpoint
# ------------------------------------------------------------
@app.post("/predict")
def predict(data: OneHotPatientData):
    """Score one patient record with the loaded model.

    Args:
        data: One-hot encoded patient record.

    Returns:
        A dict with ``"risk"`` (``"High Risk of Non-Survival"`` when the
        class-1 probability exceeds 0.5, otherwise
        ``"Low Risk of Non-Survival"``) and ``"confidence"`` (that probability
        as a percentage string with one decimal place).

    Raises:
        HTTPException: Status 500 carrying the error text if anything fails
            while building the input frame or running the model.
    """
    try:
        input_dict = data.dict()
        # NOTE(review): this writes the full patient record to the log;
        # confirm that is acceptable for the deployment.
        logger.info("Incoming data: %s", input_dict)

        feature_row = _build_feature_row(input_dict)
        input_df = pd.DataFrame([feature_row], columns=FEATURE_ORDER)

        probabilities = model.predict_proba(input_df)[0]
        confidence_high_risk = probabilities[1]
        if confidence_high_risk > 0.5:
            risk_level = "High Risk of Non-Survival"
        else:
            risk_level = "Low Risk of Non-Survival"

        return {
            "risk": risk_level,
            "confidence": f"{confidence_high_risk*100:.1f}%",
        }
    except Exception as exc:
        logger.exception("Prediction error: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc