# ==========================================
# Mumbai Bird Call Identifier — FINAL FIX
# ==========================================
import streamlit as st
import torch
import torchaudio
import numpy as np
import io
from matplotlib import colormaps
from torchvision import models, transforms
from PIL import Image

# ================== PAGE CONFIG ==================
st.set_page_config(
    page_title="Mumbai Bird Call Identifier",
    page_icon="🐦",
    layout="centered"
)

# ================== CONSTANTS ==================
SAMPLE_RATE = 22050                      # Hz; all audio is resampled to this rate
DURATION = 5                             # seconds fed to the model
TARGET_SAMPLES = SAMPLE_RATE * DURATION  # fixed-length model window, in samples

# Confidence thresholds that pick the UI verdict banner.
HIGH_CONF = 0.60
MEDIUM_CONF = 0.35

# ⚠️ MUST MATCH TRAINING
SPECTROGRAM_COLORMAP = "magma"  # ← change ONLY if training used something else


# ================== LOAD MODEL ==================
@st.cache_resource
def load_model():
    """Load the trained MobileNetV3-Small checkpoint and its class names.

    The checkpoint is expected to contain:
      - "label_map": dict mapping species name -> class index
      - "model_state_dict": weights for MobileNetV3-Small with a resized head

    Returns:
        (model, class_names): the eval-mode model and a list mapping
        class index -> species name, ordered by index.
    """
    checkpoint = torch.load("multi_species_model.pth", map_location="cpu")

    # label_map is species -> index; invert it to build an index-ordered list.
    label_map = checkpoint["label_map"]
    index_to_species = {v: k for k, v in label_map.items()}
    class_names = [index_to_species[i] for i in range(len(index_to_species))]

    # FIX: the `pretrained=` kwarg was removed from torchvision (>= 0.15);
    # `weights=None` is the supported way to get an un-pretrained backbone.
    model = models.mobilenet_v3_small(weights=None)
    # Replace the final classifier layer to match the number of species.
    model.classifier[3] = torch.nn.Linear(
        model.classifier[3].in_features,
        len(class_names)
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, class_names


model, class_names = load_model()

# ================== AUDIO TRANSFORMS ==================
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=512,
    win_length=512,
    hop_length=256,
    f_min=50,
    f_max=11000,  # just under the 11025 Hz Nyquist limit at 22.05 kHz
    n_mels=128,
    norm="slaney",
    mel_scale="slaney"
)

db_transform = torchaudio.transforms.AmplitudeToDB(
    stype="power",
    top_db=80
)

# ================== IMAGE TRANSFORM (MATCH TRAINING) ==================
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(  # ImageNet statistics, matching the training pipeline
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# ================== UI ==================
st.title("🐦 Mumbai Balcony Bird Call Identifier")
st.markdown(
    """
Identify **204 Indian bird species** from their calls.
Model trained on **PNG spectrogram images** from Mumbai & Maharashtra.
"""
)

audio_file = st.file_uploader(
    "Upload bird call audio (WAV / MP3 / M4A / OGG)",
    type=["wav", "mp3", "m4a", "ogg"]
)

# ================== PROCESS ==================
if audio_file:
    st.audio(audio_file)

    with st.spinner("🧠 Analyzing bird call..."):

        # -------- LOAD AUDIO --------
        waveform, sr = torchaudio.load(io.BytesIO(audio_file.read()))

        # Mixdown to mono so the energy crop and mel transform see one channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(
                waveform, sr, SAMPLE_RATE
            )

        # -------- ENERGY-BASED CROP --------
        # Center a DURATION-second window on the loudest sample so the call
        # (rather than leading silence) lands in the model's input.
        energy = waveform.abs()
        peak = int(energy.argmax())  # plain int: safe for max() and slicing
        start = max(0, peak - TARGET_SAMPLES // 2)
        end = start + TARGET_SAMPLES
        waveform = waveform[:, start:end]

        # Zero-pad clips shorter than the fixed model window.
        if waveform.shape[1] < TARGET_SAMPLES:
            waveform = torch.nn.functional.pad(
                waveform,
                (0, TARGET_SAMPLES - waveform.shape[1])
            )

        # -------- MEL SPECTROGRAM --------
        mel = mel_transform(waveform)
        mel = db_transform(mel)
        mel = mel.squeeze(0)

        # -------- MEL → COLORED PNG (CRITICAL FIX) --------
        mel_np = mel.numpy()
        # FIX: ndarray.ptp() was removed in NumPy 2.0; np.ptp() is the
        # portable spelling. Epsilon guards against a constant spectrogram.
        mel_norm = (mel_np - mel_np.min()) / (np.ptp(mel_np) + 1e-8)

        # FIX: cm.get_cmap() was removed in Matplotlib 3.9; look the colormap
        # up in the matplotlib.colormaps registry instead.
        cmap = colormaps[SPECTROGRAM_COLORMAP]
        colored = cmap(mel_norm)[:, :, :3]  # drop alpha
        mel_img = (colored * 255).astype(np.uint8)
        mel_pil = Image.fromarray(mel_img)

        # -------- MODEL INPUT --------
        model_input = val_transform(mel_pil).unsqueeze(0)

        # -------- INFERENCE --------
        with torch.no_grad():
            logits = model(model_input)
            probs = torch.softmax(logits[0], dim=0)
            top5_probs, top5_idx = torch.topk(probs, 5)

    # ================== RESULTS ==================
    st.markdown("---")

    top1_prob = top5_probs[0].item()
    top1_species = class_names[top5_idx[0].item()]

    if top1_prob >= HIGH_CONF:
        st.success("✅ High confidence identification")
    elif top1_prob >= MEDIUM_CONF:
        st.warning("⚠️ Medium confidence identification")
    else:
        st.error("❓ Low confidence – possibly unknown species")

    st.markdown(f"## 🐦 {top1_species}")
    st.metric("Confidence", f"{top1_prob*100:.1f}%")

    st.markdown("### 🔍 Other possible matches")
    for i in range(1, 5):
        st.markdown(
            f"- **{class_names[top5_idx[i].item()]}** — {top5_probs[i].item():.1%}"
        )

    st.markdown("---")
    st.subheader("📊 Spectrogram used by the model")
    st.image(mel_img, use_container_width=True)

else:
    st.info("👆 Upload a bird call audio file to begin")

# ================== FOOTER ==================
st.markdown("---")
st.caption(
    "⚠️ This model predicts among known species only. "
    "Low confidence may indicate an unseen species or noisy audio."
)