Kaworu17 committed on
Commit
e834c18
·
verified ·
1 Parent(s): 62aae20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -65
app.py CHANGED
@@ -1,84 +1,76 @@
1
- import gradio as gr
2
- import numpy as np
3
  import tensorflow as tf
4
  import tensorflow_hub as hub
5
- import tensorflow_io as tfio
6
  import matplotlib.pyplot as plt
7
- import io
8
- from PIL import Image
9
- import pandas as pd
10
 
11
- # Load YAMNet model
12
- yamnet_model_handle = 'https://tfhub.dev/google/yamnet/1'
13
  yamnet_model = hub.load(yamnet_model_handle)
14
 
15
- # Load class names
16
- class_map_path = yamnet_model.class_map_path().numpy().decode('utf-8')
17
- class_names = list(pd.read_csv(class_map_path)['display_name'])
 
 
 
 
 
 
 
 
18
 
19
- # Decode and resample audio
20
- def load_wav_16k_mono(audio_bytes):
21
- audio_tensor, sample_rate = tf.audio.decode_wav(audio_bytes, desired_channels=1)
22
- audio_tensor = tf.squeeze(audio_tensor, axis=-1)
23
- audio_tensor = tfio.audio.resample(audio_tensor, rate_in=tf.cast(sample_rate, tf.int64), rate_out=16000)
24
- return audio_tensor
25
 
26
- # Plot waveform
27
- def plot_waveform(audio_tensor):
28
- plt.figure(figsize=(8, 2))
29
- plt.plot(audio_tensor.numpy())
30
- plt.title("Waveform")
31
- plt.tight_layout()
32
- buf = io.BytesIO()
33
- plt.savefig(buf, format='png')
34
- plt.close()
35
- buf.seek(0)
36
- return Image.open(buf)
37
 
38
- # Plot log-mel spectrogram
39
- def plot_spectrogram(spectrogram):
40
- plt.figure(figsize=(8, 3))
41
- plt.imshow(spectrogram.numpy().T, aspect='auto', origin='lower', interpolation='nearest')
42
- plt.title("Log-mel Spectrogram")
43
- plt.xlabel("Frames")
44
- plt.ylabel("Mel Bands")
45
- plt.tight_layout()
46
- buf = io.BytesIO()
47
- plt.savefig(buf, format='png')
48
- plt.close()
49
- buf.seek(0)
50
- return Image.open(buf)
51
 
52
- # Gradio interface logic
53
- def classify_sound(audio_file):
54
- if isinstance(audio_file, str):
55
- audio_bytes = tf.io.read_file(audio_file)
56
- else:
57
- audio_bytes = audio_file.read()
58
 
59
- waveform = load_wav_16k_mono(audio_bytes)
60
- scores, embeddings, spectrogram = yamnet_model(waveform)
 
 
61
 
62
- mean_scores = tf.reduce_mean(scores, axis=0)
63
- top_class = tf.math.argmax(mean_scores)
64
- inferred_class = class_names[top_class]
 
 
 
 
65
 
66
- waveform_img = plot_waveform(waveform)
67
- spectrogram_img = plot_spectrogram(spectrogram)
68
 
69
- return inferred_class, waveform_img, spectrogram_img
 
70
 
71
- # Gradio app
72
- app = gr.Interface(
73
- fn=classify_sound,
74
- inputs=gr.Audio(type="file", label="Upload audio file"),
75
  outputs=[
76
- gr.Text(label="Predicted Class"),
77
- gr.Image(type="pil", label="Waveform"),
78
- gr.Image(type="pil", label="Log-mel Spectrogram")
79
  ],
80
- title="YAMNet Audio Classifier",
81
- description="Classify environmental and animal sounds using YAMNet. Visualize waveform and log-mel spectrogram."
82
  )
83
 
84
- app.launch()
 
 
 
 
 
1
  import tensorflow as tf
2
  import tensorflow_hub as hub
3
+ import numpy as np
4
  import matplotlib.pyplot as plt
5
+ import gradio as gr
6
+ import os
7
+ import scipy.io.wavfile as wavfile
8
 
9
+ # Load YAMNet model from TensorFlow Hub
10
+ yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
11
  yamnet_model = hub.load(yamnet_model_handle)
12
 
