File size: 5,184 Bytes
7128abc
c84f188
7128abc
c84f188
 
 
 
 
 
3b383e1
c84f188
c64fad1
 
 
 
 
c84f188
3b383e1
c64fad1
 
 
 
 
 
 
2e00ed2
c64fad1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7573b1
c64fad1
a7573b1
c64fad1
 
 
a7573b1
3b383e1
061d037
c84f188
 
 
 
 
 
 
061d037
2e00ed2
a7573b1
061d037
a92d9eb
061d037
a92d9eb
 
061d037
a92d9eb
 
061d037
a92d9eb
 
 
 
 
 
 
 
 
 
 
 
 
061d037
a92d9eb
061d037
a92d9eb
2e00ed2
 
 
 
 
a92d9eb
 
 
a7573b1
061d037
a92d9eb
 
c84f188
c64fad1
 
 
 
 
a7573b1
c84f188
3b383e1
c64fad1
3b383e1
c84f188
061d037
c84f188
a92d9eb
 
a7573b1
70b86fe
a528ee4
 
17c692f
 
 
 
 
 
 
 
a7573b1
17c692f
 
70b86fe
 
a7573b1
 
3b383e1
c84f188
7128abc
c84f188
 
3b383e1
c84f188
a7573b1
c84f188
3b383e1
 
c84f188
3b383e1
7128abc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# pulmoprobe_backend/app.py

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib
import pandas as pd
from huggingface_hub import hf_hub_download
import os
import logging

# ------------------------------------------------------------
# Setup Logging
# ------------------------------------------------------------
# Configure the root logger at INFO once, at import time, and obtain a
# logger named after this module for everything below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# ------------------------------------------------------------
# FastAPI Initialization
# ------------------------------------------------------------
app = FastAPI(title="PulmoProbe AI API")

# Cross-origin settings so the browser-hosted frontend can call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True makes
# Starlette echo the caller's Origin back, which effectively allows
# credentialed requests from any site. Nothing visible in this module uses
# cookies or auth headers, so allow_credentials=False or an explicit origin
# list is likely safer. Confirm what the frontend needs before changing it.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# ------------------------------------------------------------
# Hugging Face Model Setup
# ------------------------------------------------------------
# Always cache Hub downloads under /tmp, overriding any inherited HF_HOME.
# This is presumably because other paths are not writable in the deployment
# environment; confirm before changing.
os.environ['HF_HOME'] = '/tmp/huggingface'
os.makedirs(os.environ['HF_HOME'], exist_ok=True)

MODEL_REPO_ID = "costaspinto/PulmoProbe"
MODEL_FILENAME = "best_model.joblib"

# Fetch the serialized model into the cache directory and load it once at
# import time. Any failure aborts module import with a RuntimeError, so the
# API never starts without a model.
# SECURITY NOTE: joblib.load unpickles the file, and unpickling can execute
# arbitrary code. Only load artifacts from a repository you control, and
# consider pinning `revision=` in hf_hub_download to a known commit.
try:
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        cache_dir=os.environ['HF_HOME'],
    )
    model = joblib.load(model_path)
except Exception as e:
    # logger.exception also records the traceback, which logger.error dropped.
    logger.exception("❌ Failed to load model: %s", e)
    # Chain explicitly so the root cause is reported as the direct cause.
    raise RuntimeError(f"Model loading failed: {str(e)}") from e
else:
    logger.info("✅ Model loaded successfully")

# ------------------------------------------------------------
# Define Input Schema (One-Hot Encoded)
# ------------------------------------------------------------
class OneHotPatientData(BaseModel):
    """Request body for the /predict endpoint: one patient's features, already one-hot encoded.

    Categorical variables arrive pre-expanded into integer indicator fields,
    presumably 0/1. The int type does not range-check them.

    Some field names use underscores where the DataFrame columns fed to the
    model contain spaces, e.g. `country_Czech_Republic` here versus
    `country_Czech Republic` for the model. Pydantic field names must be
    valid Python identifiers.
    """

    # Continuous measurements.
    age: float
    bmi: float
    cholesterol_level: float
    # Comorbidity flags, presumably 0 = absent and 1 = present.
    hypertension: int
    asthma: int
    cirrhosis: int
    other_cancer: int
    # Single indicators for binary categories. Presumably female and
    # "no family history" are the all-zero baselines.
    gender_Male: int
    family_history_Yes: int

    # Country indicators. There is no Austria field, so Austria is presumably
    # the dropped reference category (all zeros). Confirm against training.
    country_Belgium: int
    country_Bulgaria: int
    country_Croatia: int
    country_Cyprus: int
    country_Czech_Republic: int
    country_Denmark: int
    country_Estonia: int
    country_Finland: int
    country_France: int
    country_Germany: int
    country_Greece: int
    country_Hungary: int
    country_Ireland: int
    country_Italy: int
    country_Latvia: int
    country_Lithuania: int
    country_Luxembourg: int
    country_Malta: int
    country_Netherlands: int
    country_Poland: int
    country_Portugal: int
    country_Romania: int
    country_Slovakia: int
    country_Slovenia: int
    country_Spain: int
    country_Sweden: int

    # Cancer stage indicators. There is no Stage I field, so Stage I is
    # presumably the all-zero reference.
    cancer_stage_Stage_II: int
    cancer_stage_Stage_III: int
    cancer_stage_Stage_IV: int

    # Smoking status indicators. The remaining status (presumably current
    # smoker) is the all-zero reference. Confirm.
    smoking_status_Former_Smoker: int
    smoking_status_Never_Smoked: int
    smoking_status_Passive_Smoker: int

    # Treatment indicators. The remaining treatment type is presumably the
    # all-zero reference. Confirm against training.
    treatment_type_Combined: int
    treatment_type_Radiation: int
    treatment_type_Surgery: int

