Kaworu17 committed on
Commit
11683d3
·
verified ·
1 Parent(s): 22c3745

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -3,13 +3,14 @@ import tensorflow_hub as hub
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
- import soundfile as sf # PySoundFile for broader audio format support
 
7
 
8
- # Load YAMNet model
9
  yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
10
  yamnet_model = hub.load(yamnet_model_handle)
11
 
12
- # Load class names
13
  def load_class_map():
14
  class_map_path = tf.keras.utils.get_file(
15
  'yamnet_class_map.csv',
@@ -21,33 +22,38 @@ def load_class_map():
21
 
22
  class_names = load_class_map()
23
 
24
- # Core classifier function
25
  def classify_audio(file_path):
26
  try:
27
- # Load audio file using soundfile (supports WAV, MP3, FLAC, OGG, etc.)
28
  audio_data, sample_rate = sf.read(file_path)
29
 
30
  # Convert stereo to mono if needed
31
  if len(audio_data.shape) > 1:
32
  audio_data = np.mean(audio_data, axis=1)
33
 
34
- # Normalize
35
  audio_data = audio_data / np.max(np.abs(audio_data))
36
 
37
- # Resample if needed
38
- if sample_rate != 16000:
39
- audio_data = tf.audio.resample(audio_data, sample_rate, 16000)
40
- sample_rate = 16000
 
 
41
 
42
- # Predict
43
- scores, embeddings, spectrogram = yamnet_model(audio_data)
44
- mean_scores = np.mean(scores, axis=0)
 
 
 
45
  top_5 = np.argsort(mean_scores)[::-1][:5]
46
 
47
  top_prediction = class_names[top_5[0]]
48
  top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
49
 
50
- # Plot waveform
51
  fig, ax = plt.subplots()
52
  ax.plot(audio_data)
53
  ax.set_title("Waveform")
@@ -60,7 +66,7 @@ def classify_audio(file_path):
60
  except Exception as e:
61
  return f"Error processing audio: {e}", {}, None
62
 
63
- # Gradio UI
64
  interface = gr.Interface(
65
  fn=classify_audio,
66
  inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
@@ -70,7 +76,7 @@ interface = gr.Interface(
70
  gr.Plot(label="Waveform")
71
  ],
72
  title="Audtheia YAMNet Audio Classifier",
73
- description="Upload environmental or animal sounds (WAV/MP3). Classifies with YAMNet and shows waveform + top 5 predictions."
74
  )
75
 
76
  if __name__ == "__main__":
 
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
+ import soundfile as sf
7
+ from scipy.signal import resample # Correct resampling method
8
 
9
+ # Load YAMNet model from TensorFlow Hub
10
  yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
11
  yamnet_model = hub.load(yamnet_model_handle)
12
 
13
+ # Load class labels
14
  def load_class_map():
15
  class_map_path = tf.keras.utils.get_file(
16
  'yamnet_class_map.csv',
 
22
 
23
  class_names = load_class_map()
24
 
25
+ # Classification function
26
  def classify_audio(file_path):
27
  try:
28
+ # Load audio file (WAV, MP3, etc.)
29
  audio_data, sample_rate = sf.read(file_path)
30
 
31
  # Convert stereo to mono if needed
32
  if len(audio_data.shape) > 1:
33
  audio_data = np.mean(audio_data, axis=1)
34
 
35
+ # Normalize audio
36
  audio_data = audio_data / np.max(np.abs(audio_data))
37
 
38
+ # Resample to 16kHz if necessary
39
+ target_rate = 16000
40
+ if sample_rate != target_rate:
41
+ duration = audio_data.shape[0] / sample_rate
42
+ new_length = int(duration * target_rate)
43
+ audio_data = resample(audio_data, new_length)
44
 
45
+ # Convert to tensor
46
+ waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
47
+
48
+ # Run YAMNet
49
+ scores, embeddings, spectrogram = yamnet_model(waveform)
50
+ mean_scores = tf.reduce_mean(scores, axis=0).numpy()
51
  top_5 = np.argsort(mean_scores)[::-1][:5]
52
 
53
  top_prediction = class_names[top_5[0]]
54
  top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
55
 
56
+ # Create waveform plot
57
  fig, ax = plt.subplots()
58
  ax.plot(audio_data)
59
  ax.set_title("Waveform")
 
66
  except Exception as e:
67
  return f"Error processing audio: {e}", {}, None
68
 
69
+ # Gradio interface
70
  interface = gr.Interface(
71
  fn=classify_audio,
72
  inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
 
76
  gr.Plot(label="Waveform")
77
  ],
78
  title="Audtheia YAMNet Audio Classifier",
79
+ description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."
80
  )
81
 
82
  if __name__ == "__main__":