File size: 3,284 Bytes
d100097 e834c18 7ee2ad3 e834c18 4a03abd e63bfc0 29e4b0d d100097 4a03abd e63bfc0 7ee2ad3 4a03abd e834c18 e63bfc0 e834c18 d100097 29e4b0d e834c18 29e4b0d 4a03abd c65a7f5 29e4b0d e63bfc0 4a03abd c65a7f5 e63bfc0 4a03abd e63bfc0 4a03abd e63bfc0 22c3745 29e4b0d 4a03abd 11683d3 29e4b0d 11683d3 c65a7f5 7ee2ad3 29e4b0d c65a7f5 e3b4d9a db6ba6b 29e4b0d e834c18 4a03abd e834c18 e63bfc0 e834c18 db6ba6b c65a7f5 db6ba6b e834c18 e63bfc0 db6ba6b 29e4b0d e834c18 4a03abd c65a7f5 4a03abd c65a7f5 e834c18 4a03abd d100097 e834c18 e63bfc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import csv
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
from scipy.signal import resample
# Load YAMNet model from TensorFlow Hub once at import time (downloaded and
# cached by tfhub); classify_audio reuses this single shared instance.
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
# Load class labels
def load_class_map():
    """Download the YAMNet class map CSV and return the class display names.

    The CSV columns are: index, mid, display_name. Display names may contain
    commas and are quoted in the file (e.g. "Male speech, man speaking"), so a
    naive ``line.split(',')`` truncates them — parse with the csv module.

    Returns:
        list[str]: display names ordered by class index, aligned with the
        score vector YAMNet produces.
    """
    class_map_path = tf.keras.utils.get_file(
        'yamnet_class_map.csv',
        'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
    )
    with open(class_map_path, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        return [row[2] for row in reader]
# Index-aligned labels: class_names[i] is the display name for YAMNet class i.
class_names = load_class_map()
# Main classification function
def classify_audio(audio_input):
    """Classify an audio clip with YAMNet and plot its waveform.

    Args:
        audio_input: either a filesystem path (str) handed over by the Gradio
            UI, or a binary file-like object exposing ``.read()`` (e.g. an
            n8n POST upload with no filename).

    Returns:
        tuple: ``(top_prediction, top_scores, fig)`` on success, where
        ``top_scores`` maps the five best class names to their mean scores;
        on failure ``("Error processing audio: ...", {}, None)`` so all three
        Gradio outputs still render.
    """
    tmp_path = None
    try:
        # Case 1: filepath from the Gradio UI.
        if isinstance(audio_input, str):
            file_path = audio_input
        # Case 2: binary upload without a filename — spool it to a temp file
        # so soundfile can open it by path.
        elif hasattr(audio_input, "read"):
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(audio_input.read())
            tmp_path = tmp.name
            file_path = tmp_path
        else:
            raise ValueError("Unsupported input format")

        try:
            audio_data, sample_rate = sf.read(file_path)
        finally:
            # Remove the temp file even when decoding fails; the previous
            # 'tmp' in locals() check leaked the file on sf.read errors.
            if tmp_path is not None:
                os.unlink(tmp_path)
                tmp_path = None

        # Downmix stereo/multi-channel to mono.
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Peak-normalize; skip silent (all-zero) clips, which previously
        # caused a divide-by-zero and a NaN waveform.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        # YAMNet expects 16 kHz mono float waveform — resample if needed.
        target_rate = 16000
        if sample_rate != target_rate:
            duration = audio_data.shape[0] / sample_rate
            new_length = int(duration * target_rate)
            audio_data = resample(audio_data, new_length)
            sample_rate = target_rate

        waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)

        # Run YAMNet: per-frame scores, averaged over time for clip-level
        # predictions, then take the five highest-scoring classes.
        scores, embeddings, spectrogram = yamnet_model(waveform)
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_5 = np.argsort(mean_scores)[::-1][:5]
        top_prediction = class_names[top_5[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}

        # Waveform plot for the Gradio Plot output.
        fig, ax = plt.subplots()
        ax.plot(audio_data)
        ax.set_title("Waveform")
        ax.set_xlabel("Time (samples)")
        ax.set_ylabel("Amplitude")
        plt.tight_layout()

        return top_prediction, top_scores, fig
    except Exception as e:
        # Top-level UI boundary: report the error instead of crashing Gradio.
        return f"Error processing audio: {str(e)}", {}, None
# Gradio Interface: wires classify_audio to three outputs — the single best
# label (Textbox), the top-5 score breakdown (Label), and the waveform (Plot).
# The Audio input passes a filesystem path (type="filepath") to the handler.
interface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
    outputs=[
        gr.Textbox(label="Top Prediction"),
        gr.Label(label="Top 5 Classes with Scores"),
        gr.Plot(label="Waveform")
    ],
    title="Audtheia YAMNet Audio Classifier",
    description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."
)
# Launch the web UI only when run as a script (not when imported).
if __name__ == "__main__":
    interface.launch()
|