| import gradio as gr |
| import torch |
| import matplotlib.pyplot as plt |
| import os |
| from transformers import AutoFeatureExtractor, AutoModelForAudioClassification |
| import logging |
| import sys |
| import librosa |
| import librosa.display |
| import os |
| import numpy as np |
|
|
# Root logging: INFO and above, written to stdout with a
# "timestamp - level - message" layout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Identifier handed to from_pretrained() in load_model().
MODEL_ID = "UshaMurux/ast-model-big"

# Sample rate (Hz) audio is loaded at in predict_audio().
AST_SR = 16000

# Use CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Filled in by load_model() the first time it runs; None until then.
feature_extractor = None
model = None
|
|
|
|
def load_model():
    """Load the feature extractor and classifier once and return the cached pair.

    On the first call both objects are created with
    ``from_pretrained(MODEL_ID)``, the model is moved to ``device`` and put
    into eval mode. The module-level ``feature_extractor`` and ``model``
    globals are assigned only after every step has succeeded, so a failure
    part-way through (for example while moving the model to the device)
    cannot leave a half-initialised model cached for later calls.

    Returns:
        tuple: ``(feature_extractor, model)``.

    Raises:
        gr.Error: If any loading step fails. The original exception is
            attached as the cause and logged with its traceback.
    """
    global feature_extractor, model

    # Fast path: a previous call already finished loading.
    if model is not None:
        return feature_extractor, model

    logger.info("Loading model...")
    try:
        extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
        classifier = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
        classifier.to(device)
        classifier.eval()
    except Exception as e:
        logger.exception(f"Model loading failed: {e}")
        raise gr.Error("Failed to load model...") from e

    # Publish to the module-level cache only once everything worked.
    feature_extractor, model = extractor, classifier
    logger.info("Model loaded successfully...")
    return feature_extractor, model
|
|
|
|
def predict_audio(audio_path):
    """Classify an audio file and return its waveform and class probabilities.

    The file is read with ``librosa.load`` as mono audio at ``AST_SR``,
    scaled so its peak absolute value is 1, passed through the feature
    extractor and the model, and the resulting logits are converted to
    probabilities with a softmax.

    Args:
        audio_path: Path of the audio file to classify. A missing value
            (``None`` or an empty string) is rejected before the model is
            loaded.

    Returns:
        tuple: ``(waveform, probs, id2label)`` where ``waveform`` is the
        normalised NumPy signal, ``probs`` is a NumPy array of softmax
        probabilities and ``id2label`` is ``model.config.id2label``.

    Raises:
        gr.Error: If no file was supplied, or if audio loading,
            preprocessing, feature extraction or inference fails. The
            underlying exception is attached as the cause.
    """
    logger.info(f"inside predict_audio : {audio_path}")

    # Nothing to classify; fail early with a clear message instead of
    # loading the model and letting the audio loader choke on a missing path.
    if not audio_path:
        raise gr.Error("Please upload an audio file first ...")

    feature_extractor, model = load_model()
    id2label = model.config.id2label

    try:
        waveform, sr = librosa.load(audio_path, sr=AST_SR, mono=True)
    except Exception as e:
        logger.exception(f"Audio loading failed : {e}")
        raise gr.Error("Failed to load the audio ...") from e

    try:
        # Peak-normalise; an all-zero signal is left as is so we never
        # divide by zero.
        max_val = np.abs(waveform).max()
        if max_val > 0:
            waveform = waveform / max_val
    except Exception as e:
        logger.exception(f"Preprocessing failed: {e}")
        raise gr.Error("Preprocessing failed ...") from e

    try:
        inputs = feature_extractor(
            waveform,
            sampling_rate=sr,
            return_tensors="pt",
        )
        # The tensors must live on the same device as the model.
        inputs = {k: v.to(device) for k, v in inputs.items()}
    except Exception as e:
        logger.exception(f"Feature extraction failed : {e}")
        raise gr.Error("Feature extraction failed ...") from e

    try:
        with torch.no_grad():
            # Drop the leading batch axis (a single clip is processed).
            logits = model(**inputs).logits.squeeze(0)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
    except Exception as e:
        logger.exception(f"Inference failed : {e}")
        raise gr.Error("Inference failed ...") from e

    return waveform, probs, id2label
|
|
|
|
# UI definition: an audio upload, a plot, a top-5 label view and a button
# that runs `wrapper` on the uploaded file.
with gr.Blocks(title="AST Model") as demo:
    gr.Markdown("AST Genre Classifier")

    audio_input = gr.Audio(sources=["upload"], type="filepath")
    plot_output = gr.Plot()
    label_output = gr.Label(num_top_classes=5)

    def wrapper(audio_path):
        """Run the classifier and build the values for the two outputs.

        Args:
            audio_path: File path delivered by the ``gr.Audio`` input.

        Returns:
            tuple: ``(fig, label_dict)``: a matplotlib figure showing the
            waveform next to its mel spectrogram, and a mapping from
            label name to probability. ``(None, None)`` if no waveform
            was produced.
        """
        waveform, probs, id2label = predict_audio(audio_path)

        # The click handler is wired to exactly two outputs (plot, label),
        # so the fallback must also return two values.
        if waveform is None:
            return None, None

        mel_spec = librosa.feature.melspectrogram(
            y=waveform, sr=AST_SR, n_mels=128
        )
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)

        label_dict = {id2label[i]: float(p) for i, p in enumerate(probs)}

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        try:
            # Left panel: time-domain signal.
            librosa.display.waveshow(waveform, sr=AST_SR, ax=ax[0])
            ax[0].set_title("Waveform")
            ax[0].set_xlabel("Time (sec)")
            ax[0].set_ylabel("Amplitude")

            # Right panel: mel spectrogram in dB with a colour bar.
            img = librosa.display.specshow(
                mel_db,
                sr=AST_SR,
                x_axis="time",
                y_axis="mel",
                cmap="viridis",
                ax=ax[1],
            )
            ax[1].set_title("Mel Spectrogram")
            fig.colorbar(img, ax=ax[1], format="%+2.0f dB")
        finally:
            # Close the figure in pyplot even when drawing raised, so
            # failed requests do not pile up open figures. As before, the
            # figure is closed before it is returned.
            plt.close(fig)

        return fig, label_dict

    btn = gr.Button("Predict")
    btn.click(wrapper, audio_input, [plot_output, label_output])
|
|
| |
| |
# Start the app: enable the request queue, then launch with the
# settings collected below.
_launch_options = {
    "server_name": "0.0.0.0",
    "server_port": 7860,
    "ssr_mode": False,
    "share": True,
    "show_error": True,
}
demo.queue().launch(**_launch_options)
|
|
|
|