Spaces:
Runtime error
Runtime error
Upload 16 files
Browse files- .gitattributes +35 -35
- README.md +14 -14
- app.py +94 -0
- models/cnn/10s.pth +3 -0
- models/cnn/1s.pth +3 -0
- models/cnn/3s.pth +3 -0
- models/cnn/5s.pth +3 -0
- models/crnn/10s.pth +3 -0
- models/crnn/1s.pth +3 -0
- models/crnn/3s.pth +3 -0
- models/crnn/5s.pth +3 -0
- models/metrics_summary_table.csv +17 -0
- requirements.txt +4 -0
- src/__init__.py +2 -0
- src/models.py +153 -0
- src/utility.py +57 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,35 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.
|
| 24 |
-
*.
|
| 25 |
-
*
|
| 26 |
-
|
| 27 |
-
*.tar
|
| 28 |
-
*.
|
| 29 |
-
*.
|
| 30 |
-
*.
|
| 31 |
-
*.
|
| 32 |
-
*.
|
| 33 |
-
*.
|
| 34 |
-
*
|
| 35 |
-
*
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: MusicGenrePulse
|
| 3 |
-
emoji: 🦀
|
| 4 |
-
colorFrom: red
|
| 5 |
-
colorTo: yellow
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.15.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
-
short_description: DL app to classify music and get genre distribution.
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: MusicGenrePulse
|
| 3 |
+
emoji: 🦀
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.15.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
short_description: DL app to classify music and get genre distribution.
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import time

import gradio as gr
import librosa
import numpy as np
import torch

# NOTE(review): app.py lives at the Space root next to src/, so the old
# "MusicGenrePulse." package prefix cannot resolve at runtime (the Space
# banner shows "Runtime error"). Import relative to the repo root, and pull
# the model classes from src.models directly since src/__init__.py is empty.
from src.utility import slice_songs
from src.models import MusicCNN, MusicCRNN2D

# Audio / feature configuration
DESIRED_SR = 22050  # sample rate every upload is resampled to
HOP_LENGTH = 512    # STFT hop used for the mel spectrogram
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 10    # ten-genre setup, see GENRE_LABELS below

# Loaded models, keyed first by architecture then by slice length in seconds.
models = {"cnn": {}, "crnn": {}}

# Checkpoint locations, keyed by slice length in seconds.
cnn_model_paths = {1: "models/cnn/1s.pth", 3: "models/cnn/3s.pth",
                   5: "models/cnn/5s.pth", 10: "models/cnn/10s.pth"}
crnn_model_paths = {1: "models/crnn/1s.pth", 3: "models/crnn/3s.pth",
                    5: "models/crnn/5s.pth", 10: "models/crnn/10s.pth"}
| 23 |
+
def get_frames(slice_length, sr=None, hop_length=None):
    """Return the number of spectrogram frames spanned by ``slice_length`` seconds.

    Args:
        slice_length: Slice duration in seconds.
        sr: Sample rate; defaults to the module-level ``DESIRED_SR``.
        hop_length: STFT hop length; defaults to the module-level ``HOP_LENGTH``.

    Returns:
        Frame count as an int (fractional frames truncated, matching the
        frame math used by ``slice_songs``).
    """
    if sr is None:
        sr = DESIRED_SR
    if hop_length is None:
        hop_length = HOP_LENGTH
    return int(slice_length * sr / hop_length)
