chirp / app.py
mg643's picture
backend changes for pi url
d48ad51
"""
app.py
FastAPI inference server for Warbler — bird audio species classifier.
Endpoints:
POST /predict — upload an audio file, returns top-K species predictions
GET /health — liveness check
GET /classes — list all supported species
Usage:
uvicorn app:app --host 0.0.0.0 --port 8000 --reload
Environment variables:
HF_REPO_ID — HuggingFace model repo to pull weights from
PI_URL — Raspberry Pi server URL, e.g. http://192.168.1.42:5000
"""
import json
import os
import tempfile
from pathlib import Path
from typing import Optional
import httpx
import joblib
import numpy as np
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from scripts.build_features import (
AUDIO_DURATION,
HOP_LENGTH,
N_FFT,
N_MELS,
N_MFCC,
SAMPLE_RATE,
compute_mel_spectrogram,
extract_mfcc,
load_audio,
)
from scripts.model import EfficientNetModel
# ── Paths ──────────────────────────────────────────────────────────────────────
# Directory (relative to the working directory) that holds the model artifacts.
MODELS_DIR = Path("models")
# JSON metadata stored next to the weights; parsed once at startup.
CONFIG_PATH = MODELS_DIR.joinpath("model_config.json")
# ── Pi config — set PI_URL env var to enable Pi notifications ─────────────────
# Base URL of the Raspberry Pi server, e.g. "http://192.168.1.42:5000".
# Unset or empty means Pi notifications are disabled.
PI_URL = os.environ.get("PI_URL")
# ── App ────────────────────────────────────────────────────────────────────────
# The ASGI application object; route decorators below register against it.
app = FastAPI(
    # Title is shown in the generated OpenAPI docs, so it must be clean text
    # (the previous value contained a mis-decoded em dash).
    title="Warbler — Bird Audio Classifier",
    description="Identify North American bird species from audio clips.",
    version="1.0.0",
)
# CORS is fully open: any origin, method, and header is accepted.
# NOTE(review): presumably so a browser front end served from another origin
# can call the API — confirm before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Global model state ─────────────────────────────────────────────────────────
# Module-level inference state. All three names start as None and are rebound
# by load_model() at server startup; the route handlers only read them.
# Classifier instance returned by EfficientNetModel.load().
_model: Optional[EfficientNetModel] = None
# Label encoder unpickled from label_encoder.pkl; its `classes_` array is
# indexed by predicted class position to obtain the species code.
_le = None
# Parsed model_config.json (keys read in this file: "num_classes", "best_model").
_config: Optional[dict] = None
def _load_from_hub(repo_id: str) -> None:
    """
    Download any missing model artifacts from HuggingFace Hub into MODELS_DIR.

    Files that already exist in MODELS_DIR are left untouched, so locally
    provided artifacts take precedence over the hub copy.

    Args:
        repo_id: e.g. 'mg643/chirp_model'
    """
    import shutil

    from huggingface_hub import hf_hub_download

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    for filename in ("efficientnet_best.pt", "label_encoder.pkl", "model_config.json"):
        dest = MODELS_DIR / filename
        if dest.exists():
            continue
        print(f"Downloading {filename} from {repo_id}…")
        cached = hf_hub_download(repo_id=repo_id, filename=filename)
        # Copy to a sibling ".part" file and rename it into place. Writing
        # straight to `dest` would leave a truncated file behind if the copy
        # were interrupted, and the exists() check above would then treat that
        # file as complete on every later startup.
        partial = dest.with_name(dest.name + ".part")
        try:
            # copyfile streams in chunks instead of holding the whole weight
            # file in memory.
            shutil.copyfile(cached, partial)
            partial.replace(dest)
        finally:
            # No-op after a successful rename; removes the leftover on failure.
            partial.unlink(missing_ok=True)
@app.on_event("startup")
def load_model() -> None:
    """
    Load model weights, label encoder, and config at server startup.

    Missing artifacts are first fetched from the HuggingFace repo named by the
    HF_REPO_ID environment variable (default: 'mg643/chirp_model'). Setting
    HF_REPO_ID to an empty string skips the download and relies on whatever is
    already present in MODELS_DIR.

    Raises:
        RuntimeError: If model_config.json is still missing after the
            optional download step.
    """
    global _model, _le, _config
    # Honour the documented HF_REPO_ID variable; the default keeps the
    # previous hard-coded behaviour when it is not set.
    hf_repo = os.getenv("HF_REPO_ID", "mg643/chirp_model")
    if hf_repo:
        _load_from_hub(hf_repo)
    if not CONFIG_PATH.exists():
        raise RuntimeError("model_config.json not found. Run setup.py or set HF_REPO_ID.")
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(CONFIG_PATH, encoding="utf-8") as f:
        _config = json.load(f)
    # NOTE(review): joblib.load unpickles arbitrary objects, which can execute
    # code. Only point HF_REPO_ID at a repository you trust.
    _le = joblib.load(MODELS_DIR / "label_encoder.pkl")
    _model = EfficientNetModel.load(num_classes=_config["num_classes"], models_dir=MODELS_DIR)
    print(f"Model loaded: {_config['best_model']} | {_config['num_classes']} classes")
    if PI_URL:
        print(f"Pi notifications → {PI_URL}")
    else:
        print("Pi notifications → disabled (set PI_URL)")
