Kaworu17 committed on
Commit
7ee2ad3
·
verified ·
1 Parent(s): b16542e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -43
app.py CHANGED
@@ -1,65 +1,84 @@
 
 
1
  import tensorflow as tf
2
  import tensorflow_hub as hub
3
  import tensorflow_io as tfio
4
- import numpy as np
5
- import gradio as gr
 
6
  import pandas as pd
7
- import os
8
 
9
- # Load class names for AudioSet/YAMNet
10
  yamnet_model_handle = 'https://tfhub.dev/google/yamnet/1'
11
  yamnet_model = hub.load(yamnet_model_handle)
 
 
12
  class_map_path = yamnet_model.class_map_path().numpy().decode('utf-8')
13
  class_names = list(pd.read_csv(class_map_path)['display_name'])
14
 
15
- # Load WAV, normalize and resample
16
- def load_wav_16k_mono(wav_bytes):
17
- audio, sample_rate = tf.audio.decode_wav(wav_bytes, desired_channels=1)
18
- audio = tf.squeeze(audio, axis=-1)
19
- audio = tfio.audio.resample(audio, rate_in=sample_rate, rate_out=16000)
20
- return audio
21
-
22
- # Create transfer learning model (simple dense classifier on top of YAMNet embeddings)
23
- def create_classifier():
24
- return tf.keras.Sequential([
25
- tf.keras.layers.Input(shape=(1024,), name='input_embedding'),
26
- tf.keras.layers.Dense(512, activation='relu'),
27
- tf.keras.layers.Dense(521) # 521 classes from YAMNet
28
- ])
29
 
30
- classifier_model = create_classifier()
31
- classifier_model.compile(
32
- loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
33
- optimizer='adam',
34
- metrics=['accuracy']
35
- )
 
 
 
 
 
36
 
37
- # Mock training weights for demo purposes
38
- # In production, load fine-tuned weights:
39
- # classifier_model.load_weights("your_finetuned_model.h5")
 
 
 
 
 
 
 
 
 
 
40
 
41
- # Full pipeline for inference
42
  def classify_sound(audio_file):
43
- wav_bytes = tf.io.read_file(audio_file.name)
44
- waveform = load_wav_16k_mono(wav_bytes)
 
 
 
 
 
45
 
46
- # Extract embeddings from YAMNet
47
- _, embeddings, _ = yamnet_model(waveform)
 
48
 
49
- # Classify using your classifier model
50
- predictions = classifier_model(embeddings)
51
- averaged_predictions = tf.reduce_mean(predictions, axis=0)
52
- top_class = tf.math.argmax(averaged_predictions).numpy()
53
- confidence = tf.reduce_max(tf.nn.softmax(averaged_predictions)).numpy()
54
 
55
- return f"{class_names[top_class]} (confidence: {confidence:.2%})"
56
 
57
- interface = gr.Interface(
 
58
  fn=classify_sound,
59
- inputs=gr.Audio(type="filepath"),
60
- outputs="text",
 
 
 
 
61
  title="YAMNet Audio Classifier",
62
- description="Upload an audio clip to classify using YAMNet and a custom classifier trained on AudioSet embeddings."
63
  )
64
 
65
- interface.launch()
 
1
+ import gradio as gr
2
+ import numpy as np
3
  import tensorflow as tf
4
  import tensorflow_hub as hub
5
  import tensorflow_io as tfio
6
+ import matplotlib.pyplot as plt
7
+ import io
8
+ from PIL import Image
9
  import pandas as pd
 
10
 
11
+ # Load YAMNet model
12
  yamnet_model_handle = 'https://tfhub.dev/google/yamnet/1'
13
  yamnet_model = hub.load(yamnet_model_handle)
14
+
15
+ # Load class names
16
  class_map_path = yamnet_model.class_map_path().numpy().decode('utf-8')
17
  class_names = list(pd.read_csv(class_map_path)['display_name'])
18
 
19
+ # Decode and resample audio
20
+ def load_wav_16k_mono(audio_bytes):
21
+ audio_tensor, sample_rate = tf.audio.decode_wav(audio_bytes, desired_channels=1)
22
+ audio_tensor = tf.squeeze(audio_tensor, axis=-1)
23
+ audio_tensor = tfio.audio.resample(audio_tensor, rate_in=tf.cast(sample_rate, tf.int64), rate_out=16000)
24
+ return audio_tensor
 
 
 
 
 
 
 
 
25
 
26
+ # Plot waveform
27
+ def plot_waveform(audio_tensor):
28
+ plt.figure(figsize=(8, 2))
29
+ plt.plot(audio_tensor.numpy())
30
+ plt.title("Waveform")
31
+ plt.tight_layout()
32
+ buf = io.BytesIO()
33
+ plt.savefig(buf, format='png')
34
+ plt.close()
35
+ buf.seek(0)
36
+ return Image.open(buf)
37
 
38
+ # Plot log-mel spectrogram
39
+ def plot_spectrogram(spectrogram):
40
+ plt.figure(figsize=(8, 3))
41
+ plt.imshow(spectrogram.numpy().T, aspect='auto', origin='lower', interpolation='nearest')
42
+ plt.title("Log-mel Spectrogram")
43
+ plt.xlabel("Frames")
44
+ plt.ylabel("Mel Bands")
45
+ plt.tight_layout()
46
+ buf = io.BytesIO()
47
+ plt.savefig(buf, format='png')
48
+ plt.close()
49
+ buf.seek(0)
50
+ return Image.open(buf)
51
 
52
+ # Gradio interface logic
53
  def classify_sound(audio_file):
54
+ if isinstance(audio_file, str):
55
+ audio_bytes = tf.io.read_file(audio_file)
56
+ else:
57
+ audio_bytes = audio_file.read()
58
+
59
+ waveform = load_wav_16k_mono(audio_bytes)
60
+ scores, embeddings, spectrogram = yamnet_model(waveform)
61
 
62
+ mean_scores = tf.reduce_mean(scores, axis=0)
63
+ top_class = tf.math.argmax(mean_scores)
64
+ inferred_class = class_names[top_class]
65
 
66
+ waveform_img = plot_waveform(waveform)
67
+ spectrogram_img = plot_spectrogram(spectrogram)
 
 
 
68
 
69
+ return inferred_class, waveform_img, spectrogram_img
70
 
71
+ # Gradio app
72
+ app = gr.Interface(
73
  fn=classify_sound,
74
+ inputs=gr.Audio(type="file", label="Upload audio file"),
75
+ outputs=[
76
+ gr.Text(label="Predicted Class"),
77
+ gr.Image(type="pil", label="Waveform"),
78
+ gr.Image(type="pil", label="Log-mel Spectrogram")
79
+ ],
80
  title="YAMNet Audio Classifier",
81
+ description="Classify environmental and animal sounds using YAMNet. Visualize waveform and log-mel spectrogram."
82
  )
83
 
84
+ app.launch()