|
def _load_genre_models(model_cls, paths):
    """Instantiate, warm up, and load one checkpoint per slice length.

    Both architectures build their FC/GRU heads lazily on the first forward
    pass, so a dummy forward must run *before* ``load_state_dict`` to
    materialize those parameters with the right shapes.

    Args:
        model_cls: Model class to instantiate (``MusicCNN`` or ``MusicCRNN2D``).
        paths: Mapping of slice length (seconds) -> checkpoint path.

    Returns:
        Mapping of slice length -> loaded model in eval mode on ``DEVICE``.
    """
    loaded = {}
    for slice_len, path in paths.items():
        model = model_cls(num_classes=NUM_CLASSES, device=DEVICE)
        # Batch of 2 so BatchNorm1d can run in train mode during the warm-up.
        dummy_input = torch.randn(2, 1, 128, get_frames(slice_len)).to(DEVICE)
        _ = model(dummy_input)  # materialize lazy layers before loading weights
        model.load_state_dict(torch.load(path, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()
        loaded[slice_len] = model
    return loaded


# Load all checkpoints for both architectures (was two copy-pasted loops).
models["cnn"] = _load_genre_models(MusicCNN, cnn_model_paths)
models["crnn"] = _load_genre_models(MusicCRNN2D, crnn_model_paths)

GENRE_LABELS = ["Blues", "Classical", "Country", "Disco", "HipHop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def predict_genre(audio_file, slice_length, architecture):
    """Predict a genre probability distribution for an uploaded audio file.

    The clip is resampled, converted to a normalized mel spectrogram, split
    into overlapping slices, classified per slice, and the slice-level
    probabilities are averaged into one song-level distribution.

    Args:
        audio_file: Path to the uploaded audio file (Gradio ``filepath``).
        slice_length: Slice length in seconds; arrives as a string from the
            dropdown ("1", "3", "5" or "10").
        architecture: Model family key, "cnn" or "crnn".

    Returns:
        Tuple of (``{genre: probability}`` dict, inference-time message).

    Raises:
        gr.Error: If the clip is too short to yield even one slice.
    """
    slice_length = int(slice_length)  # dropdown delivers the choice as a string
    start_time = time.time()

    # Load/resample, then zero-pad up to a whole number of seconds so short
    # tails still reach a slice boundary.
    y, sr = librosa.load(audio_file, sr=DESIRED_SR)
    target_length = int(np.ceil(len(y) / sr)) * sr
    if len(y) < target_length:
        y = np.pad(y, (0, target_length - len(y)), mode='constant')

    mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=2048, hop_length=HOP_LENGTH, n_mels=128)
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
    # Min-max normalize to [0, 1]; skip on constant/silent input to avoid /0.
    min_val, max_val = np.min(mel_spectrogram_db), np.max(mel_spectrogram_db)
    normalized_spectrogram = (mel_spectrogram_db - min_val) / (
        max_val - min_val) if max_val - min_val > 0 else mel_spectrogram_db

    X_slices, _, _ = slice_songs([normalized_spectrogram], [0], ["temp"], sr=sr, hop_length=HOP_LENGTH,
                                 length_in_seconds=slice_length)
    if len(X_slices) == 0:
        # Previously this fell through to torch with an opaque shape error.
        raise gr.Error(f"Audio is too short for a {slice_length}s slice; please upload a longer clip.")
    X_slices = torch.tensor(X_slices, dtype=torch.float32).unsqueeze(1).to(DEVICE)

    model_used = models[architecture][slice_length]
    with torch.no_grad():
        outputs = model_used(X_slices)
        probabilities = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()

    # Average slice-level probabilities into one song-level distribution.
    avg_probs = np.mean(probabilities, axis=0)
    genre_distribution = {GENRE_LABELS[i]: float(avg_probs[i]) for i in range(NUM_CLASSES)}
    inference_time = time.time() - start_time
    return genre_distribution, f"Inference Time: {inference_time:.2f} seconds"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Gradio wiring: an audio upload plus two dropdowns feed predict_genre.
length_selector = gr.Dropdown(
    choices=["1", "3", "5", "10"],
    value="1",
    label="Slice Length (seconds)",
)
arch_selector = gr.Dropdown(
    choices=["cnn", "crnn"],
    value="cnn",
    label="Model Architecture",
)

demo = gr.Interface(
    fn=predict_genre,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        length_selector,
        arch_selector,
    ],
    outputs=[
        gr.Label(num_top_classes=10, label="Genre Distribution"),
        gr.Textbox(label="Inference Time"),
    ],
    title="Music Genre Classifier",
    description="Upload an audio file, select a slice length and model architecture to predict its genre distribution."
)

if __name__ == "__main__":
    demo.launch()
|
| 93 |
+
|
| 94 |
+
|
models/cnn/10s.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7280b3c52a5f2c741180160eed533237817bb9ba66bd8edf6519b0ce7776670b
|
| 3 |
+
size 224021010
|
models/cnn/1s.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2914628c6cb675e7c75e14f1686c0f20bd84c6cacad4a41db367bd476796f6be
|
| 3 |
+
size 22694418
|
models/cnn/3s.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e8135267a6db5e835e246755cde7997d00944415bd5db344fc2e036cb8f5b06
|
| 3 |
+
size 68831762
|
models/cnn/5s.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3ca0c926d4393f112c3acee391ab1cc4fbfef41f015c156e8aeb68dbdf1cee09
|
| 3 |
+
size 110774802
|
models/crnn/10s.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:161bb49d3ed4bb761e9d7c6095afe5fbc144c7c95d49df5c1198d477d96191ce
|
| 3 |
+
size 1626402
|
models/crnn/1s.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fbdeec8991d3b655e6785a2aa451accebd479ab3a0f3d61a4fc39429af471d6a
|
| 3 |
+
size 1626402
|
models/crnn/3s.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68bdbe9d17904bfb57720554edd4947c3c9e3df9ad4f14c505fdeb4420ce3b77
|
| 3 |
+
size 1626402
|
models/crnn/5s.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e72bf03a664a8c1418d3b16d0634df595428aed5e4bcda75e75115548ab085df
|
| 3 |
+
size 1626402
|
models/metrics_summary_table.csv
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Model,Split Size,Slice Accuracy,Slice Loss,Song Accuracy,Execution Time,Epoch
|
| 2 |
+
0,MusicCNN,1s,0.7991666666666667,0.8332288451790809,0.9,0h 21m 1s,106
|
| 3 |
+
1,MusicCNN,3s,0.8205263157894737,0.8359026940245378,0.86,0h 24m 59s,82
|
| 4 |
+
2,MusicCNN,5s,0.8372727272727273,0.8710557313398881,0.85,0h 24m 19s,84
|
| 5 |
+
3,MusicCNN,10s,0.836,0.987051441192627,0.88,0h 38m 37s,133
|
| 6 |
+
4,MusicCRNN2D,1s,0.8333333333333334,0.746949385046959,0.94,0h 13m 12s,48
|
| 7 |
+
5,MusicCRNN2D,3s,0.8078947368421052,0.8572936429475483,0.89,0h 11m 48s,43
|
| 8 |
+
6,MusicCRNN2D,5s,0.8190909090909091,0.8851883194663308,0.89,0h 13m 50s,68
|
| 9 |
+
7,MusicCRNN2D,10s,0.778,0.9759823226928712,0.85,0h 18m 24s,66
|
| 10 |
+
8,MusicCRNN1D,1s,0.5648333333333333,1.7309636125564576,0.7,0h 11m 57s,97
|
| 11 |
+
9,MusicCRNN1D,3s,0.5510526315789473,1.7096873275857225,0.62,0h 2m 33s,65
|
| 12 |
+
10,MusicCRNN1D,5s,0.5972727272727273,1.698943519592285,0.69,0h 3m 30s,123
|
| 13 |
+
11,MusicCRNN1D,10s,0.532,1.783525552749634,0.59,0h 2m 3s,124
|
| 14 |
+
12,MusicRNN,1s,0.6378333333333334,1.2230633710006171,0.78,0h 5m 56s,45
|
| 15 |
+
13,MusicRNN,3s,0.6210526315789474,1.2467606188984293,0.71,0h 2m 14s,42
|
| 16 |
+
14,MusicRNN,5s,0.6018181818181818,1.1574554492668672,0.63,0h 0m 43s,46
|
| 17 |
+
15,MusicRNN,10s,0.502,1.3741582012176514,0.53,0h 0m 49s,47
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
torch
|
| 3 |
+
librosa
|
| 4 |
+
numpy
|
src/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
src/models.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
class MusicCNN(nn.Module):
    """CNN genre classifier over mel-spectrogram inputs.

    Three conv/pool stages with doubling channel width halve the spatial
    dimensions each time. The fully connected head is built lazily on the
    first forward pass, once the flattened feature size is known for the
    given input shape — callers must run a dummy forward before
    ``load_state_dict``.
    """

    def __init__(self, num_classes, dropout_rate=0.3, device="cuda"):
        super(MusicCNN, self).__init__()
        self.device = device

        # Three conv stages (attribute names fixed: checkpoints key on them).
        self.conv_block1 = self._make_block(1, 32, dropout_rate).to(device)
        self.conv_block2 = self._make_block(32, 64, dropout_rate).to(device)
        self.conv_block3 = self._make_block(64, 128, dropout_rate).to(device)

        self.fc_layers = None  # built on first forward (input-size dependent)
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate

    @staticmethod
    def _make_block(in_ch, out_ch, dropout_rate):
        """Two conv+BN+ReLU layers, then a 2x2 max-pool and spatial dropout."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(dropout_rate),
        )

    def forward(self, x):
        out = self.conv_block3(self.conv_block2(self.conv_block1(x)))

        # Flatten everything but the batch dimension.
        out = out.view(out.size(0), -1)

        if self.fc_layers is None:
            # First call: size the FC head from the flattened feature width.
            self.fc_layers = nn.Sequential(
                nn.Linear(out.size(1), 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(self.dropout_rate),
                nn.Linear(512, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),
                nn.Dropout(self.dropout_rate),
                nn.Linear(256, self.num_classes),
            ).to(self.device)

        return self.fc_layers(out)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class MusicCRNN2D(nn.Module):
    """Conv-recurrent genre classifier: CNN front-end, bidirectional GRU back-end.

    Four conv/pool stages compress the frequency axis; the pooled feature
    map is re-laid-out as a time sequence, summarized by the final GRU step,
    and classified. The GRU and the linear classifier are built lazily on
    the first forward pass, once the per-step feature size is known —
    callers must run a dummy forward before ``load_state_dict``.
    """

    def __init__(self, num_classes, dropout_rate=0.1, gru_hidden_size=32, device="cuda"):
        super(MusicCRNN2D, self).__init__()
        self.device = device

        # Normalize the raw single-channel spectrogram before any convolution.
        self.input_bn = nn.BatchNorm2d(1).to(device)

        # Conv stages (attribute names fixed: checkpoints key on them).
        self.conv_block1 = self._make_block(1, 64, (2, 2), dropout_rate).to(device)
        self.conv_block2 = self._make_block(64, 128, (4, 2), dropout_rate).to(device)
        self.conv_block3 = self._make_block(128, 128, (4, 2), dropout_rate).to(device)
        self.conv_block4 = self._make_block(128, 128, (4, 2), dropout_rate).to(device)

        self.gru_stack = None   # built on first forward (input-size dependent)
        self.classifier = None  # built together with the GRU
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.gru_hidden_size = gru_hidden_size

    @staticmethod
    def _make_block(in_ch, out_ch, pool, dropout_rate):
        """One conv+BN+ELU stage followed by max-pool and spatial dropout."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(),
            nn.MaxPool2d(pool),
            nn.Dropout2d(dropout_rate),
        )

    def forward(self, x):
        out = self.input_bn(x)
        for block in (self.conv_block1, self.conv_block2,
                      self.conv_block3, self.conv_block4):
            out = block(out)

        # (batch, ch, freq, time) -> (batch, time, ch*freq) for the GRU.
        n_batch, _, _, n_time = out.shape
        out = out.permute(0, 3, 1, 2).reshape(n_batch, n_time, -1)

        if self.gru_stack is None:
            # First call: size the recurrent stack from the per-step features.
            self.gru_stack = nn.GRU(
                input_size=out.size(2),
                hidden_size=self.gru_hidden_size,
                batch_first=True,
                bidirectional=True,
            ).to(self.device)
            self.classifier = nn.Sequential(
                nn.Dropout(self.dropout_rate * 3),
                nn.Linear(self.gru_hidden_size * 2, self.num_classes)  # *2: bidirectional
            ).to(self.device)

        out, _ = self.gru_stack(out)

        # Summarize the sequence by its final time step, then classify.
        return self.classifier(out[:, -1, :])
