File size: 5,067 Bytes
4195b51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import streamlit as st
import numpy as np
import torch
import os
from utils import *
from model.model import CVAE
# Run inference on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Directory holding bundled demo clips offered in the sample selectbox.
AUDIO_SAMPLES_DIR = "audios_samples"
# Page chrome: title, favicon, logo banner, and main heading.
st.set_page_config(
page_title="Audio Reconstruction",
page_icon="./static/aivn_favicon.png",
)
st.image("./static/aivn_logo.png", width=300)
st.title('New Genres Audio Reconstruction')
@st.cache_resource
def load_models():
    """Load the pretrained CVAE once and cache it across Streamlit reruns.

    Returns:
        CVAE: the reconstruction model, moved to ``device`` and in eval mode.
    """
    # st.cache_resource (not st.cache_data) is the documented cache for
    # unserializable objects such as torch models: it stores the object
    # itself instead of pickling a copy on every call.
    # The spinner must be used as a context manager — calling
    # st.spinner(...) as a bare statement displays nothing.
    with st.spinner('Đang tải mô hình...'):
        model = CVAE(64, 128, 256, 130, len(uni_genres_list)).to(device)
        # Checkpoint is loaded onto CPU first, then moved by .to(device)
        # above — safe on machines without a GPU.
        model.load_state_dict(
            torch.load('model/model_256.pth', map_location=torch.device('cpu'))
        )
        model.eval()
    return model
def gen_audio(model, audio_source, genres_list, fixed_length_seconds=3):
    """Reconstruct ``audio_source`` conditioned on the chosen genres.

    The waveform is cut into fixed-length segments; each segment becomes a
    normalized mel spectrogram, the batch is pushed through the CVAE, and
    the reconstructed spectrograms are inverted back to audio and
    concatenated.

    Args:
        model: trained CVAE, already on ``device`` and in eval mode.
        audio_source: path or file-like object accepted by
            ``load_and_resample_audio``.
        genres_list: genre names used as the conditioning vector.
        fixed_length_seconds: duration of each processed segment (seconds).

    Returns:
        np.ndarray: the concatenated reconstructed waveform.

    Raises:
        ValueError: if the audio is shorter than one segment.
    """
    with st.spinner('Đang xử lý âm thanh...'):
        audio_data, sr = load_and_resample_audio(audio_source)
        segment_length_frame = int(fixed_length_seconds * sr)
        n_segments = len(audio_data) // segment_length_frame
        if n_segments == 0:
            # Without this guard torch.cat below receives an empty list and
            # raises an opaque RuntimeError; fail with a clear message.
            raise ValueError("Audio is shorter than one segment")
        split_audio_text_placeholder = st.empty()
        split_audio_text_placeholder.text("Đang chia nhỏ audio... ✂")
        progress_bar_placeholder = st.empty()
        progress_bar = progress_bar_placeholder.progress(0)
        # Each entry: (normalized spec as a (1,1,H,W) tensor for the model,
        # raw dB spec kept for de-normalization after reconstruction).
        audios = []
        for i in range(n_segments):
            start = i * segment_length_frame
            end = (i + 1) * segment_length_frame
            segment = audio_data[start:end]
            mel_spec = audio_to_melspec(segment, sr, to_db=True)
            mel_spec_norm = normalize_melspec(mel_spec)
            mel_spec = torch.tensor(mel_spec, dtype=torch.float32)
            mel_spec_norm = torch.tensor(mel_spec_norm, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            audios.append((mel_spec_norm, mel_spec))
            progress_bar.progress(int((i + 1) / n_segments * 100))
        progress_bar_placeholder.empty()
        split_audio_text_placeholder.empty()
        # Inputs must live on the same device as the model; the original
        # code left them on CPU, which fails when device == "cuda".
        audios_input = torch.cat([audio[0] for audio in audios], dim=0).to(device)
        genres_input = onehot_encode(tokenize(genres_list), len(uni_genres_list))
        genres_input = torch.tensor(genres_input, dtype=torch.long).unsqueeze(0).unsqueeze(0)
        genres_input = genres_input.repeat(audios_input.shape[0], 1, 1).to(device)
    with st.spinner('Mô hình đang nấu ăn... 🍳🍴'):
        # Inference only — no_grad avoids building the autograd graph.
        with torch.no_grad():
            recons, _, _ = model(audios_input, genres_input)
    recon_audio_text_placeholder = st.empty()
    recon_audio_text_placeholder.text("Đang dựng lại audio video... 🎵")
    progress_bar_placeholder = st.empty()
    progress_bar = progress_bar_placeholder.progress(0)
    recon_audios = []
    for i in range(len(recons)):
        # .cpu() is required before .numpy() when the tensor is on CUDA.
        spec_denorm = denormalize_melspec(recons[i].detach().cpu().numpy().squeeze(), audios[i][1])
        audio_reconstructed = melspec_to_audio(spec_denorm)
        recon_audios.append(audio_reconstructed)
        progress_bar.progress(int((i + 1) / len(recons) * 100))
    recon_audios = np.concatenate(recon_audios)
    progress_bar_placeholder.empty()
    recon_audio_text_placeholder.empty()
    return recon_audios
def run():
    """Render the main UI: pick an audio, pick genres, reconstruct."""
    cvae = load_models()

    # Suffix appended to sample filenames in the selectbox; stripped again
    # before building the path on disk.
    sample_suffix = " - được lấy từ GTZAN Dataset"
    sample_options = [
        f"{name}{sample_suffix}"
        for name in os.listdir(AUDIO_SAMPLES_DIR)
        if name.endswith(('.wav', '.mp3'))
    ]

    audio_source = st.file_uploader("Tải lên 1 audio (chỉ xử lý 15s đầu tiên)", type=['wav', 'mp3'])
    sample_choice = st.selectbox(
        "Hoặc chọn 1 audio mẫu dưới dây:",
        options=[""] + sample_options,
        index=0,
        format_func=lambda opt: "Không chọn audio mẫu" if opt == "" else opt
    )

    # Guard clause: nothing chosen yet, render nothing further.
    if audio_source is None and sample_choice == "":
        return

    if audio_source is None:
        # Fall back to the selected bundled sample.
        audio_source = os.path.join(AUDIO_SAMPLES_DIR, sample_choice.replace(sample_suffix, ""))
    st.audio(audio_source, format='audio/wav')

    chosen_genres = st.multiselect('Chọn thể loại', uni_genres_list)
    if st.button('Xử lý Âm Thanh'):
        reconstructed = gen_audio(cvae, audio_source, chosen_genres)
        st.write('Kết quả:')
        st.audio(reconstructed, format='audio/wav', sample_rate=22050)
# Launch the UI (Streamlit executes this script top-to-bottom on each rerun).
run()
# Fixed footer with dataset citation and author credit.
st.markdown(
"""
<style>
.footer {
position: fixed;
bottom: 0;
left: 0;
width: 100%;
background-color: #f1f1f1;
text-align: center;
padding: 10px 0;
font-size: 14px;
color: #555;
}
</style>
<div class="footer">
<div>
<a href="https://ieeexplore.ieee.org/document/1021072">*GTZAN Dataset</a>
</div>
<div>
2024 AI VIETNAM | Made by <a href="https://github.com/Koii2k3/Music-Reconstruction" target="_blank">Koii2k3</a>
</div>
</div>
""",
unsafe_allow_html=True
)