vectorverse committed on
Commit
bd43bae
·
verified ·
1 Parent(s): 7d84440

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +132 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import librosa
4
+ import numpy as np
5
+ from transformers import ASTFeatureExtractor, ASTForAudioClassification
6
+
7
+ # ── CONFIG ──────────────────────────────────────────────────────────────────
8
HF_REPO = "kashishvijayvergiya/music-genre-ast"  # Hugging Face repo the checkpoint is loaded from
SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate when loaded
DURATION = 20  # seconds of audio per crop
MAX_LENGTH = SAMPLE_RATE * DURATION  # samples per crop
N_TTA = 5  # number of crops whose probabilities are averaged per prediction

# NOTE(review): the order here is assumed to match the class indices the
# checkpoint was trained with -- confirm against the model's own label map.
GENRES = [
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock",
]
id2label = dict(enumerate(GENRES))

# Decorative emoji shown next to each genre name in the UI.
GENRE_EMOJI = {
    "blues": "🎸", "classical": "🎻", "country": "🤠", "disco": "🪩",
    "hiphop": "🎤", "jazz": "🎺", "metal": "🤘", "pop": "🎵",
    "reggae": "🌴", "rock": "🔥",
}
23
+
24
# ── LOAD MODEL (once at startup) ────────────────────────────────────────────
# Runs at import time so the checkpoint is loaded once, not on every request.
print("Loading model...")
feature_extractor = ASTFeatureExtractor.from_pretrained(HF_REPO)
model = ASTForAudioClassification.from_pretrained(HF_REPO)
model.eval()  # evaluation mode: this app only runs inference

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
model.to(DEVICE)
print(f"Model ready on {DEVICE}!")
32
+
33
+ # ── AUDIO HELPERS ───────────────────────────────────────────────────────────
34
def load_audio(path):
    """Load an audio file as a mono float32 waveform at ``SAMPLE_RATE``.

    Args:
        path: Location of the audio file, passed straight to ``librosa.load``.

    Returns:
        The samples as a ``np.float32`` array; the sample rate reported by
        librosa is discarded because resampling to ``SAMPLE_RATE`` was requested.
    """
    samples = librosa.load(path, sr=SAMPLE_RATE, mono=True)[0]
    return samples.astype(np.float32)
37
+
38
def normalize(y):
    """Peak-normalize a waveform.

    Args:
        y: NumPy array of audio samples.

    Returns:
        ``y`` divided by its largest absolute sample. A small epsilon is added
        to the divisor so all-zero input yields zeros instead of dividing by
        zero. An empty array is returned unchanged.
    """
    if np.size(y) == 0:
        # np.max raises ValueError on an empty array; there is nothing to scale.
        return y
    return y / (np.max(np.abs(y)) + 1e-6)
40
+
41
def random_crop(y, max_length=None):
    """Return a randomly positioned window of ``max_length`` samples from ``y``.

    Args:
        y: 1-D NumPy array of audio samples.
        max_length: Window length in samples. Defaults to the module-level
            ``MAX_LENGTH`` when omitted.

    Returns:
        A slice of ``y`` exactly ``max_length`` long when ``y`` is long enough;
        otherwise ``y`` zero-padded at the end up to ``max_length``.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    if len(y) >= max_length:
        # randint's upper bound is exclusive: the +1 keeps the last valid start
        # reachable and avoids randint(0, 0) raising when len(y) == max_length.
        start = np.random.randint(0, len(y) - max_length + 1)
        return y[start:start + max_length]
    return np.pad(y, (0, max_length - len(y)))
46
+
47
def center_crop(y, max_length=None):
    """Return the middle ``max_length`` samples of ``y``.

    Args:
        y: 1-D NumPy array of audio samples.
        max_length: Window length in samples. Defaults to the module-level
            ``MAX_LENGTH`` when omitted.

    Returns:
        The centered slice of ``y`` when it is at least ``max_length`` long;
        otherwise ``y`` zero-padded at the end up to ``max_length``.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    if len(y) >= max_length:
        start = (len(y) - max_length) // 2
        return y[start:start + max_length]
    return np.pad(y, (0, max_length - len(y)))
