# JukeBox / app.py
# hjimjim
# upload model
# b672e6e
import streamlit as st
import torch
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi as pm
from VAE import VAE
import pretty_midi as pm
from scipy.io.wavfile import write
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load VAE model
@st.cache_resource
def load_model():
    """Build the VAE, load its saved weights and return it in eval mode.

    The weights are read from ``vae_model_all.pth`` and mapped onto the
    module-level ``device``. The function is wrapped in
    ``st.cache_resource`` so the result is cached by Streamlit.

    Returns:
        The ``VAE`` instance, moved to ``device`` and switched to eval mode.
    """
    model = VAE(input_dim=76, hidden_dim=512, latent_dim=256)
    weights = torch.load("vae_model_all.pth", map_location=device)
    model.load_state_dict(weights)
    model = model.to(device)
    model.eval()
    return model
# Function to process the uploaded MIDI file
def process_midi(file):
    """Extract fixed-size right/left-hand piano rolls from a MIDI file.

    Every instrument whose program is in 0-7 (the General MIDI piano family)
    and whose piano roll is non-empty is a candidate. The candidate with the
    highest mean active pitch becomes the right hand and the one with the
    lowest mean active pitch becomes the left hand. Both rolls are sampled at
    10 frames per second and cropped to pitches 24-99 and frames 26-175.

    Args:
        file: Path or file-like object accepted by ``pretty_midi.PrettyMIDI``.

    Returns:
        ``(right_hand, left_hand)`` as 2-D arrays (pitch rows, time columns),
        or ``(None, None)`` after reporting the problem through ``st.error``
        when the file cannot be parsed, has no usable piano track, or either
        cropped hand is empty.
    """
    fs = 10  # piano-roll sampling rate, frames per second
    pitch_start, pitch_stop = 24, 100
    frame_start, duration = 26, 150
    try:
        mid = pm.PrettyMIDI(file)
        # The sampling grid is identical for every instrument; build it once.
        times = np.arange(0, mid.get_end_time(), 1.0 / fs)

        # Keep each accepted roll next to its mean pitch.  Indexing
        # ``mid.instruments`` with a position taken from the filtered list
        # picks the wrong track as soon as any instrument has been skipped.
        rolls = []
        mean_pitches = []
        for inst in mid.instruments:
            if inst.program // 8 > 0:
                continue  # not a piano-family program
            piano_roll = inst.get_piano_roll(times=times)
            if np.sum(piano_roll) == 0:
                continue  # track contains no sounding notes
            active_rows = np.where(piano_roll)[0]
            rolls.append(piano_roll)
            mean_pitches.append(np.average(active_rows))

        if not rolls:
            st.error("No valid piano data found.")
            return None, None

        right_roll = rolls[int(np.argmax(mean_pitches))]
        if len(rolls) == 1:
            # NOTE(review): a single usable track yields an all-zero left
            # hand, which the emptiness check below then rejects -- confirm
            # whether single-track files are meant to be supported.
            left_roll = np.zeros_like(right_roll)
        else:
            left_roll = rolls[int(np.argmin(mean_pitches))]

        frame_stop = frame_start + duration
        right_hand = right_roll[pitch_start:pitch_stop, frame_start:frame_stop]
        left_hand = left_roll[pitch_start:pitch_stop, frame_start:frame_stop]

        if np.sum(right_hand) == 0 or np.sum(left_hand) == 0:
            st.error("Invalid data detected in MIDI file.")
            return None, None

        return right_hand, left_hand
    except Exception as e:
        # UI boundary: surface any parsing/processing failure to the user
        # instead of crashing the Streamlit script.
        st.error(f"Error processing MIDI: {e}")
        return None, None
# Run the VAE model for reconstruction
def reconstruct(right, left, model):
    """Run both hands through the model and return its reconstruction.

    ``right`` and ``left`` are converted to float32 tensors on the
    module-level ``device``, concatenated along their first axis (right
    first) and given a leading batch dimension of one. The model is called
    without gradient tracking and must return four values; only the first
    one is used.

    Args:
        right: Array-like data for the right hand.
        left: Array-like data for the left hand.
        model: Callable returning a 4-tuple whose first item is the
            reconstruction tensor.

    Returns:
        The reconstruction as a NumPy array with the batch dimension removed.
    """
    hand_tensors = [
        torch.tensor(hand, dtype=torch.float32).to(device)
        for hand in (right, left)
    ]
    batch = torch.cat(hand_tensors, dim=0).unsqueeze(0)

    with torch.no_grad():
        recon_data, _, _, _ = model(batch)

    unbatched = recon_data.squeeze(0)
    return unbatched.cpu().numpy()
def midi_to_wav(midi_file, wav_file="output.wav", volume_increase_db=17):
    """Synthesize a MIDI file to a 16-bit, 44.1 kHz WAV file.

    The synthesized signal is peak-normalized to 90 % of the int16 range. If
    the signal is empty or completely silent there is no peak to normalize
    by, so silence is written instead of dividing by zero.

    Args:
        midi_file: Path or file-like object accepted by
            ``pretty_midi.PrettyMIDI``.
        wav_file: Destination path of the WAV file.
        volume_increase_db: Currently unused; kept so existing callers that
            pass it keep working.

    Returns:
        ``wav_file``, unchanged.
    """
    sample_rate = 44100
    midi_data = pm.PrettyMIDI(midi_file)
    # Replace NaNs up front so they cannot poison the peak computation.
    audio_data = np.nan_to_num(
        np.asarray(midi_data.synthesize(fs=sample_rate), dtype=np.float64)
    )

    peak = float(np.max(np.abs(audio_data))) if audio_data.size else 0.0
    if peak > 0.0:
        pcm = np.int16(audio_data / peak * 32767 * 0.9)
    else:
        # Nothing audible: emit silence of the same length, or one second
        # when the synthesizer produced no samples at all.
        pcm = np.zeros(audio_data.size or sample_rate, dtype=np.int16)

    write(wav_file, sample_rate, pcm)
    return wav_file
