costaspinto committed on
Commit
c64fad1
·
verified ·
1 Parent(s): fa72262

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -20
app.py CHANGED
@@ -9,22 +9,66 @@ from huggingface_hub import hf_hub_download
9
  import os
10
  import logging
11
 
12
- # ... (logging setup, FastAPI initialization, and model loading remain the same) ...
 
 
 
 
13
 
14
  # ------------------------------------------------------------
15
- # Define Input Schema (Corrected Names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # ------------------------------------------------------------
17
  class OneHotPatientData(BaseModel):
 
18
  age: float
19
  bmi: float
20
  cholesterol_level: float
 
21
  hypertension: int
22
  asthma: int
23
  cirrhosis: int
24
  other_cancer: int
 
25
  gender_Male: int
26
- family_history_Yes: int
27
-
28
  country_Belgium: int
29
  country_Bulgaria: int
30
  country_Croatia: int
@@ -51,23 +95,30 @@ class OneHotPatientData(BaseModel):
51
  country_Slovenia: int
52
  country_Spain: int
53
  country_Sweden: int
54
-
55
- cancer_stage_Stage_II: int
56
- cancer_stage_Stage_III: int
57
- cancer_stage_Stage_IV: int
58
-
 
 
59
  smoking_status_Former_Smoker: int
60
  smoking_status_Never_Smoked: int
61
  smoking_status_Passive_Smoker: int
62
-
63
  treatment_type_Combined: int
64
  treatment_type_Radiation: int
65
  treatment_type_Surgery: int
66
 
67
- # ... (root endpoint remains the same) ...
 
 
 
 
 
68
 
69
  # ------------------------------------------------------------
70
- # Prediction Endpoint (Corrected Feature Order)
71
  # ------------------------------------------------------------
72
  @app.post("/predict")
73
  def predict(data: OneHotPatientData):
@@ -75,26 +126,26 @@ def predict(data: OneHotPatientData):
75
  input_dict = data.dict()
76
  logger.info(f"Incoming data: {input_dict}")
77
 
78
- # Define the exact feature order your model expects (with underscores)
79
  feature_order = [
80
  'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
81
  'cirrhosis', 'other_cancer', 'gender_Male', 'country_Belgium',
82
  'country_Bulgaria', 'country_Croatia', 'country_Cyprus',
83
- 'country_Czech_Republic', 'country_Denmark', 'country_Estonia',
84
  'country_Finland', 'country_France', 'country_Germany',
85
  'country_Greece', 'country_Hungary', 'country_Ireland',
86
  'country_Italy', 'country_Latvia', 'country_Lithuania',
87
  'country_Luxembourg', 'country_Malta', 'country_Netherlands',
88
  'country_Poland', 'country_Portugal', 'country_Romania',
89
  'country_Slovakia', 'country_Slovenia', 'country_Spain',
90
- 'country_Sweden', 'cancer_stage_Stage_II', 'cancer_stage_Stage_III',
91
- 'cancer_stage_Stage_IV', 'family_history_Yes',
92
- 'smoking_status_Former_Smoker', 'smoking_status_Never_Smoked',
93
- 'smoking_status_Passive_Smoker', 'treatment_type_Combined',
94
  'treatment_type_Radiation', 'treatment_type_Surgery'
95
  ]
96
 
97
- # Create DataFrame and ensure the columns are in the correct order
98
  input_df = pd.DataFrame([input_dict], columns=feature_order)
99
  logger.info(f"DataFrame for prediction: {input_df}")
100
 
@@ -102,6 +153,7 @@ def predict(data: OneHotPatientData):
102
  probabilities = model.predict_proba(input_df)[0]
103
  logger.info(f"Model probabilities: {probabilities}")
104
 
 
105
  confidence_high_risk = probabilities[0]
106
  risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
107
 
@@ -115,4 +167,4 @@ def predict(data: OneHotPatientData):
115
 
116
  except Exception as e:
117
  logger.error(f"Prediction error: {str(e)}")
118
- return {"error": str(e)}
 
9
  import os
10
  import logging
11
 
12
# ------------------------------------------------------------
# Setup Logging
# ------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------------------------------------
# FastAPI Initialization
# ------------------------------------------------------------
# The app object must exist before any @app.<method> decorator is evaluated.
app = FastAPI(title="PulmoProbe AI API")

# Allow CORS for frontend communication.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# advised against by the FastAPI CORS documentation; list the real frontend
# origin(s) explicitly before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Use specific domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ------------------------------------------------------------
# Hugging Face Model Setup
# ------------------------------------------------------------
# Use a cache directory under /tmp (presumably because the default cache
# location is not writable in the deployment container - TODO confirm).
os.environ['HF_HOME'] = '/tmp/huggingface'
os.makedirs(os.environ['HF_HOME'], exist_ok=True)
logger.info(f"HF_HOME set to {os.environ['HF_HOME']}")

