Skynova committed
Commit ffa9b64 · verified · 1 parent: 83d6852

Upload 16 files

.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,14 @@
- ---
- title: MusicGenrePulse
- emoji: 🦀
- colorFrom: red
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.15.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: DL app to classify music and get genre distribution.
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: MusicGenrePulse
+ emoji: 🦀
+ colorFrom: red
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 5.15.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: DL app to classify music and get genre distribution.
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,94 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ import librosa
+ import time
+ from src.utility import slice_songs
+ from src.models import MusicCNN, MusicCRNN2D
+
+ # Configuration
+ DESIRED_SR = 22050
+ HOP_LENGTH = 512
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ NUM_CLASSES = 10
+
+ # Model loading (one checkpoint per architecture and slice length)
+ models = {"cnn": {}, "crnn": {}}
+ cnn_model_paths = {1: "models/cnn/1s.pth", 3: "models/cnn/3s.pth", 5: "models/cnn/5s.pth", 10: "models/cnn/10s.pth"}
+ crnn_model_paths = {1: "models/crnn/1s.pth", 3: "models/crnn/3s.pth", 5: "models/crnn/5s.pth",
+                     10: "models/crnn/10s.pth"}
+
+
+ def get_frames(slice_length):
+     return int(slice_length * DESIRED_SR / HOP_LENGTH)
+
+
+ # Load cnn models
+ for slice_len, path in cnn_model_paths.items():
+     model = MusicCNN(num_classes=NUM_CLASSES, device=DEVICE)
+     dummy_input = torch.randn(2, 1, 128, get_frames(slice_len)).to(DEVICE)
+     _ = model(dummy_input)  # build the lazily initialized FC layers before loading weights
+     model.load_state_dict(torch.load(path, map_location=DEVICE))
+     model.to(DEVICE)
+     model.eval()
+     models["cnn"][slice_len] = model
+
+ # Load crnn models
+ for slice_len, path in crnn_model_paths.items():
+     model = MusicCRNN2D(num_classes=NUM_CLASSES, device=DEVICE)
+     dummy_input = torch.randn(2, 1, 128, get_frames(slice_len)).to(DEVICE)
+     _ = model(dummy_input)  # build the lazily initialized GRU and classifier before loading weights
+     model.load_state_dict(torch.load(path, map_location=DEVICE))
+     model.to(DEVICE)
+     model.eval()
+     models["crnn"][slice_len] = model
+
+ GENRE_LABELS = ["Blues", "Classical", "Country", "Disco", "HipHop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]
+
+
+ def predict_genre(audio_file, slice_length, architecture):
+     slice_length = int(slice_length)
+     start_time = time.time()
+
+     # Load audio and zero-pad to a whole number of seconds
+     y, sr = librosa.load(audio_file, sr=DESIRED_SR)
+     target_length = int(np.ceil(len(y) / sr)) * sr
+     if len(y) < target_length:
+         y = np.pad(y, (0, target_length - len(y)), mode='constant')
+
+     # Log-mel spectrogram, min-max normalized to [0, 1]
+     mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=2048, hop_length=HOP_LENGTH, n_mels=128)
+     mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
+     min_val, max_val = np.min(mel_spectrogram_db), np.max(mel_spectrogram_db)
+     normalized_spectrogram = (mel_spectrogram_db - min_val) / (
+         max_val - min_val) if max_val - min_val > 0 else mel_spectrogram_db
+
+     # Slice into fixed-length windows and run them through the selected model
+     X_slices, _, _ = slice_songs([normalized_spectrogram], [0], ["temp"], sr=sr, hop_length=HOP_LENGTH,
+                                  length_in_seconds=slice_length)
+     X_slices = torch.tensor(X_slices, dtype=torch.float32).unsqueeze(1).to(DEVICE)
+
+     model_used = models[architecture][slice_length]
+     with torch.no_grad():
+         outputs = model_used(X_slices)
+         probabilities = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
+
+     # Average per-slice probabilities to get a song-level genre distribution
+     avg_probs = np.mean(probabilities, axis=0)
+     genre_distribution = {GENRE_LABELS[i]: float(avg_probs[i]) for i in range(NUM_CLASSES)}
+     inference_time = time.time() - start_time
+     return genre_distribution, f"Inference Time: {inference_time:.2f} seconds"
+
+
+ slice_length_dropdown = gr.Dropdown(choices=["1", "3", "5", "10"], value="1", label="Slice Length (seconds)")
+ architecture_dropdown = gr.Dropdown(choices=["cnn", "crnn"], value="cnn", label="Model Architecture")
+
+ demo = gr.Interface(
+     fn=predict_genre,
+     inputs=[gr.Audio(type="filepath", label="Upload Audio File"), slice_length_dropdown, architecture_dropdown],
+     outputs=[gr.Label(num_top_classes=10, label="Genre Distribution"), gr.Textbox(label="Inference Time")],
+     title="Music Genre Classifier",
+     description="Upload an audio file, then select a slice length and model architecture to predict its genre distribution."
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
+
+
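For orientation (an editor's sketch, not part of the committed files): with DESIRED_SR = 22050 and HOP_LENGTH = 512, one second of audio spans roughly 43 spectrogram frames, which is exactly what `get_frames` computes and what the dummy inputs above use as their time dimension. A minimal sketch of that arithmetic:

```python
DESIRED_SR = 22050   # sample rate used by the app
HOP_LENGTH = 512     # hop length of the mel spectrogram

def get_frames(slice_length_seconds: int) -> int:
    # frames per second = sr / hop_length ≈ 43.07
    return int(slice_length_seconds * DESIRED_SR / HOP_LENGTH)

for seconds in (1, 3, 5, 10):
    print(seconds, "s ->", get_frames(seconds), "frames")
# 1 s -> 43, 3 s -> 129, 5 s -> 215, 10 s -> 430
```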
models/cnn/10s.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7280b3c52a5f2c741180160eed533237817bb9ba66bd8edf6519b0ce7776670b
+ size 224021010
models/cnn/1s.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2914628c6cb675e7c75e14f1686c0f20bd84c6cacad4a41db367bd476796f6be
+ size 22694418
models/cnn/3s.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8e8135267a6db5e835e246755cde7997d00944415bd5db344fc2e036cb8f5b06
+ size 68831762
models/cnn/5s.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ca0c926d4393f112c3acee391ab1cc4fbfef41f015c156e8aeb68dbdf1cee09
+ size 110774802
models/crnn/10s.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:161bb49d3ed4bb761e9d7c6095afe5fbc144c7c95d49df5c1198d477d96191ce
+ size 1626402
models/crnn/1s.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fbdeec8991d3b655e6785a2aa451accebd479ab3a0f3d61a4fc39429af471d6a
+ size 1626402
models/crnn/3s.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:68bdbe9d17904bfb57720554edd4947c3c9e3df9ad4f14c505fdeb4420ce3b77
+ size 1626402
models/crnn/5s.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e72bf03a664a8c1418d3b16d0634df595428aed5e4bcda75e75115548ab085df
+ size 1626402
models/metrics_summary_table.csv ADDED
@@ -0,0 +1,17 @@
+ ,Model,Split Size,Slice Accuracy,Slice Loss,Song Accuracy,Execution Time,Epoch
+ 0,MusicCNN,1s,0.7991666666666667,0.8332288451790809,0.9,0h 21m 1s,106
+ 1,MusicCNN,3s,0.8205263157894737,0.8359026940245378,0.86,0h 24m 59s,82
+ 2,MusicCNN,5s,0.8372727272727273,0.8710557313398881,0.85,0h 24m 19s,84
+ 3,MusicCNN,10s,0.836,0.987051441192627,0.88,0h 38m 37s,133
+ 4,MusicCRNN2D,1s,0.8333333333333334,0.746949385046959,0.94,0h 13m 12s,48
+ 5,MusicCRNN2D,3s,0.8078947368421052,0.8572936429475483,0.89,0h 11m 48s,43
+ 6,MusicCRNN2D,5s,0.8190909090909091,0.8851883194663308,0.89,0h 13m 50s,68
+ 7,MusicCRNN2D,10s,0.778,0.9759823226928712,0.85,0h 18m 24s,66
+ 8,MusicCRNN1D,1s,0.5648333333333333,1.7309636125564576,0.7,0h 11m 57s,97
+ 9,MusicCRNN1D,3s,0.5510526315789473,1.7096873275857225,0.62,0h 2m 33s,65
+ 10,MusicCRNN1D,5s,0.5972727272727273,1.698943519592285,0.69,0h 3m 30s,123
+ 11,MusicCRNN1D,10s,0.532,1.783525552749634,0.59,0h 2m 3s,124
+ 12,MusicRNN,1s,0.6378333333333334,1.2230633710006171,0.78,0h 5m 56s,45
+ 13,MusicRNN,3s,0.6210526315789474,1.2467606188984293,0.71,0h 2m 14s,42
+ 14,MusicRNN,5s,0.6018181818181818,1.1574554492668672,0.63,0h 0m 43s,46
+ 15,MusicRNN,10s,0.502,1.3741582012176514,0.53,0h 0m 49s,47
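A hedged aside (not part of the committed files): the summary table is easiest to inspect with pandas, which is not listed in requirements.txt and is assumed to be installed separately. A minimal sketch, reading the file from the repository root:

```python
import pandas as pd

# Load the training summary committed above.
df = pd.read_csv("models/metrics_summary_table.csv", index_col=0)

# Rank configurations by song-level accuracy, the metric the app ultimately reports.
print(df.sort_values("Song Accuracy", ascending=False)[["Model", "Split Size", "Song Accuracy"]].head())
```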
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ torch
+ librosa
+ numpy
src/__init__.py ADDED
@@ -0,0 +1,2 @@
+
+
src/models.py ADDED
@@ -0,0 +1,153 @@
+ import torch.nn as nn
+
+ class MusicCNN(nn.Module):
+     def __init__(self, num_classes, dropout_rate=0.3, device="cuda"):
+         super(MusicCNN, self).__init__()
+         self.device = device
+
+         # Convolutional blocks
+         self.conv_block1 = nn.Sequential(
+             nn.Conv2d(1, 32, kernel_size=3, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(),
+             nn.Conv2d(32, 32, kernel_size=3, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(),
+             nn.MaxPool2d(2, 2),
+             nn.Dropout2d(dropout_rate)
+         ).to(device)
+
+         self.conv_block2 = nn.Sequential(
+             nn.Conv2d(32, 64, kernel_size=3, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Conv2d(64, 64, kernel_size=3, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.MaxPool2d(2, 2),
+             nn.Dropout2d(dropout_rate)
+         ).to(device)
+
+         self.conv_block3 = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=3, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Conv2d(128, 128, kernel_size=3, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.MaxPool2d(2, 2),
+             nn.Dropout2d(dropout_rate)
+         ).to(device)
+
+         self.fc_layers = None  # Fully connected layers will be initialized later
+         self.num_classes = num_classes
+         self.dropout_rate = dropout_rate
+
+     def forward(self, x):
+         x = self.conv_block1(x)
+         x = self.conv_block2(x)
+         x = self.conv_block3(x)
+
+         # Flatten dynamically
+         x = x.view(x.size(0), -1)
+
+         # Initialize FC layers dynamically
+         if self.fc_layers is None:
+             fc_input_size = x.size(1)
+             self.fc_layers = nn.Sequential(
+                 nn.Linear(fc_input_size, 512),
+                 nn.BatchNorm1d(512),
+                 nn.ReLU(),
+                 nn.Dropout(self.dropout_rate),
+                 nn.Linear(512, 256),
+                 nn.BatchNorm1d(256),
+                 nn.ReLU(),
+                 nn.Dropout(self.dropout_rate),
+                 nn.Linear(256, self.num_classes)
+             ).to(self.device)
+
+         x = self.fc_layers(x)
+         return x
+
+
+ class MusicCRNN2D(nn.Module):
+     def __init__(self, num_classes, dropout_rate=0.1, gru_hidden_size=32, device="cuda"):
+         super(MusicCRNN2D, self).__init__()
+         self.device = device
+
+         # Input batch normalization
+         self.input_bn = nn.BatchNorm2d(1).to(device)
+
+         # Convolutional blocks
+         self.conv_block1 = nn.Sequential(
+             nn.Conv2d(1, 64, kernel_size=3, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ELU(),
+             nn.MaxPool2d((2, 2)),
+             nn.Dropout2d(dropout_rate)
+         ).to(device)
+
+         self.conv_block2 = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=3, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ELU(),
+             nn.MaxPool2d((4, 2)),
+             nn.Dropout2d(dropout_rate)
+         ).to(device)
+
+         self.conv_block3 = nn.Sequential(
+             nn.Conv2d(128, 128, kernel_size=3, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ELU(),
+             nn.MaxPool2d((4, 2)),
+             nn.Dropout2d(dropout_rate)
+         ).to(device)
+
+         self.conv_block4 = nn.Sequential(
+             nn.Conv2d(128, 128, kernel_size=3, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ELU(),
+             nn.MaxPool2d((4, 2)),
+             nn.Dropout2d(dropout_rate)
+         ).to(device)
+
+         self.gru_stack = None  # GRU layers will be initialized later
+         self.classifier = None
+         self.num_classes = num_classes
+         self.dropout_rate = dropout_rate
+         self.gru_hidden_size = gru_hidden_size
+
+     def forward(self, x):
+         x = self.input_bn(x)
+         x = self.conv_block1(x)
+         x = self.conv_block2(x)
+         x = self.conv_block3(x)
+         x = self.conv_block4(x)
+
+         # Reshape for GRU
+         batch_size, _, freq, time = x.shape
+         x = x.permute(0, 3, 1, 2)  # (batch, time, channels, freq)
+         x = x.reshape(batch_size, time, -1)
+
+         # Initialize GRU dynamically
+         if self.gru_stack is None:
+             gru_input_size = x.size(2)
+             self.gru_stack = nn.GRU(
+                 input_size=gru_input_size,
+                 hidden_size=self.gru_hidden_size,
+                 batch_first=True,
+                 bidirectional=True,
+             ).to(self.device)
+             self.classifier = nn.Sequential(
+                 nn.Dropout(self.dropout_rate * 3),
+                 nn.Linear(self.gru_hidden_size * 2, self.num_classes)  # * 2 for bidirectional
+             ).to(self.device)
+
+         x, _ = self.gru_stack(x)
+
+         # Take the last time step
+         x = x[:, -1, :]
+
+         # Classification
+         x = self.classifier(x)
+         return x
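A note on the design (an editor's sketch, not part of the committed files): because `fc_layers`, `gru_stack`, and `classifier` are built lazily on the first forward pass, a checkpoint can only be loaded after one forward call has created those submodules, which is why app.py pushes a dummy batch through each model before `load_state_dict`. A minimal sketch of that loading pattern, with the 129-frame width for a 3-second slice assumed from the app's constants (sr=22050, hop=512):

```python
import torch
from src.models import MusicCNN

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MusicCNN(num_classes=10, device=DEVICE)

# One forward pass with a correctly shaped dummy batch (batch, channel, mels, frames)
# instantiates the lazily created fc_layers so the checkpoint's keys have somewhere to go.
dummy = torch.randn(2, 1, 128, 129).to(DEVICE)  # 129 frames ≈ 3 s at sr=22050, hop=512
_ = model(dummy)

state = torch.load("models/cnn/3s.pth", map_location=DEVICE)
model.load_state_dict(state)
model.eval()
```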
src/utility.py ADDED
@@ -0,0 +1,57 @@
+ import numpy as np
+
+
+ def slice_songs(X, Y, S,
+                 sr=22050,
+                 hop_length=512,
+                 length_in_seconds=30,
+                 overlap=0.5):
+     """
+     Slice spectrograms into smaller splits with overlap.
+
+     Parameters:
+         X: Array of spectrograms
+         Y: Array of labels
+         S: Array of song names
+         sr: Sample rate (default: 22050)
+         hop_length: Hop length used in spectrogram creation (default: 512)
+         length_in_seconds: Length of each slice in seconds (default: 30)
+         overlap: Overlap ratio between consecutive slices (default: 0.5 for 50% overlap)
+     """
+     # Compute the number of frames for the desired slice length
+     frames_per_second = sr / hop_length
+     slice_length_frames = int(length_in_seconds * frames_per_second)
+
+     # Calculate hop size for overlapping (stride)
+     stride = int(slice_length_frames * (1 - overlap))
+
+     # Initialize lists for sliced data
+     X_slices = []
+     Y_slices = []
+     S_slices = []
+
+     # Slice each spectrogram
+     for i, spectrogram in enumerate(X):
+         num_frames = spectrogram.shape[1]
+
+         # Calculate start positions for all slices
+         start_positions = range(0, num_frames - slice_length_frames + 1, stride)
+
+         for start_frame in start_positions:
+             end_frame = start_frame + slice_length_frames
+
+             # Extract the slice
+             slice_ = spectrogram[:, start_frame:end_frame]
+
+             # Only add if the slice is the expected length
+             if slice_.shape[1] == slice_length_frames:
+                 X_slices.append(slice_)
+                 Y_slices.append(Y[i])
+                 S_slices.append(S[i])
+
+     # Convert lists to numpy arrays
+     X_slices = np.array(X_slices)
+     Y_slices = np.array(Y_slices)
+     S_slices = np.array(S_slices)
+
+     return X_slices, Y_slices, S_slices
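A hedged usage sketch (not part of the committed files): with the function's defaults, a 3-second window at sr=22050 and hop_length=512 is 129 frames wide, and 50% overlap means a new slice starts every 64 frames, so a ~10-second spectrogram yields five slices. The song name "demo_song" below is purely illustrative:

```python
import numpy as np
from src.utility import slice_songs

# One fake normalized mel spectrogram: 128 mel bins, ~10 s of audio at hop_length=512.
spec = np.random.rand(128, 431).astype(np.float32)

X_slices, Y_slices, S_slices = slice_songs(
    [spec], [0], ["demo_song"],
    sr=22050, hop_length=512,
    length_in_seconds=3, overlap=0.5,
)

# Each slice is 129 frames wide (3 s * 22050 / 512) and starts every 64 frames.
print(X_slices.shape)   # (5, 128, 129)
print(Y_slices)         # [0 0 0 0 0]
print(S_slices)         # ['demo_song' 'demo_song' 'demo_song' 'demo_song' 'demo_song']
```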