Kaworu17 commited on
Commit
4a03abd
·
verified ·
1 Parent(s): e3b4d9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -29
app.py CHANGED
@@ -3,77 +3,81 @@ import tensorflow_hub as hub
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
- from scipy.signal import resample
 
7
 
8
- # Load YAMNet model
9
- yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
 
10
 
11
- # Load class names
12
  def load_class_map():
13
  class_map_path = tf.keras.utils.get_file(
14
  'yamnet_class_map.csv',
15
  'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
16
  )
17
  with open(class_map_path, 'r') as f:
18
- return [line.strip().split(',')[2] for line in f.readlines()[1:]]
 
19
 
20
  class_names = load_class_map()
21
 
22
  # Classification function
23
- def classify_audio(audio, sample_rate):
24
  try:
25
- # Convert stereo to mono
26
- if len(audio.shape) > 1:
27
- audio = np.mean(audio, axis=1)
28
 
29
- # Normalize
30
- audio = audio / np.max(np.abs(audio))
 
31
 
32
- # Resample to 16kHz if needed
33
- target_sr = 16000
34
- if sample_rate != target_sr:
35
- duration = audio.shape[0] / sample_rate
36
- new_length = int(duration * target_sr)
37
- audio = resample(audio, new_length)
38
- sample_rate = target_sr
 
 
39
 
40
  # Convert to tensor
41
- waveform = tf.convert_to_tensor(audio, dtype=tf.float32)
42
 
43
- # Predict
44
  scores, embeddings, spectrogram = yamnet_model(waveform)
45
  mean_scores = tf.reduce_mean(scores, axis=0).numpy()
46
  top_5 = np.argsort(mean_scores)[::-1][:5]
47
 
48
- # Extract predictions
49
  top_prediction = class_names[top_5[0]]
50
  top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
51
 
52
  # Create waveform plot
53
  fig, ax = plt.subplots()
54
- ax.plot(audio)
55
  ax.set_title("Waveform")
56
- ax.set_xlabel("Time (samples)")
57
  ax.set_ylabel("Amplitude")
58
  plt.tight_layout()
59
 
60
  return top_prediction, top_scores, fig
61
 
62
  except Exception as e:
63
- return f"Error: {str(e)}", {}, None
64
 
65
- # Gradio Interface (binary audio compatible for n8n)
66
  interface = gr.Interface(
67
  fn=classify_audio,
68
- inputs=gr.Audio(source="upload", type="numpy", label="Upload .wav or .mp3"),
69
  outputs=[
70
  gr.Textbox(label="Top Prediction"),
71
- gr.Label(label="Top 5 Class Scores"),
72
  gr.Plot(label="Waveform")
73
  ],
74
  title="Audtheia YAMNet Audio Classifier",
75
- description="Classifies audio with YAMNet and returns predictions with waveform plot."
76
  )
77
 
78
  if __name__ == "__main__":
79
- interface.launch()
 
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
+ import soundfile as sf
7
+ from scipy.signal import resample # Correct resampling method
8
 
9
+ # Load YAMNet model from TensorFlow Hub
10
+ yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
11
+ yamnet_model = hub.load(yamnet_model_handle)
12
 
13
+ # Load class labels
14
  def load_class_map():
15
  class_map_path = tf.keras.utils.get_file(
16
  'yamnet_class_map.csv',
17
  'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
18
  )
19
  with open(class_map_path, 'r') as f:
20
+ class_names = [line.strip().split(',')[2] for line in f.readlines()[1:]]
21
+ return class_names
22
 
23
  class_names = load_class_map()
24
 
25
  # Classification function
26
+ def classify_audio(file_path):
27
  try:
28
+ # Load audio file (WAV, MP3, etc.)
29
+ audio_data, sample_rate = sf.read(file_path)
 
30
 
31
+ # Convert stereo to mono if needed
32
+ if len(audio_data.shape) > 1:
33
+ audio_data = np.mean(audio_data, axis=1)
34
 
35
+ # Normalize audio
36
+ audio_data = audio_data / np.max(np.abs(audio_data))
37
+
38
+ # Resample to 16kHz if necessary
39
+ target_rate = 16000
40
+ if sample_rate != target_rate:
41
+ duration = audio_data.shape[0] / sample_rate
42
+ new_length = int(duration * target_rate)
43
+ audio_data = resample(audio_data, new_length)
44
 
45
  # Convert to tensor
46
+ waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
47
 
48
+ # Run YAMNet
49
  scores, embeddings, spectrogram = yamnet_model(waveform)
50
  mean_scores = tf.reduce_mean(scores, axis=0).numpy()
51
  top_5 = np.argsort(mean_scores)[::-1][:5]
52
 
 
53
  top_prediction = class_names[top_5[0]]
54
  top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
55
 
56
  # Create waveform plot
57
  fig, ax = plt.subplots()
58
+ ax.plot(audio_data)
59
  ax.set_title("Waveform")
60
+ ax.set_xlabel("Time")
61
  ax.set_ylabel("Amplitude")
62
  plt.tight_layout()
63
 
64
  return top_prediction, top_scores, fig
65
 
66
  except Exception as e:
67
+ return f"Error processing audio: {e}", {}, None
68
 
69
+ # Gradio interface
70
  interface = gr.Interface(
71
  fn=classify_audio,
72
+ inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
73
  outputs=[
74
  gr.Textbox(label="Top Prediction"),
75
+ gr.Label(label="Top 5 Classes with Scores"),
76
  gr.Plot(label="Waveform")
77
  ],
78
  title="Audtheia YAMNet Audio Classifier",
79
+ description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."
80
  )
81
 
82
  if __name__ == "__main__":
83
+ interface.launch()