|
|
| """ |
| app.py |
| |
| FastAPI inference server for Warbler — bird audio species classifier. |
| |
| Endpoints: |
| POST /predict — upload an audio file, returns top-K species predictions |
| GET /health — liveness check |
| GET /classes — list all supported species |
| |
| Usage: |
| uvicorn app:app --host 0.0.0.0 --port 8000 --reload |
| |
| Environment variables: |
| HF_REPO_ID — HuggingFace model repo to pull weights from |
| PI_URL — Raspberry Pi server URL, e.g. http://192.168.1.42:5000 |
| """ |
|
|
| import json |
| import os |
| import tempfile |
| from pathlib import Path |
| from typing import Optional |
|
|
| import httpx |
| import joblib |
| import numpy as np |
| import uvicorn |
| from fastapi import FastAPI, File, HTTPException, UploadFile |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| from scripts.build_features import ( |
| AUDIO_DURATION, |
| HOP_LENGTH, |
| N_FFT, |
| N_MELS, |
| N_MFCC, |
| SAMPLE_RATE, |
| compute_mel_spectrogram, |
| extract_mfcc, |
| load_audio, |
| ) |
| from scripts.model import EfficientNetModel |
|
|
| |
# Directory holding the model artifacts, and the JSON config stored inside it.
MODELS_DIR = Path("models")
CONFIG_PATH = MODELS_DIR.joinpath("model_config.json")

# Base URL of the Raspberry Pi server; None when the variable is unset.
PI_URL = os.environ.get("PI_URL")
|
|
| |
# Application instance, created with its title, description and version.
app = FastAPI(
    version="1.0.0",
    title="Warbler β Bird Audio Classifier",
    description="Identify North American bird species from audio clips.",
)

# CORS is left fully open: any origin, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
| |
# Serving state. All three stay None until load_model() assigns them.
_config: Optional[dict] = None  # parsed contents of model_config.json
_le = None  # label encoder restored from label_encoder.pkl
_model: Optional[EfficientNetModel] = None  # classifier used by /predict
|
|
|
|
def _load_from_hub(repo_id: str) -> None:
    """
    Fetch any missing model artifacts from the HuggingFace Hub into MODELS_DIR.

    Files that already exist in MODELS_DIR are left untouched. Each missing
    file is streamed into a temporary sibling and renamed into place only once
    the copy has completed, so an interrupted copy never leaves a truncated
    artifact that a later run would mistake for a complete one.

    Args:
        repo_id: e.g. 'mg643/chirp_model'
    """
    import shutil

    from huggingface_hub import hf_hub_download

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    for filename in ("efficientnet_best.pt", "label_encoder.pkl", "model_config.json"):
        dest = MODELS_DIR / filename
        if dest.exists():
            continue

        print(f"Downloading {filename} from {repo_id}β¦")
        cached = hf_hub_download(repo_id=repo_id, filename=filename)

        # Copy in chunks (no whole-file read into memory), then rename into place.
        partial = dest.with_name(dest.name + ".part")
        try:
            shutil.copyfile(cached, partial)
            partial.replace(dest)
        finally:
            # No-op after a successful rename; removes the leftover on failure.
            partial.unlink(missing_ok=True)
|
|
|
|
@app.on_event("startup")
def load_model() -> None:
    """
    Load model weights, label encoder, and config at server startup.

    The Hub repository is read from the HF_REPO_ID environment variable and
    defaults to 'mg643/chirp_model'. Setting HF_REPO_ID to an empty string
    skips the download step and relies on files already in MODELS_DIR.

    Raises:
        RuntimeError: If model_config.json is still missing after the
            optional download step.
    """
    global _model, _le, _config

    hf_repo = os.getenv("HF_REPO_ID", "mg643/chirp_model")
    if hf_repo:
        _load_from_hub(hf_repo)

    if not CONFIG_PATH.exists():
        raise RuntimeError("model_config.json not found. Run setup.py or set HF_REPO_ID.")

    with open(CONFIG_PATH, encoding="utf-8") as f:
        config = json.load(f)

    # NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    # code. Only point HF_REPO_ID at repositories you trust.
    label_encoder = joblib.load(MODELS_DIR / "label_encoder.pkl")
    model = EfficientNetModel.load(num_classes=config["num_classes"], models_dir=MODELS_DIR)

    # Publish the globals together, only after every artifact loaded successfully.
    _config, _le, _model = config, label_encoder, model

    pi_status = f"Pi notifications β {PI_URL}" if PI_URL else "Pi notifications β disabled (set PI_URL)"
    print(f"Model loaded: {_config['best_model']} | {_config['num_classes']} classes")
    print(pi_status)