MODEL_REPO_ID = "costaspinto/PulmoProbe"
MODEL_FILENAME = "best_model.joblib"

logger.info("Downloading model from Hugging Face Hub...")
try:
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        cache_dir=os.environ['HF_HOME'],
    )
    # NOTE(review): joblib.load unpickles arbitrary Python objects; only load
    # artifacts from a repository you control and trust.
    model = joblib.load(model_path)
    logger.info("Model loaded successfully.")
except Exception as e:
    # logger.exception keeps the same message but also records the traceback.
    logger.exception("Failed to load model: %s", e)
    # Chain the original error so the root cause is preserved for callers
    # and in the startup failure output.
    raise RuntimeError(f"Model loading failed: {str(e)}") from e
55
+
56
+ # ------------------------------------------------------------
57
+ # Define Input Schema
58
  # ------------------------------------------------------------
59
  class OneHotPatientData(BaseModel):
60
+ # Continuous fields
61
  age: float
62
  bmi: float
63
  cholesterol_level: float
64
+ # Binary medical conditions
65
  hypertension: int
66
  asthma: int
67
  cirrhosis: int
68
  other_cancer: int
69
+ # Gender (Male = 1, Female = 0)
70
  gender_Male: int
71
+ # Countries (One-Hot)
 
72
  country_Belgium: int
73
  country_Bulgaria: int
74
  country_Croatia: int
 
95
  country_Slovenia: int
96
  country_Spain: int
97
  country_Sweden: int
98
+ # Cancer stages (Stage I is baseline)
99
+ cancer_stage_Stage_Ii: int
100
+ cancer_stage_Stage_Iii: int
101
+ cancer_stage_Stage_Iv: int
102
+ # Family history
103
+ family_history_Yes: int
104
+ # Smoking status (Current Smoker is baseline)
105
  smoking_status_Former_Smoker: int
106
  smoking_status_Never_Smoked: int
107
  smoking_status_Passive_Smoker: int
108
+ # Treatment type (Chemotherapy is baseline)
109
  treatment_type_Combined: int
110
  treatment_type_Radiation: int
111
  treatment_type_Surgery: int
112
 
113
+ # ------------------------------------------------------------
114
+ # Root Endpoint
115
+ # ------------------------------------------------------------
116
@app.get("/")
def read_root():
    """Return the static welcome payload served at the API root."""
    welcome = {"message": "Welcome to the PulmoProbe AI API"}
    return welcome
119
 
120
  # ------------------------------------------------------------
121
+ # Prediction Endpoint
122
  # ------------------------------------------------------------
123
  @app.post("/predict")
124
  def predict(data: OneHotPatientData):
 
126
  input_dict = data.dict()
127
  logger.info(f"Incoming data: {input_dict}")
128
 
129
+ # Define the exact feature order your model expects
130
  feature_order = [
131
  'age', 'bmi', 'cholesterol_level', 'hypertension', 'asthma',
132
  'cirrhosis', 'other_cancer', 'gender_Male', 'country_Belgium',
133
  'country_Bulgaria', 'country_Croatia', 'country_Cyprus',
134
+ 'country_Czech Republic', 'country_Denmark', 'country_Estonia',
135
  'country_Finland', 'country_France', 'country_Germany',
136
  'country_Greece', 'country_Hungary', 'country_Ireland',
137
  'country_Italy', 'country_Latvia', 'country_Lithuania',
138
  'country_Luxembourg', 'country_Malta', 'country_Netherlands',
139
  'country_Poland', 'country_Portugal', 'country_Romania',
140
  'country_Slovakia', 'country_Slovenia', 'country_Spain',
141
+ 'country_Sweden', 'cancer_stage_Stage Ii', 'cancer_stage_Stage Iii',
142
+ 'cancer_stage_Stage Iv', 'family_history_Yes',
143
+ 'smoking_status_Former Smoker', 'smoking_status_Never Smoked',
144
+ 'smoking_status_Passive Smoker', 'treatment_type_Combined',
145
  'treatment_type_Radiation', 'treatment_type_Surgery'
146
  ]
147
 
148
+ # Convert dictionary to a DataFrame and ensure the columns are in the correct order
149
  input_df = pd.DataFrame([input_dict], columns=feature_order)
150
  logger.info(f"DataFrame for prediction: {input_df}")
151
 
 
153
  probabilities = model.predict_proba(input_df)[0]
154
  logger.info(f"Model probabilities: {probabilities}")
155
 
156
+ # Assuming index 0 = High Risk
157
  confidence_high_risk = probabilities[0]
158
  risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
159
 
 
167
 
168
  except Exception as e:
169
  logger.error(f"Prediction error: {str(e)}")
170
+ return {"error": str(e), "input_data_received": data.dict()}