File size: 3,284 Bytes
d100097
 
e834c18
7ee2ad3
e834c18
4a03abd
e63bfc0
29e4b0d
 
d100097
4a03abd
e63bfc0
7ee2ad3
4a03abd
e834c18
 
 
 
 
 
e63bfc0
e834c18
 
d100097
29e4b0d
 
e834c18
29e4b0d
 
 
 
 
 
 
 
 
 
 
 
 
4a03abd
c65a7f5
29e4b0d
 
 
 
e63bfc0
4a03abd
 
c65a7f5
e63bfc0
4a03abd
 
e63bfc0
4a03abd
 
 
 
 
e63bfc0
22c3745
29e4b0d
4a03abd
11683d3
29e4b0d
11683d3
 
c65a7f5
7ee2ad3
29e4b0d
c65a7f5
e3b4d9a
db6ba6b
29e4b0d
e834c18
4a03abd
e834c18
e63bfc0
e834c18
 
db6ba6b
c65a7f5
db6ba6b
e834c18
e63bfc0
db6ba6b
29e4b0d
e834c18
 
4a03abd
c65a7f5
 
4a03abd
c65a7f5
 
e834c18
4a03abd
d100097
 
e834c18
e63bfc0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import csv
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
from scipy.signal import resample

# YAMNet from TensorFlow Hub; loaded once at import time and reused by every
# call to classify_audio.
_YAMNET_HANDLE = "https://tfhub.dev/google/yamnet/1"
yamnet_model = hub.load(_YAMNET_HANDLE)

# Load class labels
def load_class_map(class_map_path=None):
    """Return YAMNet's class display names in class-map file order.

    Args:
        class_map_path: Optional path to a local copy of
            ``yamnet_class_map.csv``. When omitted (the default), the file is
            fetched and cached via ``tf.keras.utils.get_file``.

    Returns:
        A list with the third column (``display_name``) of every data row,
        header excluded, in row order. Callers index it by YAMNet class index.
    """
    if class_map_path is None:
        class_map_path = tf.keras.utils.get_file(
            'yamnet_class_map.csv',
            'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
        )
    # Parse with the csv module: some display names contain commas and are
    # quoted (e.g. "Child speech, kid speaking"); a naive str.split(',')
    # would truncate them at the first comma.
    with open(class_map_path, newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        return [row[2] for row in reader if row]  # `if row` skips blank lines

# Class display names, built once at import time. classify_audio looks labels
# up by score index, i.e. class_names[i] names column i of the model's scores.
class_names = load_class_map()

# Main classification function
def _read_audio(audio_input):
    """Decode ``audio_input`` with soundfile and return ``(samples, sample_rate)``.

    Accepts a filesystem path (what ``gr.Audio(type="filepath")`` passes) or a
    binary file-like object exposing ``read()`` (e.g. a raw HTTP upload).
    File-like input is spooled to a temporary ``.wav`` file so soundfile can
    open it by path; that file is always removed, even if decoding fails.

    Raises:
        ValueError: if ``audio_input`` is neither a ``str`` nor readable.
    """
    if isinstance(audio_input, str):
        return sf.read(audio_input)
    if not hasattr(audio_input, "read"):
        raise ValueError("Unsupported input format")

    # delete=False: the handle is closed before soundfile reopens the file by
    # name, and closing a delete=True temp file would remove it. Removal is
    # instead guaranteed by the finally clause, which also covers failures in
    # write() or sf.read().
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with tmp:
            tmp.write(audio_input.read())
        return sf.read(tmp.name)
    finally:
        os.unlink(tmp.name)


def _prepare_waveform(audio_data, sample_rate, target_rate=16000):
    """Return a mono, peak-normalized waveform resampled to ``target_rate`` Hz.

    Multi-channel input (2-D, channels along axis 1) is averaged to mono.
    An all-silent clip (peak amplitude 0) is left unscaled rather than divided
    by zero. 16 kHz mono in [-1, 1] is the input format YAMNet expects.

    Raises:
        ValueError: if the audio contains no samples.
    """
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)
    if audio_data.size == 0:
        raise ValueError("Audio file contains no samples")

    # Dividing by a zero peak would fill the waveform with NaNs and poison
    # every model score, so only rescale when there is signal.
    peak = np.max(np.abs(audio_data))
    if peak > 0:
        audio_data = audio_data / peak

    if sample_rate != target_rate:
        # scipy.signal.resample is FFT-based and takes the output sample count.
        new_length = int(audio_data.shape[0] / sample_rate * target_rate)
        audio_data = resample(audio_data, new_length)
    return audio_data


def _plot_waveform(audio_data):
    """Build and return a matplotlib Figure plotting ``audio_data`` per sample."""
    fig, ax = plt.subplots()
    ax.plot(audio_data)
    ax.set_title("Waveform")
    ax.set_xlabel("Time (samples)")
    ax.set_ylabel("Amplitude")
    fig.tight_layout()
    # Unregister the figure from pyplot's global figure manager; otherwise each
    # request leaves an open figure behind in this long-running server. The
    # Figure object itself stays renderable (savefig works after close).
    plt.close(fig)
    return fig


def classify_audio(audio_input):
    """Classify an audio clip with YAMNet.

    Args:
        audio_input: A path to an audio file, or a binary file-like object
            with ``read()``.

    Returns:
        ``(top_prediction, top_scores, figure)``: the highest-scoring class
        name, a dict of the five best class names to their frame-averaged
        scores, and a Figure of the preprocessed waveform. On any failure it
        returns ``("Error processing audio: <reason>", {}, None)`` so the UI
        shows the message instead of raising.
    """
    try:
        audio_data, sample_rate = _read_audio(audio_input)
        audio_data = _prepare_waveform(audio_data, sample_rate)

        waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
        # Scores come back per frame; average over axis 0 for one score per
        # class across the whole clip.
        scores, _embeddings, _spectrogram = yamnet_model(waveform)
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_5 = np.argsort(mean_scores)[::-1][:5]

        top_prediction = class_names[top_5[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
        return top_prediction, top_scores, _plot_waveform(audio_data)

    except Exception as e:
        # Deliberate UI boundary: surface the error text in the output box.
        return f"Error processing audio: {str(e)}", {}, None

# Gradio UI: one audio upload in; prediction text, top-5 score label and
# waveform plot out (matching the tuple classify_audio returns).
_audio_upload = gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file")
_result_views = [
    gr.Textbox(label="Top Prediction"),
    gr.Label(label="Top 5 Classes with Scores"),
    gr.Plot(label="Waveform"),
]

interface = gr.Interface(
    fn=classify_audio,
    inputs=_audio_upload,
    outputs=_result_views,
    title="Audtheia YAMNet Audio Classifier",
    description=(
        "Upload an environmental or animal sound to classify using the "
        "YAMNet model. Returns label predictions and waveform."
    ),
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    interface.launch()