bird-identifier / app.py
AKMESSI's picture
Update app.py
fa148af verified
# ==========================================
# Mumbai Bird Call Identifier — FINAL FIX
# ==========================================
import io

import matplotlib
import matplotlib.cm as cm
import numpy as np
import streamlit as st
import torch
import torchaudio
from PIL import Image
from torchvision import models, transforms
# ================== PAGE CONFIG ==================
# Browser-tab title, favicon emoji and page layout for the Streamlit app.
_PAGE_SETTINGS = dict(
    page_title="Mumbai Bird Call Identifier",
    page_icon="🐦",
    layout="centered",
)
st.set_page_config(**_PAGE_SETTINGS)
# ================== CONSTANTS ==================
# Audio framing: uploads are resampled to SAMPLE_RATE and cropped/padded to a
# fixed window of DURATION seconds, i.e. TARGET_SAMPLES samples.
SAMPLE_RATE: int = 22050
DURATION: int = 5
TARGET_SAMPLES: int = DURATION * SAMPLE_RATE

# Top-1 probability cut-offs used to label a prediction as medium or high
# confidence; anything below MEDIUM_CONF is reported as low confidence.
MEDIUM_CONF: float = 0.35
HIGH_CONF: float = 0.60

# ⚠️ MUST MATCH TRAINING
SPECTROGRAM_COLORMAP: str = "magma"  # ← change ONLY if training used something else
# ================== LOAD MODEL ==================
@st.cache_resource
def load_model():
    """Load the trained classifier and its class names from the checkpoint.

    Reads ``multi_species_model.pth`` onto the CPU. The checkpoint must hold
    ``"label_map"`` (species name -> output index) and ``"model_state_dict"``
    (weights for a MobileNetV3-Small whose final linear layer has one output
    per species). Wrapped in ``st.cache_resource`` so Streamlit can reuse the
    loaded model instead of rebuilding it.

    Returns:
        tuple: ``(model, class_names)`` where ``model`` is in eval mode and
        ``class_names[i]`` is the species name for output index ``i``.

    Raises:
        ValueError: if the label map's indices are not exactly
            ``0 .. N-1`` with one species per index.
    """
    checkpoint = torch.load("multi_species_model.pth", map_location="cpu")
    label_map = checkpoint["label_map"]

    # Invert name -> index into index -> name so logits can be labelled.
    index_to_species = {index: species for species, index in label_map.items()}
    num_classes = len(index_to_species)
    if (
        num_classes != len(label_map)
        or sorted(index_to_species) != list(range(num_classes))
    ):
        # Fail with a readable message instead of a bare KeyError below.
        raise ValueError(
            "checkpoint label_map must map each species to a unique index "
            "in 0..N-1"
        )
    class_names = [index_to_species[i] for i in range(num_classes)]

    # Calling the builder without arguments yields randomly initialised
    # weights on every torchvision release and avoids the deprecated
    # `pretrained=` keyword; the real weights come from the checkpoint.
    model = models.mobilenet_v3_small()
    head_in_features = model.classifier[3].in_features
    model.classifier[3] = torch.nn.Linear(head_in_features, num_classes)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, class_names


model, class_names = load_model()
# ================== AUDIO TRANSFORMS ==================
# Mel filterbank configuration, kept in one place so it is easy to compare
# against whatever produced the training spectrograms.
_MEL_SETTINGS = dict(
    sample_rate=SAMPLE_RATE,
    n_fft=512,
    win_length=512,
    hop_length=256,
    f_min=50,
    f_max=11000,
    n_mels=128,
    norm="slaney",
    mel_scale="slaney",
)
mel_transform = torchaudio.transforms.MelSpectrogram(**_MEL_SETTINGS)

# Converts the power mel spectrogram to decibels, limited to an 80 dB range.
db_transform = torchaudio.transforms.AmplitudeToDB(stype="power", top_db=80)
# ================== IMAGE TRANSFORM (MATCH TRAINING) ==================
# Per-channel statistics applied after ToTensor (these are the commonly used
# ImageNet mean/std values).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Spectrogram image -> 224x224 normalised float tensor for the network.
_inference_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
val_transform = transforms.Compose(_inference_steps)
# ================== UI ==================
# Page heading followed by a short description of what the model covers.
st.title("🐦 Mumbai Balcony Bird Call Identifier")

_INTRO_MARKDOWN = (
    "\n"
    "Identify **204 Indian bird species** from their calls.\n"
    "Model trained on **PNG spectrogram images** from Mumbai & Maharashtra.\n"
)
st.markdown(_INTRO_MARKDOWN)

