"""Gradio demo: genre classification with a fine-tuned Audio Spectrogram Transformer (AST).

The user uploads an audio file. The app resamples it to mono ``AST_SR`` Hz,
peak-normalises it and runs the classifier. It then shows the waveform, a mel
spectrogram and the top predicted labels.
"""

import logging
import os
import sys

import gradio as gr
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Hugging Face Hub id of the audio-classification checkpoint.
MODEL_ID = "UshaMurux/ast-model-big"
# Sample rate every upload is resampled to before feature extraction and plotting.
AST_SR = 16000
# Number of mel bands in the spectrogram shown in the UI.
N_MELS = 128
# Number of top classes displayed by the label widget.
TOP_K = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Populated lazily by load_model() on the first prediction.
feature_extractor = None
model = None


def load_model():
    """Load the feature extractor and model once and cache them in module globals.

    Returns:
        tuple: ``(feature_extractor, model)``, with the model moved to ``device``
        and switched to eval mode.

    Raises:
        gr.Error: if either component fails to load or to move to ``device``.
    """
    global feature_extractor, model
    if model is None:
        try:
            logger.info("Loading model...")
            extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
            loaded = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
            loaded.to(device)
            loaded.eval()
        except Exception as e:
            logger.exception(f"Model loading failed: {e}")
            raise gr.Error("Failed to load model...") from e
        # Publish only after every step succeeded. Otherwise a failure in
        # .to()/.eval() would leave a half-initialised model cached forever.
        feature_extractor, model = extractor, loaded
        logger.info("Model loaded successfully...")
    return feature_extractor, model


def predict_audio(audio_path):
    """Classify a single audio file.

    Args:
        audio_path: Path to an audio file readable by librosa, as supplied by
            ``gr.Audio(type="filepath")``. ``None`` when nothing was uploaded.

    Returns:
        tuple: ``(waveform, probs, id2label)``.
            - ``waveform``: mono, peak-normalised float32 NumPy array at ``AST_SR``.
            - ``probs``: NumPy array of softmax probabilities, one per class.
            - ``id2label``: the model config's index -> label mapping.

    Raises:
        gr.Error: on missing, empty or unreadable audio, or on any
            feature-extraction or inference failure.
    """
    logger.info(f"inside predict_audio : {audio_path}")
    if audio_path is None:
        raise gr.Error("Please upload an audio file first.")

    feature_extractor, model = load_model()
    id2label = model.config.id2label

    try:
        waveform, sr = librosa.load(audio_path, sr=AST_SR, mono=True)
    except Exception as e:
        logger.exception(f"Audio loading failed : {e}")
        raise gr.Error("Failed to load the audio ...") from e

    if waveform.size == 0:
        raise gr.Error("The uploaded audio is empty.")

    # Peak-normalise to [-1, 1]. All-silent input is left as-is to avoid
    # division by zero.
    peak = np.abs(waveform).max()
    if peak > 0:
        waveform = waveform / peak

    try:
        inputs = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
    except Exception as e:
        logger.exception(f"Feature extraction failed : {e}")
        raise gr.Error("Feature extraction failed ...") from e

    try:
        with torch.no_grad():
            # squeeze(0) drops the batch axis: exactly one clip per request.
            logits = model(**inputs).logits.squeeze(0)
        probs = torch.softmax(logits, dim=0).cpu().numpy()
    except Exception as e:
        logger.exception(f"Inference failed : {e}")
        raise gr.Error("Inference failed ...") from e

    return waveform, probs, id2label


def _make_figure(waveform):
    """Return a figure with the waveform (left) and its mel spectrogram in dB (right)."""
    mel_spec = librosa.feature.melspectrogram(y=waveform, sr=AST_SR, n_mels=N_MELS)
    mel_db = librosa.power_to_db(mel_spec, ref=np.max)

    fig, (ax_wave, ax_mel) = plt.subplots(1, 2, figsize=(10, 5))

    librosa.display.waveshow(waveform, sr=AST_SR, ax=ax_wave)
    ax_wave.set_title("Waveform")
    ax_wave.set_xlabel("Time (sec)")
    ax_wave.set_ylabel("Amplitude")

    img = librosa.display.specshow(
        mel_db,
        sr=AST_SR,
        x_axis="time",
        y_axis="mel",
        cmap="viridis",
        ax=ax_mel,
    )
    ax_mel.set_title("Mel Spectrogram")
    fig.colorbar(img, ax=ax_mel, format="%+2.0f dB")

    # Detach from pyplot's figure manager so figures do not accumulate across
    # requests. The Figure object itself can still be rendered by gr.Plot.
    plt.close(fig)
    return fig


def wrapper(audio_path):
    """Gradio click handler.

    Returns:
        tuple: ``(figure, {label: probability})``, one value for each of the
        two output components (``plot_output`` and ``label_output``).
    """
    waveform, probs, id2label = predict_audio(audio_path)
    label_dict = {id2label[i]: float(p) for i, p in enumerate(probs)}
    return _make_figure(waveform), label_dict


with gr.Blocks(title="AST Model") as demo:
    gr.Markdown("AST Genre Classifier")
    audio_input = gr.Audio(sources=["upload"], type="filepath")
    plot_output = gr.Plot()
    label_output = gr.Label(num_top_classes=TOP_K)
    btn = gr.Button("Predict")
    btn.click(wrapper, audio_input, [plot_output, label_output])


if __name__ == "__main__":
    # share=True also opens a public gradio.live tunnel in addition to
    # binding 0.0.0.0:7860.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
        share=True,
        show_error=True,
    )