costaspinto commited on
Commit
a92d9eb
·
verified ·
1 Parent(s): 3b383e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -47
app.py CHANGED
@@ -10,20 +10,20 @@ import os
10
  import logging
11
 
12
  # ------------------------------------------------------------
13
- # Setup Logging for Debugging
14
  # ------------------------------------------------------------
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
  # ------------------------------------------------------------
19
- # FastAPI App Initialization
20
  # ------------------------------------------------------------
21
  app = FastAPI(title="PulmoProbe AI API")
22
 
23
- # Enable CORS so React frontend can communicate with FastAPI backend
24
  app.add_middleware(
25
  CORSMiddleware,
26
- allow_origins=["*"], # In production, specify domain
27
  allow_credentials=True,
28
  allow_methods=["*"],
29
  allow_headers=["*"],
@@ -32,7 +32,6 @@ app.add_middleware(
32
  # ------------------------------------------------------------
33
  # Hugging Face Model Setup
34
  # ------------------------------------------------------------
35
- # Use a writable temp directory for Hugging Face cache
36
  os.environ['HF_HOME'] = '/tmp/huggingface'
37
  os.makedirs(os.environ['HF_HOME'], exist_ok=True)
38
  logger.info(f"HF_HOME set to {os.environ['HF_HOME']}")
@@ -50,11 +49,12 @@ try:
50
  model = joblib.load(model_path)
51
  logger.info("Model loaded successfully.")
52
  except Exception as e:
53
- logger.error(f"Failed to download or load model: {str(e)}")
54
  raise RuntimeError(f"Model loading failed: {str(e)}")
55
 
56
  # ------------------------------------------------------------
57
- # Define Input Schema - Must Match Frontend Fields Exactly
 
58
  # ------------------------------------------------------------
59
  class OneHotPatientData(BaseModel):
60
  # Continuous fields
@@ -68,46 +68,54 @@ class OneHotPatientData(BaseModel):
68
  cirrhosis: int
69
  other_cancer: int
70
 
71
- # Family history
72
- family_history_Yes: int
73
-
74
- # Gender one-hot
75
  gender_Male: int
76
- gender_Female: int
77
 
78
- # Country one-hot
79
- country_Sweden: int
80
- country_Netherlands: int
81
- country_Hungary: int
82
  country_Belgium: int
83
- country_Italy: int
84
  country_Croatia: int
 
 
85
  country_Denmark: int
86
- country_Germany: int
 
87
  country_France: int
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  country_Slovakia: int
89
- country_Finland: int
90
  country_Spain: int
91
- country_UnitedKingdom: int
92
- country_UnitedStates: int
93
-
94
- # Cancer stage one-hot
95
- cancer_stage_StageI: int
96
- cancer_stage_StageII: int
97
- cancer_stage_StageIII: int
98
- cancer_stage_StageIV: int
99
-
100
- # Smoking status one-hot
101
- smoking_status_NeverSmoked: int
102
- smoking_status_FormerSmoker: int
103
- smoking_status_PassiveSmoker: int
104
- smoking_status_CurrentSmoker: int
105
-
106
- # Treatment type one-hot
107
- treatment_type_Chemotherapy: int
108
- treatment_type_Surgery: int
109
- treatment_type_Radiation: int
110
  treatment_type_Combined: int
 
 
111
 
112
  # ------------------------------------------------------------
113
  # Root Endpoint
@@ -122,18 +130,18 @@ def read_root():
122
  @app.post("/predict")
123
  def predict(data: OneHotPatientData):
124
  try:
125
- # Convert incoming data to DataFrame
126
- input_data = data.dict()
127
- logger.info(f"Received prediction request: {input_data}")
128
 
129
- input_df = pd.DataFrame([input_data])
 
130
 
131
- # Make prediction
132
  probabilities = model.predict_proba(input_df)[0]
133
- confidence_high_risk = probabilities[0] # Assuming index 0 = High Risk
134
- logger.info(f"Model raw probabilities: {probabilities}")
135
 
136
- # Determine risk level
 
137
  risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
138
 
139
  result = {
@@ -147,4 +155,3 @@ def predict(data: OneHotPatientData):
147
  except Exception as e:
148
  logger.error(f"Prediction error: {str(e)}")
149
  return {"error": str(e)}
150
-
 
10
  import logging
11
 
12
  # ------------------------------------------------------------
13
+ # Setup Logging
14
  # ------------------------------------------------------------
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
  # ------------------------------------------------------------
19
+ # FastAPI Initialization
20
  # ------------------------------------------------------------
21
  app = FastAPI(title="PulmoProbe AI API")
22
 
23
+ # Allow CORS for frontend communication
24
  app.add_middleware(
25
  CORSMiddleware,
26
+ allow_origins=["*"], # Use specific domain in production
27
  allow_credentials=True,
28
  allow_methods=["*"],
29
  allow_headers=["*"],
 
32
  # ------------------------------------------------------------
33
  # Hugging Face Model Setup
34
  # ------------------------------------------------------------
 
35
  os.environ['HF_HOME'] = '/tmp/huggingface'
36
  os.makedirs(os.environ['HF_HOME'], exist_ok=True)
37
  logger.info(f"HF_HOME set to {os.environ['HF_HOME']}")
 
49
  model = joblib.load(model_path)
50
  logger.info("Model loaded successfully.")
51
  except Exception as e:
52
+ logger.error(f"Failed to load model: {str(e)}")
53
  raise RuntimeError(f"Model loading failed: {str(e)}")
54
 
55
  # ------------------------------------------------------------
56
+ # Define Input Schema
57
+ # Must exactly match the model feature names and order
58
  # ------------------------------------------------------------
59
  class OneHotPatientData(BaseModel):
60
  # Continuous fields
 
68
  cirrhosis: int
69
  other_cancer: int
70
 
71
+ # Gender (Male = 1, Female = 0)
 
 
 
72
  gender_Male: int
 
73
 
74
+ # Countries (One-Hot)
 
 
 
75
  country_Belgium: int
76
+ country_Bulgaria: int
77
  country_Croatia: int
78
+ country_Cyprus: int
79
+ country_Czech_Republic: int
80
  country_Denmark: int
81
+ country_Estonia: int
82
+ country_Finland: int
83
  country_France: int
84
+ country_Germany: int
85
+ country_Greece: int
86
+ country_Hungary: int
87
+ country_Ireland: int
88
+ country_Italy: int
89
+ country_Latvia: int
90
+ country_Lithuania: int
91
+ country_Luxembourg: int
92
+ country_Malta: int
93
+ country_Netherlands: int
94
+ country_Poland: int
95
+ country_Portugal: int
96
+ country_Romania: int
97
  country_Slovakia: int
98
+ country_Slovenia: int
99
  country_Spain: int
100
+ country_Sweden: int
101
+
102
+ # Cancer stages (Stage I is baseline)
103
+ cancer_stage_Stage_Ii: int
104
+ cancer_stage_Stage_Iii: int
105
+ cancer_stage_Stage_Iv: int
106
+
107
+ # Family history
108
+ family_history_Yes: int
109
+
110
+ # Smoking status (Current Smoker is baseline)
111
+ smoking_status_Former_Smoker: int
112
+ smoking_status_Never_Smoked: int
113
+ smoking_status_Passive_Smoker: int
114
+
115
+ # Treatment type (Chemotherapy is baseline)
 
 
 
116
  treatment_type_Combined: int
117
+ treatment_type_Radiation: int
118
+ treatment_type_Surgery: int
119
 
120
  # ------------------------------------------------------------
121
  # Root Endpoint
 
130
  @app.post("/predict")
131
  def predict(data: OneHotPatientData):
132
  try:
133
+ input_dict = data.dict()
134
+ logger.info(f"Incoming data: {input_dict}")
 
135
 
136
+ # Create DataFrame
137
+ input_df = pd.DataFrame([input_dict])
138
 
139
+ # Predict probabilities
140
  probabilities = model.predict_proba(input_df)[0]
141
+ logger.info(f"Model probabilities: {probabilities}")
 
142
 
143
+ # Assuming index 0 = High Risk
144
+ confidence_high_risk = probabilities[0]
145
  risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
146
 
147
  result = {
 
155
  except Exception as e:
156
  logger.error(f"Prediction error: {str(e)}")
157
  return {"error": str(e)}