"""FastAPI service that serves a sleep-stage random-forest classifier.

The model file is downloaded from the Hugging Face Hub (or reused from the
local cache) once, at import time, and then used by the ``/predict`` endpoint.
"""

import os

# Point the Hugging Face cache at a writable location. This must run BEFORE
# ``huggingface_hub`` is imported: the library reads HF_HOME into module-level
# constants on import, so setting it afterwards does not change its defaults.
os.environ["HF_HOME"] = "/tmp/hf_cache"

import traceback
from typing import List

import joblib
import numpy as np
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

app = FastAPI()

# Download the model (or reuse the cached copy) and load it into memory.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code, so this fully trusts the contents of the remote repo. Consider pinning
# ``revision=`` so the loaded artifact cannot change underneath the service.
model_path = hf_hub_download(
    repo_id="haseebnawazz/sleep_stage_classifier-RF",
    filename="class_balanced_RF_model.joblib",
    cache_dir="/tmp/hf_cache",
)
model = joblib.load(model_path)


class FeatureInput(BaseModel):
    """Request body for ``/predict``: one flat list of feature values."""

    # Feature values for a single sample.
    # NOTE(review): expected length/order must match what the model was
    # trained on — confirm against the training pipeline.
    features: List[float]


@app.post("/predict")
def predict(input: FeatureInput):
    """Classify a single feature vector with the loaded model.

    The features are reshaped into a one-row 2-D array before calling
    ``model.predict`` (presumably the usual scikit-learn
    ``(n_samples, n_features)`` layout — confirm against the model).

    Returns:
        ``{"prediction": [...]}`` on success, or ``{"error": "<message>"}``
        if building the array or predicting raises.
    """
    try:
        features = np.array(input.features).reshape(1, -1)
        prediction = model.predict(features).tolist()
    except Exception as e:
        # Request boundary: record the full traceback in the server logs
        # (visible in Hugging Face logs) and return the message to the client.
        # flush=True so the entry shows up immediately even when stdout is
        # block-buffered (non-TTY container).
        error_message = traceback.format_exc()
        print("[SERVER ERROR]:", error_message, flush=True)
        return {"error": str(e)}
    return {"prediction": prediction}


@app.get("/")
async def read_root():
    """Root endpoint returning a static welcome message."""
    return {"message": "Welcome to the API"}


@app.get("/logs")
async def get_logs():
    """Placeholder endpoint returning a static string, not real logs."""
    # TODO: not implemented; returns placeholder text only.
    return {"logs": "container logs here"}