# Upload widget; `audio_file` stays falsy until the user provides a file.
_ACCEPTED_EXTENSIONS = ["wav", "mp3", "m4a", "ogg"]
audio_file = st.file_uploader(
    "Upload bird call audio (WAV / MP3 / M4A / OGG)",
    type=_ACCEPTED_EXTENSIONS,
)
# ================== PROCESS ==================
# Full pipeline for one upload: decode -> mono -> resample -> crop/pad to
# TARGET_SAMPLES -> mel spectrogram in dB -> colour-mapped RGB image ->
# model -> ranked predictions. Without an upload, only a hint is shown.
if audio_file:
    st.audio(audio_file)
    with st.spinner("🧠 Analyzing bird call..."):
        # -------- LOAD AUDIO --------
        # getvalue() returns the complete upload regardless of the stream's
        # current read position, so decoding never sees a consumed buffer.
        try:
            waveform, sr = torchaudio.load(io.BytesIO(audio_file.getvalue()))
        except (RuntimeError, OSError, ValueError) as exc:
            st.error(f"Could not decode the uploaded audio file: {exc}")
            st.stop()
        if waveform.numel() == 0:
            # argmax below is undefined for an empty tensor.
            st.error("The uploaded audio file contains no samples.")
            st.stop()

        # Down-mix multi-channel audio to a single channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(
                waveform, sr, SAMPLE_RATE
            )

        # -------- ENERGY-BASED CROP --------
        # Centre a TARGET_SAMPLES window on the loudest sample and zero-pad
        # on the right when the window runs past the end of the clip.
        # NOTE(review): when the peak is near the end of a long recording the
        # window is padded rather than shifted back; left unchanged because
        # it may mirror how the training clips were cut — confirm.
        peak = int(waveform.abs().argmax())
        start = max(0, peak - TARGET_SAMPLES // 2)
        waveform = waveform[:, start:start + TARGET_SAMPLES]
        shortfall = TARGET_SAMPLES - waveform.shape[1]
        if shortfall > 0:
            waveform = torch.nn.functional.pad(waveform, (0, shortfall))

        # -------- MEL SPECTROGRAM --------
        mel = db_transform(mel_transform(waveform)).squeeze(0)

        # -------- MEL → COLORED PNG (CRITICAL FIX) --------
        mel_np = mel.numpy()
        # np.ptp() is used because the ndarray.ptp() method no longer exists
        # in NumPy 2.x; the epsilon guards against a constant spectrogram.
        mel_norm = (mel_np - mel_np.min()) / (np.ptp(mel_np) + 1e-8)
        # The colormap registry replaces cm.get_cmap(), which newer
        # Matplotlib releases have removed.
        cmap = matplotlib.colormaps[SPECTROGRAM_COLORMAP]
        colored = cmap(mel_norm)[:, :, :3]  # drop alpha
        mel_img = (colored * 255).astype(np.uint8)
        mel_pil = Image.fromarray(mel_img)

        # -------- MODEL INPUT --------
        model_input = val_transform(mel_pil).unsqueeze(0)

        # -------- INFERENCE --------
        with torch.no_grad():
            logits = model(model_input)
            probs = torch.softmax(logits[0], dim=0)
            # Never request more candidates than the model has classes.
            top_k = min(5, probs.numel())
            top_probs, top_idx = torch.topk(probs, top_k)
        top_probs = top_probs.tolist()
        top_idx = top_idx.tolist()

    # ================== RESULTS ==================
    st.markdown("---")
    top1_prob = top_probs[0]
    top1_species = class_names[top_idx[0]]
    if top1_prob >= HIGH_CONF:
        st.success("✅ High confidence identification")
    elif top1_prob >= MEDIUM_CONF:
        st.warning("⚠️ Medium confidence identification")
    else:
        st.error("❓ Low confidence – possibly unknown species")
    st.markdown(f"## 🐦 {top1_species}")
    st.metric("Confidence", f"{top1_prob*100:.1f}%")

    # Runner-up candidates (everything after the top-1 prediction).
    if top_k > 1:
        st.markdown("### 🔍 Other possible matches")
        for idx, prob in zip(top_idx[1:], top_probs[1:]):
            st.markdown(f"- **{class_names[idx]}** — {prob:.1%}")

    st.markdown("---")
    st.subheader("📊 Spectrogram used by the model")
    st.image(mel_img, use_container_width=True)
else:
    st.info("👆 Upload a bird call audio file to begin")
# ================== FOOTER ==================
# Closing rule plus a standing disclaimer about the closed set of species.
_FOOTER_NOTE = (
    "⚠️ This model predicts among known species only. "
    "Low confidence may indicate an unseen species or noisy audio."
)
st.markdown("---")
st.caption(_FOOTER_NOTE)