13
+ # Load class names for YAMNet
14
def load_class_map(class_map_path=None):
    """Return the YAMNet class display names in file order.

    Args:
        class_map_path: Optional path to a local copy of
            ``yamnet_class_map.csv``. When omitted, the file is obtained
            with ``tf.keras.utils.get_file`` from the TensorFlow models
            repository.

    Returns:
        A list with the third column (the display name) of every data
        row; the header row and blank/short rows are skipped.
    """
    import csv  # local import: keeps the block self-contained

    if class_map_path is None:
        class_map_path = tf.keras.utils.get_file(
            'yamnet_class_map.csv',
            'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
        )

    # Use a real CSV parser: several display names are quoted and contain
    # commas (e.g. "Child speech, kid speaking"), which a plain
    # ``split(',')`` truncates and leaves with a stray quote character.
    with open(class_map_path, newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        return [row[2] for row in reader if len(row) > 2]
22
+
23
+ class_names = load_class_map()
24
 
25
+ # Function to preprocess and classify audio
26
def _prepare_waveform(audio_data, sample_rate, target_rate=16000):
    """Convert raw WAV samples to a mono, peak-normalised float32 waveform at ``target_rate``."""
    from math import gcd
    from scipy import signal

    # Work in float64 from the start: np.abs on int16 overflows for -32768.
    samples = np.asarray(audio_data, dtype=np.float64)
    if samples.size == 0:
        raise ValueError("audio file contains no samples")

    # Down-mix multi-channel audio by averaging the channels.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    # The model is fed the samples without any rate information, so the
    # audio must be resampled to the rate it expects before inference.
    if sample_rate != target_rate:
        divisor = gcd(int(sample_rate), int(target_rate))
        samples = signal.resample_poly(
            samples, target_rate // divisor, int(sample_rate) // divisor
        )

    # Peak-normalise, leaving an all-zero (silent) signal untouched instead
    # of dividing by zero and producing NaNs.
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples / peak

    return samples.astype(np.float32)


def classify_audio(file_path):
    """Classify an audio file with YAMNet and plot its waveform.

    Args:
        file_path: Path to an audio file readable by
            ``scipy.io.wavfile.read``.

    Returns:
        A 3-tuple ``(top_class, top_5_scores, figure)`` where ``top_class``
        is the display name of the best-scoring class, ``top_5_scores``
        maps the five best class names to their mean score across frames,
        and ``figure`` is a matplotlib figure of the prepared waveform.
        On any failure the tuple is
        ``("Error processing audio: <message>", {}, None)``.
    """
    try:
        if not file_path:
            raise ValueError("no audio file was provided")

        # Read audio file and convert it to the model's input format.
        sample_rate, audio_data = wavfile.read(file_path)
        waveform = _prepare_waveform(audio_data, sample_rate)

        # Run inference; scores has one row per analysed frame.
        scores, embeddings, spectrogram = yamnet_model(waveform)
        mean_scores = np.mean(scores.numpy(), axis=0)

        # Rank classes by their mean score, best first.
        top_5_indices = np.argsort(mean_scores)[::-1][:5]
        top_class = class_names[top_5_indices[0]]

        # Prepare waveform plot.
        fig, ax = plt.subplots()
        ax.plot(waveform)
        ax.set_title("Waveform")
        ax.set_xlabel("Sample Index")
        ax.set_ylabel("Amplitude")
        fig.tight_layout()
        # Deregister the figure from pyplot so a long-running app does not
        # accumulate one open figure per request; the Figure object itself
        # stays valid for rendering.
        plt.close(fig)

        return top_class, {class_names[i]: float(mean_scores[i]) for i in top_5_indices}, fig

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return f"Error processing audio: {str(e)}", {}, None
60
 
61
+ # Build Gradio interface
62
interface = gr.Interface(
    classify_audio,
    # Input: the handler receives the uploaded audio as a path on disk.
    gr.Audio(label="Upload .wav or .mp3 audio file", type="filepath"),
    # Outputs, in the order classify_audio returns them.
    [
        gr.Textbox(label="Top Prediction"),
        gr.Label(label="Top 5 Classes with Scores"),
        gr.Plot(label="Waveform"),
    ],
    description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform.",
    title="Audtheia YAMNet Audio Classifier",
)

# Start the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    interface.launch()