52
+
53
+ # ── PREDICTION WITH TTA ─────────────────────────────────────────────────────
54
def predict(audio_path):
    """Predict the genre of an audio file, averaging over several crops.

    One center crop plus ``N_TTA - 1`` random crops are each peak-normalized,
    run through the feature extractor and model, and their softmax
    probabilities are averaged (test-time augmentation).

    Args:
        audio_path: Path to the audio file, or ``None`` when nothing was
            provided.

    Returns:
        A ``(markdown, label_probs)`` tuple. ``markdown`` is the headline
        result (or an error/help message); ``label_probs`` maps a display
        label per genre to its averaged probability, or is ``None`` when no
        prediction could be made.
    """
    if audio_path is None:
        return "Please upload an audio file.", None

    try:
        audio = load_audio(audio_path)
    except Exception as e:  # UI boundary: show the failure instead of raising
        return f"Error loading audio: {e}", None

    if len(audio) == 0:
        # Without this guard the crops would be pure zero padding and the
        # model would confidently "classify" silence.
        return "Error loading audio: the file contains no audio samples.", None

    # TTA: center crop + N_TTA-1 random crops
    crops = [center_crop(audio)]
    crops.extend(random_crop(audio) for _ in range(N_TTA - 1))

    all_probs = []
    with torch.no_grad():
        for crop in crops:
            inputs = feature_extractor(
                normalize(crop), sampling_rate=SAMPLE_RATE, return_tensors="pt"
            )
            input_values = inputs["input_values"].to(DEVICE)
            logits = model(input_values=input_values).logits
            all_probs.append(torch.softmax(logits, dim=1).cpu().numpy())

    # Each entry holds one row of probabilities; average over crops, then
    # take that single row.
    avg_probs = np.mean(all_probs, axis=0)[0]
    pred_idx = int(np.argmax(avg_probs))
    pred_genre = id2label[pred_idx]
    confidence = float(avg_probs[pred_idx]) * 100

    # Build label dict for the Gradio Label component
    label_probs = {
        f"{GENRE_EMOJI.get(genre, '')} {genre.capitalize()}": float(avg_probs[i])
        for i, genre in id2label.items()
    }

    result = (
        f"## {GENRE_EMOJI.get(pred_genre, '')} {pred_genre.capitalize()}\n"
        f"**Confidence: {confidence:.1f}%**"
    )
    return result, label_probs
93
+
94
+ # ── GRADIO UI ────────────────────────────────────────────────────────────────
95
# Two-column layout: audio input and button on the left, prediction on the right.
with gr.Blocks(title="🎵 Music Genre Classifier") as demo:
    gr.Markdown(
        """
        # 🎵 Music Genre Classifier
        Upload any music file and the model will predict its genre.
        Supports: blues, classical, country, disco, hip-hop, jazz, metal, pop, reggae, rock.

        *Model: Fine-tuned Audio Spectrogram Transformer (AST) · TTA x5*
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath" hands predict() a path string, which is what it expects.
            audio_input = gr.Audio(
                label="Upload Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )
            predict_btn = gr.Button("🎯 Predict Genre", variant="primary")

        with gr.Column(scale=1):
            result_md = gr.Markdown(label="Prediction")
            prob_chart = gr.Label(label="Genre Probabilities", num_top_classes=10)

    # predict() returns (markdown, label->probability dict), matching the two outputs.
    predict_btn.click(
        fn=predict,
        inputs=[audio_input],
        outputs=[result_md, prob_chart],
    )

    gr.Examples(
        examples=[],  # add example audio paths here if you have them
        inputs=[audio_input],
        label="Examples",
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ transformers>=4.36.0
3
+ torch>=2.0.0
4
+ torchaudio>=2.0.0
5
+ librosa>=0.10.0
6
+ numpy>=1.24.0