# Author: haseebnawazz
# Commit: ba4b2a5 "Update api.py" (verified)
import os
import traceback
from typing import List

import joblib
import numpy as np
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
# Single source of truth for the Hugging Face cache location. /tmp is presumably
# chosen because it is writable in the hosting container — TODO confirm.
HF_CACHE_DIR = "/tmp/hf_cache"

# NOTE: huggingface_hub resolves HF_HOME into module-level constants when it is
# imported, and that import happens above this line, so this assignment does NOT
# redirect the library's own default paths. That is why the download below passes
# cache_dir explicitly. The variable is still exported for anything that reads it
# later (e.g. libraries imported afterwards or child processes).
os.environ["HF_HOME"] = HF_CACHE_DIR

app = FastAPI()

# Download the serialized classifier from the Hub into HF_CACHE_DIR (or reuse the
# cached copy) and load it once at import time so all requests share one model.
model_path = hf_hub_download(
    repo_id="haseebnawazz/sleep_stage_classifier-RF",
    filename="class_balanced_RF_model.joblib",
    cache_dir=HF_CACHE_DIR,
)
# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only load artifacts from a repository you control or trust.
model = joblib.load(model_path)
class FeatureInput(BaseModel):
    """Request body for ``POST /predict``.

    ``features`` is one flat list of numbers that the ``predict`` endpoint
    reshapes into a single row (one sample). Its length presumably has to
    match the number of features the model was trained on — TODO confirm.
    """

    features: List[float]  # one sample's feature values; order presumably matches training columns — TODO confirm
@app.post("/predict")
def predict(input: FeatureInput):
    """Classify a single feature vector with the module-level model.

    Args:
        input: Request body; its ``features`` list is treated as one sample.

    Returns:
        ``{"prediction": [...]}`` on success: the model's ``predict`` output
        converted to plain Python values via ``tolist()``.
        ``{"error": "<message>"}`` if anything raises. The error reply is
        still sent with the route's default success status code, because no
        exception is propagated to FastAPI.
    """
    try:
        # reshape(1, -1): the whole list becomes a single row, i.e. one sample.
        features = np.array(input.features).reshape(1, -1)
        prediction = model.predict(features).tolist()
        return {"prediction": prediction}
    except Exception as e:
        # Request boundary: log the full traceback for the server logs, then
        # hand the message back to the client. `traceback` must be imported at
        # module level; if it is missing, this handler raises NameError and
        # masks the original error.
        error_message = traceback.format_exc()
        # flush=True so the message reaches the container logs immediately even
        # when stdout is block-buffered (non-TTY).
        print("[SERVER ERROR]:", error_message, flush=True)
        return {"error": str(e)}
@app.get("/")
async def read_root():
    """Landing endpoint for the API root: reply with a fixed greeting."""
    greeting = "Welcome to the API"
    return {"message": greeting}
@app.get("/logs")
async def get_logs():
    """Return a placeholder logs payload.

    NOTE(review): the value is a hard-coded stub string, not real container
    output. Implement or remove before relying on this endpoint.
    """
    placeholder = "container logs here"
    return dict(logs=placeholder)