Spaces:
Sleeping
Sleeping
| # pulmoprobe_backend/app.py | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import joblib | |
| import pandas as pd | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import logging | |
| # ------------------------------------------------------------ | |
| # Setup Logging | |
| # ------------------------------------------------------------ | |
# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
# Configure the root logger so INFO-and-above records are emitted.
logging.basicConfig(level=logging.INFO)
| # ------------------------------------------------------------ | |
| # FastAPI Initialization | |
| # ------------------------------------------------------------ | |
# Create the FastAPI application.
app = FastAPI(title="PulmoProbe AI API")

# CORS policy applied to every request: any origin, method and header is
# accepted so that a separately hosted frontend can call this API.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination -- confirm it is intended, or list the
# frontend origin(s) explicitly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
| # ------------------------------------------------------------ | |
| # Hugging Face Model Setup | |
| # ------------------------------------------------------------ | |
# Point the Hugging Face cache at /tmp and make sure the directory exists.
# NOTE(review): presumably chosen because /tmp is writable in the deployment
# environment -- confirm.
os.environ['HF_HOME'] = '/tmp/huggingface'
os.makedirs(os.environ['HF_HOME'], exist_ok=True)

MODEL_REPO_ID = "costaspinto/PulmoProbe"
MODEL_FILENAME = "best_model.joblib"

# Download and deserialize the model once, at import time. Any failure is
# logged with its traceback and re-raised as RuntimeError so the application
# does not start without a model.
#
# SECURITY: joblib.load() unpickles the file, which can execute arbitrary code
# embedded in it. This is only safe while MODEL_REPO_ID is a trusted
# repository.
try:
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        cache_dir=os.environ['HF_HOME'],
    )
    model = joblib.load(model_path)
except Exception as e:
    # logger.exception keeps the traceback, which a plain error() call drops.
    logger.exception("❌ Failed to load model: %s", e)
    raise RuntimeError(f"Model loading failed: {str(e)}") from e
else:
    logger.info("✅ Model loaded successfully")
| # ------------------------------------------------------------ | |
| # Define Input Schema (One-Hot Encoded) | |
| # ------------------------------------------------------------ | |
class OneHotPatientData(BaseModel):
    """Request body for a prediction: one patient, already one-hot encoded.

    The first three fields are floats; every other field is an int.
    NOTE(review): the int fields are presumably 0/1 indicators, and each
    categorical group presumably omits one baseline category -- confirm
    against the training pipeline. No range validation is performed here.
    """

    # Continuous measurements.
    age: float
    bmi: float
    cholesterol_level: float

    # Comorbidity flags.
    hypertension: int
    asthma: int
    cirrhosis: int
    other_cancer: int

    # Demographics / history.
    gender_Male: int
    family_history_Yes: int

    # Country indicators.
    country_Belgium: int
    country_Bulgaria: int
    country_Croatia: int
    country_Cyprus: int
    country_Czech_Republic: int
    country_Denmark: int
    country_Estonia: int
    country_Finland: int
    country_France: int
    country_Germany: int
    country_Greece: int
    country_Hungary: int
    country_Ireland: int
    country_Italy: int
    country_Latvia: int
    country_Lithuania: int
    country_Luxembourg: int
    country_Malta: int
    country_Netherlands: int
    country_Poland: int
    country_Portugal: int
    country_Romania: int
    country_Slovakia: int
    country_Slovenia: int
    country_Spain: int
    country_Sweden: int

    # Cancer stage indicators.
    cancer_stage_Stage_II: int
    cancer_stage_Stage_III: int
    cancer_stage_Stage_IV: int

    # Smoking status indicators.
    smoking_status_Former_Smoker: int
    smoking_status_Never_Smoked: int
    smoking_status_Passive_Smoker: int

    # Treatment type indicators.
    treatment_type_Combined: int
    treatment_type_Radiation: int
    treatment_type_Surgery: int
