# ==========================================
# Mumbai Bird Call Identifier – FINAL FIX
# ==========================================
import streamlit as st
import torch
import torchaudio
import numpy as np
import io
from matplotlib import colormaps
from torchvision import models, transforms
from PIL import Image
# ================== PAGE CONFIG ==================
st.set_page_config(
    page_title="Mumbai Bird Call Identifier",
    page_icon="🐦",
    layout="centered"
)
# ================== CONSTANTS ==================
SAMPLE_RATE = 22050
DURATION = 5
TARGET_SAMPLES = SAMPLE_RATE * DURATION
HIGH_CONF = 0.60
MEDIUM_CONF = 0.35

# ⚠️ MUST MATCH TRAINING
SPECTROGRAM_COLORMAP = "magma"  # ← change ONLY if training used something else
# ================== LOAD MODEL ==================
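# Cache the loaded model across Streamlit reruns so the checkpoint is read
# from disk only once, not on every widget interaction.
@st.cache_resource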
def load_model():
    checkpoint = torch.load("multi_species_model.pth", map_location="cpu")
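    # label_map is assumed to map species name -> class index; inverting it
    # recovers the class names in model-output order.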
    label_map = checkpoint["label_map"]
    index_to_species = {v: k for k, v in label_map.items()}
    class_names = [index_to_species[i] for i in range(len(index_to_species))]

    model = models.mobilenet_v3_small(weights=None)
    model.classifier[3] = torch.nn.Linear(
        model.classifier[3].in_features,
        len(class_names)
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, class_names

model, class_names = load_model()
# ================== AUDIO TRANSFORMS ==================
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=512,
    win_length=512,
    hop_length=256,
    f_min=50,
    f_max=11000,
    n_mels=128,
    norm="slaney",
    mel_scale="slaney"
)

db_transform = torchaudio.transforms.AmplitudeToDB(
    stype="power",
    top_db=80
)
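# For a 5 s clip at 22 050 Hz (110 250 samples) with hop_length=256, these
# transforms produce a 128 x 431 dB-scaled mel spectrogram.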
# ================== IMAGE TRANSFORM (MATCH TRAINING) ==================
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
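# Standard ImageNet statistics; these must match whatever normalization the
# training pipeline applied to its spectrogram PNGs.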
# ================== UI ==================
st.title("🐦 Mumbai Balcony Bird Call Identifier")
st.markdown(
    """
    Identify **204 Indian bird species** from their calls.
    Model trained on **PNG spectrogram images** from Mumbai & Maharashtra.
    """
)

audio_file = st.file_uploader(
    "Upload bird call audio (WAV / MP3 / M4A / OGG)",
    type=["wav", "mp3", "m4a", "ogg"]
)
# ================== PROCESS ==================
if audio_file:
    audio_bytes = audio_file.read()
    st.audio(audio_bytes)

    with st.spinner("🎧 Analyzing bird call..."):
        # -------- LOAD AUDIO --------
        waveform, sr = torchaudio.load(io.BytesIO(audio_bytes))

        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(
                waveform, sr, SAMPLE_RATE
            )

        # -------- ENERGY-BASED CROP --------
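        # Center a 5 s window on the loudest sample so the call, rather than
        # leading silence, fills the model input; shorter clips are
        # zero-padded on the right.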
        energy = waveform.abs()
        peak = int(energy.argmax())
        start = max(0, peak - TARGET_SAMPLES // 2)
        end = start + TARGET_SAMPLES
        waveform = waveform[:, start:end]

        if waveform.shape[1] < TARGET_SAMPLES:
            waveform = torch.nn.functional.pad(
                waveform,
                (0, TARGET_SAMPLES - waveform.shape[1])
            )

        # -------- MEL SPECTROGRAM --------
        mel = mel_transform(waveform)
        mel = db_transform(mel)
        mel = mel.squeeze(0)
        # -------- MEL → COLORED PNG (CRITICAL FIX) --------
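        # The model was trained on colored PNG spectrograms, so the same
        # colormap must be applied at inference; a grayscale array would
        # shift the input distribution and hurt predictions.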
        mel_np = mel.numpy()
        mel_norm = (mel_np - mel_np.min()) / (np.ptp(mel_np) + 1e-8)

        cmap = colormaps[SPECTROGRAM_COLORMAP]
        colored = cmap(mel_norm)[:, :, :3]  # drop alpha
        mel_img = (colored * 255).astype(np.uint8)
        mel_pil = Image.fromarray(mel_img)

        # -------- MODEL INPUT --------
        model_input = val_transform(mel_pil).unsqueeze(0)
        # -------- INFERENCE --------
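        # Softmax converts raw logits into class probabilities; top-5 gives
        # a ranked shortlist for the results panel below.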
        with torch.no_grad():
            logits = model(model_input)
            probs = torch.softmax(logits[0], dim=0)
            top5_probs, top5_idx = torch.topk(probs, 5)
    # ================== RESULTS ==================
    st.markdown("---")

    top1_prob = top5_probs[0].item()
    top1_species = class_names[top5_idx[0]]

    if top1_prob >= HIGH_CONF:
        st.success("✅ High confidence identification")
    elif top1_prob >= MEDIUM_CONF:
        st.warning("⚠️ Medium confidence identification")
    else:
        st.error("❌ Low confidence – possibly unknown species")

    st.markdown(f"## 🐦 {top1_species}")
    st.metric("Confidence", f"{top1_prob*100:.1f}%")

    st.markdown("### 🔍 Other possible matches")
    for i in range(1, 5):
        st.markdown(
            f"- **{class_names[top5_idx[i]]}** – {top5_probs[i].item():.1%}"
        )

    st.markdown("---")
    st.subheader("📊 Spectrogram used by the model")
    st.image(mel_img, use_container_width=True)
else:
    st.info("👆 Upload a bird call audio file to begin")
# ================== FOOTER ==================
st.markdown("---")
st.caption(
    "⚠️ This model predicts among known species only. "
    "Low confidence may indicate an unseen species or noisy audio."
)