Upload 11 files
- .gitattributes +2 -0
- app.py +137 -0
- audios_samples/classical.00000.wav +3 -0
- audios_samples/country.00031.wav +3 -0
- model/__pycache__/model.cpython-312.pyc +0 -0
- model/model.py +91 -0
- model/model_256.pth +3 -0
- requirements.txt +0 -0
- static/aivn_favicon.png +0 -0
- static/aivn_logo.png +0 -0
- static/demo.png +0 -0
- utils.py +78 -0
.gitattributes
CHANGED

@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+audios_samples/classical.00000.wav filter=lfs diff=lfs merge=lfs -text
+audios_samples/country.00031.wav filter=lfs diff=lfs merge=lfs -text
app.py
ADDED

@@ -0,0 +1,137 @@
+import streamlit as st
+import numpy as np
+import torch
+import os
+
+from utils import *
+from model.model import CVAE
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+AUDIO_SAMPLES_DIR = "audios_samples"
+
+st.set_page_config(
+    page_title="Audio Reconstruction",
+    page_icon="./static/aivn_favicon.png",
+)
+
+st.image("./static/aivn_logo.png", width=300)
+
+st.title('New Genres Audio Reconstruction')
+
+@st.cache_resource  # cache the model so it is not reloaded on every rerun (cache_resource, not cache_data, is the idiomatic cache for torch models)
+def load_models():
+    with st.spinner('Loading model...'):  # spinner must be used as a context manager; a bare call is a no-op
+        model = CVAE(64, 128, 256, 130, len(uni_genres_list)).to(device)
+        model.load_state_dict(torch.load('model/model_256.pth', map_location=torch.device('cpu')))
+        model.eval()
+    return model
+
+
+def gen_audio(model, audio_source, genres_list, fixed_length_seconds=3):
+    with st.spinner('Processing audio...'):
+        audio_data, sr = load_and_resample_audio(audio_source)
+        n_frames = len(audio_data)
+        segment_length_frame = int(fixed_length_seconds * sr)
+        n_segments = n_frames // segment_length_frame
+
+        split_audio_text_placeholder = st.empty()
+        split_audio_text_placeholder.text("Splitting the audio... ✂")
+        progress_bar_placeholder = st.empty()
+        progress_bar = progress_bar_placeholder.progress(0)
+
+        audios = []
+        for i in range(n_segments):
+            start = i * segment_length_frame
+            end = (i + 1) * segment_length_frame
+            segment = audio_data[start:end]
+            mel_spec = audio_to_melspec(segment, sr, to_db=True)
+            mel_spec_norm = normalize_melspec(mel_spec)
+            mel_spec = torch.tensor(mel_spec, dtype=torch.float32)
+            mel_spec_norm = torch.tensor(mel_spec_norm, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+            audios.append((mel_spec_norm, mel_spec))
+            progress_bar.progress(int((i + 1) / n_segments * 100))
+
+        progress_bar_placeholder.empty()
+        split_audio_text_placeholder.empty()
+
+        audios_input = torch.cat([audio[0] for audio in audios], dim=0)
+
+        genres_input = onehot_encode(tokenize(genres_list), len(uni_genres_list))
+        # float32, not long: the one-hot map is concatenated with the float spectrogram inside the model
+        genres_input = torch.tensor(genres_input, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+        genres_input = genres_input.repeat(audios_input.shape[0], 1, 1)
+
+    with st.spinner('The model is cooking... 🍳🍴'):
+        with torch.no_grad():  # inference only; move inputs to the model's device
+            recons, _, _ = model(audios_input.to(device), genres_input.to(device))
+
+    recon_audio_text_placeholder = st.empty()
+    recon_audio_text_placeholder.text("Rebuilding the audio... 🎵")
+    progress_bar_placeholder = st.empty()
+    progress_bar = progress_bar_placeholder.progress(0)
+    recon_audios = []
+    for i in range(len(recons)):
+        # bring the reconstruction back to CPU before converting to numpy
+        spec_denorm = denormalize_melspec(recons[i].detach().cpu().numpy().squeeze(), audios[i][1])
+        audio_reconstructed = melspec_to_audio(spec_denorm)
+        recon_audios.append(audio_reconstructed)
+        progress_bar.progress(int((i + 1) / len(recons) * 100))
+    recon_audios = np.concatenate(recon_audios)
+    progress_bar_placeholder.empty()
+    recon_audio_text_placeholder.empty()
+
+    return recon_audios
+
+
+def run():
+    model = load_models()
+    uploaded_audio = st.file_uploader("Upload an audio file (only the first 15s are processed)", type=['wav', 'mp3'])
+
+    select_audio = st.selectbox(
+        "Or pick a sample audio below:",
+        options=[""] + [f"{file} - taken from the GTZAN Dataset" for file in os.listdir(AUDIO_SAMPLES_DIR) if file.endswith(('.wav', '.mp3'))],
+        index=0,
+        format_func=lambda x: "No sample audio selected" if x == "" else x
+    )
+
+    if uploaded_audio is not None or select_audio != "":
+        if uploaded_audio is not None:
+            st.audio(uploaded_audio, format='audio/wav')
+        else:
+            uploaded_audio = os.path.join(AUDIO_SAMPLES_DIR, select_audio.replace(" - taken from the GTZAN Dataset", ""))
+            st.audio(uploaded_audio, format='audio/wav')
+
+        genres_list = st.multiselect('Select genres', uni_genres_list)
+
+        if st.button('Process Audio'):
+            result = gen_audio(model, uploaded_audio, genres_list)
+            st.write('Result:')
+            st.audio(result, format='audio/wav', sample_rate=22050)
+
+run()
+
+st.markdown(
+    """
+    <style>
+    .footer {
+        position: fixed;
+        bottom: 0;
+        left: 0;
+        width: 100%;
+        background-color: #f1f1f1;
+        text-align: center;
+        padding: 10px 0;
+        font-size: 14px;
+        color: #555;
+    }
+    </style>

+    <div class="footer">
+        <div>
+            <a href="https://ieeexplore.ieee.org/document/1021072">*GTZAN Dataset</a>
+        </div>
+        <div>
+            2024 AI VIETNAM | Made by <a href="https://github.com/Koii2k3/Music-Reconstruction" target="_blank">Koii2k3</a>
+        </div>
+    </div>
+    """,
+    unsafe_allow_html=True
+)
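
For reference outside the Streamlit UI, the same pipeline can be driven from a plain script. This is a minimal sketch, assuming the utils.py and model/model.py modules from this commit are importable, that a 3-second segment at 22050 Hz yields a (256, 130) mel spectrogram, and using 'Jazz' purely as an example genre:

import torch
from utils import (uni_genres_list, tokenize, onehot_encode,
                   load_and_resample_audio, audio_to_melspec,
                   normalize_melspec, denormalize_melspec, melspec_to_audio)
from model.model import CVAE

# Load the checkpoint shipped in this commit on CPU.
model = CVAE(64, 128, 256, 130, len(uni_genres_list))
model.load_state_dict(torch.load('model/model_256.pth', map_location='cpu'))
model.eval()

audio, sr = load_and_resample_audio('audios_samples/classical.00000.wav')
segment = audio[:3 * sr]                                   # first 3-second segment
mel = audio_to_melspec(segment, sr, to_db=True)            # (256, 130) dB-scale mel
mel_norm = torch.tensor(normalize_melspec(mel), dtype=torch.float32)[None, None]

genres = torch.tensor(onehot_encode(tokenize(['Jazz']), len(uni_genres_list)),
                      dtype=torch.float32)[None, None]     # [1, 1, 25] one-hot
with torch.no_grad():
    recon, mu, logvar = model(mel_norm, genres)

# Undo the per-spectrogram normalization, then invert the mel transform.
recon_db = denormalize_melspec(recon.squeeze(0).squeeze(0).numpy(), mel)
wave = melspec_to_audio(recon_db)                          # Griffin-Lim inversion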
audios_samples/classical.00000.wav
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8add3a7d1add1e157ce5de91be72372773b1ad2779742532c4f0ad1c7316f2a4
+size 1323632
audios_samples/country.00031.wav
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb99205cc54237a0462c593bb472af9ad977806495671d5aa1e231aca9885ccc
+size 1323632
model/__pycache__/model.cpython-312.pyc
ADDED

Binary file (5.52 kB).
model/model.py
ADDED

@@ -0,0 +1,91 @@
+import torch
+import numpy as np
+import torch.nn as nn
+
+class CVAE(nn.Module):
+    def __init__(self, d_model, latent_dim, n_frames, n_mels, n_genres):
+        super(CVAE, self).__init__()
+        self.d_model = d_model
+        self.latent_dim = latent_dim
+        self.n_frames = int(np.ceil(n_frames / 2**3))
+        self.n_mels = int(np.ceil(n_mels / 2**3))
+        self.n_genres = n_genres
+        print(self.n_frames, self.n_mels)
+
+        # Encoder
+        self.encoder = nn.Sequential(
+            nn.Conv2d(1 + self.n_genres, d_model, kernel_size=3, stride=2, padding=1),  # [B, d, ceil(n_mels/2), ceil(n_frames/2)]
+            nn.BatchNorm2d(d_model),
+            nn.SiLU(),
+            nn.Dropout2d(0.05),
+
+            nn.Conv2d(d_model, d_model * 2, kernel_size=3, stride=2, padding=1),  # [B, 2*d, ceil(n_mels/2**2), ceil(n_frames/2**2)]
+            nn.BatchNorm2d(d_model * 2),
+            nn.SiLU(),
+            nn.Dropout2d(0.1),
+
+            nn.Conv2d(d_model * 2, d_model * 4, kernel_size=3, stride=2, padding=1),  # [B, 4*d, ceil(n_mels/2**3), ceil(n_frames/2**3)]
+            nn.BatchNorm2d(d_model * 4),
+            nn.SiLU(),
+            nn.Dropout2d(0.15),
+
+            nn.AdaptiveAvgPool2d((1, 1)),  # [B, 4*d, 1, 1]
+            nn.Flatten()
+        )
+
+        # Latent space
+        self.fc_mu = nn.Linear(d_model * 4, latent_dim)
+        self.fc_logvar = nn.Linear(d_model * 4, latent_dim)
+
+        # Decoder
+        self.decoder_input = nn.Linear(latent_dim + self.n_genres, d_model * 4 * self.n_frames * self.n_mels)  # [B, 4*d, ceil(n_mels/2**3), ceil(n_frames/2**3)]
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(d_model * 4, d_model * 2, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),  # [B, 2*d, ceil(n_mels/2**2), ceil(n_frames/2**2)]
+            nn.BatchNorm2d(d_model * 2),
+            nn.SiLU(),
+            nn.Dropout2d(0.1),
+
+            nn.ConvTranspose2d(d_model * 2, d_model, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),  # [B, d, ceil(n_mels/2), ceil(n_frames/2)]
+            nn.BatchNorm2d(d_model),
+            nn.SiLU(),
+            nn.Dropout2d(0.05),
+
+            nn.ConvTranspose2d(d_model, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # [B, 1, n_mels, n_frames]
+            nn.Sigmoid()
+        )
+
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, x, genres_input):
+        ori_genres_embed = genres_input.view(genres_input.size(0), -1)
+        genres_embed = ori_genres_embed.unsqueeze(-1).unsqueeze(-1)
+        genres_embed = genres_embed.expand(-1, -1, x.size(2), x.size(3))
+        x_genres = torch.cat((x, genres_embed), dim=1)
+
+        h = x_genres
+        shortcuts = []
+        for block in self.encoder:
+            h = block(h)
+            if isinstance(block, nn.SiLU):
+                shortcuts.append(h)  # skip-connection
+
+        mu = self.fc_mu(h)
+        logvar = self.fc_logvar(h)
+
+        z = self.reparameterize(mu, logvar)
+        z_genres = torch.cat((z, ori_genres_embed), dim=1)
+
+        h_dec = self.decoder_input(z_genres)
+        h_dec = h_dec.view(-1, self.d_model * 4, self.n_frames, self.n_mels)
+
+        for block in self.decoder:
+            if isinstance(block, nn.ConvTranspose2d) and shortcuts:
+                shortcut = shortcuts.pop()  # skip-connection
+                h_dec = h_dec + shortcut
+            h_dec = block(h_dec)
+
+        recon = h_dec[:, :, :x.size(2), :x.size(3)]
+        return recon, mu, logvar
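
One detail worth noting: app.py instantiates CVAE(64, 128, 256, 130, ...), passing 256 into the n_frames slot and 130 into n_mels, the reverse of how the spectrograms are actually laid out (n_mels=256 bins by n_frames=130 time steps). The naming is swapped, but the bottleneck arithmetic stays consistent (256/8 = 32, ceil(130/8) = 17), so the decoder's skip connections line up with the encoder activations. A quick dummy forward pass, as a sketch, confirms the shapes:

import torch
from model.model import CVAE

# Same hyperparameters as app.py; 25 = len(uni_genres_list).
model = CVAE(d_model=64, latent_dim=128, n_frames=256, n_mels=130, n_genres=25)
x = torch.randn(2, 1, 256, 130)   # [B, 1, mel bins, time frames]
genres = torch.zeros(2, 1, 25)    # one-hot genre maps, float like x
recon, mu, logvar = model(x, genres)
print(recon.shape, mu.shape, logvar.shape)
# torch.Size([2, 1, 256, 130]) torch.Size([2, 128]) torch.Size([2, 128])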
model/model_256.pth
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aade85a5d0f5a11612c2e4fd17a72842936b3a18e8c9c28462f4ae64cd1f9755
+size 89088777
requirements.txt
ADDED

Binary file (172 Bytes).
static/aivn_favicon.png
ADDED

static/aivn_logo.png
ADDED

static/demo.png
ADDED
utils.py
ADDED

@@ -0,0 +1,78 @@
+import librosa
+import numpy as np
+from sklearn.preprocessing import MinMaxScaler
+
+uni_genres_list = ['House', 'Soundtrack', 'Composed Music', 'Drone', 'Instrumental', 'Ambient Electronic', 'Blues', 'Easy Listening', 'Classical', 'Jazz', 'Christmas', 'Electronic', 'Ambient', 'Lo-fi Instrumental', 'Lounge', 'Contemporary Classical', 'Indie-Rock', 'Dance', 'New Age', 'Halloween', 'Lo-fi Electronic', '20th Century Classical', 'Piano', 'Chill-out', 'Pop']
+genres2idx = {genre: idx for idx, genre in enumerate(uni_genres_list)}
+idx2genres = {idx: genre for genre, idx in genres2idx.items()}
+
+def tokenize(genres):
+    return [genres2idx[genre] for genre in genres if genre in genres2idx]
+
+def detokenize_tolist(tokens):
+    return [idx2genres[token] for token in tokens if token in idx2genres]
+
+def onehot_encode(tokens, max_genres):
+    onehot = np.zeros(max_genres)
+    onehot[tokens] = 1
+    return onehot
+
+def onehot_decode(onehot):
+    return [idx for idx, val in enumerate(onehot) if val == 1]
+
+def load_and_resample_audio(file_path, target_sr=22050, max_duration=15):
+    audio, sr = librosa.load(file_path, sr=None)
+    if sr != target_sr:
+        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
+    if len(audio) > target_sr * max_duration:
+        audio = audio[:target_sr * max_duration]
+    return audio, target_sr
+
+def audio_to_melspec(audio, sr, n_mels=256, n_fft=2048, hop_length=512, to_db=False):
+    spec = librosa.feature.melspectrogram(y=audio,
+                                          sr=sr,
+                                          n_fft=n_fft,
+                                          hop_length=hop_length,
+                                          win_length=None,
+                                          window='hann',
+                                          center=True,
+                                          pad_mode='reflect',
+                                          power=2.0,
+                                          n_mels=n_mels)
+
+    if to_db:
+        spec = librosa.power_to_db(spec, ref=np.max)
+
+    return spec
+
+# Normalize the Mel spectrogram
+def normalize_melspec(melspec, norm_range=(0, 1)):
+    scaler = MinMaxScaler(feature_range=norm_range)
+    melspec = melspec.T
+    melspec_normalized = scaler.fit_transform(melspec)
+    return melspec_normalized.T
+
+# Denormalize the Mel spectrogram
+def denormalize_melspec(melspec_normalized, original_melspec, norm_range=(0, 1)):
+    scaler = MinMaxScaler(feature_range=norm_range)
+    melspec = original_melspec.T
+    scaler.fit(melspec)
+    melspec_denormalized = scaler.inverse_transform(melspec_normalized.T)
+    return melspec_denormalized.T
+
+# Function to convert Mel spectrogram back to audio
+def melspec_to_audio(melspec, sr=22050, n_fft=2048, hop_length=512, n_iter=64):
+    if np.any(melspec < 0):
+        melspec = librosa.db_to_power(melspec)
+
+    audio_reconstructed = librosa.feature.inverse.mel_to_audio(melspec,
+                                                               sr=sr,
+                                                               n_fft=n_fft,
+                                                               hop_length=hop_length,
+                                                               win_length=None,
+                                                               window='hann',
+                                                               center=True,
+                                                               pad_mode='reflect',
+                                                               power=2.0,  # Ensure the correct inverse transformation
+                                                               n_iter=n_iter)
+    return audio_reconstructed
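
The normalize/denormalize pair is worth spelling out: normalize_melspec fits a MinMaxScaler on the transposed spectrogram, so each of the 256 mel bins is min-max scaled independently across time, and denormalize_melspec re-fits an identical scaler on the original spectrogram before inverting, so the model only ever has to predict values in [0, 1]. A short round-trip sketch under those assumptions:

import numpy as np
from utils import (load_and_resample_audio, audio_to_melspec,
                   normalize_melspec, denormalize_melspec, melspec_to_audio)

audio, sr = load_and_resample_audio('audios_samples/country.00031.wav')
mel = audio_to_melspec(audio[:3 * sr], sr, to_db=True)   # (256, 130), dB scale
norm = normalize_melspec(mel)                            # each mel bin scaled to [0, 1]
back = denormalize_melspec(norm, mel)
print(np.allclose(mel, back, atol=1e-4))                 # True: the round trip is lossless
wave = melspec_to_audio(back)                            # Griffin-Lim (n_iter=64) back to a waveform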