|
|
import streamlit as st |
|
|
import torch |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
import pretty_midi as pm |
|
|
|
|
|
from VAE import VAE |
|
|
|
|
|
import pretty_midi as pm |
|
|
from scipy.io.wavfile import write |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
@st.cache_resource |
|
|
def load_model(): |
|
|
vae = VAE(input_dim=76, hidden_dim=512, latent_dim=256) |
|
|
vae.load_state_dict(torch.load("vae_model_all.pth", map_location=device)) |
|
|
vae = vae.to(device) |
|
|
vae.eval() |
|
|
return vae |
|
|
|
|
|
|
|
|
def _crop_hand(roll, pitch_start, pitch_stop, frame_start, duration):
    """Cut a (pitch, time) window out of ``roll``, zero-padding short pieces.

    The result always has exactly ``duration`` columns, so downstream code
    (model input, splitting the reconstruction) sees a fixed time length even
    when the MIDI file ends before ``frame_start + duration`` frames.
    """
    window = roll[pitch_start:pitch_stop, frame_start:frame_start + duration]
    missing = duration - window.shape[1]
    if missing > 0:
        window = np.pad(window, ((0, 0), (0, missing)))
    return window


def process_midi(file):
    """Extract fixed-size right/left-hand piano rolls from a MIDI file.

    Non-piano tracks (General MIDI programs outside 0-7), drum tracks and
    silent tracks are ignored. Of the remaining tracks, the one with the
    highest mean sounding pitch becomes the right hand and the one with the
    lowest becomes the left hand. With a single track the left hand is all
    zeros.

    Each hand is sampled at 10 frames per second and cropped to pitch rows
    24-99 (76 rows) and frames 26-175 (150 columns). Shorter pieces are
    zero-padded on the time axis.

    Args:
        file: Path or file-like object accepted by ``pretty_midi.PrettyMIDI``
            (for example a Streamlit ``UploadedFile``).

    Returns:
        ``(right_hand, left_hand)`` arrays, or ``(None, None)`` after the
        problem has been reported with ``st.error``.
    """
    try:
        mid = pm.PrettyMIDI(file)

        fs = 10
        pitch_start, pitch_stop, duration = 24, 100, 150
        frame_start = 26
        # One shared sampling grid so both hands line up frame-for-frame.
        times = np.arange(0, mid.get_end_time(), 1.0 / fs)

        rolls = []
        pitch_list = []
        for inst in mid.instruments:
            # Keep only the piano family (GM programs 0-7). Drum tracks are
            # flagged by ``is_drum`` and may also report program 0.
            if inst.is_drum or inst.program // 8 > 0:
                continue
            piano_roll = inst.get_piano_roll(times=times)
            if np.sum(piano_roll) == 0:
                continue
            rolls.append(piano_roll)
            # Mean row index (= MIDI pitch) of the sounding cells; used to
            # tell the higher (right) hand from the lower (left) hand.
            pitch_list.append(np.average(np.nonzero(piano_roll)[0]))

        if not rolls:
            st.error("No valid piano data found.")
            return None, None

        # Indices refer to ``rolls``, not ``mid.instruments``: skipped tracks
        # would otherwise shift the positions. A stable argsort also yields
        # two distinct tracks when the mean pitches tie.
        order = np.argsort(pitch_list, kind="stable")
        right_roll = rolls[order[-1]]
        if len(rolls) == 1:
            left_roll = np.zeros_like(right_roll)
        else:
            left_roll = rolls[order[0]]

        right_hand = _crop_hand(right_roll, pitch_start, pitch_stop, frame_start, duration)
        left_hand = _crop_hand(left_roll, pitch_start, pitch_stop, frame_start, duration)

        # NOTE(review): a single-track file always has an all-zero left hand
        # and is therefore rejected here; confirm that this is intended.
        if np.sum(right_hand) == 0 or np.sum(left_hand) == 0:
            st.error("Invalid data detected in MIDI file.")
            return None, None

        return right_hand, left_hand

    except Exception as e:
        # UI boundary: surface any parsing/processing failure to the user.
        st.error(f"Error processing MIDI: {e}")
        return None, None
|
|
|
|
|
|
|
|
def reconstruct(right, left, model):
    """Run both hands through ``model`` and return its reconstruction as NumPy.

    ``right`` and ``left`` are converted to float32 tensors on ``device``,
    concatenated along their first axis (right first) and given a leading
    batch dimension of 1. The forward pass runs without gradient tracking.
    The model output is unpacked as a 4-tuple and only its first element is
    kept.

    Returns:
        The first model output with the batch dimension removed, as a NumPy
        array on the CPU.
    """
    hands = [torch.tensor(part, dtype=torch.float32).to(device) for part in (right, left)]
    batch = torch.cat(hands, dim=0).unsqueeze(0)

    with torch.no_grad():
        recon_data, _, _, _ = model(batch)

    return recon_data.squeeze(0).cpu().numpy()
