Kaworu17 commited on
Commit
9c15f86
·
verified ·
1 Parent(s): 11683d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -25
app.py CHANGED
@@ -4,7 +4,9 @@ import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import soundfile as sf
7
- from scipy.signal import resample # Correct resampling method
 
 
8
 
9
  # Load YAMNet model from TensorFlow Hub
10
  yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
@@ -22,61 +24,80 @@ def load_class_map():
22
 
23
  class_names = load_class_map()
24
 
25
- # Classification function
26
  def classify_audio(file_path):
27
  try:
28
- # Load audio file (WAV, MP3, etc.)
29
  audio_data, sample_rate = sf.read(file_path)
30
-
31
- # Convert stereo to mono if needed
32
  if len(audio_data.shape) > 1:
33
  audio_data = np.mean(audio_data, axis=1)
34
-
35
- # Normalize audio
36
  audio_data = audio_data / np.max(np.abs(audio_data))
37
 
38
- # Resample to 16kHz if necessary
39
  target_rate = 16000
40
  if sample_rate != target_rate:
41
- duration = audio_data.shape[0] / sample_rate
42
- new_length = int(duration * target_rate)
43
- audio_data = resample(audio_data, new_length)
44
 
45
- # Convert to tensor
46
  waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
47
 
48
  # Run YAMNet
49
  scores, embeddings, spectrogram = yamnet_model(waveform)
50
  mean_scores = tf.reduce_mean(scores, axis=0).numpy()
51
- top_5 = np.argsort(mean_scores)[::-1][:5]
 
 
 
52
 
53
- top_prediction = class_names[top_5[0]]
54
- top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
55
 
56
- # Create waveform plot
57
  fig, ax = plt.subplots()
58
  ax.plot(audio_data)
59
  ax.set_title("Waveform")
60
  ax.set_xlabel("Time")
61
  ax.set_ylabel("Amplitude")
62
  plt.tight_layout()
 
 
 
63
 
64
- return top_prediction, top_scores, fig
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  except Exception as e:
67
- return f"Error processing audio: {e}", {}, None
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  # Gradio interface
70
  interface = gr.Interface(
71
  fn=classify_audio,
72
- inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
73
- outputs=[
74
- gr.Textbox(label="Top Prediction"),
75
- gr.Label(label="Top 5 Classes with Scores"),
76
- gr.Plot(label="Waveform")
77
- ],
78
  title="Audtheia YAMNet Audio Classifier",
79
- description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."
80
  )
81
 
82
  if __name__ == "__main__":
 
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import soundfile as sf
7
+ from scipy.signal import resample
8
+ import uuid
9
+ import os
10
 
11
  # Load YAMNet model from TensorFlow Hub
12
  yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
 
24
 
25
  class_names = load_class_map()
26
 
27
+ # Audio classification
28
  def classify_audio(file_path):
29
  try:
30
+ # Load and normalize audio
31
  audio_data, sample_rate = sf.read(file_path)
 
 
32
  if len(audio_data.shape) > 1:
33
  audio_data = np.mean(audio_data, axis=1)
 
 
34
  audio_data = audio_data / np.max(np.abs(audio_data))
35
 
36
+ # Resample to 16 kHz
37
  target_rate = 16000
38
  if sample_rate != target_rate:
39
+ duration = len(audio_data) / sample_rate
40
+ new_len = int(duration * target_rate)
41
+ audio_data = resample(audio_data, new_len)
42
 
 
43
  waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
44
 
45
  # Run YAMNet
46
  scores, embeddings, spectrogram = yamnet_model(waveform)
47
  mean_scores = tf.reduce_mean(scores, axis=0).numpy()
48
+ top_5_indices = np.argsort(mean_scores)[::-1][:5]
49
+
50
+ top_prediction = class_names[top_5_indices[0]]
51
+ confidence = float(mean_scores[top_5_indices[0]])
52
 
53
+ # Dominant classes
54
+ dominant_bands = ", ".join([class_names[i] for i in top_5_indices[:3]])
55
 
56
+ # Waveform image
57
  fig, ax = plt.subplots()
58
  ax.plot(audio_data)
59
  ax.set_title("Waveform")
60
  ax.set_xlabel("Time")
61
  ax.set_ylabel("Amplitude")
62
  plt.tight_layout()
63
+ waveform_filename = f"waveform_{uuid.uuid4().hex}.png"
64
+ fig.savefig(waveform_filename)
65
+ plt.close(fig)
66
 
67
+ # Structured JSON output
68
+ return {
69
+ "classification": top_prediction,
70
+ "confidence": confidence,
71
+ "denoised_audio_url": "N/A",
72
+ "spectrogram_url": "N/A",
73
+ "bonus": {
74
+ "frequency_range": "0–8000 Hz",
75
+ "dominant_bands": dominant_bands
76
+ },
77
+ "waveform_url": waveform_filename
78
+ }
79
 
80
  except Exception as e:
81
+ return {
82
+ "classification": "Error",
83
+ "confidence": 0.0,
84
+ "denoised_audio_url": "N/A",
85
+ "spectrogram_url": "N/A",
86
+ "bonus": {
87
+ "frequency_range": "N/A",
88
+ "dominant_bands": "N/A"
89
+ },
90
+ "waveform_url": "N/A",
91
+ "error": str(e)
92
+ }
93
 
94
  # Gradio interface
95
  interface = gr.Interface(
96
  fn=classify_audio,
97
+ inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3"),
98
+ outputs="json",
 
 
 
 
99
  title="Audtheia YAMNet Audio Classifier",
100
+ description="Classify audio using YAMNet and return structured JSON output for n8n."
101
  )
102
 
103
  if __name__ == "__main__":