costaspinto commited on
Commit
c84f188
·
verified ·
1 Parent(s): 5bfcf9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -65
app.py CHANGED
@@ -1,66 +1,71 @@
1
- # backend/app.py
2
-
3
- from fastapi import FastAPI
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from pydantic import BaseModel
6
- import joblib
7
- import pandas as pd
8
- from huggingface_hub import hf_hub_download
9
- import os
10
-
11
- app = FastAPI(title="PulmoProbe AI API")
12
-
13
- # Add CORS middleware to allow your Vercel app to call the API
14
- app.add_middleware(
15
- CORSMiddleware,
16
- allow_origins=["*"],
17
- allow_credentials=True,
18
- allow_methods=["*"],
19
- allow_headers=["*"],
20
- )
21
-
22
- # --- Download and Load Model from Hugging Face Hub ---
23
- # This points to your model repository
24
- MODEL_REPO_ID = "costaspinto/PulmoProbe"
25
- MODEL_FILENAME = "best_model.joblib"
26
-
27
- print("Downloading model from Hugging Face Hub...")
28
- model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
29
- model = joblib.load(model_path)
30
- print("Model loaded successfully.")
31
-
32
- # --- Define Input Data Model ---
33
- class PatientData(BaseModel):
34
- age: float
35
- gender: str
36
- country: str
37
- cancer_stage: str
38
- family_history: int
39
- smoking_status: str
40
- bmi: float
41
- cholesterol_level: float
42
- hypertension: int
43
- asthma: int
44
- cirrhosis: int
45
- other_cancer: int
46
- treatment_type: str
47
-
48
- # --- Define API Endpoints ---
49
- @app.get("/")
50
- def read_root():
51
- return {"message": "Welcome to the PulmoProbe AI API"}
52
-
53
- @app.post("/predict")
54
- def predict(data: PatientData):
55
- try:
56
- input_df = pd.DataFrame([data.dict()])
57
- probabilities = model.predict_proba(input_df)[0]
58
- confidence_high_risk = probabilities[0]
59
- risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
60
-
61
- return {
62
- "risk": risk_level,
63
- "confidence": f"{confidence_high_risk * 100:.1f}"
64
- }
65
- except Exception as e:
 
 
 
 
 
66
  return {"error": str(e)}
 
1
+ # backend/app.py
2
+
3
+ from fastapi import FastAPI
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ import joblib
7
+ import pandas as pd
8
+ from huggingface_hub import hf_hub_download
9
+ import os
10
+
11
+ app = FastAPI(title="PulmoProbe AI API")
12
+
13
+ # Add CORS middleware to allow your Vercel app to call the API
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"],
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
+ # --- Set the Hugging Face cache directory to a writable location ---
23
+ # This ensures the application has permission to download and store the model.
24
+ os.environ['HF_HOME'] = '/app/cache'
25
+ print("Setting HF_HOME to /app/cache")
26
+
27
+ # --- Download and Load Model from Hugging Face Hub ---
28
+ # This points to your model repository
29
+ MODEL_REPO_ID = "costaspinto/PulmoProbe"
30
+ MODEL_FILENAME = "best_model.joblib"
31
+
32
+ print("Downloading model from Hugging Face Hub...")
33
+ model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
34
+ model = joblib.load(model_path)
35
+ print("Model loaded successfully.")
36
+
37
+ # --- Define Input Data Model ---
38
+ class PatientData(BaseModel):
39
+ age: float
40
+ gender: str
41
+ country: str
42
+ cancer_stage: str
43
+ family_history: int
44
+ smoking_status: str
45
+ bmi: float
46
+ cholesterol_level: float
47
+ hypertension: int
48
+ asthma: int
49
+ cirrhosis: int
50
+ other_cancer: int
51
+ treatment_type: str
52
+
53
+ # --- Define API Endpoints ---
54
+ @app.get("/")
55
+ def read_root():
56
+ return {"message": "Welcome to the PulmoProbe AI API"}
57
+
58
+ @app.post("/predict")
59
+ def predict(data: PatientData):
60
+ try:
61
+ input_df = pd.DataFrame([data.dict()])
62
+ probabilities = model.predict_proba(input_df)[0]
63
+ confidence_high_risk = probabilities[0]
64
+ risk_level = "High Risk of Non-Survival" if confidence_high_risk > 0.5 else "Low Risk of Non-Survival"
65
+
66
+ return {
67
+ "risk": risk_level,
68
+ "confidence": f"{confidence_high_risk * 100:.1f}"
69
+ }
70
+ except Exception as e:
71
  return {"error": str(e)}