Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,12 +34,10 @@ app.add_middleware(
|
|
| 34 |
# ------------------------------------------------------------
|
| 35 |
os.environ['HF_HOME'] = '/tmp/huggingface'
|
| 36 |
os.makedirs(os.environ['HF_HOME'], exist_ok=True)
|
| 37 |
-
logger.info(f"HF_HOME set to {os.environ['HF_HOME']}")
|
| 38 |
|
| 39 |
MODEL_REPO_ID = "costaspinto/PulmoProbe"
|
| 40 |
MODEL_FILENAME = "best_model.joblib"
|
| 41 |
|
| 42 |
-
logger.info("Downloading model from Hugging Face Hub...")
|
| 43 |
try:
|
| 44 |
model_path = hf_hub_download(
|
| 45 |
repo_id=MODEL_REPO_ID,
|
|
@@ -47,13 +45,13 @@ try:
|
|
| 47 |
cache_dir=os.environ['HF_HOME']
|
| 48 |
)
|
| 49 |
model = joblib.load(model_path)
|
| 50 |
-
logger.info("Model loaded successfully
|
| 51 |
except Exception as e:
|
| 52 |
-
logger.error(f"Failed to load model: {str(e)}")
|
| 53 |
raise RuntimeError(f"Model loading failed: {str(e)}")
|
| 54 |
|
| 55 |
# ------------------------------------------------------------
|
| 56 |
-
# Define Input Schema (
|
| 57 |
# ------------------------------------------------------------
|
| 58 |
class OneHotPatientData(BaseModel):
|
| 59 |
age: float
|
|
@@ -65,7 +63,7 @@ class OneHotPatientData(BaseModel):
|
|
| 65 |
other_cancer: int
|
| 66 |
gender_Male: int
|
| 67 |
family_history_Yes: int
|
| 68 |
-
|
| 69 |
country_Belgium: int
|
| 70 |
country_Bulgaria: int
|
| 71 |
country_Croatia: int
|
|
@@ -93,7 +91,6 @@ class OneHotPatientData(BaseModel):
|
|
| 93 |
country_Spain: int
|
| 94 |
country_Sweden: int
|
| 95 |
|
| 96 |
-
# Corrected to use uppercase Roman numerals
|
| 97 |
cancer_stage_Stage_II: int
|
| 98 |
cancer_stage_Stage_III: int
|
| 99 |
cancer_stage_Stage_IV: int
|
|
@@ -101,7 +98,7 @@ class OneHotPatientData(BaseModel):
|
|
| 101 |
smoking_status_Former_Smoker: int
|
| 102 |
smoking_status_Never_Smoked: int
|
| 103 |
smoking_status_Passive_Smoker: int
|
| 104 |
-
|
| 105 |
treatment_type_Combined: int
|
| 106 |
treatment_type_Radiation: int
|
| 107 |
treatment_type_Surgery: int
|
|
@@ -111,7 +108,7 @@ class OneHotPatientData(BaseModel):
|
|
| 111 |
# ------------------------------------------------------------
|
| 112 |
@app.get("/")
|
| 113 |
def read_root():
|
| 114 |
-
return {"message": "Welcome to
|
| 115 |
|
| 116 |
# ------------------------------------------------------------
|
| 117 |
# Prediction Endpoint
|
|
@@ -121,46 +118,38 @@ def predict(data: OneHotPatientData):
|
|
| 121 |
try:
|
| 122 |
input_dict = data.dict()
|
| 123 |
logger.info(f"Incoming data: {input_dict}")
|
| 124 |
-
|
| 125 |
-
# Define the exact feature order your model expects (with underscores and uppercase Roman numerals)
|
| 126 |
feature_order = [
|
| 127 |
'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
|
| 128 |
-
'cirrhosis', 'other_cancer', 'gender_Male',
|
| 129 |
-
'country_Bulgaria',
|
| 130 |
-
'country_Czech_Republic',
|
| 131 |
-
'
|
| 132 |
-
'
|
| 133 |
-
'
|
| 134 |
-
'
|
| 135 |
-
'
|
| 136 |
-
'
|
| 137 |
-
'
|
| 138 |
-
'
|
| 139 |
-
'
|
| 140 |
-
'smoking_status_Former_Smoker', 'smoking_status_Never_Smoked',
|
| 141 |
-
'smoking_status_Passive_Smoker', 'treatment_type_Combined',
|
| 142 |
-
'treatment_type_Radiation', 'treatment_type_Surgery'
|
| 143 |
]
|
| 144 |
|
| 145 |
-
#
|
| 146 |
-
|
| 147 |
-
|
| 148 |
|
| 149 |
# Predict probabilities
|
| 150 |
probabilities = model.predict_proba(input_df)[0]
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
confidence_high_risk = probabilities[0]
|
| 154 |
risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
|
| 155 |
|
| 156 |
result = {
|
| 157 |
"risk": risk_level,
|
| 158 |
-
"confidence": f"{confidence_high_risk
|
| 159 |
}
|
| 160 |
-
|
| 161 |
-
logger.info(f"Prediction result: {result}")
|
| 162 |
return result
|
| 163 |
|
| 164 |
except Exception as e:
|
| 165 |
logger.error(f"Prediction error: {str(e)}")
|
| 166 |
-
return {"error": str(e), "input_data_received": data.dict()}
|
|
|
|
| 34 |
# ------------------------------------------------------------
|
| 35 |
os.environ['HF_HOME'] = '/tmp/huggingface'
|
| 36 |
os.makedirs(os.environ['HF_HOME'], exist_ok=True)
|
|
|
|
| 37 |
|
| 38 |
MODEL_REPO_ID = "costaspinto/PulmoProbe"
|
| 39 |
MODEL_FILENAME = "best_model.joblib"
|
| 40 |
|
|
|
|
| 41 |
try:
|
| 42 |
model_path = hf_hub_download(
|
| 43 |
repo_id=MODEL_REPO_ID,
|
|
|
|
| 45 |
cache_dir=os.environ['HF_HOME']
|
| 46 |
)
|
| 47 |
model = joblib.load(model_path)
|
| 48 |
+
logger.info("✅ Model loaded successfully")
|
| 49 |
except Exception as e:
|
| 50 |
+
logger.error(f"❌ Failed to load model: {str(e)}")
|
| 51 |
raise RuntimeError(f"Model loading failed: {str(e)}")
|
| 52 |
|
| 53 |
# ------------------------------------------------------------
|
| 54 |
+
# Define Input Schema (One-Hot Encoded)
|
| 55 |
# ------------------------------------------------------------
|
| 56 |
class OneHotPatientData(BaseModel):
|
| 57 |
age: float
|
|
|
|
| 63 |
other_cancer: int
|
| 64 |
gender_Male: int
|
| 65 |
family_history_Yes: int
|
| 66 |
+
|
| 67 |
country_Belgium: int
|
| 68 |
country_Bulgaria: int
|
| 69 |
country_Croatia: int
|
|
|
|
| 91 |
country_Spain: int
|
| 92 |
country_Sweden: int
|
| 93 |
|
|
|
|
| 94 |
cancer_stage_Stage_II: int
|
| 95 |
cancer_stage_Stage_III: int
|
| 96 |
cancer_stage_Stage_IV: int
|
|
|
|
| 98 |
smoking_status_Former_Smoker: int
|
| 99 |
smoking_status_Never_Smoked: int
|
| 100 |
smoking_status_Passive_Smoker: int
|
| 101 |
+
|
| 102 |
treatment_type_Combined: int
|
| 103 |
treatment_type_Radiation: int
|
| 104 |
treatment_type_Surgery: int
|
|
|
|
| 108 |
# ------------------------------------------------------------
|
| 109 |
@app.get("/")
|
| 110 |
def read_root():
|
| 111 |
+
return {"message": "Welcome to PulmoProbe AI API"}
|
| 112 |
|
| 113 |
# ------------------------------------------------------------
|
| 114 |
# Prediction Endpoint
|
|
|
|
| 118 |
try:
|
| 119 |
input_dict = data.dict()
|
| 120 |
logger.info(f"Incoming data: {input_dict}")
|
| 121 |
+
|
|
|
|
| 122 |
feature_order = [
|
| 123 |
'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
|
| 124 |
+
'cirrhosis', 'other_cancer', 'gender_Male',
|
| 125 |
+
'country_Belgium','country_Bulgaria','country_Croatia','country_Cyprus',
|
| 126 |
+
'country_Czech_Republic','country_Denmark','country_Estonia','country_Finland',
|
| 127 |
+
'country_France','country_Germany','country_Greece','country_Hungary',
|
| 128 |
+
'country_Ireland','country_Italy','country_Latvia','country_Lithuania',
|
| 129 |
+
'country_Luxembourg','country_Malta','country_Netherlands','country_Poland',
|
| 130 |
+
'country_Portugal','country_Romania','country_Slovakia','country_Slovenia',
|
| 131 |
+
'country_Spain','country_Sweden',
|
| 132 |
+
'cancer_stage_Stage_II','cancer_stage_Stage_III','cancer_stage_Stage_IV',
|
| 133 |
+
'family_history_Yes',
|
| 134 |
+
'smoking_status_Former_Smoker','smoking_status_Never_Smoked','smoking_status_Passive_Smoker',
|
| 135 |
+
'treatment_type_Combined','treatment_type_Radiation','treatment_type_Surgery'
|
|
|
|
|
|
|
|
|
|
| 136 |
]
|
| 137 |
|
| 138 |
+
# Fill missing fields with 0
|
| 139 |
+
input_dict_complete = {col: input_dict.get(col, 0) for col in feature_order}
|
| 140 |
+
input_df = pd.DataFrame([input_dict_complete], columns=feature_order)
|
| 141 |
|
| 142 |
# Predict probabilities
|
| 143 |
probabilities = model.predict_proba(input_df)[0]
|
| 144 |
+
confidence_high_risk = probabilities[1] # Class 1 = High Risk
|
|
|
|
|
|
|
| 145 |
risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
|
| 146 |
|
| 147 |
result = {
|
| 148 |
"risk": risk_level,
|
| 149 |
+
"confidence": f"{confidence_high_risk*100:.1f}%"
|
| 150 |
}
|
|
|
|
|
|
|
| 151 |
return result
|
| 152 |
|
| 153 |
except Exception as e:
|
| 154 |
logger.error(f"Prediction error: {str(e)}")
|
| 155 |
+
return {"error": str(e), "input_data_received": data.dict()}
|