costaspinto commited on
Commit
3b383e1
·
verified ·
1 Parent(s): 061d037

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -20
app.py CHANGED
@@ -7,49 +7,75 @@ import joblib
7
  import pandas as pd
8
  from huggingface_hub import hf_hub_download
9
  import os
 
10
 
 
 
 
 
 
 
 
 
 
11
  app = FastAPI(title="PulmoProbe AI API")
12
 
13
- # Add CORS middleware to allow your frontend app to call the API
14
  app.add_middleware(
15
  CORSMiddleware,
16
- allow_origins=["*"],
17
  allow_credentials=True,
18
  allow_methods=["*"],
19
  allow_headers=["*"],
20
  )
21
 
22
- # --- Set Hugging Face cache to a writable location ---
 
 
 
23
  os.environ['HF_HOME'] = '/tmp/huggingface'
24
  os.makedirs(os.environ['HF_HOME'], exist_ok=True)
25
- print(f"HF_HOME set to {os.environ['HF_HOME']}")
26
 
27
- # --- Download and Load Model from Hugging Face Hub ---
28
  MODEL_REPO_ID = "costaspinto/PulmoProbe"
29
  MODEL_FILENAME = "best_model.joblib"
30
 
31
- print("Downloading model from Hugging Face Hub...")
32
- model_path = hf_hub_download(
33
- repo_id=MODEL_REPO_ID,
34
- filename=MODEL_FILENAME,
35
- cache_dir=os.environ['HF_HOME']
36
- )
37
- model = joblib.load(model_path)
38
- print("Model loaded successfully.")
 
 
 
 
39
 
40
- # --- Define Input Data Model for one-hot encoded features ---
41
- # This new model directly matches the one-hot encoded data from the frontend
 
42
  class OneHotPatientData(BaseModel):
 
43
  age: float
44
  bmi: float
45
  cholesterol_level: float
 
 
46
  hypertension: int
47
  asthma: int
48
  cirrhosis: int
49
  other_cancer: int
 
 
50
  family_history_Yes: int
 
 
51
  gender_Male: int
52
  gender_Female: int
 
 
53
  country_Sweden: int
54
  country_Netherlands: int
55
  country_Hungary: int
@@ -64,35 +90,61 @@ class OneHotPatientData(BaseModel):
64
  country_Spain: int
65
  country_UnitedKingdom: int
66
  country_UnitedStates: int
 
 
67
  cancer_stage_StageI: int
68
  cancer_stage_StageII: int
69
  cancer_stage_StageIII: int
70
  cancer_stage_StageIV: int
 
 
71
  smoking_status_NeverSmoked: int
72
  smoking_status_FormerSmoker: int
73
  smoking_status_PassiveSmoker: int
74
  smoking_status_CurrentSmoker: int
 
 
75
  treatment_type_Chemotherapy: int
76
  treatment_type_Surgery: int
77
  treatment_type_Radiation: int
78
  treatment_type_Combined: int
79
 
80
- # --- API Endpoints ---
 
 
81
  @app.get("/")
82
  def read_root():
83
  return {"message": "Welcome to the PulmoProbe AI API"}
84
 
 
 
 
85
  @app.post("/predict")
86
  def predict(data: OneHotPatientData):
87
  try:
88
- input_df = pd.DataFrame([data.dict()])
 
 
 
 
 
 
89
  probabilities = model.predict_proba(input_df)[0]
90
- confidence_high_risk = probabilities[0]
 
 
 
91
  risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
92
 
93
- return {
94
  "risk": risk_level,
95
  "confidence": f"{confidence_high_risk * 100:.1f}%"
96
  }
 
 
 
 
97
  except Exception as e:
98
- return {"error": str(e)}
 
 
 
7
  import pandas as pd
8
  from huggingface_hub import hf_hub_download
9
  import os
10
+ import logging
11
 
12
+ # ------------------------------------------------------------
13
+ # Setup Logging for Debugging
14
+ # ------------------------------------------------------------
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # ------------------------------------------------------------
19
+ # FastAPI App Initialization
20
+ # ------------------------------------------------------------
21
  app = FastAPI(title="PulmoProbe AI API")
22
 
