madhwdh11 commited on
Commit
5f03b61
·
verified ·
1 Parent(s): 7318d27

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -26
main.py CHANGED
@@ -1,25 +1,8 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from typing import List
4
- import numpy as np
5
- import joblib
6
  import os
 
7
 
8
- app = FastAPI(title="Stroke Risk Prediction")
9
-
10
- MODEL_PATH = "best_random_forest.pkl"
11
- _model = None
12
-
13
- def get_model():
14
- global _model
15
- if _model is None:
16
- if not os.path.exists(MODEL_PATH):
17
- raise FileNotFoundError(f"Missing model file: {MODEL_PATH}")
18
- _model = joblib.load(MODEL_PATH)
19
- return _model
20
-
21
- class PredictRequest(BaseModel):
22
- input_data: List[List[float]]
23
 
24
  @app.get("/")
25
  def root():
@@ -30,14 +13,17 @@ def health():
30
  return {"status": "healthy"}
31
 
32
  @app.post("/predict")
33
- def predict(req: PredictRequest):
34
  try:
35
- clf = get_model()
36
- arr = np.array(req.input_data, dtype=float)
37
- preds = clf.predict(arr)
38
- return {"predictions": preds.tolist()}
 
 
 
39
  except Exception as e:
40
- raise HTTPException(status_code=500, detail=str(e))
41
 
42
 
43
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
 
 
 
 
2
  import os
3
+ from rf_model import score
4
 
5
+ app = FastAPI(title="Stroke Risk (Pure‑Python RF)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  @app.get("/")
8
  def root():
 
13
  return {"status": "healthy"}
14
 
15
  @app.post("/predict")
16
+ async def predict(request: Request):
17
  try:
18
+ body = await request.json()
19
+ data = body.get("input_data")
20
+ if not isinstance(data, list):
21
+ raise ValueError("`input_data` must be a list of rows")
22
+ # each row is a list of floats
23
+ preds = [score(row) for row in data]
24
+ return {"predictions": preds}
25
  except Exception as e:
26
+ raise HTTPException(status_code=400, detail=str(e))
27
 
28
 
29