madhwdh11 commited on
Commit
b1d342c
·
verified ·
1 Parent(s): d82831d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -29
main.py CHANGED
@@ -1,42 +1,27 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from typing import List
4
  import numpy as np
5
- import joblib
6
- import os
7
 
8
- app = FastAPI(title="Stroke Risk Prediction")
9
 
10
- MODEL_PATH = "best_random_forest.pkl"
11
- _model = None
 
 
12
 
13
- def get_model():
14
- global _model
15
- if _model is None:
16
- if not os.path.exists(MODEL_PATH):
17
- raise FileNotFoundError(f"Missing model file: {MODEL_PATH}")
18
- _model = joblib.load(MODEL_PATH)
19
- return _model
20
-
21
- class PredictRequest(BaseModel):
22
- input_data: List[List[float]]
23
-
24
- @app.get("/")
25
- def root():
26
- return {"status": "ok"}
27
-
28
- @app.get("/health")
29
- def health():
30
- return {"status": "healthy"}
31
 
32
  @app.post("/predict")
33
- def predict(req: PredictRequest):
34
  try:
35
- clf = get_model()
36
- arr = np.array(req.input_data, dtype=float)
37
- preds = clf.predict(arr)
38
- return {"predictions": preds.tolist()}
39
  except Exception as e:
40
  raise HTTPException(status_code=500, detail=str(e))
41
 
42
 
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ import onnxruntime as ort
4
  import numpy as np
 
 
5
 
6
+ app = FastAPI()
7
 
8
+ # Load the ONNX model session
9
+ session = ort.InferenceSession("app/model.onnx", providers=["CPUExecutionProvider"])
10
+ input_name = session.get_inputs()[0].name
11
+ output_name = session.get_outputs()[0].name
12
 
13
+ # Define input schema
14
+ class InputData(BaseModel):
15
+ features: list[float]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  @app.post("/predict")
18
+ def predict(data: InputData):
19
  try:
20
+ input_array = np.array([data.features], dtype=np.float32)
21
+ prediction = session.run([output_name], {input_name: input_array})
22
+ return {"prediction": prediction[0][0].item()}
 
23
  except Exception as e:
24
  raise HTTPException(status_code=500, detail=str(e))
25
 
26
 
27
+