File size: 2,659 Bytes
d100097
 
e834c18
7ee2ad3
e834c18
4a03abd
e63bfc0
d100097
4a03abd
e63bfc0
7ee2ad3
4a03abd
e834c18
 
 
 
 
 
e63bfc0
e834c18
 
d100097
e3b4d9a
4a03abd
e834c18
e63bfc0
4a03abd
c65a7f5
e63bfc0
4a03abd
 
c65a7f5
e63bfc0
4a03abd
 
e63bfc0
4a03abd
 
 
 
 
e63bfc0
22c3745
c65a7f5
4a03abd
11683d3
4a03abd
11683d3
 
c65a7f5
7ee2ad3
c65a7f5
e3b4d9a
db6ba6b
e63bfc0
e834c18
4a03abd
e834c18
e63bfc0
e834c18
 
db6ba6b
c65a7f5
db6ba6b
e834c18
e63bfc0
db6ba6b
e63bfc0
e834c18
 
4a03abd
c65a7f5
 
4a03abd
c65a7f5
 
e834c18
4a03abd
d100097
 
e834c18
e63bfc0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
import soundfile as sf
from scipy.signal import resample

# Load YAMNet model from TensorFlow Hub.
# NOTE: runs at import time; first run downloads the model, so network access
# is required once (subsequent runs use the local TF Hub cache).
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")

# Load class labels
# Load class labels
def load_class_map():
    """Download YAMNet's class-map CSV and return its display names.

    Returns:
        list[str]: the display_name column for every class, ordered so the
        list index matches the YAMNet class index.
    """
    import csv  # local import: only needed here; avoids touching file-level imports

    class_map_path = tf.keras.utils.get_file(
        'yamnet_class_map.csv',
        'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
    )
    with open(class_map_path, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header row: index,mid,display_name
        # csv.reader correctly handles quoted display names that contain
        # commas; the previous naive split(',')[2] truncated such labels.
        return [row[2] for row in reader]

# Class-index -> display-name lookup, built once at import time.
class_names = load_class_map()

# Classification function
def classify_audio(file_path):
    """Classify an audio file with YAMNet and plot its waveform.

    Args:
        file_path: path to an audio file readable by soundfile (.wav, .flac,
            etc. — .mp3 support depends on the installed libsndfile).

    Returns:
        tuple: (top class name, dict of the top-5 class names to scores,
        matplotlib Figure of the waveform). On any failure returns
        (error message, {}, None) so the Gradio UI still renders.
    """
    try:
        # Load audio
        audio_data, sample_rate = sf.read(file_path)

        # Reject empty files up front; np.max on a zero-size array raises
        # an unhelpful ValueError otherwise.
        if audio_data.size == 0:
            return "Error processing audio: file contains no samples", {}, None

        # Convert stereo (or any multi-channel) to mono by averaging channels
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Peak-normalize. Guard against all-zero (silent) input, which would
        # otherwise divide by zero and feed NaNs to the model.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        # Resample to the 16 kHz rate YAMNet expects
        target_rate = 16000
        if sample_rate != target_rate:
            duration = audio_data.shape[0] / sample_rate
            new_length = int(duration * target_rate)
            audio_data = resample(audio_data, new_length)
            sample_rate = target_rate

        # YAMNet takes a 1-D float32 waveform tensor
        waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)

        # Run YAMNet; scores is per-frame, so average over frames for a
        # single clip-level score per class.
        scores, embeddings, spectrogram = yamnet_model(waveform)
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_5 = np.argsort(mean_scores)[::-1][:5]

        top_prediction = class_names[top_5[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}

        # Waveform plot
        fig, ax = plt.subplots()
        ax.plot(audio_data)
        ax.set_title("Waveform")
        ax.set_xlabel("Time (samples)")
        ax.set_ylabel("Amplitude")
        plt.tight_layout()
        # Close the figure so repeated requests in a long-running server don't
        # leak pyplot state; Gradio renders the returned Figure object anyway.
        plt.close(fig)

        return top_prediction, top_scores, fig

    except Exception as e:
        # Broad catch is intentional at this UI boundary: report any failure
        # to the user instead of crashing the Gradio worker.
        return f"Error processing audio: {str(e)}", {}, None

# Gradio interface (HF-compatible)
# type="filepath" hands classify_audio a path on disk rather than raw samples.
interface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
    outputs=[
        gr.Textbox(label="Top Prediction"),
        # gr.Label renders the {class_name: score} dict returned by classify_audio
        gr.Label(label="Top 5 Classes with Scores"),
        gr.Plot(label="Waveform")
    ],
    title="Audtheia YAMNet Audio Classifier",
    description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."
)

# Launch the web UI only when run as a script (not when imported by a host
# runtime such as Hugging Face Spaces, which calls launch itself).
if __name__ == "__main__":
    interface.launch()