# mashup / app.py
# Author: UshaMurux
# Last commit: dfa0324 (verified) - "Update app.py"
import gradio as gr
import torch
import matplotlib.pyplot as plt
import os
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import logging
import sys
import librosa
import librosa.display
import os
import numpy as np
# Route all log records to stdout (rather than the default stderr) so they
# appear alongside the app's regular console output.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Checkpoint identifier handed to the transformers from_pretrained() loaders.
MODEL_ID = "UshaMurux/ast-model-big"
# Target sample rate in Hz that input audio is resampled to.
AST_SR = 16000

# Prefer the GPU whenever torch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Populated lazily by load_model(); None means "not loaded yet".
feature_extractor = model = None
def load_model():
    """Load the AST feature extractor and classifier once, then reuse them.

    On the first call the checkpoint ``MODEL_ID`` is loaded, moved to
    ``device`` and switched to eval mode. The pair is cached in the
    module-level globals, so later calls return it without reloading.

    Returns:
        ``(feature_extractor, model)``.

    Raises:
        gr.Error: if loading or device placement fails. The underlying
            exception is chained as ``__cause__`` and logged with its
            traceback. Nothing is cached in that case, so the next call
            retries.
    """
    global feature_extractor, model
    if model is None:
        logger.info("Loading model...")
        try:
            extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
            classifier = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
            classifier.to(device)
            classifier.eval()
        except Exception as e:
            logger.exception(f"Model loading failed: {e}")
            raise gr.Error("Failed to load model...") from e
        # Publish to the globals only once the model is fully on-device and in
        # eval mode. Otherwise a partial failure would leave a half-initialised
        # model cached, and it would be reused forever.
        feature_extractor, model = extractor, classifier
        logger.info("Model loaded successfully...")
    return feature_extractor, model
def predict_audio(audio_path):
    """Classify one audio file with the AST model.

    Processing steps:
    1. Decode the file as mono and resample it to ``AST_SR``.
    2. Peak-normalise it to [-1, 1] (silent clips are left unchanged).
    3. Turn it into model features.
    4. Run it through the classifier.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.

    Returns:
        ``(waveform, probs, id2label)``:
        - the normalised float32 waveform as a NumPy array;
        - the softmax class probabilities as a NumPy array;
        - the model config's ``id2label`` mapping (class index -> label).

    Raises:
        gr.Error: with a user-facing message when no file is given, the file
            is empty, or any processing stage fails. For stage failures the
            original exception is chained and its traceback is logged.
    """
    logger.info(f"inside predict_audio : {audio_path}")
    # gr.Audio passes None when nothing was uploaded; fail with a clear message.
    if audio_path is None:
        raise gr.Error("No audio file provided ...")

    feature_extractor, model = load_model()
    id2label = model.config.id2label

    try:
        waveform, sr = librosa.load(audio_path, sr=AST_SR, mono=True)
    except Exception as e:
        logger.exception(f"Audio loading failed : {e}")
        raise gr.Error("Failed to load the audio ...") from e

    waveform = np.asarray(waveform, dtype=np.float32)
    # Without this check, an empty decode would crash in the peak computation
    # with an opaque error.
    if waveform.size == 0:
        raise gr.Error("The audio file contains no samples ...")

    try:
        peak = np.abs(waveform).max()
        # Skip silent clips so we never divide by zero.
        if peak > 0:
            waveform = waveform / peak
    except Exception as e:
        logger.exception(f"Preprocessing failed: {e}")
        raise gr.Error("Preprocessing failed ...") from e

    try:
        inputs = feature_extractor(
            waveform,
            sampling_rate=sr,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
    except Exception as e:
        logger.exception(f"Feature extraction failed : {e}")
        raise gr.Error("Feature extraction failed ...") from e

    try:
        with torch.no_grad():
            # Drop the batch dimension of the single-clip batch.
            logits = model(**inputs).logits.squeeze(0)
        probs = torch.softmax(logits, dim=0).cpu().numpy()
    except Exception as e:
        logger.exception(f"Inference failed : {e}")
        raise gr.Error("Inference failed ...") from e

    return waveform, probs, id2label
def _render_audio_figure(waveform):
    """Draw the waveform and its mel spectrogram side by side.

    Args:
        waveform: 1-D audio samples at ``AST_SR``, as returned by
            ``predict_audio``.

    Returns:
        A matplotlib Figure with the waveform on the left and the 128-band mel
        spectrogram (in dB relative to its maximum) on the right.
    """
    mel_spec = librosa.feature.melspectrogram(y=waveform, sr=AST_SR, n_mels=128)
    mel_db = librosa.power_to_db(mel_spec, ref=np.max)

    fig, (wave_ax, mel_ax) = plt.subplots(1, 2, figsize=(10, 5))
    try:
        librosa.display.waveshow(waveform, sr=AST_SR, ax=wave_ax)
        wave_ax.set_title("Waveform")
        wave_ax.set_xlabel("Time (sec)")
        wave_ax.set_ylabel("Amplitude")

        img = librosa.display.specshow(
            mel_db, sr=AST_SR,
            x_axis="time", y_axis="mel",
            cmap="viridis", ax=mel_ax,
        )
        mel_ax.set_title("Mel Spectrogram")
        fig.colorbar(img, ax=mel_ax, format="%+2.0f dB")
    finally:
        # Detach the figure from pyplot's global registry, even if drawing
        # failed, so figures do not pile up across requests. The Figure object
        # itself stays valid for gr.Plot to render.
        plt.close(fig)
    return fig


def wrapper(audio_path):
    """Handle a Predict click: classify the uploaded file and plot it.

    Args:
        audio_path: File path from ``gr.Audio(type="filepath")``, or ``None``
            when nothing has been uploaded.

    Returns:
        ``(figure, label_dict)``, one value per output in
        ``[plot_output, label_output]``. ``label_dict`` maps every class label
        to its probability.

    Raises:
        gr.Error: if no file was uploaded, or as propagated from
            ``predict_audio``.
    """
    if audio_path is None:
        raise gr.Error("Please upload an audio file first.")
    waveform, probs, id2label = predict_audio(audio_path)
    fig = _render_audio_figure(waveform)
    label_dict = {id2label[i]: float(p) for i, p in enumerate(probs)}
    return fig, label_dict


# Components are created in display order: title, upload, plot, labels, button.
with gr.Blocks(title="AST Model") as demo:
    gr.Markdown("AST Genre Classifier")
    audio_input = gr.Audio(sources=["upload"], type="filepath")
    plot_output = gr.Plot()
    # All class probabilities are returned; the widget shows only the top five.
    label_output = gr.Label(num_top_classes=5)
    btn = gr.Button("Predict")
    btn.click(wrapper, audio_input, [plot_output, label_output])

demo.queue().launch(
    server_name="0.0.0.0",
    server_port=7860,
    ssr_mode=False,
    share=True,
    show_error=True,
)