File size: 3,177 Bytes
b3dff30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from transformers import pipeline
import numpy as np

# Hugging Face audio-classification pipeline for the bird-sound model.
# device=-1 keeps inference on the CPU.
classifier = pipeline(
    "audio-classification",
    model="dima806/bird_sounds_classification",
    device=-1,
)

# Alphabetically sorted, de-duplicated label names taken from the model's
# id2label mapping; rendered in the UI's "Species this model knows" section.
SPECIES_LIST = sorted({*classifier.model.config.id2label.values()})

def _to_mono_float32(y):
    """Convert a raw audio array to a 1-D float32 waveform.

    Signed integer PCM is scaled into [-1, 1) by dividing by 2**(bits-1)
    (32768 for int16, 2147483648 for int32). Unsigned integer PCM (e.g. 8-bit
    WAV) is re-centred on its midpoint before scaling. Floating-point input is
    cast to float32 without rescaling. For multi-channel input, only the first
    channel (column 0) is kept.
    """
    y = np.asarray(y)
    if np.issubdtype(y.dtype, np.signedinteger):
        # -iinfo.min == 2**(bits-1): 32768 for int16, 2147483648 for int32.
        scale = -float(np.iinfo(y.dtype).min)
        y = y.astype(np.float32) / scale
    elif np.issubdtype(y.dtype, np.unsignedinteger):
        # Unsigned PCM stores silence at the midpoint (128 for uint8).
        mid = (float(np.iinfo(y.dtype).max) + 1.0) / 2.0
        y = (y.astype(np.float32) - mid) / mid
    else:
        y = y.astype(np.float32, copy=False)

    # Assumes multi-channel data is laid out as (samples, channels).
    if y.ndim > 1:
        y = y[:, 0]
    return y


def _resample_linear(y, sr, target_sr):
    """Resample the 1-D waveform ``y`` from ``sr`` Hz to ``target_sr`` Hz.

    Uses plain linear interpolation (``np.interp``); no anti-aliasing filter is
    applied, so downsampling can fold content above ``target_sr / 2`` back into
    the band. Returns float32. An input that is already at ``target_sr`` or is
    empty is returned as-is; if the resampled length would be 0, an empty
    float32 array is returned.
    """
    if sr == target_sr or y.size == 0:
        return y
    new_length = int(y.size * target_sr // sr)
    if new_length < 1:
        return np.zeros(0, dtype=np.float32)
    resampled = np.interp(
        np.linspace(0, y.size - 1, new_length),
        np.arange(y.size),
        y,
    )
    # np.interp returns float64; keep the waveform in float32 throughout.
    return resampled.astype(np.float32)


def _format_predictions(results):
    """Render pipeline predictions as the plain-text report shown in the UI.

    Each prediction becomes two lines: ``"<rank>. <label>"`` and a 20-character
    ``#``/``.`` bar followed by the score as a percentage. When the top score
    is below 0.40, a warning, a "Best guess" line and a "Top N predictions:"
    header come first. All entries use the same layout in both cases.
    """
    lines = []
    if results and results[0]["score"] < 0.40:
        best = results[0]
        lines.append("Not confident - this may not be a recognizable bird song,")
        lines.append("or the species may not be in this model's training data.")
        lines.append(f"Best guess: {best['label']} ({best['score']:.0%})")
        lines.append("")
        lines.append(f"Top {len(results)} predictions:")

    for rank, pred in enumerate(results, 1):
        score = pred["score"]
        bar_length = int(score * 20)
        bar = "#" * bar_length + "." * (20 - bar_length)
        lines.append(f"{rank}. {pred['label']}")
        lines.append(f"   {bar}  {score:.1%}")

    return "\n".join(lines)


def classify_bird(audio):
    """Classify a recording from a ``gr.Audio(type="numpy")`` input.

    Args:
        audio: ``None`` when nothing was provided; otherwise a
            ``(sample_rate, samples)`` tuple.

    Returns:
        A multi-line text report of the top-5 predictions, or a short message
        when no audio (or an empty clip) was supplied.
    """
    if audio is None:
        return "Please upload or record an audio file."

    sr, y = audio
    y = _to_mono_float32(y)

    # Resample to the rate the pipeline's feature extractor declares; fall back
    # to 16 kHz (the rate the original code assumed) if none is reported.
    extractor = getattr(classifier, "feature_extractor", None)
    target_sr = getattr(extractor, "sampling_rate", None) or 16000
    y = _resample_linear(y, sr, target_sr)

    # An empty waveform would otherwise fail inside np.interp or the model.
    if y.size == 0:
        return "The audio is empty - please upload or record a longer clip."

    results = classifier({"sampling_rate": target_sr, "raw": y}, top_k=5)
    return _format_predictions(results)


# Gradio UI: one audio input passed to classify_bird as a numpy-typed value,
# and a 12-line text box for the formatted results.
demo = gr.Interface(
    fn=classify_bird,
    inputs=gr.Audio(
        label="Upload or Record a Bird Song",
        type="numpy",
    ),
    outputs=gr.Textbox(label="Classification Results", lines=12),
    title="Bird Song Classifier",
    description=(
        "Upload a bird song recording and this model will try to identify the species. "
        "Uses dima806/bird_sounds_classification, a wav2vec2-based classifier trained on "
        # Derive the species count from the loaded model rather than hard-coding it.
        f"{len(SPECIES_LIST)} bird species (mostly Tinamous, Guans, and Chachalacas - neotropical birds). "
        "Best results with clean recordings of 3+ seconds.\n\n"
        "Note: This model was trained on tropical/neotropical species. "
        "It won't recognize common North American backyard birds like cardinals or robins. "
        "That's a training data limitation, not an architecture limitation.\n\n"
        "Try recordings from Xeno-Canto (https://xeno-canto.org/) - search for species like "
        "Great Tinamou, Plain Chachalaca, or Crested Guan."
    ),
    article=(
        "### Species this model knows\n\n"
        + ", ".join(SPECIES_LIST)
        + "\n\n---\n*Riley's Space 2 - AI + Research Level 2*"
    ),
    theme=gr.themes.Soft(),
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. from tests or tooling) does not launch the app.
if __name__ == "__main__":
    demo.launch()