File size: 4,197 Bytes
d6f60ca dbfcb72 d6f60ca 3207bb6 d6f60ca 0976f5a d6f60ca 0976f5a 8306ff7 0976f5a 8306ff7 0976f5a 8306ff7 0976f5a 8306ff7 0976f5a d6f60ca baf4fb6 dbfcb72 d6f60ca dbfcb72 dfa0324 206ce3a dbfcb72 206ce3a dbfcb72 0976f5a dbfcb72 d6f60ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | import gradio as gr
import torch
import matplotlib.pyplot as plt
import os
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import logging
import sys
import librosa
import librosa.display
import os
import numpy as np
# Send log records at INFO level and above to stdout, formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Identifier handed to the transformers from_pretrained() loaders.
MODEL_ID = "UshaMurux/ast-model-big"
# Sample rate (Hz) used when decoding audio and when drawing the plots.
AST_SR = 16000

# Run on the GPU when torch reports CUDA support, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Lazily populated on first use; both stay None until the model is loaded.
feature_extractor = None
model = None
def load_model():
    """Return the shared feature extractor and model, loading them on first use.

    The pair is cached in the module-level ``feature_extractor`` and ``model``
    globals, so only the first successful call performs the load.

    Returns:
        tuple: ``(feature_extractor, model)``; the model has been moved to
        ``device`` and switched to eval mode.

    Raises:
        gr.Error: If loading or preparing the model fails for any reason.
    """
    global feature_extractor, model
    if model is None:
        try:
            logger.info("Loading model...")
            # Build into locals first: the globals are only published once
            # every step has succeeded, so a failure in .to()/.eval() cannot
            # leave a half-initialised model cached for later calls.
            extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
            classifier = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
            classifier.to(device)
            classifier.eval()
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Model loading failed: %s", e)
            raise gr.Error("Failed to load model...") from e
        feature_extractor, model = extractor, classifier
        logger.info("Model loaded successfully...")
    return feature_extractor, model
def predict_audio(audio_path):
    """Classify one audio file and return the data needed to display it.

    Args:
        audio_path: Path of the audio file to classify (the Gradio audio
            input in this file is configured with ``type="filepath"``).

    Returns:
        tuple: ``(waveform, probs, id2label)`` where ``waveform`` is the mono
        signal resampled to ``AST_SR`` and peak-normalised (NumPy array),
        ``probs`` is the softmax over the model logits (NumPy array), and
        ``id2label`` is ``model.config.id2label``.

    Raises:
        gr.Error: If no file was supplied, or if loading the model, decoding
            the audio, preprocessing, feature extraction or inference fails.
    """
    logger.info(f"inside predict_audio : {audio_path}")

    # Reject a missing path up front so that an empty click does not pay for
    # the model load only to fail while decoding.
    if not audio_path:
        logger.error("No audio file was provided")
        raise gr.Error("Please upload an audio file first.")

    feature_extractor, model = load_model()
    id2label = model.config.id2label

    try:
        waveform, sr = librosa.load(audio_path, sr=AST_SR, mono=True)
    except Exception as e:
        logger.exception("Audio loading failed : %s", e)
        raise gr.Error("Failed to load the audio ...") from e

    try:
        # Peak-normalise directly on the NumPy array; silent input (peak of
        # zero) is left untouched to avoid dividing by zero.
        peak = np.abs(waveform).max()
        if peak > 0:
            waveform = waveform / peak
    except Exception as e:
        logger.exception("Preprocessing failed: %s", e)
        raise gr.Error("Preprocessing failed ...") from e

    try:
        inputs = feature_extractor(
            waveform,
            sampling_rate=sr,
            return_tensors="pt",
        )
        # Move every tensor to the same device as the model.
        inputs = {k: v.to(device) for k, v in inputs.items()}
    except Exception as e:
        logger.exception("Feature extraction failed : %s", e)
        raise gr.Error("Feature extraction failed ...") from e

    try:
        with torch.no_grad():
            # Drop the leading batch dimension (a single clip is passed in),
            # then normalise the class scores over the remaining axis.
            logits = model(**inputs).logits.squeeze(0)
            probs = torch.softmax(logits, dim=0).cpu().numpy()
    except Exception as e:
        logger.exception("Inference failed : %s", e)
        raise gr.Error("Inference failed ...") from e

    return waveform, probs, id2label
# UI layout: an upload-only audio input, a plot area, a top-5 label view and
# a button that runs the classifier.
with gr.Blocks(title="AST Model") as demo:
    gr.Markdown("AST Genre Classifier")
    audio_input = gr.Audio(sources=["upload"], type="filepath")
    plot_output = gr.Plot()
    label_output = gr.Label(num_top_classes=5)

    def wrapper(audio_path):
        """Run the classifier on ``audio_path`` and build both UI outputs.

        Args:
            audio_path: File path delivered by the audio input component.

        Returns:
            tuple: ``(fig, label_dict)`` where ``fig`` is a Matplotlib figure
            holding the waveform and mel spectrogram side by side, and
            ``label_dict`` maps each class label to its probability as a
            plain float. ``(None, None)`` if no waveform is available.
        """
        waveform, probs, id2label = predict_audio(audio_path)
        if waveform is None:
            # Exactly one value per output component wired in btn.click.
            return None, None

        label_dict = {id2label[i]: float(p) for i, p in enumerate(probs)}

        mel_spec = librosa.feature.melspectrogram(
            y=waveform, sr=AST_SR, n_mels=128
        )
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        try:
            # Left panel: time-domain waveform.
            librosa.display.waveshow(waveform, sr=AST_SR, ax=ax[0])
            ax[0].set_title("Waveform")
            ax[0].set_xlabel("Time (sec)")
            ax[0].set_ylabel("Amplitude")

            # Right panel: mel spectrogram in dB with a colour bar.
            img = librosa.display.specshow(
                mel_db, sr=AST_SR,
                x_axis='time', y_axis='mel',
                cmap='viridis', ax=ax[1]
            )
            ax[1].set_title("Mel Spectrogram")
            fig.colorbar(img, ax=ax[1], format="%+2.0f dB")
        finally:
            # Always unregister the figure from pyplot, including when a
            # drawing call raises, so figures do not pile up across requests.
            # The Figure object itself is still returned on success.
            plt.close(fig)

        return fig, label_dict

    btn = gr.Button("Predict")
    btn.click(wrapper, audio_input, [plot_output, label_output])
# Enable the request queue and start the server with explicit host, port,
# SSR, sharing and error-display settings.
demo.queue().launch(
    show_error=True,
    share=True,
    ssr_mode=False,
    server_port=7860,
    server_name="0.0.0.0",
)
|