# Create a MIDI stream from piano roll data
def create_midi_from_piano_roll(right_hand, left_hand, fs=8, *, threshold=0.92,
                                pitch_offset=24, velocity=100):
    """Convert two piano rolls into a two-instrument ``PrettyMIDI`` object.

    Each roll is iterated row by row; row ``i`` is written as MIDI pitch
    ``i + pitch_offset``. Within a row, every maximal run of consecutive
    frames whose value is ``>= threshold`` becomes one note that starts at
    the first frame of the run and ends at the first frame after it.

    Runs that are already active at frame 0 are emitted too, and a note's
    start is the frame where it first reaches the threshold (not the frame
    before it).

    Args:
        right_hand: 2-D array-like (rows x frames) for the first instrument.
        left_hand: 2-D array-like (rows x frames) for the second instrument.
        fs: Frames per second used to convert frame indices to seconds.
        threshold: Minimum cell value for a frame to count as "note on".
        pitch_offset: MIDI pitch assigned to row 0.
        velocity: Velocity given to every generated note.

    Returns:
        A ``pretty_midi.PrettyMIDI`` object with one Acoustic Grand Piano
        instrument per hand, right hand first.
    """
    pm_obj = pm.PrettyMIDI()
    for hand_data in (right_hand, left_hand):
        instrument = pm.Instrument(program=0)  # Acoustic Grand Piano
        pm_obj.instruments.append(instrument)

        for row, series in enumerate(hand_data):
            active = np.asarray(series) >= threshold
            if not active.any():
                continue

            # Pad with "off" on both sides so runs touching either end of
            # the roll still produce a rising and a falling edge.
            padded = np.concatenate(([False], active, [False])).astype(np.int8)
            edges = np.diff(padded)
            onsets = np.flatnonzero(edges == 1)    # first active frame of each run
            offsets = np.flatnonzero(edges == -1)  # first frame after each run

            for first_frame, end_frame in zip(onsets, offsets):
                instrument.notes.append(
                    pm.Note(
                        velocity=velocity,
                        pitch=row + pitch_offset,
                        start=float(first_frame / fs),
                        end=float(end_frame / fs),
                    )
                )
    return pm_obj
# Function to convert reconstructed data to MIDI files
def convert_to_midi(right_hand, left_hand, file_name="output.mid", fs=8):
    """Write two piano rolls to a MIDI file and return the file's path.

    Args:
        right_hand: Piano roll passed on as the first hand.
        left_hand: Piano roll passed on as the second hand.
        file_name: Destination path of the MIDI file.
        fs: Frames per second, forwarded to ``create_midi_from_piano_roll``.

    Returns:
        ``file_name``, unchanged.
    """
    create_midi_from_piano_roll(right_hand, left_hand, fs=fs).write(file_name)
    print(f"MIDI file saved to {file_name}")
    return file_name
# Streamlit interface
st.title("GRU-VAE Reconstruction Demo")
model = load_model()


def _show_hand_rolls(right_roll, left_roll, right_title, left_title):
    """Plot two piano rolls side by side in Streamlit, then free the figure."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    axs[0].imshow(right_roll, aspect='auto', cmap='gray')
    axs[0].set_title(right_title)
    axs[1].imshow(left_roll, aspect='auto', cmap='gray')
    axs[1].set_title(left_title)
    st.pyplot(fig)
    # Figures created through pyplot stay registered until closed; close each
    # one once rendered so repeated script runs do not accumulate them.
    plt.close(fig)


# File upload
uploaded_file = st.file_uploader("Upload a MIDI file", type=["mid", "midi"])
if uploaded_file is not None:
    st.write("Processing MIDI file...")
    right_hand, left_hand = process_midi(uploaded_file)
    if right_hand is not None and left_hand is not None:
        # Display original data
        st.write("Original MIDI Data:")
        _show_hand_rolls(right_hand, left_hand, "Right Hand", "Left Hand")

        # Reconstruction: the model takes time-major input, and its output is
        # split back into the two hands at frame 150 before being transposed
        # to pitch-major for plotting and MIDI export.
        recon_data = reconstruct(right_hand.T, left_hand.T, model)
        recon_right = recon_data[:150].T
        recon_left = recon_data[150:].T

        # Display reconstructed data
        st.write("Reconstructed MIDI Data:")
        _show_hand_rolls(
            recon_right,
            recon_left,
            "Right Hand (Reconstructed)",
            "Left Hand (Reconstructed)",
        )

        # Convert to MIDI and then to WAV for playback
        # NOTE(review): process_midi samples the rolls at 10 frames/s but they
        # are rendered here at fs=8, so playback is 1.25x slower than the
        # source -- confirm this is intended.
        original_midi = convert_to_midi(right_hand, left_hand, "original_output.mid", fs=8)
        recon_midi = convert_to_midi(recon_right, recon_left, "reconstructed_output.mid", fs=8)

        # Save and play audio
        original_wav_path = midi_to_wav(original_midi, "original_output.wav")
        recon_wav_path = midi_to_wav(recon_midi, "reconstructed_output.wav")

        st.write("Original MIDI Playback:")
        st.audio(original_wav_path, format='audio/wav')
        st.write("Reconstructed MIDI Playback:")
        st.audio(recon_wav_path, format='audio/wav')