Kaworu17 commited on
Commit
579b540
·
verified ·
1 Parent(s): c65a7f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -35
app.py CHANGED
@@ -3,81 +3,77 @@ import tensorflow_hub as hub
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
- import soundfile as sf
7
- from scipy.signal import resample # Correct resampling method
8
 
9
- # Load YAMNet model from TensorFlow Hub
10
- yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
11
- yamnet_model = hub.load(yamnet_model_handle)
12
 
13
- # Load class labels
14
  def load_class_map():
15
  class_map_path = tf.keras.utils.get_file(
16
  'yamnet_class_map.csv',
17
  'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
18
  )
19
  with open(class_map_path, 'r') as f:
20
- class_names = [line.strip().split(',')[2] for line in f.readlines()[1:]]
21
- return class_names
22
 
23
  class_names = load_class_map()
24
 
25
- # Classification function
26
- def classify_audio(file_path):
27
  try:
28
- # Load audio file (WAV, MP3, etc.)
29
- audio_data, sample_rate = sf.read(file_path)
 
30
 
31
- # Convert stereo to mono if needed
32
- if len(audio_data.shape) > 1:
33
- audio_data = np.mean(audio_data, axis=1)
34
 
35
- # Normalize audio
36
- audio_data = audio_data / np.max(np.abs(audio_data))
37
-
38
- # Resample to 16kHz if necessary
39
- target_rate = 16000
40
- if sample_rate != target_rate:
41
- duration = audio_data.shape[0] / sample_rate
42
- new_length = int(duration * target_rate)
43
- audio_data = resample(audio_data, new_length)
44
 
45
  # Convert to tensor
46
- waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
47
 
48
- # Run YAMNet
49
  scores, embeddings, spectrogram = yamnet_model(waveform)
50
  mean_scores = tf.reduce_mean(scores, axis=0).numpy()
51
  top_5 = np.argsort(mean_scores)[::-1][:5]
52
 
 
53
  top_prediction = class_names[top_5[0]]
54
- top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
55
 
56
  # Create waveform plot
57
  fig, ax = plt.subplots()
58
- ax.plot(audio_data)
59
  ax.set_title("Waveform")
60
- ax.set_xlabel("Time")
61
  ax.set_ylabel("Amplitude")
62
  plt.tight_layout()
63
 
64
  return top_prediction, top_scores, fig
65
 
66
  except Exception as e:
67
- return f"Error processing audio: {e}", {}, None
68
 
69
- # Gradio interface
70
  interface = gr.Interface(
71
  fn=classify_audio,
72
- inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
73
  outputs=[
74
  gr.Textbox(label="Top Prediction"),
75
- gr.Label(label="Top 5 Classes with Scores"),
76
  gr.Plot(label="Waveform")
77
  ],
78
  title="Audtheia YAMNet Audio Classifier",
79
- description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."
80
  )
81
 
82
  if __name__ == "__main__":
83
- interface.launch()
 
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
+ from scipy.signal import resample
 
7
 
8
# Load the pretrained YAMNet audio-event classifier from TensorFlow Hub.
YAMNET_HANDLE = "https://tfhub.dev/google/yamnet/1"
yamnet_model = hub.load(YAMNET_HANDLE)
 
10
 
11
# Load class names
def load_class_map():
    """Download YAMNet's class-map CSV and return the list of display names.

    The CSV columns are (index, mid, display_name). Display names may
    themselves contain commas and are quoted in the file, so a real CSV
    parser is required — a naive ``line.split(',')`` truncates those
    names and returns wrong labels.

    Returns:
        list[str]: class display names, ordered by class index.
    """
    import csv  # local import: only needed once at startup

    class_map_path = tf.keras.utils.get_file(
        'yamnet_class_map.csv',
        'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
    )
    with open(class_map_path, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        return [row[2] for row in reader]


class_names = load_class_map()
21
 
22
# Classification function for binary audio input
def classify_audio(audio, sample_rate=None):
    """Classify an audio clip with YAMNet and plot its waveform.

    Args:
        audio: Either the ``(sample_rate, samples)`` tuple that
            ``gr.Audio(type="numpy")`` passes to the handler, or a bare
            sample array (in which case ``sample_rate`` must be given).
        sample_rate: Sampling rate in Hz when ``audio`` is a bare array;
            leave as ``None`` when ``audio`` is the Gradio tuple.

    Returns:
        tuple: ``(top_prediction, top_scores, fig)`` — the best class name,
        a ``{class_name: score}`` dict for the top 5 classes, and a
        matplotlib waveform figure. On failure returns
        ``("Error: ...", {}, None)``.
    """
    try:
        # gr.Audio(type="numpy") delivers ONE argument: a (sample_rate, data)
        # tuple. Unpack it here so both Gradio calls and direct
        # classify_audio(samples, sr) calls work.
        if sample_rate is None:
            sample_rate, audio = audio

        # Gradio numpy audio is typically int16; YAMNet expects float32.
        audio = np.asarray(audio, dtype=np.float32)

        # Convert stereo to mono
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        # Normalize to [-1, 1]; guard against an all-zero (silent) clip,
        # which would otherwise produce NaNs from 0/0.
        peak = np.max(np.abs(audio)) if audio.size else 0.0
        if peak > 0:
            audio = audio / peak

        # Resample to the 16 kHz rate YAMNet requires, if needed
        target_sr = 16000
        if sample_rate != target_sr:
            duration = audio.shape[0] / sample_rate
            new_length = int(duration * target_sr)
            audio = resample(audio, new_length)
            sample_rate = target_sr

        # Convert to tensor
        waveform = tf.convert_to_tensor(audio, dtype=tf.float32)

        # Predict
        scores, embeddings, spectrogram = yamnet_model(waveform)
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_5 = np.argsort(mean_scores)[::-1][:5]

        # Extract predictions
        top_prediction = class_names[top_5[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}

        # Create waveform plot
        fig, ax = plt.subplots()
        ax.plot(audio)
        ax.set_title("Waveform")
        ax.set_xlabel("Time (samples)")
        ax.set_ylabel("Amplitude")
        plt.tight_layout()

        return top_prediction, top_scores, fig

    except Exception as e:
        return f"Error: {str(e)}", {}, None
64
 
65
# Gradio Interface (IMPORTANT: type="numpy" allows binary POSTs from n8n)
_audio_input = gr.Audio(source="upload", type="numpy", label="Upload .wav or .mp3")
_outputs = [
    gr.Textbox(label="Top Prediction"),
    gr.Label(label="Top 5 Class Scores"),
    gr.Plot(label="Waveform"),
]

interface = gr.Interface(
    fn=classify_audio,
    inputs=_audio_input,
    outputs=_outputs,
    title="Audtheia YAMNet Audio Classifier",
    description="Classifies audio with YAMNet and returns predictions with waveform plot.",
)

if __name__ == "__main__":
    interface.launch()