def _preprocess_audio(audio_bytes: bytes) -> tuple[np.ndarray, np.ndarray]:
    """
    Decode uploaded audio bytes and extract MFCC + mel spectrogram features.

    The bytes are spilled to a temporary file because load_audio() is called
    with a file path. The file is always removed, including when writing or
    decoding fails.

    Args:
        audio_bytes: Raw bytes of an audio file in any format load_audio()
            can decode.

    Returns:
        Tuple of (mfcc_vector, mel_spectrogram).
    """
    # delete=False because the file is reopened by path after it is closed;
    # cleanup is therefore done explicitly in the finally clause.
    tmp = tempfile.NamedTemporaryFile(suffix=".ogg", delete=False)
    tmp_path = tmp.name
    try:
        # The write is inside the try so a failed write (e.g. disk full) does
        # not leave the temp file behind.
        with tmp:
            tmp.write(audio_bytes)
        audio = load_audio(tmp_path, sr=SAMPLE_RATE, duration=AUDIO_DURATION)
    finally:
        Path(tmp_path).unlink(missing_ok=True)
    return extract_mfcc(audio), compute_mel_spectrogram(audio)
async def _notify_pi(species_code: str, common_name: str, confidence: float) -> None:
    """
    Best-effort POST of a detection to the Raspberry Pi server.

    Returns immediately when PI_URL is unset. Otherwise the coroutine waits
    for the POST (3 second timeout). Any failure is logged and swallowed so
    Pi issues never break the main prediction response.

    Args:
        species_code: Species code, e.g. 'norcar'
        common_name: Human-readable name, e.g. 'Northern Cardinal'
        confidence: Model confidence 0–1
    """
    if not PI_URL:
        return
    payload = {
        "species_code": species_code,
        "common_name": common_name,
        "confidence": confidence,
    }
    # Tolerate a trailing slash in PI_URL so the request never targets "//bird".
    endpoint = f"{PI_URL.rstrip('/')}/bird"
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            res = await client.post(endpoint, json=payload)
        print(f"Pi notified: {res.status_code}")
    except Exception as exc:
        # Deliberately broad: log but never raise — the Pi being offline or
        # misconfigured shouldn't break predictions.
        print(f"Pi notification failed (continuing): {exc}")
# ── Routes ─────────────────────────────────────────────────────────────────────
@app.get("/health")
def health() -> dict:
    """
    Liveness check.

    Always reports status "ok"; the model name and class count fall back to
    placeholder values while the config has not been loaded.
    """
    if _config:
        model_name = _config["best_model"]
        class_count = _config["num_classes"]
    else:
        model_name = "not loaded"
        class_count = 0
    return {
        "status": "ok",
        "model": model_name,
        "classes": class_count,
        "pi": PI_URL or "disabled",
    }
@app.get("/classes")
def list_classes() -> dict:
    """
    Return all species codes the model can predict.

    Raises:
        HTTPException: 503 while the label encoder has not been loaded.
    """
    encoder = _le
    if encoder is not None:
        # tolist() converts the numpy array into plain Python values for JSON.
        return {"classes": encoder.classes_.tolist()}
    raise HTTPException(status_code=503, detail="Model not loaded.")
# Strong references to in-flight Pi notification tasks. The event loop only
# keeps weak references to tasks, so without this set a pending notification
# could be garbage-collected before it finishes.
_pi_notify_tasks: set = set()


@app.post("/predict")
async def predict(file: UploadFile = File(...), top_k: int = 3) -> dict:
    """
    Identify a bird species from an uploaded audio file.

    Args:
        file: Audio file (.ogg, .mp3, .wav, .flac).
        top_k: Number of top predictions to return (default 3). Must be at
            least 1; values above the number of classes are clamped.

    Returns:
        JSON with top-K predictions plus top species name and confidence.

    Raises:
        HTTPException: 503 if the model is not loaded, 422 if top_k < 1 or
            the audio cannot be processed, 400 if the upload is empty.
    """
    import asyncio

    if _model is None:
        raise HTTPException(status_code=503, detail="Model not loaded.")
    # top_k == 0 would leave no "top" prediction to report, and a negative
    # value would silently slice the wrong set, so reject both up front.
    if top_k < 1:
        raise HTTPException(status_code=422, detail="top_k must be at least 1.")
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    try:
        # Only the mel spectrogram feeds the model; the MFCC vector is unused.
        _, mel = _preprocess_audio(audio_bytes)
    except Exception as exc:
        raise HTTPException(
            status_code=422, detail=f"Audio processing failed: {exc}"
        ) from exc
    # Run inference on a batch of one and take that single row of scores.
    probs_flat = _model.predict_proba(mel[np.newaxis])[0]
    # tolist() yields native Python labels (as /classes does), so the
    # response never carries numpy scalar types.
    class_names = _le.classes_.tolist()
    top_k = min(top_k, len(class_names))
    top_idx = np.argsort(probs_flat)[::-1][:top_k]
    predictions = [
        {
            "species_code": class_names[i],
            "confidence": round(float(probs_flat[i]), 4),
        }
        for i in top_idx
    ]
    top_species = predictions[0]["species_code"]
    top_conf = predictions[0]["confidence"]
    # Notify the Pi in the background so a slow or offline Pi cannot delay the
    # response. _notify_pi logs and swallows its own errors.
    notify_task = asyncio.create_task(
        _notify_pi(
            species_code=top_species,
            # The species code doubles as the name until a lookup table of
            # common names is available.
            common_name=top_species,
            confidence=top_conf,
        )
    )
    _pi_notify_tasks.add(notify_task)
    notify_task.add_done_callback(_pi_notify_tasks.discard)
    return {
        "predictions": predictions,
        "model": _config["best_model"],
        "top_species": top_species,
        "confidence": top_conf,
    }
# ── Entry point ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Direct-run launcher: serve on all interfaces, port 8000, with auto-reload.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)