|
src/utility.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def slice_songs(X, Y, S,
                sr=22050,
                hop_length=512,
                length_in_seconds=30,
                overlap=0.5):
    """
    Slice spectrograms into smaller overlapping windows.

    Parameters:
        X: Array of spectrograms, each shaped (n_mels, n_frames)
        Y: Array of labels, aligned with X
        S: Array of song names, aligned with X
        sr: Sample rate (default: 22050)
        hop_length: Hop length used in spectrogram creation (default: 512)
        length_in_seconds: Length of each slice in seconds (default: 30)
        overlap: Overlap ratio between consecutive slices (default: 0.5 for 50% overlap)

    Returns:
        Tuple of numpy arrays (slices, labels, names); the label and name of
        a spectrogram are repeated once per slice taken from it.
    """
    # Window width in frames, and the stride implied by the overlap ratio.
    frames_per_second = sr / hop_length
    width = int(length_in_seconds * frames_per_second)
    step = int(width * (1 - overlap))

    pieces, labels, names = [], [], []

    for idx, spec in enumerate(X):
        total_frames = spec.shape[1]
        # Every start position whose window fits entirely in the spectrogram.
        for start in range(0, total_frames - width + 1, step):
            window = spec[:, start:start + width]
            # Defensive: only keep full-width windows (the range bound should
            # already guarantee this).
            if window.shape[1] == width:
                pieces.append(window)
                labels.append(Y[idx])
                names.append(S[idx])

    return np.array(pieces), np.array(labels), np.array(names)
|