Spaces:
Running
Running
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from huggingface_hub import hf_hub_download | |
| from utils import load_model_by_type, encoder_from_model | |
| from preproc import label_decoding, apple_csv_to_data, apple_extract_beats | |
| import pandas as pd | |
| from io import StringIO | |
| from pathlib import Path | |
| import os | |
# Directory layout: the package root is three levels above this file, and
# locally bundled model files are looked up in its "models" folder.
PACKAGE_ROOT = Path(__file__).parent.parent.parent
MODEL_DIR = PACKAGE_ROOT.joinpath("models")

# FastAPI application; the docs and schema routes are spelled out explicitly.
app = FastAPI(
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json",
)

# Download cache location: the CACHE_DIR environment variable when set,
# otherwise a "cache" folder relative to the current working directory.
DEFAULT_CACHE_DIR = "./cache"
CACHE_DIR = os.environ.get("CACHE_DIR", DEFAULT_CACHE_DIR)
os.makedirs(CACHE_DIR, exist_ok=True)  # created at import time if missing

# Per-process caches filled by model_loader(), both keyed by model name.
model_cache = {}    # model name -> object returned by load_model_by_type()
encoder_cache = {}  # model name -> path of the matching label encoder

# Hugging Face Hub repository queried when a model is not found in MODEL_DIR.
HF_REPO_ID = "fabriciojm/hadt-models"

# Model used by the latest prediction request; None until one is loaded.
app.state.model = None
def root():
    """Return a fixed greeting payload.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm it is actually registered with ``app``.
    """
    greeting_payload = {"greeting": "Hello"}
    return greeting_payload
def model_loader(model_name):
    """Return the model stored under ``model_name``, loading it on first use.

    The model file is looked up in ``MODEL_DIR`` first. When it is not there,
    both the model and its label encoder are downloaded from the
    ``HF_REPO_ID`` repository into ``CACHE_DIR``. On success the loaded model
    is stored in ``model_cache`` and the encoder path in ``encoder_cache``,
    both under ``model_name``.

    Args:
        model_name: File name of the model, relative to ``MODEL_DIR`` or to
            the Hub repository.

    Returns:
        The object produced by ``load_model_by_type`` for this model.

    Raises:
        HTTPException: 404 when the name is rejected, or when resolving,
            downloading or loading the model fails for any reason.
    """
    # Cache hit: nothing to resolve, download or load.
    if model_name in model_cache:
        return model_cache[model_name]

    # predict()/predict_multibeats() forward model_name unvalidated, so refuse
    # anything that would escape MODEL_DIR once joined onto it.
    requested = Path(f"{model_name}")
    if requested.is_absolute() or ".." in requested.parts:
        raise HTTPException(
            status_code=404,
            detail=f"Model {model_name} not found: invalid model name",
        )

    try:
        # Resolved inside the try so an unknown model name is reported as a
        # 404 like every other loading failure.
        encoder_name = encoder_from_model(model_name)
        model_path = MODEL_DIR / f"{model_name}"
        encoder_path = MODEL_DIR / encoder_name

        # Fall back to the Hub only when the model is not bundled locally.
        if not model_path.exists():
            model_path = Path(
                hf_hub_download(
                    repo_id=HF_REPO_ID,
                    filename=f"{model_name}",
                    cache_dir=CACHE_DIR,
                )
            )
            encoder_path = Path(
                hf_hub_download(
                    repo_id=HF_REPO_ID,
                    filename=f"{encoder_name}",
                    cache_dir=CACHE_DIR,
                )
            )

        model = load_model_by_type(model_path)
    except Exception as e:
        print(f"Error loading model: {str(e)}")  # Add debug print
        raise HTTPException(
            status_code=404,
            detail=f"Model {model_name} not found: {str(e)}",
        ) from e

    # Record the encoder first: model_cache membership is what marks the
    # model as fully loaded for later calls.
    encoder_cache[model_name] = encoder_path
    model_cache[model_name] = model
    return model
async def predict(model_name: str, filepath_csv: UploadFile = File(...)):
    """Run ``model_name`` on an uploaded CSV and return the decoded prediction.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm it is actually registered with ``app``.

    Args:
        model_name: Name passed to ``model_loader`` to obtain the model.
        filepath_csv: Uploaded file; its content is decoded as UTF-8 and
            parsed with ``pandas.read_csv``.

    Returns:
        A dict with a single ``"prediction"`` key holding the output of
        ``label_decoding``.

    Raises:
        HTTPException: 400 when the upload is empty, is not valid UTF-8 or
            cannot be parsed as CSV; whatever ``model_loader`` raises when
            the model cannot be loaded.
    """
    # Loading also records the model as the application's current one.
    model = app.state.model = model_loader(model_name)

    file_content = await filepath_csv.read()
    try:
        X = pd.read_csv(StringIO(file_content.decode('utf-8')))
    except (
        UnicodeDecodeError,
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
    ) as e:
        # A bad upload is a client error, not an internal server failure.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid CSV upload: {str(e)}",
        ) from e

    y_pred = model.predict_with_pipeline(X)
    # The encoder path was stored by model_loader() under the same name.
    y_pred = label_decoding(values=y_pred, path=encoder_cache[model_name])
    return {"prediction": y_pred}
async def predict_multibeats(model_name: str, filepath_csv: UploadFile = File(...)):
    """Extract beats from an uploaded file and return their decoded predictions.

    The raw upload bytes go through ``apple_csv_to_data`` and
    ``apple_extract_beats``; the result is fed to the model's
    ``predict_with_pipeline`` and decoded with ``label_decoding``.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm it is actually registered with ``app``.
    NOTE(review): the expected file layout is defined by ``apple_csv_to_data``,
    which is not visible here -- verify before documenting it for clients.

    Returns:
        A dict with a single ``"prediction"`` key holding the output of
        ``label_decoding``.
    """
    # Load (or fetch from cache) the model and record it as the current one.
    classifier = model_loader(model_name)
    app.state.model = classifier

    # The helper receives the upload as raw bytes, not decoded text.
    raw_upload = await filepath_csv.read()
    parsed, sample_rate = apple_csv_to_data(raw_upload)
    beats = apple_extract_beats(parsed, sample_rate)

    raw_pred = classifier.predict_with_pipeline(beats)
    # The encoder path was stored by model_loader() under the same name.
    decoded = label_decoding(values=raw_pred, path=encoder_cache[model_name])
    return {"prediction": decoded}
# NOTE: disabled earlier draft of the /predict_multibeats endpoint, kept for
# reference only. Its body uses `model` without loading it, so it cannot be
# re-enabled as written.
# @app.post("/predict_multibeats")
# async def predict_multibeats(model_name: str, filepath_csv: UploadFile = File(...)):
#     # Read the uploaded CSV file
#     file_content = await filepath_csv.read()
#     X = pd.read_csv(StringIO(file_content.decode('utf-8')))
#     y_pred = model.predict_with_pipeline(X)
#     return {"prediction": y_pred}