madhwdh11 commited on
Commit
70ad082
·
verified ·
1 Parent(s): 5f03b61

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -12
main.py CHANGED
@@ -1,26 +1,26 @@
1
- from fastapi import FastAPI, HTTPException, Request
2
- import os
 
 
 
 
3
  from rf_model import score
4
 
5
  app = FastAPI(title="Stroke Risk (Pure‑Python RF)")
6
 
7
- @app.get("/")
8
- def root():
9
- return {"status": "ok"}
10
 
11
  @app.get("/health")
12
  def health():
13
  return {"status": "healthy"}
14
 
15
  @app.post("/predict")
16
- async def predict(request: Request):
17
  try:
18
- body = await request.json()
19
- data = body.get("input_data")
20
- if not isinstance(data, list):
21
- raise ValueError("`input_data` must be a list of rows")
22
- # each row is a list of floats
23
- preds = [score(row) for row in data]
24
  return {"predictions": preds}
25
  except Exception as e:
26
  raise HTTPException(status_code=400, detail=str(e))
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ import numpy as np
5
+
6
+ # import the pure‑Python scoring function
7
  from rf_model import score
8
 
9
  app = FastAPI(title="Stroke Risk (Pure‑Python RF)")
10
 
11
+ class PredictRequest(BaseModel):
12
+ input_data: List[List[float]]
 
13
 
14
  @app.get("/health")
15
  def health():
16
  return {"status": "healthy"}
17
 
18
  @app.post("/predict")
19
+ def predict(req: PredictRequest):
20
  try:
21
+ arr = np.array(req.input_data, dtype=float)
22
+ # call the generated function directly
23
+ preds = [score(row.tolist()) for row in arr]
 
 
 
24
  return {"predictions": preds}
25
  except Exception as e:
26
  raise HTTPException(status_code=400, detail=str(e))