amritn8 committed on
Commit
2843631
·
verified ·
1 Parent(s): a7819ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -49
app.py CHANGED
@@ -3,82 +3,58 @@ import joblib
3
  import numpy as np
4
  import gradio as gr
5
  from scipy.io import wavfile
6
- import os
7
 
8
- # Load assets
9
  model = tf.keras.models.load_model("animal_sound_cnn.keras")
10
  label_encoder = joblib.load("label_encoder.joblib")
11
 
12
- def get_model_input_shape():
13
- """Dynamically get the model's expected input shape"""
14
- if len(model.input_shape) == 2:
15
- return model.input_shape[1] # For (None, 384) shape
16
- elif len(model.input_shape) == 4:
17
- return model.input_shape[1:] # For (None, 64, 64, 1) shape
18
- return None
19
-
20
  def preprocess_audio(audio_path):
21
- """Universal audio preprocessing that adapts to your model"""
22
  try:
23
- # 1. Load and normalize audio
24
  sr, y = wavfile.read(audio_path)
25
- y = np.mean(y, axis=1) if len(y.shape) > 1 else y # Stereo to mono
26
- y = y.astype(np.float32) / np.max(np.abs(y))
27
-
28
- # 2. Create spectrogram
29
- n_fft = 512
30
- hop_length = 256
31
- stft = tf.signal.stft(y, frame_length=n_fft, frame_step=hop_length, fft_length=n_fft)
32
- spectrogram = tf.abs(stft)
33
 
34
- # 3. Reshape based on model requirements
35
- expected_shape = get_model_input_shape()
 
36
 
37
- if expected_shape and len(expected_shape) == 1: # Flattened input (384)
38
- flattened = tf.reshape(spectrogram, (1, -1))
39
- if flattened.shape[1] < expected_shape[0]:
40
- flattened = tf.pad(flattened, [[0, 0], [0, expected_shape[0] - flattened.shape[1]]])
41
- else:
42
- flattened = flattened[:, :expected_shape[0]]
43
- return flattened.numpy().astype(np.float32)
44
-
45
- else: # Image-like input (64, 64, 1)
46
- # Convert to mel spectrogram
47
- linear_to_mel = tf.signal.linear_to_mel_weight_matrix(
48
- num_mel_bins=64,
49
- num_spectrogram_bins=spectrogram.shape[-1],
50
- sample_rate=22050,
51
- lower_edge_hertz=125,
52
- upper_edge_hertz=7500)
53
- mel_spectrogram = tf.tensordot(spectrogram, linear_to_mel, 1)
54
- log_mel = tf.math.log(mel_spectrogram + 1e-6)
55
-
56
- # Resize and add channel dimension
57
- resized = tf.image.resize(tf.expand_dims(log_mel, -1), (64, 64))
58
- return tf.expand_dims(resized, 0).numpy().astype(np.float32)
59
 
 
 
60
  except Exception as e:
61
- print(f"Preprocessing error: {str(e)}")
62
  return None
63
 
64
  def predict(audio_path):
65
  try:
 
66
  processed = preprocess_audio(audio_path)
67
  if processed is None:
68
- return "Error: Invalid audio input"
69
 
70
- print(f"Final input shape: {processed.shape}")
 
71
 
 
72
  pred = model.predict(processed)
73
  return label_encoder.inverse_transform([np.argmax(pred)])[0]
74
 
75
  except Exception as e:
76
- return f"Prediction failed: {str(e)}"
77
 
 
78
  gr.Interface(
79
  fn=predict,
80
  inputs=gr.Audio(type="filepath"),
81
  outputs="label",
82
  title="Animal Sound Classifier",
83
- examples=["example.wav"] if os.path.exists("example.wav") else None
84
  ).launch()
 
3
  import numpy as np
4
  import gradio as gr
5
  from scipy.io import wavfile
 
6
 
7
# Load model and label encoder: the trained Keras CNN plus the joblib-pickled
# LabelEncoder used to map predicted class indices back to animal names.
# NOTE(review): both files are assumed to sit next to app.py — confirm paths.
model = tf.keras.models.load_model("animal_sound_cnn.keras")
label_encoder = joblib.load("label_encoder.joblib")
10
 
 
 
 
 
 
 
 
 
11
def preprocess_audio(audio_path):
    """Convert a WAV file into the flat (1, 384) float32 vector the model expects.

    Pipeline: read WAV -> mix down to mono -> peak-normalize -> STFT
    magnitude spectrogram -> flatten -> zero-pad or trim to exactly 384
    values. The STFT parameters (frame_length=256, frame_step=128,
    fft_length=256) must match whatever was used at training time.

    Args:
        audio_path: Path to a WAV file readable by scipy.io.wavfile.

    Returns:
        numpy array of shape (1, 384), or None if processing fails.
    """
    try:
        # 1. Load audio file (convert to mono if stereo by averaging channels).
        sr, y = wavfile.read(audio_path)
        y = np.mean(y, axis=1) if len(y.shape) > 1 else y

        # Peak-normalize to [-1, 1]. Guard against all-zero (silent) input:
        # the previous unconditional division produced NaNs via 0/0, which
        # silently poisoned the model input.
        y = y.astype(np.float32)
        peak = np.max(np.abs(y))
        if peak > 0:
            y = y / peak

        # 2. Create spectrogram (adjust these parameters to match your training).
        spectrogram = tf.signal.stft(y, frame_length=256, frame_step=128, fft_length=256)
        spectrogram = tf.abs(spectrogram)  # Magnitude

        # 3. Reshape to what the model expects: flatten everything, then
        # zero-pad short clips or trim long ones to exactly 384 features.
        flattened = tf.reshape(spectrogram, (1, -1))
        if flattened.shape[1] < 384:
            flattened = tf.pad(flattened, [[0, 0], [0, 384 - flattened.shape[1]]])
        else:
            flattened = flattened[:, :384]  # Trim if too long

        return flattened.numpy()

    except Exception as e:
        # App boundary: log the problem and signal failure with None so the
        # caller can show a friendly error instead of crashing the UI.
        print(f"Audio processing error: {str(e)}")
        return None
35
 
36
def predict(audio_path):
    """Classify one uploaded audio file and return the predicted animal name.

    Never raises: any failure is converted into a human-readable error
    string so the Gradio label widget always has something to display.
    """
    try:
        # Turn the raw file into the model's fixed-size feature vector.
        features = preprocess_audio(audio_path)
        if features is None:
            return "Error: Couldn't process audio"

        # Debug output — makes misshapen inputs easy to spot in the logs.
        print(f"Model input shape: {features.shape}")

        # Score all classes, then map the best index back to its label.
        probabilities = model.predict(features)
        best_class = np.argmax(probabilities)
        return label_encoder.inverse_transform([best_class])[0]

    except Exception as e:
        return f"Prediction error: {str(e)}"
52
 
53
# Create simple interface: one audio-upload input wired to predict(), with
# the predicted class shown as a label, then start the Gradio server.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(type="filepath"),
    outputs="label",
    title="Animal Sound Classifier",
    description="Upload a short animal sound (2-5 seconds)",
)
demo.launch()