| # ------------------------------------------------------------ | |
| # Root Endpoint | |
| # ------------------------------------------------------------ | |
# Register the handler on the application; without a route decorator the
# function is defined but never served.
@app.get("/")
def read_root():
    """Return a static welcome message identifying the API."""
    return {"message": "Welcome to PulmoProbe AI API"}
| # ------------------------------------------------------------ | |
| # Prediction Endpoint | |
| # ------------------------------------------------------------ | |
# Column names, in order, of the DataFrame handed to the model. They are NOT
# spelled like the OneHotPatientData fields: several contain spaces and
# title-cased suffixes (e.g. 'cancer_stage_Stage Ii' here versus the schema
# field 'cancer_stage_Stage_II').
# NOTE(review): presumably these are the training-time column names and
# order -- confirm against the training pipeline.
_MODEL_FEATURE_ORDER = [
    'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
    'cirrhosis', 'other_cancer', 'gender_Male',
    'country_Belgium', 'country_Bulgaria', 'country_Croatia', 'country_Cyprus',
    'country_Czech Republic', 'country_Denmark', 'country_Estonia', 'country_Finland',
    'country_France', 'country_Germany', 'country_Greece', 'country_Hungary',
    'country_Ireland', 'country_Italy', 'country_Latvia', 'country_Lithuania',
    'country_Luxembourg', 'country_Malta', 'country_Netherlands', 'country_Poland',
    'country_Portugal', 'country_Romania', 'country_Slovakia', 'country_Slovenia',
    'country_Spain', 'country_Sweden',
    'cancer_stage_Stage Ii', 'cancer_stage_Stage Iii', 'cancer_stage_Stage Iv',
    'family_history_Yes',
    'smoking_status_Former Smoker', 'smoking_status_Never Smoked', 'smoking_status_Passive Smoker',
    'treatment_type_Combined', 'treatment_type_Radiation', 'treatment_type_Surgery',
]


def _lookup_key(name):
    """Fold a model column name or schema field name to a shared lookup key."""
    return name.replace(" ", "_").lower()


def _build_feature_row(input_dict):
    """Re-key schema-named request values to the model's column names."""
    values_by_key = {_lookup_key(field): value for field, value in input_dict.items()}
    # Index directly instead of defaulting to 0: a plain
    # `input_dict.get(col, 0)` misses every column whose spelling differs
    # from the schema field and silently feeds 0 to the model. A KeyError
    # here means the schema and the column list have drifted apart.
    return {col: values_by_key[_lookup_key(col)] for col in _MODEL_FEATURE_ORDER}


# NOTE(review): the "/predict" path is an assumption -- align it with the
# path the frontend actually calls.
@app.post("/predict")
def predict(data: OneHotPatientData):
    """Score one patient record with the loaded model.

    Args:
        data: Validated, one-hot encoded patient features.

    Returns:
        A dict with two keys: "risk", a label chosen by comparing the
        probability at index 1 of ``predict_proba`` with 0.5, and
        "confidence", that same index-1 probability formatted as a percentage
        with one decimal (it is reported for both labels).

    Raises:
        HTTPException: status 500 with the error text as detail if building
            the input row or running the model fails.
    """
    try:
        input_dict = data.dict()
        # NOTE(review): this writes the full patient record to the log --
        # confirm that is acceptable for this deployment.
        logger.info("Incoming data: %s", input_dict)

        feature_row = _build_feature_row(input_dict)
        input_df = pd.DataFrame([feature_row], columns=_MODEL_FEATURE_ORDER)

        probabilities = model.predict_proba(input_df)[0]
        # Index 1 is treated as the high-risk class.
        confidence_high_risk = probabilities[1]
        if confidence_high_risk > 0.5:
            risk_level = "High Risk of Non-Survival"
        else:
            risk_level = "Low Risk of Non-Survival"

        return {
            "risk": risk_level,
            "confidence": f"{confidence_high_risk*100:.1f}%",
        }
    except Exception as e:
        # Keep the traceback in the log; the client only receives the message.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e