File size: 5,284 Bytes
fa148af
 
 
acd11c2
062923b
 
acd11c2
062923b
fa148af
 
acd11c2
062923b
e9f7d24
062923b
 
 
 
 
 
 
acd11c2
 
 
 
 
 
 
 
fa148af
 
 
acd11c2
062923b
acd11c2
062923b
acd11c2
 
 
 
 
062923b
acd11c2
 
 
 
 
 
062923b
acd11c2
062923b
 
 
acd11c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44dff6a
acd11c2
 
 
 
 
 
 
e9f7d24
acd11c2
 
fa148af
062923b
acd11c2
 
fa148af
 
acd11c2
 
 
 
 
 
 
 
fa148af
acd11c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa148af
acd11c2
fa148af
 
 
 
 
 
 
acd11c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa148af
acd11c2
 
 
 
 
 
 
 
fa148af
 
062923b
 
acd11c2
 
 
062923b
acd11c2
fa148af
 
acd11c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# ==========================================
# Mumbai Bird Call Identifier — FINAL FIX
# ==========================================

import io

import matplotlib
import matplotlib.cm as cm
import numpy as np
import streamlit as st
import torch
import torchaudio
from PIL import Image
from torchvision import models, transforms

# ================== PAGE CONFIG ==================
# Browser-tab title/icon and page layout. Streamlit expects this call before
# any other st.* UI element is rendered.
st.set_page_config(
    layout="centered",
    page_icon="🐦",
    page_title="Mumbai Bird Call Identifier",
)

# ================== CONSTANTS ==================
# Every upload is resampled to SAMPLE_RATE and cropped/padded to a clip of
# exactly TARGET_SAMPLES samples before the spectrogram is computed.
SAMPLE_RATE = 22050                       # Hz
DURATION = 5                              # seconds
TARGET_SAMPLES = DURATION * SAMPLE_RATE   # samples per analysed clip

# Cut-offs on the top-1 softmax probability that pick the result banner.
HIGH_CONF = 0.60
MEDIUM_CONF = 0.35

# ⚠️ MUST MATCH TRAINING: Matplotlib colormap used to colour the spectrogram.
SPECTROGRAM_COLORMAP = "magma"   # ← change ONLY if training used something else