23
+ # Enable CORS so React frontend can communicate with FastAPI backend
24
  app.add_middleware(
25
  CORSMiddleware,
26
+ allow_origins=["*"], # In production, specify domain
27
  allow_credentials=True,
28
  allow_methods=["*"],
29
  allow_headers=["*"],
30
  )
31
 
32
+ # ------------------------------------------------------------
33
+ # Hugging Face Model Setup
34
+ # ------------------------------------------------------------
35
+ # Use a writable temp directory for Hugging Face cache
36
  os.environ['HF_HOME'] = '/tmp/huggingface'
37
  os.makedirs(os.environ['HF_HOME'], exist_ok=True)
38
+ logger.info(f"HF_HOME set to {os.environ['HF_HOME']}")
39
 
 
40
  MODEL_REPO_ID = "costaspinto/PulmoProbe"
41
  MODEL_FILENAME = "best_model.joblib"
42
 
43
+ logger.info("Downloading model from Hugging Face Hub...")
44
+ try:
45
+ model_path = hf_hub_download(
46
+ repo_id=MODEL_REPO_ID,
47
+ filename=MODEL_FILENAME,
48
+ cache_dir=os.environ['HF_HOME']
49
+ )
50
+ model = joblib.load(model_path)
51
+ logger.info("Model loaded successfully.")
52
+ except Exception as e:
53
+ logger.error(f"Failed to download or load model: {str(e)}")
54
+ raise RuntimeError(f"Model loading failed: {str(e)}")
55
 
56
+ # ------------------------------------------------------------
57
+ # Define Input Schema - Must Match Frontend Fields Exactly
58
+ # ------------------------------------------------------------
59
  class OneHotPatientData(BaseModel):
60
+ # Continuous fields
61
  age: float
62
  bmi: float
63
  cholesterol_level: float
64
+
65
+ # Binary medical conditions
66
  hypertension: int
67
  asthma: int
68
  cirrhosis: int
69
  other_cancer: int
70
+
71
+ # Family history
72
  family_history_Yes: int
73
+
74
+ # Gender one-hot
75
  gender_Male: int
76
  gender_Female: int
77
+
78
+ # Country one-hot
79
  country_Sweden: int
80
  country_Netherlands: int
81
  country_Hungary: int
 
90
  country_Spain: int
91
  country_UnitedKingdom: int
92
  country_UnitedStates: int
93
+
94
+ # Cancer stage one-hot
95
  cancer_stage_StageI: int
96
  cancer_stage_StageII: int
97
  cancer_stage_StageIII: int
98
  cancer_stage_StageIV: int
99
+
100
+ # Smoking status one-hot
101
  smoking_status_NeverSmoked: int
102
  smoking_status_FormerSmoker: int
103
  smoking_status_PassiveSmoker: int
104
  smoking_status_CurrentSmoker: int
105
+
106
+ # Treatment type one-hot
107
  treatment_type_Chemotherapy: int
108
  treatment_type_Surgery: int
109
  treatment_type_Radiation: int
110
  treatment_type_Combined: int
111
 
112
+ # ------------------------------------------------------------
113
+ # Root Endpoint
114
+ # ------------------------------------------------------------
115
  @app.get("/")
116
  def read_root():
117
  return {"message": "Welcome to the PulmoProbe AI API"}
118
 
119
+ # ------------------------------------------------------------
120
+ # Prediction Endpoint
121
+ # ------------------------------------------------------------
122
  @app.post("/predict")
123
  def predict(data: OneHotPatientData):
124
  try:
125
+ # Convert incoming data to DataFrame
126
+ input_data = data.dict()
127
+ logger.info(f"Received prediction request: {input_data}")
128
+
129
+ input_df = pd.DataFrame([input_data])
130
+
131
+ # Make prediction
132
  probabilities = model.predict_proba(input_df)[0]
133
+ confidence_high_risk = probabilities[0] # Assuming index 0 = High Risk
134
+ logger.info(f"Model raw probabilities: {probabilities}")
135
+
136
+ # Determine risk level
137
  risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
138
 
139
+ result = {
140
  "risk": risk_level,
141
  "confidence": f"{confidence_high_risk * 100:.1f}%"
142
  }
143
+
144
+ logger.info(f"Prediction result: {result}")
145
+ return result
146
+
147
  except Exception as e:
148
+ logger.error(f"Prediction error: {str(e)}")
149
+ return {"error": str(e)}
150
+