Ripefog committed on
Commit 4195b51 · verified · 1 Parent(s): 2c3fb80

Upload 11 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ audios_samples/classical.00000.wav filter=lfs diff=lfs merge=lfs -text
+ audios_samples/country.00031.wav filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,137 @@
+ import streamlit as st
+ import numpy as np
+ import torch
+ import os
+
+ from utils import *
+ from model.model import CVAE
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ AUDIO_SAMPLES_DIR = "audios_samples"
+
+ st.set_page_config(
+     page_title="Audio Reconstruction",
+     page_icon="./static/aivn_favicon.png",
+ )
+
+ st.image("./static/aivn_logo.png", width=300)
+
+ st.title('New Genres Audio Reconstruction')
+
+ @st.cache_resource  # cache the model object so it is not reloaded on every rerun
+ def load_models():
+     with st.spinner('Loading model...'):
+         model = CVAE(64, 128, 256, 130, len(uni_genres_list)).to(device)
+         model.load_state_dict(torch.load('model/model_256.pth', map_location=torch.device('cpu')))
+         model.eval()
+     return model
+
+
+ def gen_audio(model, audio_source, genres_list, fixed_length_seconds=3):
+     with st.spinner('Processing audio...'):
+         audio_data, sr = load_and_resample_audio(audio_source)
+         n_frames = len(audio_data)
+         segment_length_frame = int(fixed_length_seconds * sr)
+         n_segments = n_frames // segment_length_frame
+
+         split_audio_text_placeholder = st.empty()
+         split_audio_text_placeholder.text("Splitting the audio... ✂")
+         progress_bar_placeholder = st.empty()
+         progress_bar = progress_bar_placeholder.progress(0)
+
+         audios = []
+         for i in range(n_segments):
+             start = i * segment_length_frame
+             end = (i + 1) * segment_length_frame
+             segment = audio_data[start:end]
+             mel_spec = audio_to_melspec(segment, sr, to_db=True)
+             mel_spec_norm = normalize_melspec(mel_spec)
+             mel_spec = torch.tensor(mel_spec, dtype=torch.float32)
+             mel_spec_norm = torch.tensor(mel_spec_norm, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+             audios.append((mel_spec_norm, mel_spec))
+             progress_bar.progress(int((i + 1) / n_segments * 100))
+
+         progress_bar_placeholder.empty()
+         split_audio_text_placeholder.empty()
+
+         audios_input = torch.cat([audio[0] for audio in audios], dim=0)
+
+         genres_input = onehot_encode(tokenize(genres_list), len(uni_genres_list))
+         # float32, not long: the one-hot conditioning is concatenated with float tensors inside the model
+         genres_input = torch.tensor(genres_input, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+         genres_input = genres_input.repeat(audios_input.shape[0], 1, 1)
+
+     with st.spinner('The model is cooking... 🍳🍴'):
+         recons, _, _ = model(audios_input.to(device), genres_input.to(device))
+
+     recon_audio_text_placeholder = st.empty()
+     recon_audio_text_placeholder.text("Rebuilding the audio... 🎵")
+     progress_bar_placeholder = st.empty()
+     progress_bar = progress_bar_placeholder.progress(0)
+     recon_audios = []
+     for i in range(len(recons)):
+         spec_denorm = denormalize_melspec(recons[i].detach().cpu().numpy().squeeze(), audios[i][1])
+         audio_reconstructed = melspec_to_audio(spec_denorm)
+         recon_audios.append(audio_reconstructed)
+         progress_bar.progress(int((i + 1) / len(recons) * 100))
+     recon_audios = np.concatenate(recon_audios)
+     progress_bar_placeholder.empty()
+     recon_audio_text_placeholder.empty()
+
+     return recon_audios
+
+
+ def run():
+     model = load_models()
+     uploaded_audio = st.file_uploader("Upload an audio file (only the first 15 s are processed)", type=['wav', 'mp3'])
+
+     select_audio = st.selectbox(
+         "Or pick a sample audio below:",
+         options=[""] + [f"{file} - taken from the GTZAN Dataset" for file in os.listdir(AUDIO_SAMPLES_DIR) if file.endswith(('.wav', '.mp3'))],
+         index=0,
+         format_func=lambda x: "No sample audio selected" if x == "" else x
+     )
+
+     if uploaded_audio is not None or select_audio != "":
+         if uploaded_audio is not None:
+             st.audio(uploaded_audio, format='audio/wav')
+         else:
+             uploaded_audio = os.path.join(AUDIO_SAMPLES_DIR, select_audio.replace(" - taken from the GTZAN Dataset", ""))
+             st.audio(uploaded_audio, format='audio/wav')
+
+     genres_list = st.multiselect('Select genres', uni_genres_list)
+
+     if st.button('Process Audio'):
+         result = gen_audio(model, uploaded_audio, genres_list)
+         st.write('Result:')
+         st.audio(result, format='audio/wav', sample_rate=22050)
+
+ run()
+
+ st.markdown(
+     """
+     <style>
+     .footer {
+         position: fixed;
+         bottom: 0;
+         left: 0;
+         width: 100%;
+         background-color: #f1f1f1;
+         text-align: center;
+         padding: 10px 0;
+         font-size: 14px;
+         color: #555;
+     }
+     </style>
+
+     <div class="footer">
+         <div>
+             <a href="https://ieeexplore.ieee.org/document/1021072">*GTZAN Dataset</a>
+         </div>
+         <div>
+             2024 AI VIETNAM | Made by <a href="https://github.com/Koii2k3/Music-Reconstruction" target="_blank">Koii2k3</a>
+         </div>
+     </div>
+     """,
+     unsafe_allow_html=True
+ )
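
For readers tracing the conditioning logic in gen_audio above: the genre input ends up as a float one-hot vector of length 25 (the size of uni_genres_list), repeated once per 3-second segment. A minimal sketch outside Streamlit, assuming utils.py is importable; the two-segment batch size is illustrative:

import torch
from utils import uni_genres_list, tokenize, onehot_encode

genres = ['Jazz', 'Pop']                                           # any subset of uni_genres_list
onehot = onehot_encode(tokenize(genres), len(uni_genres_list))     # numpy array of shape (25,)
cond = torch.tensor(onehot, dtype=torch.float32).view(1, 1, -1)    # [1, 1, 25]
cond = cond.repeat(2, 1, 1)                                        # one copy per segment, as gen_audio does
print(cond.shape)                                                  # torch.Size([2, 1, 25])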
audios_samples/classical.00000.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8add3a7d1add1e157ce5de91be72372773b1ad2779742532c4f0ad1c7316f2a4
+ size 1323632
audios_samples/country.00031.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb99205cc54237a0462c593bb472af9ad977806495671d5aa1e231aca9885ccc
+ size 1323632
model/__pycache__/model.cpython-312.pyc ADDED
Binary file (5.52 kB).
 
model/model.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import numpy as np
+ import torch.nn as nn
+
+ class CVAE(nn.Module):
+     def __init__(self, d_model, latent_dim, n_frames, n_mels, n_genres):
+         super(CVAE, self).__init__()
+         self.d_model = d_model
+         self.latent_dim = latent_dim
+         # spatial size of the bottleneck after three stride-2 convolutions
+         self.n_frames = int(np.ceil(n_frames / 2**3))
+         self.n_mels = int(np.ceil(n_mels / 2**3))
+         self.n_genres = n_genres
+
+         # Encoder
+         self.encoder = nn.Sequential(
+             nn.Conv2d(1 + self.n_genres, d_model, kernel_size=3, stride=2, padding=1),  # [B, d, ceil(n_frames/2), ceil(n_mels/2)]
+             nn.BatchNorm2d(d_model),
+             nn.SiLU(),
+             nn.Dropout2d(0.05),
+
+             nn.Conv2d(d_model, d_model * 2, kernel_size=3, stride=2, padding=1),  # [B, 2*d, ceil(n_frames/2**2), ceil(n_mels/2**2)]
+             nn.BatchNorm2d(d_model * 2),
+             nn.SiLU(),
+             nn.Dropout2d(0.1),
+
+             nn.Conv2d(d_model * 2, d_model * 4, kernel_size=3, stride=2, padding=1),  # [B, 4*d, ceil(n_frames/2**3), ceil(n_mels/2**3)]
+             nn.BatchNorm2d(d_model * 4),
+             nn.SiLU(),
+             nn.Dropout2d(0.15),
+
+             nn.AdaptiveAvgPool2d((1, 1)),  # [B, 4*d, 1, 1]
+             nn.Flatten()
+         )
+
+         # Latent space
+         self.fc_mu = nn.Linear(d_model * 4, latent_dim)
+         self.fc_logvar = nn.Linear(d_model * 4, latent_dim)
+
+         # Decoder
+         self.decoder_input = nn.Linear(latent_dim + self.n_genres, d_model * 4 * self.n_frames * self.n_mels)  # [B, 4*d * ceil(n_frames/2**3) * ceil(n_mels/2**3)]
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose2d(d_model * 4, d_model * 2, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),  # [B, 2*d, ceil(n_frames/2**2), ceil(n_mels/2**2)]
+             nn.BatchNorm2d(d_model * 2),
+             nn.SiLU(),
+             nn.Dropout2d(0.1),
+
+             nn.ConvTranspose2d(d_model * 2, d_model, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),  # [B, d, ceil(n_frames/2), ceil(n_mels/2)]
+             nn.BatchNorm2d(d_model),
+             nn.SiLU(),
+             nn.Dropout2d(0.05),
+
+             nn.ConvTranspose2d(d_model, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # [B, 1, n_frames, n_mels]
+             nn.Sigmoid()
+         )
+
+     def reparameterize(self, mu, logvar):
+         # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+
+     def forward(self, x, genres_input):
+         # broadcast the one-hot genre vector over the spectrogram and stack it as extra input channels
+         ori_genres_embed = genres_input.view(genres_input.size(0), -1)
+         genres_embed = ori_genres_embed.unsqueeze(-1).unsqueeze(-1)
+         genres_embed = genres_embed.expand(-1, -1, x.size(2), x.size(3))
+         x_genres = torch.cat((x, genres_embed), dim=1)
+
+         h = x_genres
+         shortcuts = []
+         for block in self.encoder:
+             h = block(h)
+             if isinstance(block, nn.SiLU):
+                 shortcuts.append(h)  # save post-activation features for skip connections
+
+         mu = self.fc_mu(h)
+         logvar = self.fc_logvar(h)
+
+         z = self.reparameterize(mu, logvar)
+         z_genres = torch.cat((z, ori_genres_embed), dim=1)
+
+         h_dec = self.decoder_input(z_genres)
+         h_dec = h_dec.view(-1, self.d_model * 4, self.n_frames, self.n_mels)
+
+         for block in self.decoder:
+             if isinstance(block, nn.ConvTranspose2d) and shortcuts:
+                 h_dec = h_dec + shortcuts.pop()  # add the matching encoder skip connection
+             h_dec = block(h_dec)
+
+         # crop any transposed-convolution overshoot back to the input size
+         recon = h_dec[:, :, :x.size(2), :x.size(3)]
+         return recon, mu, logvar
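
As a sanity check on the architecture above, here is a dummy forward pass with the hyperparameters app.py uses (d_model=64, latent_dim=128, 256 mel bands by 130 frames per 3-second segment); the printed shapes are what the layer arithmetic implies, not outputs taken from the source:

import torch
from model.model import CVAE

model = CVAE(64, 128, 256, 130, 25).eval()  # 25 = len(uni_genres_list)
x = torch.rand(2, 1, 256, 130)              # two normalized mel-spectrogram segments
genres = torch.zeros(2, 1, 25)              # float one-hot genre conditioning
genres[:, 0, 9] = 1.0                       # 'Jazz' is index 9 in uni_genres_list
with torch.no_grad():
    recon, mu, logvar = model(x, genres)
print(recon.shape, mu.shape, logvar.shape)
# torch.Size([2, 1, 256, 130]) torch.Size([2, 128]) torch.Size([2, 128])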
model/model_256.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aade85a5d0f5a11612c2e4fd17a72842936b3a18e8c9c28462f4ae64cd1f9755
+ size 89088777
requirements.txt ADDED
Binary file (172 Bytes).
 
static/aivn_favicon.png ADDED
static/aivn_logo.png ADDED
static/demo.png ADDED
utils.py ADDED
@@ -0,0 +1,78 @@
+ import librosa
+ import numpy as np
+ from sklearn.preprocessing import MinMaxScaler
+
+ uni_genres_list = ['House', 'Soundtrack', 'Composed Music', 'Drone', 'Instrumental', 'Ambient Electronic', 'Blues', 'Easy Listening', 'Classical', 'Jazz', 'Christmas', 'Electronic', 'Ambient', 'Lo-fi Instrumental', 'Lounge', 'Contemporary Classical', 'Indie-Rock', 'Dance', 'New Age', 'Halloween', 'Lo-fi Electronic', '20th Century Classical', 'Piano', 'Chill-out', 'Pop']
+ genres2idx = {genre: idx for idx, genre in enumerate(uni_genres_list)}
+ idx2genres = {idx: genre for genre, idx in genres2idx.items()}
+
+ def tokenize(genres):
+     return [genres2idx[genre] for genre in genres if genre in genres2idx]
+
+ def detokenize_tolist(tokens):
+     return [idx2genres[token] for token in tokens if token in idx2genres]
+
+ def onehot_encode(tokens, max_genres):
+     onehot = np.zeros(max_genres)
+     onehot[tokens] = 1
+     return onehot
+
+ def onehot_decode(onehot):
+     return [idx for idx, val in enumerate(onehot) if val == 1]
+
+ def load_and_resample_audio(file_path, target_sr=22050, max_duration=15):
+     audio, sr = librosa.load(file_path, sr=None)
+     if sr != target_sr:
+         audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
+     if len(audio) > target_sr * max_duration:
+         audio = audio[:target_sr * max_duration]  # keep only the first max_duration seconds
+     return audio, target_sr
+
+ def audio_to_melspec(audio, sr, n_mels=256, n_fft=2048, hop_length=512, to_db=False):
+     spec = librosa.feature.melspectrogram(y=audio,
+                                           sr=sr,
+                                           n_fft=n_fft,
+                                           hop_length=hop_length,
+                                           win_length=None,
+                                           window='hann',
+                                           center=True,
+                                           pad_mode='reflect',
+                                           power=2.0,
+                                           n_mels=n_mels)
+
+     if to_db:
+         spec = librosa.power_to_db(spec, ref=np.max)
+
+     return spec
+
+ # Normalize the mel spectrogram to norm_range, independently per mel band
+ def normalize_melspec(melspec, norm_range=(0, 1)):
+     scaler = MinMaxScaler(feature_range=norm_range)
+     melspec = melspec.T
+     melspec_normalized = scaler.fit_transform(melspec)
+     return melspec_normalized.T
+
+ # Denormalize using the min/max statistics of the original mel spectrogram
+ def denormalize_melspec(melspec_normalized, original_melspec, norm_range=(0, 1)):
+     scaler = MinMaxScaler(feature_range=norm_range)
+     melspec = original_melspec.T
+     scaler.fit(melspec)
+     melspec_denormalized = scaler.inverse_transform(melspec_normalized.T)
+     return melspec_denormalized.T
+
+ # Convert a mel spectrogram back to audio via Griffin-Lim
+ def melspec_to_audio(melspec, sr=22050, n_fft=2048, hop_length=512, n_iter=64):
+     if np.any(melspec < 0):  # negative values imply a dB-scaled spectrogram
+         melspec = librosa.db_to_power(melspec)
+
+     audio_reconstructed = librosa.feature.inverse.mel_to_audio(melspec,
+                                                                sr=sr,
+                                                                n_fft=n_fft,
+                                                                hop_length=hop_length,
+                                                                win_length=None,
+                                                                window='hann',
+                                                                center=True,
+                                                                pad_mode='reflect',
+                                                                power=2.0,  # invert assuming a power spectrogram
+                                                                n_iter=n_iter)
+     return audio_reconstructed
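
Taken together, the helpers above form the app's preprocessing round trip. A minimal sketch, assuming one of the committed sample files is on disk; note the round trip is lossy, since power_to_db normalizes against ref=np.max while db_to_power inverts against a fixed reference, and Griffin-Lim only estimates phase:

from utils import (load_and_resample_audio, audio_to_melspec,
                   normalize_melspec, denormalize_melspec, melspec_to_audio)

audio, sr = load_and_resample_audio("audios_samples/classical.00000.wav")  # 22.05 kHz, first 15 s
mel_db = audio_to_melspec(audio[:3 * sr], sr, to_db=True)   # (256, 130) for one 3 s segment
mel_norm = normalize_melspec(mel_db)                        # per-band min-max scaling to [0, 1]
mel_back = denormalize_melspec(mel_norm, mel_db)            # invert using the original's statistics
wav = melspec_to_audio(mel_back)                            # dB -> power, then Griffin-Lim (64 iterations)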