madhwdh11 commited on
Commit
e29208d
·
verified ·
1 Parent(s): 48f38d2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -19
main.py CHANGED
@@ -3,37 +3,38 @@ from pydantic import BaseModel
3
  from typing import List
4
  import joblib
5
  import numpy as np
 
6
 
7
- app = FastAPI(title="Stroke Risk Prediction API")
8
 
9
- # Lazy loading
10
  model = None
11
 
12
- @app.on_event("startup")
13
  def load_model():
14
  global model
15
- try:
16
- model = joblib.load("best_random_forest.pkl")
17
- except Exception as e:
18
- print(f"Failed to load model: {e}")
19
- raise RuntimeError("Model loading failed.")
20
 
21
  class PredictRequest(BaseModel):
22
  input_data: List[List[float]]
23
 
24
- @app.post("/predict")
25
- def predict(req: PredictRequest):
26
- try:
27
- input_array = np.array(req.input_data)
28
- predictions = model.predict(input_array)
29
- return {"predictions": predictions.tolist()}
30
- except Exception as e:
31
- raise HTTPException(status_code=400, detail=str(e))
32
-
33
  @app.get("/")
34
- def read_root():
35
- return {"status": "ok"}
36
 
37
  @app.get("/health")
38
  def health():
39
  return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
 
 
 
3
  from typing import List
4
  import joblib
5
  import numpy as np
6
+ import os
7
 
8
+ app = FastAPI()
9
 
 
10
  model = None
11
 
 
12
  def load_model():
13
  global model
14
+ if model is None:
15
+ model_path = "best_random_forest.pkl"
16
+ if not os.path.exists(model_path):
17
+ raise RuntimeError("Model file not found.")
18
+ model = joblib.load(model_path)
19
 
20
  class PredictRequest(BaseModel):
21
  input_data: List[List[float]]
22
 
 
 
 
 
 
 
 
 
 
23
  @app.get("/")
24
+ def root():
25
+ return {"message": "API is running."}
26
 
27
  @app.get("/health")
28
  def health():
29
  return {"status": "healthy"}
30
+
31
+ @app.post("/predict")
32
+ def predict(data: PredictRequest):
33
+ try:
34
+ load_model() # Only loads once, on first request
35
+ inputs = np.array(data.input_data)
36
+ preds = model.predict(inputs)
37
+ return {"predictions": preds.tolist()}
38
+ except Exception as e:
39
+ raise HTTPException(status_code=500, detail=str(e))
40
+