Kaworu17 committed on
Commit
c65a7f5
·
verified ·
1 Parent(s): 9c15f86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -47
app.py CHANGED
@@ -4,9 +4,7 @@ import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import soundfile as sf
7
- from scipy.signal import resample
8
- import uuid
9
- import os
10
 
11
  # Load YAMNet model from TensorFlow Hub
12
  yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
@@ -24,81 +22,62 @@ def load_class_map():
24
 
25
  class_names = load_class_map()
26
 
27
- # Audio classification
28
  def classify_audio(file_path):
29
  try:
30
- # Load and normalize audio
31
  audio_data, sample_rate = sf.read(file_path)
 
 
32
  if len(audio_data.shape) > 1:
33
  audio_data = np.mean(audio_data, axis=1)
 
 
34
  audio_data = audio_data / np.max(np.abs(audio_data))
35
 
36
- # Resample to 16 kHz
37
  target_rate = 16000
38
  if sample_rate != target_rate:
39
- duration = len(audio_data) / sample_rate
40
- new_len = int(duration * target_rate)
41
- audio_data = resample(audio_data, new_len)
42
 
 
43
  waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
44
 
45
  # Run YAMNet
46
  scores, embeddings, spectrogram = yamnet_model(waveform)
47
  mean_scores = tf.reduce_mean(scores, axis=0).numpy()
48
- top_5_indices = np.argsort(mean_scores)[::-1][:5]
49
-
50
- top_prediction = class_names[top_5_indices[0]]
51
- confidence = float(mean_scores[top_5_indices[0]])
52
 
53
- # Dominant classes
54
- dominant_bands = ", ".join([class_names[i] for i in top_5_indices[:3]])
55
 
56
- # Waveform image
57
  fig, ax = plt.subplots()
58
  ax.plot(audio_data)
59
  ax.set_title("Waveform")
60
  ax.set_xlabel("Time")
61
  ax.set_ylabel("Amplitude")
62
  plt.tight_layout()
63
- waveform_filename = f"waveform_{uuid.uuid4().hex}.png"
64
- fig.savefig(waveform_filename)
65
- plt.close(fig)
66
 
67
- # Structured JSON output
68
- return {
69
- "classification": top_prediction,
70
- "confidence": confidence,
71
- "denoised_audio_url": "N/A",
72
- "spectrogram_url": "N/A",
73
- "bonus": {
74
- "frequency_range": "0–8000 Hz",
75
- "dominant_bands": dominant_bands
76
- },
77
- "waveform_url": waveform_filename
78
- }
79
 
80
  except Exception as e:
81
- return {
82
- "classification": "Error",
83
- "confidence": 0.0,
84
- "denoised_audio_url": "N/A",
85
- "spectrogram_url": "N/A",
86
- "bonus": {
87
- "frequency_range": "N/A",
88
- "dominant_bands": "N/A"
89
- },
90
- "waveform_url": "N/A",
91
- "error": str(e)
92
- }
93
 
94
  # Gradio interface
95
  interface = gr.Interface(
96
  fn=classify_audio,
97
- inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3"),
98
- outputs="json",
 
 
 
 
99
  title="Audtheia YAMNet Audio Classifier",
100
- description="Classify audio using YAMNet and return structured JSON output for n8n."
101
  )
102
 
103
  if __name__ == "__main__":
104
- interface.launch()
 
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import soundfile as sf
7
+ from scipy.signal import resample # Correct resampling method
 
 
8
 
9
  # Load YAMNet model from TensorFlow Hub
10
  yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
 
22
 
23
  class_names = load_class_map()
24
 
25
+ # Classification function
26
# Classification function
def classify_audio(file_path):
    """Classify an audio clip with YAMNet.

    Parameters
    ----------
    file_path : str
        Path to an audio file readable by soundfile (e.g. .wav, .mp3).

    Returns
    -------
    tuple
        ``(top_prediction, top_scores, fig)`` — the best class name, a
        dict mapping the top-5 class names to their mean scores, and a
        matplotlib waveform figure. On failure, returns
        ``("Error processing audio: ...", {}, None)``.
    """
    try:
        # Load audio file (WAV, MP3, etc.)
        audio_data, sample_rate = sf.read(file_path)

        # Convert stereo to mono if needed
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Normalize to [-1, 1]; guard against silent or empty input,
        # which would otherwise divide by zero and feed NaNs to YAMNet.
        peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak

        # Resample to the 16 kHz rate YAMNet expects
        target_rate = 16000
        if sample_rate != target_rate:
            duration = audio_data.shape[0] / sample_rate
            new_length = int(duration * target_rate)
            audio_data = resample(audio_data, new_length)

        # Convert to tensor
        waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)

        # Run YAMNet; average the per-frame scores into clip-level scores
        scores, embeddings, spectrogram = yamnet_model(waveform)
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_5 = np.argsort(mean_scores)[::-1][:5]

        top_prediction = class_names[top_5[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}

        # Create waveform plot (returned to Gradio, which renders it)
        fig, ax = plt.subplots()
        ax.plot(audio_data)
        ax.set_title("Waveform")
        ax.set_xlabel("Time")
        ax.set_ylabel("Amplitude")
        plt.tight_layout()

        return top_prediction, top_scores, fig

    except Exception as e:
        # Surface the error in the UI instead of crashing the app
        return f"Error processing audio: {e}", {}, None
 
 
 
 
 
 
 
 
 
 
 
68
 
69
# Gradio interface: one audio input, three outputs (label, score dict, plot).
_output_components = [
    gr.Textbox(label="Top Prediction"),
    gr.Label(label="Top 5 Classes with Scores"),
    gr.Plot(label="Waveform"),
]

interface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
    outputs=_output_components,
    title="Audtheia YAMNet Audio Classifier",
    description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform.",
)

# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()