haseebnawazz commited on
Commit
481d5dc
·
verified ·
1 Parent(s): 7fee830

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +15 -10
api.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from huggingface_hub import hf_hub_download
2
  import joblib
3
  from fastapi import FastAPI
@@ -7,23 +8,27 @@ import os
7
  # Set a different cache directory
8
  os.environ["HF_HOME"] = "/tmp/hf_cache"
9
 
10
- # Initialize FastAPI app
11
  app = FastAPI()
12
 
13
- # Download the model from Hugging Face Model Repo
14
  model_path = hf_hub_download(
15
- repo_id="haseebnawazz/sleep-stage-classifier-model", # Update with your HF model repo
16
  filename="model.joblib",
17
- cache_dir="/tmp/hf_cache" # Ensure the cache is stored in a writable directory
18
  )
19
  model = joblib.load(model_path)
20
 
21
- # Define a prediction endpoint
 
 
 
 
22
  @app.post("/predict")
23
- def predict(features: list):
24
- features = np.array(features).reshape(1, -1) # Convert input to NumPy array
25
- prediction = model.predict(features).tolist() # Predict and convert to list
26
- return {"prediction": prediction}
 
27
 
28
  @app.get("/")
29
  async def read_root():
@@ -31,4 +36,4 @@ async def read_root():
31
 
32
  @app.get("/logs")
33
  async def get_logs():
34
- return {"logs": "container logs here"}
 
1
+ from pydantic import BaseModel
2
  from huggingface_hub import hf_hub_download
3
  import joblib
4
  from fastapi import FastAPI
 
8
  # Set a different cache directory
9
  os.environ["HF_HOME"] = "/tmp/hf_cache"
10
 
 
11
  app = FastAPI()
12
 
13
+ # Download model
14
  model_path = hf_hub_download(
15
+ repo_id="haseebnawazz/sleep-stage-classifier-model",
16
  filename="model.joblib",
17
+ cache_dir="/tmp/hf_cache"
18
  )
19
  model = joblib.load(model_path)
20
 
21
+ # Define Pydantic schema to accept {"features": {...}}
22
+ class FeaturesInput(BaseModel):
23
+ features: dict
24
+
25
+ # ✅ Update endpoint to use schema
26
  @app.post("/predict")
27
+ def predict(input: FeaturesInput):
28
+ features = input.features
29
+ features_array = np.array([list(features.values())]) # Ensure 2D array for model
30
+ prediction = model.predict(features_array).tolist()
31
+ return {"predicted_stage": prediction[0]}
32
 
33
  @app.get("/")
34
  async def read_root():
 
36
 
37
  @app.get("/logs")
38
  async def get_logs():
39
+ return {"logs": "container logs here"}