# ================== LOAD MODEL ==================
@st.cache_resource
def load_model(checkpoint_path="multi_species_model.pth"):
    """Build the MobileNetV3-Small classifier and restore its trained weights.

    The checkpoint must be a dict holding ``"label_map"`` (species name ->
    class index) and ``"model_state_dict"``.

    Args:
        checkpoint_path: Location of the saved checkpoint, resolved relative
            to the current working directory when not absolute.

    Returns:
        A ``(model, class_names)`` tuple: the network on CPU in eval mode, and
        a list where ``class_names[i]`` is the species for output index ``i``.

    Raises:
        KeyError: if the label map's indices are not exactly ``0..N-1``.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    label_map = checkpoint["label_map"]
    index_to_species = {index: species for species, index in label_map.items()}
    class_names = [index_to_species[i] for i in range(len(index_to_species))]

    # weights=None -> random init; the checkpoint supplies the real weights.
    # The older ``pretrained=`` keyword is deprecated since torchvision 0.13.
    model = models.mobilenet_v3_small(weights=None)
    # Swap the output Linear layer of the classifier head so it emits one
    # logit per species, matching the shape stored in the checkpoint.
    model.classifier[3] = torch.nn.Linear(
        model.classifier[3].in_features,
        len(class_names),
    )

    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    return model, class_names


# load_model is wrapped in st.cache_resource, so script reruns reuse the cached result.
(model, class_names) = load_model()

# ================== AUDIO TRANSFORMS ==================
# Waveform -> mel power spectrogram. These settings must agree with the ones
# used to render the training spectrograms.
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=512,
    win_length=512,
    hop_length=256,      # half-window hop between frames
    n_mels=128,          # mel bands = rows of the spectrogram
    f_min=50,
    f_max=11000,         # just below SAMPLE_RATE / 2 (11025 Hz)
    mel_scale="slaney",
    norm="slaney",
)

# Power spectrogram -> decibels, clipped to 80 dB below the loudest bin.
db_transform = torchaudio.transforms.AmplitudeToDB(
    top_db=80,
    stype="power",
)

# ================== IMAGE TRANSFORM (MATCH TRAINING) ==================
# PIL RGB image -> 224x224 float tensor, normalised per channel with the
# ImageNet mean/std values.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # R, G, B
            std=[0.229, 0.224, 0.225],   # R, G, B
        ),
    ]
)

# ================== UI ==================
st.title("🐦 Mumbai Balcony Bird Call Identifier")
# The species count comes from the loaded checkpoint rather than a hard-coded
# number, so the blurb stays correct if the model is retrained.
# The two spaces before the first "\n" are a Markdown line break.
st.markdown(
    f"\nIdentify **{len(class_names)} Indian bird species** from their calls.  \n"
    "Model trained on **PNG spectrogram images** from Mumbai & Maharashtra.\n"
)

audio_file = st.file_uploader(
    "Upload bird call audio (WAV / MP3 / M4A / OGG)",
    type=["wav", "mp3", "m4a", "ogg"]
)

# ================== PROCESS ==================
def _crop_around_peak(waveform, target_samples):
    """Return exactly ``target_samples`` samples of ``waveform`` around its loudest point.

    ``waveform`` is a ``(channels, samples)`` tensor. The window is centred on
    the sample with the largest absolute amplitude, but it is shifted left when
    the peak lies near the end. A clip at least ``target_samples`` long is
    therefore never padded with silence. Shorter clips are zero-padded at the end.
    """
    num_samples = waveform.shape[-1]
    # Loudest time index across channels (same as a plain argmax for mono).
    peak = int(waveform.abs().amax(dim=0).argmax())
    # Clamp start into [0, num_samples - target_samples] so the window fits.
    start = max(0, min(peak - target_samples // 2, num_samples - target_samples))
    cropped = waveform[:, start:start + target_samples]

    shortfall = target_samples - cropped.shape[-1]
    if shortfall > 0:
        cropped = torch.nn.functional.pad(cropped, (0, shortfall))
    return cropped


def _spectrogram_to_rgb(mel_db, colormap_name):
    """Render a 2-D dB spectrogram tensor as an ``(H, W, 3)`` uint8 RGB array.

    Values are min-max normalised per image, then mapped through the named
    Matplotlib colormap. SPECTROGRAM_COLORMAP notes this must match training.
    """
    mel_np = mel_db.numpy()
    # np.ptp rather than the ndarray.ptp method, which NumPy 2.0 removed.
    mel_norm = (mel_np - mel_np.min()) / (np.ptp(mel_np) + 1e-8)

    # matplotlib.colormaps replaces cm.get_cmap (removed in Matplotlib 3.9).
    cmap = matplotlib.colormaps[colormap_name]
    colored = cmap(mel_norm)[:, :, :3]  # drop alpha
    return (colored * 255).astype(np.uint8)


if audio_file:
    st.audio(audio_file)

    with st.spinner("🧠 Analyzing bird call..."):
        # -------- LOAD AUDIO --------
        # getvalue() returns the whole upload regardless of the stream
        # position, unlike read().
        try:
            waveform, sr = torchaudio.load(io.BytesIO(audio_file.getvalue()))
        except Exception as err:  # UI boundary: decoder errors vary by backend/codec
            st.error(f"❌ Could not decode this audio file: {err}")
            st.stop()

        if waveform.numel() == 0:
            st.error("❌ The uploaded file contains no audio samples.")
            st.stop()

        # Down-mix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(
                waveform, sr, SAMPLE_RATE
            )

        # -------- ENERGY-BASED CROP --------
        waveform = _crop_around_peak(waveform, TARGET_SAMPLES)

        # -------- MEL SPECTROGRAM --------
        mel = mel_transform(waveform)
        mel = db_transform(mel)
        mel = mel.squeeze(0)  # drop the single channel dimension

        # -------- MEL → COLORED PNG (CRITICAL FIX) --------
        mel_img = _spectrogram_to_rgb(mel, SPECTROGRAM_COLORMAP)
        mel_pil = Image.fromarray(mel_img)

        # -------- MODEL INPUT --------
        model_input = val_transform(mel_pil).unsqueeze(0)  # add batch dimension

        # -------- INFERENCE --------
        with torch.no_grad():
            logits = model(model_input)
            probs = torch.softmax(logits[0], dim=0)

        # torch.topk raises if k exceeds the number of classes.
        top_probs, top_idx = torch.topk(probs, min(5, probs.numel()))

    # ================== RESULTS ==================
    st.markdown("---")

    top1_prob = top_probs[0].item()
    top1_species = class_names[top_idx[0].item()]

    if top1_prob >= HIGH_CONF:
        st.success("✅ High confidence identification")
    elif top1_prob >= MEDIUM_CONF:
        st.warning("⚠️ Medium confidence identification")
    else:
        st.error("❓ Low confidence – possibly unknown species")

    st.markdown(f"## 🐦 {top1_species}")
    st.metric("Confidence", f"{top1_prob*100:.1f}%")

    runner_ups = list(zip(top_probs[1:].tolist(), top_idx[1:].tolist()))
    if runner_ups:
        st.markdown("### 🔍 Other possible matches")
        for prob, idx in runner_ups:
            st.markdown(f"- **{class_names[idx]}** — {prob:.1%}")

    st.markdown("---")
    st.subheader("📊 Spectrogram used by the model")
    st.image(mel_img, use_container_width=True)

else:
    st.info("👆 Upload a bird call audio file to begin")

# ================== FOOTER ==================
# Rendered on every run, whether or not a file has been uploaded.
st.markdown("---")
st.caption(
    "⚠️ This model predicts among known species only. Low confidence may "
    "indicate an unseen species or noisy audio."
)