|
|
|
|
|
|
|
|
def midi_to_wav(midi_file, wav_file="output.wav", volume_increase_db=17):
    """Render a MIDI file to a 16-bit PCM WAV file and return its path.

    The audio comes from ``pretty_midi.PrettyMIDI.synthesize`` at 44.1 kHz.
    It is peak-normalised to 90 % of int16 full scale. A silent or empty
    render (e.g. a MIDI file without notes) is written as silence instead of
    failing.

    Args:
        midi_file: Path or file-like object readable by ``pretty_midi``.
        wav_file: Destination path of the WAV file.
        volume_increase_db: Currently ignored; kept for backward
            compatibility. Peak normalisation already sets the output level.

    Returns:
        ``wav_file``.
    """
    sample_rate = 44100
    midi_data = pm.PrettyMIDI(midi_file)
    audio_data = midi_data.synthesize(fs=sample_rate)

    # np.max on an empty array raises, and a zero peak would divide by zero
    # (NaN -> garbage int16 samples); only normalise when there is signal.
    peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
    if peak > 0:
        audio_data = audio_data / peak * 32767 * 0.9

    write(wav_file, sample_rate, np.int16(audio_data))
    return wav_file
|
|
|
|
|
|
|
|
|
|
|
def create_midi_from_piano_roll(right_hand, left_hand, fs=8, threshold=0.92):
    """Convert right/left-hand piano rolls into a two-track PrettyMIDI object.

    Each hand becomes one program-0 instrument, right hand first. Row ``i``
    of a roll is MIDI pitch ``i + 24`` and column ``j`` is time ``j / fs``
    seconds. A cell is "on" when its value is ``>= threshold``. Every maximal
    run of "on" cells becomes one velocity-100 note, from the run's first
    frame to the frame just after its last. Runs touching the first or last
    frame are included.

    Args:
        right_hand: 2-D array-like (pitch rows x time frames) for the right hand.
        left_hand: 2-D array-like (pitch rows x time frames) for the left hand.
        fs: Frames per second of the piano rolls.
        threshold: Minimum cell value that counts as a sounding note.

    Returns:
        A ``pretty_midi.PrettyMIDI`` object with two instruments.
    """
    pm_obj = pm.PrettyMIDI()

    for hand_data in (right_hand, left_hand):
        instrument = pm.Instrument(program=0)
        pm_obj.instruments.append(instrument)

        for pitch, series in enumerate(hand_data):
            active = np.asarray(series) >= threshold
            # Pad with "off" on both sides so every run has a rising and a
            # falling edge. The edges then alternate start, stop, start, ...
            padded = np.concatenate(([False], active, [False])).astype(np.int8)
            edges = np.flatnonzero(np.diff(padded))

            for start_frame, stop_frame in zip(edges[::2], edges[1::2]):
                instrument.notes.append(
                    pm.Note(
                        velocity=100,
                        pitch=pitch + 24,
                        start=float(start_frame) / fs,
                        end=float(stop_frame) / fs,
                    )
                )

    return pm_obj
|
|
|
|
|
|
|
|
|
|
|
def convert_to_midi(right_hand, left_hand, file_name="output.mid", fs=8):
    """Write the two hand piano rolls to ``file_name`` as a MIDI file.

    The rolls are turned into notes by ``create_midi_from_piano_roll`` using
    ``fs`` frames per second. The saved path is printed to stdout.

    Returns:
        ``file_name``, so the result can be passed straight to ``midi_to_wav``.
    """
    create_midi_from_piano_roll(right_hand, left_hand, fs=fs).write(file_name)
    print(f"MIDI file saved to {file_name}")
    return file_name
|
|
|
|
|
|
|
|
|
|
|
def _show_hands(right, left, titles):
    """Plot two hand piano rolls side by side on the Streamlit page.

    The figure is closed after rendering. Streamlit re-executes this script
    on every interaction, and unclosed pyplot figures would accumulate in
    memory.
    """
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    for ax, roll, title in zip(axs, (right, left), titles):
        ax.imshow(roll, aspect='auto', cmap='gray')
        ax.set_title(title)
    st.pyplot(fig)
    plt.close(fig)


# Frame rate of the piano rolls produced by process_midi (its ``fs = 10``).
# Writing them back with a different rate would change the playback tempo.
PIANO_ROLL_FS = 10

st.title("GRU-VAE Reconstruction Demo")

model = load_model()

uploaded_file = st.file_uploader("Upload a MIDI file", type=["mid", "midi"])

if uploaded_file is not None:
    st.write("Processing MIDI file...")
    right_hand, left_hand = process_midi(uploaded_file)

    if right_hand is not None and left_hand is not None:
        st.write("Original MIDI Data:")
        _show_hands(right_hand, left_hand, ("Right Hand", "Left Hand"))

        # reconstruct() stacks both hands along the time axis, right hand
        # first. The reconstruction is assumed to keep that layout, so split
        # it at the actual number of right-hand frames, not a fixed 150.
        recon_data = reconstruct(right_hand.T, left_hand.T, model)
        split = right_hand.shape[1]
        recon_right = recon_data[:split].T
        recon_left = recon_data[split:].T

        st.write("Reconstructed MIDI Data:")
        _show_hands(
            recon_right,
            recon_left,
            ("Right Hand (Reconstructed)", "Left Hand (Reconstructed)"),
        )

        original_midi = convert_to_midi(
            right_hand, left_hand, "original_output.mid", fs=PIANO_ROLL_FS
        )
        recon_midi = convert_to_midi(
            recon_right, recon_left, "reconstructed_output.mid", fs=PIANO_ROLL_FS
        )

        original_wav_path = midi_to_wav(original_midi, "original_output.wav")
        recon_wav_path = midi_to_wav(recon_midi, "reconstructed_output.wav")

        st.write("Original MIDI Playback:")
        st.audio(original_wav_path, format='audio/wav')

        st.write("Reconstructed MIDI Playback:")
        st.audio(recon_wav_path, format='audio/wav')
|
|
|
|
|
|