madhwdh11 commited on
Commit
323f563
·
verified ·
1 Parent(s): 7ffb757

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -6
main.py CHANGED
@@ -4,15 +4,22 @@ from typing import List
4
  import joblib
5
  import numpy as np
6
 
7
- # Initialize app
8
- app = FastAPI(title="Random Forest Classifier API")
9
 
10
- # Load the saved model
11
- model = joblib.load("best_random_forest.pkl")
 
 
 
 
 
 
 
 
 
12
 
13
- # Define input data format
14
  class PredictRequest(BaseModel):
15
- input_data: List[List[float]] # 2D list for batch prediction
16
 
17
  @app.post("/predict")
18
  def predict(req: PredictRequest):
@@ -22,3 +29,11 @@ def predict(req: PredictRequest):
22
  return {"predictions": predictions.tolist()}
23
  except Exception as e:
24
  raise HTTPException(status_code=400, detail=str(e))
 
 
 
 
 
 
 
 
 
4
  import joblib
5
  import numpy as np
6
 
7
+ app = FastAPI(title="Stroke Risk Prediction API")
 
8
 
9
+ # Lazy loading
10
+ model = None
11
+
12
+ @app.on_event("startup")
13
+ def load_model():
14
+ global model
15
+ try:
16
+ model = joblib.load("best_random_forest.pkl")
17
+ except Exception as e:
18
+ print(f"Failed to load model: {e}")
19
+ raise RuntimeError("Model loading failed.")
20
 
 
21
  class PredictRequest(BaseModel):
22
+ input_data: List[List[float]]
23
 
24
  @app.post("/predict")
25
  def predict(req: PredictRequest):
 
29
  return {"predictions": predictions.tolist()}
30
  except Exception as e:
31
  raise HTTPException(status_code=400, detail=str(e))
32
+
33
+ @app.get("/")
34
+ def read_root():
35
+ return {"status": "ok"}
36
+
37
+ @app.get("/health")
38
+ def health():
39
+ return {"status": "healthy"}