saha23s's picture
rerunning
d8af6cc
import gradio as gr
import numpy as np
import tensorflow as tf
import librosa
from pydub import AudioSegment
import pickle
import gdown
import os
# Download the model from Google Drive, unless a local copy is already present.
model_url = "https://drive.google.com/uc?id=1Z7Of2hr48B3c1mb_j0WV8y6JaTkFvrUe"
output = "best_model.keras"
if not os.path.exists(output):
    gdown.download(model_url, output, quiet=False)
    # Some gdown versions signal a failed download by returning None rather
    # than raising, so verify the file exists and fail with a clear message
    # instead of letting the model loader report a missing file later.
    if not os.path.exists(output):
        raise RuntimeError(
            f"Failed to download the model from {model_url} to {output!r}"
        )

# Load the species names (maps the model's output index to a species label).
# NOTE: pickle.load can execute arbitrary code; this file must come from a
# trusted source and be shipped alongside the app.
with open('species_names_resnet.pkl', 'rb') as f:
    species_names_resnet = pickle.load(f)

# Load the trained model from the path that was just checked/downloaded.
model = tf.keras.models.load_model(output)
# Define preprocessing functions
def pad_audio(segment, target_duration_ms):
    """Pad ``segment`` with trailing silence up to ``target_duration_ms``.

    Args:
        segment: A pydub ``AudioSegment``; its ``len()`` is compared against
            ``target_duration_ms``, i.e. it is treated as a duration in ms.
        target_duration_ms: Minimum duration of the returned segment, in ms.

    Returns:
        ``segment`` itself when it is already long enough, otherwise a new
        segment consisting of ``segment`` followed by silence.
    """
    missing_ms = target_duration_ms - len(segment)
    if missing_ms <= 0:
        return segment
    # Generate the silence at the segment's own frame rate.
    # AudioSegment.silent defaults to 11025 Hz and concatenation syncs both
    # operands to the higher rate, which would silently resample clips
    # recorded below 11025 Hz.
    padding = AudioSegment.silent(
        duration=missing_ms, frame_rate=segment.frame_rate
    )
    return segment + padding
def preprocess_audio(audio_path, target_duration_ms=3000, sr=22050, n_mels=96, fmax=8000, n_frames=130):
    """Turn an audio file into a fixed-size log-mel spectrogram.

    Args:
        audio_path: Path of the audio file to decode with pydub.
        target_duration_ms: Clips shorter than this are padded with silence.
        sr: Sample rate (Hz) the audio is converted to and analysed at.
        n_mels: Number of mel bands.
        fmax: Highest frequency (Hz) covered by the mel filter bank.
        n_frames: Number of time frames in the result; longer spectrograms
            are truncated and shorter ones padded. The default of 130
            matches 3 s of audio at 22050 Hz with librosa's default hop
            length of 512 samples.

    Returns:
        A ``[n_frames, n_mels]`` float array of dB values relative to the
        spectrogram's peak.
    """
    segment = AudioSegment.from_file(audio_path)
    # Normalize the decoded audio to what the rest of the pipeline assumes:
    # the declared sample rate, a single channel (stereo samples would
    # otherwise be interleaved in one array) and 16-bit samples (required by
    # the 32768 scaling below).
    segment = segment.set_frame_rate(sr).set_channels(1).set_sample_width(2)
    segment = pad_audio(segment, target_duration_ms)

    # Scale 16-bit integer PCM to floats in [-1, 1).
    samples = np.array(segment.get_array_of_samples(), dtype=np.float32) / 32768.0

    mel_spectrogram = librosa.feature.melspectrogram(y=samples, sr=sr, n_mels=n_mels, fmax=fmax)
    log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

    # Force a fixed number of time frames.
    frame_count = log_mel_spectrogram.shape[1]
    if frame_count < n_frames:
        log_mel_spectrogram = np.pad(
            log_mel_spectrogram,
            ((0, 0), (0, n_frames - frame_count)),
            mode='constant',
        )
    else:
        log_mel_spectrogram = log_mel_spectrogram[:, :n_frames]
    return log_mel_spectrogram.T  # Transpose to [n_frames, n_mels]
def predict(audio_file):
    """Predict the bird species for an uploaded audio file.

    Args:
        audio_file: Path to the audio file as supplied by the Gradio audio
            input, or ``None``/empty when nothing was uploaded.

    Returns:
        The entry of ``species_names_resnet`` for the highest-scoring class,
        or a short message asking for input when no file was provided.
    """
    # Gradio calls the function with None when the user submits without
    # providing audio; answer with a message instead of crashing in the
    # decoder.
    if not audio_file:
        return "No audio provided. Please upload an audio file."

    features = preprocess_audio(audio_file)  # [frames, mel bands]
    # Swap to [mel bands, frames], add batch and channel axes, then repeat
    # the single channel 3 times to match the expected input shape
    # (None, 96, 130, 3).
    batch = np.repeat(features.T[np.newaxis, ..., np.newaxis], 3, axis=-1)

    y_pred_prob = model.predict(batch)
    # Convert the NumPy integer to a plain int before indexing the labels.
    best_index = int(np.argmax(y_pred_prob, axis=1)[0])
    return species_names_resnet[best_index]
# Build the Gradio UI: one audio upload (handed to `predict` as a file path)
# and a plain-text field showing the predicted species.
interface = gr.Interface(
    title="Bird Species Prediction",
    description="Upload an audio file to predict the bird species.",
    fn=predict,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()