|
|
|
|
def _preprocess_audio(audio_bytes: bytes) -> tuple[np.ndarray, np.ndarray]:
    """
    Decode uploaded audio bytes and extract MFCC + mel spectrogram features.

    The bytes are written to a temporary '.ogg' file because load_audio is
    called with a filesystem path. The file is removed on every exit path,
    including a failed write.

    Args:
        audio_bytes: Raw bytes of any librosa-supported audio format.

    Returns:
        Tuple of (mfcc_vector, mel_spectrogram).
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".ogg", delete=False)
    tmp_path = tmp.name
    try:
        # Close the handle before the loader opens the same path.
        with tmp:
            tmp.write(audio_bytes)
        audio = load_audio(tmp_path, sr=SAMPLE_RATE, duration=AUDIO_DURATION)
    finally:
        Path(tmp_path).unlink(missing_ok=True)

    return extract_mfcc(audio), compute_mel_spectrogram(audio)
|
|
|
|
async def _notify_pi(species_code: str, common_name: str, confidence: float) -> None:
    """
    Best-effort POST of a detection to the Raspberry Pi server.

    Does nothing when PI_URL is unset. Any exception raised by the request is
    caught and printed, so Pi issues never break the main prediction response.
    The coroutine itself waits for the POST (3-second timeout).

    Args:
        species_code: BirdCLEF 6-letter code e.g. 'norcar'
        common_name: Human-readable name e.g. 'Northern Cardinal'
        confidence: Model confidence between 0 and 1
    """
    if not PI_URL:
        return

    payload = {
        "species_code": species_code,
        "common_name": common_name,
        "confidence": confidence,
    }

    # Tolerate a trailing slash in PI_URL so the request never targets '//bird'.
    endpoint = f"{PI_URL.rstrip('/')}/bird"

    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            res = await client.post(endpoint, json=payload)
    except Exception as exc:
        # Deliberately broad: a notification failure must not fail the caller.
        print(f"Pi notification failed (continuing): {exc}")
    else:
        print(f"Pi notified: {res.status_code}")
|
|
|
|
| |
|
|
@app.get("/health")
def health() -> dict:
    """Liveness check that also reports the model name, class count and Pi URL."""
    if _config:
        model_name, class_count = _config["best_model"], _config["num_classes"]
    else:
        # Config not loaded (or empty): report placeholder values.
        model_name, class_count = "not loaded", 0

    return {
        "status": "ok",
        "model": model_name,
        "classes": class_count,
        "pi": PI_URL or "disabled",
    }
|
|
|
|
@app.get("/classes")
def list_classes() -> dict:
    """List every species code known to the label encoder (503 until it is loaded)."""
    encoder = _le
    if encoder is not None:
        return {"classes": encoder.classes_.tolist()}
    raise HTTPException(status_code=503, detail="Model not loaded.")
|
|
|
|
@app.post("/predict")
async def predict(file: UploadFile = File(...), top_k: int = 3) -> dict:
    """
    Identify a bird species from an uploaded audio file.

    Args:
        file: Audio file (.ogg, .mp3, .wav, .flac).
        top_k: Number of top predictions to return (default 3). Must be at
            least 1; values above the number of classes are clamped.

    Returns:
        JSON with top-K predictions plus top species name and confidence.

    Raises:
        HTTPException: 503 if the model is not loaded, 422 if top_k is below 1
            or the audio cannot be processed, 400 if the upload is empty.
    """
    if _model is None:
        raise HTTPException(status_code=503, detail="Model not loaded.")

    # A non-positive top_k would yield an empty or wrongly sliced list and
    # then fail on predictions[0] below.
    if top_k < 1:
        raise HTTPException(status_code=422, detail="top_k must be at least 1.")

    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # Only the mel spectrogram is fed to the model; the MFCC vector is unused.
        _, mel = _preprocess_audio(audio_bytes)
    except Exception as exc:
        raise HTTPException(status_code=422, detail=f"Audio processing failed: {exc}") from exc

    # Add a leading axis to the spectrogram and take the first row of the result.
    probs_flat = _model.predict_proba(mel[np.newaxis])[0]
    top_k = min(top_k, len(_le.classes_))
    # Indices of the top_k highest probabilities, best first.
    top_idx = np.argsort(probs_flat)[::-1][:top_k]

    predictions = [
        {
            "species_code": _le.classes_[i],
            "confidence": round(float(probs_flat[i]), 4),
        }
        for i in top_idx
    ]

    top_species = predictions[0]["species_code"]
    top_conf = predictions[0]["confidence"]

    # Awaited inline: the response is returned only after the Pi call finishes.
    # NOTE(review): the species code is also sent as common_name; no
    # code-to-name lookup exists in this file.
    await _notify_pi(
        species_code=top_species,
        common_name=top_species,
        confidence=top_conf,
    )

    return {
        "predictions": predictions,
        "model": _config["best_model"],
        "top_species": top_species,
        "confidence": top_conf,
    }
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Direct execution: serve app:app on all interfaces, port 8000, reload on.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )