UshaMurux commited on
Commit
d6f60ca
·
verified ·
1 Parent(s): 8645c94

created app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
6
+ import logging
7
+ import sys
8
+ import librosa
9
+ import os
10
+
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ format="%(asctime)s - %(levelname)s - %(message)s",
14
+ handlers=[logging.StreamHandler(sys.stdout)]
15
+ )
16
+
17
+ logger = logging.getLogger(__name__)
18
+ MODEL_ID = "UshaMurux/ast-model-big"
19
+ AST_SR = 16000
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ feature_extractor = None
22
+ model = None
23
+
24
+
25
+ def load_model():
26
+ global feature_extractor, model
27
+ if model is None:
28
+ try:
29
+ logger.info("Loading model...")
30
+
31
+ feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
32
+ model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
33
+
34
+ model.to(device)
35
+ model.eval()
36
+
37
+ logger.info("Model loaded successfully...")
38
+ except Exception as e:
39
+ logger.error(f"Model loading failed: {e}")
40
+
41
+ raise gr.Error(
42
+ "Failed to load model.........."
43
+ )
44
+ return feature_extractor, model
45
+
46
+
47
+
48
+ def predict_audio(audio_path):
49
+ logger.info(f"inside predict_audio : {audio_path}")
50
+
51
+ feature_extractor, model = load_model()
52
+ id2label = model.config.id2label
53
+
54
+ waveform, sr = librosa.load(audio_path, sr=AST_SR, mono=True)
55
+
56
+ waveform = torch.tensor(waveform)
57
+ max_val = waveform.abs().max()
58
+ if max_val > 0:
59
+ waveform = waveform / max_val
60
+
61
+ inputs = feature_extractor(
62
+ waveform.numpy(),
63
+ sampling_rate=sr,
64
+ return_tensors="pt"
65
+ )
66
+
67
+ inputs = {k: v.to(device) for k, v in inputs.items()}
68
+ with torch.no_grad():
69
+ logits = model(**inputs).logits.squeeze(0)
70
+
71
+ probs = torch.softmax(logits, dim=0).cpu().numpy()
72
+
73
+ return waveform.numpy(), probs, id2label
74
+
75
+
76
+ with gr.Blocks(title="AST Model") as demo:
77
+ gr.Markdown("AST Genre Classifier")
78
+
79
+ audio_input = gr.Audio(sources=["upload"], type="filepath")
80
+ plot_output = gr.Plot()
81
+ label_output = gr.Label(num_top_classes=5)
82
+
83
+ def wrapper(audio_path):
84
+ waveform, probs, id2label = predict_audio(audio_path)
85
+
86
+ fig, ax = plt.subplots(figsize=(10, 3))
87
+ ax.plot(waveform)
88
+ ax.set_title("Waveform")
89
+
90
+ label_dict = {
91
+ id2label[i]: float(probs[i])
92
+ for i in range(len(probs))
93
+ }
94
+
95
+ plt.close(fig)
96
+ return fig, label_dict
97
+
98
+ btn = gr.Button("Predict")
99
+ btn.click(wrapper, audio_input, [plot_output, label_output])
100
+
101
+ #demo.queue().launch(show_error=True)
102
+ #demo.queue().launch(share=True, show_error=True)
103
+ demo.queue().launch(
104
+ server_name="0.0.0.0",
105
+ server_port=7860,
106
+ ssr_mode=False,
107
+ share=True, show_error=True)
108
+