# ------------------------------------------------------------
# Root Endpoint
# ------------------------------------------------------------
@app.get("/")
def read_root():
    """Return a static JSON welcome message for requests to the API root."""
    welcome_text = "Welcome to PulmoProbe AI API"
    return dict(message=welcome_text)

# ------------------------------------------------------------
# Prediction Endpoint
# ------------------------------------------------------------
# Exact column names, in order, of the one-row DataFrame passed to the model.
# They contain spaces and title-cased Roman numerals (e.g. "Stage Ii").
# NOTE(review): presumably identical to `model.feature_names_in_`. Re-check
# whenever the model is retrained.
MODEL_FEATURE_ORDER = [
    'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
    'cirrhosis', 'other_cancer', 'gender_Male',
    'country_Belgium', 'country_Bulgaria', 'country_Croatia', 'country_Cyprus',
    'country_Czech Republic', 'country_Denmark', 'country_Estonia', 'country_Finland',
    'country_France', 'country_Germany', 'country_Greece', 'country_Hungary',
    'country_Ireland', 'country_Italy', 'country_Latvia', 'country_Lithuania',
    'country_Luxembourg', 'country_Malta', 'country_Netherlands', 'country_Poland',
    'country_Portugal', 'country_Romania', 'country_Slovakia', 'country_Slovenia',
    'country_Spain', 'country_Sweden',
    'cancer_stage_Stage Ii', 'cancer_stage_Stage Iii', 'cancer_stage_Stage Iv',
    'family_history_Yes',
    'smoking_status_Former Smoker', 'smoking_status_Never Smoked', 'smoking_status_Passive Smoker',
    'treatment_type_Combined', 'treatment_type_Radiation', 'treatment_type_Surgery',
]

# Model columns whose names are not valid Python identifiers are exposed in
# the request schema under underscore names. This maps model column to
# request field; any column not listed has the same name in both places.
_REQUEST_FIELD_FOR_COLUMN = {
    'country_Czech Republic': 'country_Czech_Republic',
    'cancer_stage_Stage Ii': 'cancer_stage_Stage_II',
    'cancer_stage_Stage Iii': 'cancer_stage_Stage_III',
    'cancer_stage_Stage Iv': 'cancer_stage_Stage_IV',
    'smoking_status_Former Smoker': 'smoking_status_Former_Smoker',
    'smoking_status_Never Smoked': 'smoking_status_Never_Smoked',
    'smoking_status_Passive Smoker': 'smoking_status_Passive_Smoker',
}

# A class-1 probability strictly above this value is labelled high risk.
HIGH_RISK_THRESHOLD = 0.5


def _build_feature_row(input_dict):
    """Return request values keyed by model column name, in MODEL_FEATURE_ORDER.

    Args:
        input_dict: request fields as produced by `OneHotPatientData.dict()`.

    Raises:
        KeyError: if a model column has no matching request field. The lookup
            is strict rather than defaulting to 0, so a naming mismatch fails
            loudly instead of silently zeroing that feature.
    """
    return {
        column: input_dict[_REQUEST_FIELD_FOR_COLUMN.get(column, column)]
        for column in MODEL_FEATURE_ORDER
    }


@app.post("/predict")
def predict(data: OneHotPatientData):
    """Score one patient and return a risk label with its probability.

    Returns:
        `{"risk": <label>, "confidence": "<pct>%"}`. `confidence` is the
        model's class-1 probability as a percentage with one decimal place.

    Raises:
        HTTPException: status 500, with the error text as detail, if feature
            assembly or inference fails.
    """
    try:
        input_dict = data.dict()
        # NOTE(review): this logs patient health data at INFO level.
        # Consider redacting it or lowering the level to DEBUG.
        logger.info("Incoming data: %s", input_dict)

        input_df = pd.DataFrame(
            [_build_feature_row(input_dict)], columns=MODEL_FEATURE_ORDER
        )

        # Assumes predict_proba column 1 is the non-survival class, i.e.
        # model.classes_ == [0, 1] with 1 = non-survival. TODO confirm.
        confidence_high_risk = float(model.predict_proba(input_df)[0][1])
        if confidence_high_risk > HIGH_RISK_THRESHOLD:
            risk_level = "High Risk of Non-Survival"
        else:
            risk_level = "Low Risk of Non-Survival"

        return {
            "risk": risk_level,
            "confidence": f"{confidence_high_risk*100:.1f}%",
        }

    except Exception as e:
        # API boundary: log the full traceback, then surface a 500 to the client.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e