import streamlit as st
import torch
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi as pm
from VAE import VAE
import pretty_midi as pm
from scipy.io.wavfile import write
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load VAE model
@st.cache_resource
def load_model():
    """Build the VAE, load the trained weights, and cache it across reruns.

    Weights are read from ``vae_model_all.pth`` and mapped onto the
    module-level ``device``. The model is returned in eval mode.
    """
    model = VAE(input_dim=76, hidden_dim=512, latent_dim=256)
    state = torch.load("vae_model_all.pth", map_location=device)
    model.load_state_dict(state)
    # .to() and .eval() both return the module itself, so this is safe.
    return model.to(device).eval()
# Function to process the uploaded MIDI file
def process_midi(file, fs=10, pitch_start=24, pitch_stop=100, duration=150, offset=26):
    """Extract fixed-size right/left-hand piano-roll slices from a MIDI file.

    Parameters
    ----------
    file : file-like or path
        Uploaded MIDI file, passed straight to ``pretty_midi.PrettyMIDI``.
    fs : int
        Sampling rate (frames per second) for the piano roll.
    pitch_start, pitch_stop : int
        Pitch rows kept from the 128-row piano roll.
    duration : int
        Number of time frames kept.
    offset : int
        First time frame of the kept window.

    Returns
    -------
    (right_hand, left_hand) : tuple of np.ndarray or (None, None)
        ``(None, None)`` on any error (reported via ``st.error``).
    """
    try:
        mid = pm.PrettyMIDI(file)
        times = np.arange(0, mid.get_end_time(), 1.0 / fs)

        # Collect (mean pitch, piano roll) for each usable track.
        # BUG FIX: the original indexed ``mid.instruments`` with the argmax of
        # a list built only from NON-skipped instruments, so skipping any
        # track (non-piano program or empty roll) selected the wrong
        # instrument. Keeping the roll alongside its mean pitch avoids the
        # index mismatch and also avoids recomputing piano rolls.
        candidates = []
        for inst in mid.instruments:
            # Keep only the piano family (General MIDI programs 0-7).
            if inst.program // 8 > 0:
                continue
            piano_roll = inst.get_piano_roll(times=times)
            if np.sum(piano_roll) == 0:
                continue
            active_rows = np.where(piano_roll)[0]
            candidates.append((np.average(active_rows), piano_roll))

        if not candidates:
            st.error("No valid piano data found.")
            return None, None

        mean_pitches = [p for p, _ in candidates]
        # Higher mean pitch is assumed to be the right hand.
        right_roll = candidates[int(np.argmax(mean_pitches))][1]
        if len(candidates) == 1:
            left_roll = np.zeros_like(right_roll)
        else:
            left_roll = candidates[int(np.argmin(mean_pitches))][1]

        right_hand = right_roll[pitch_start:pitch_stop, offset:offset + duration]
        left_hand = left_roll[pitch_start:pitch_stop, offset:offset + duration]
        if np.sum(right_hand) == 0 or np.sum(left_hand) == 0:
            st.error("Invalid data detected in MIDI file.")
            return None, None
        return right_hand, left_hand
    except Exception as e:
        st.error(f"Error processing MIDI: {e}")
        return None, None
# Run the VAE model for reconstruction
def reconstruct(right, left, model):
    """Run the VAE on the concatenated right/left-hand rolls.

    The two arrays are stacked along dim 0, given a batch dimension, and
    fed through ``model`` with gradients disabled. Only the reconstruction
    (first element of the model's 4-tuple output) is kept and returned as
    a NumPy array on the CPU.
    """
    stacked = torch.cat(
        [
            torch.tensor(right, dtype=torch.float32),
            torch.tensor(left, dtype=torch.float32),
        ],
        dim=0,
    ).to(device)
    batch = stacked.unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        reconstruction, *_ = model(batch)
    return reconstruction.squeeze(0).cpu().numpy()
def midi_to_wav(midi_file, wav_file="output.wav", volume_increase_db=17):
    """Render a MIDI file to a normalized 16-bit PCM WAV file.

    Parameters
    ----------
    midi_file : str
        Path to the MIDI file to synthesize.
    wav_file : str
        Output WAV path; returned so callers can chain it into ``st.audio``.
    volume_increase_db : int
        NOTE(review): accepted for backward compatibility but never applied
        by this implementation — confirm whether the gain should be used.

    Returns
    -------
    str
        ``wav_file``.
    """
    midi_data = pm.PrettyMIDI(midi_file)
    audio_data = midi_data.synthesize(fs=44100)
    peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
    if peak == 0:
        # BUG FIX: a silent (or empty) synthesis used to divide by zero,
        # producing NaNs in the WAV. Emit digital silence instead.
        pcm = np.zeros(audio_data.shape, dtype=np.int16)
    else:
        # Normalize to 90% of full scale to leave a little headroom.
        pcm = np.int16(audio_data / peak * 32767 * 0.9)
    write(wav_file, 44100, pcm)
    return wav_file
# Create a MIDI stream from piano roll data
def create_midi_from_piano_roll(right_hand, left_hand, fs=8, threshold=0.92, base_pitch=24, velocity=100):
    """Convert two (pitch x time) piano-roll arrays into a PrettyMIDI object.

    Each row is scanned for threshold crossings: a rising edge opens a note,
    a falling edge closes it. Values are assumed to be activations in [0, 1].

    Parameters
    ----------
    right_hand, left_hand : array-like
        Piano rolls; row index + ``base_pitch`` gives the MIDI pitch.
    fs : int
        Frames per second — converts frame indices to seconds.
    threshold : float
        Activation level at or above which a note is considered "on".
    base_pitch : int
        MIDI pitch of row 0 (matches the ``pitch_start`` crop upstream).
    velocity : int
        Velocity assigned to every emitted note.
    """
    pm_obj = pm.PrettyMIDI()
    for hand_data in [right_hand, left_hand]:
        instrument = pm.Instrument(program=0)  # Acoustic Grand Piano
        pm_obj.instruments.append(instrument)
        for pitch, series in enumerate(hand_data):
            # BUG FIX: a note already sounding at frame 0 has no rising edge,
            # so the original loop never opened it and the note was dropped.
            # Seed start_time when the series begins above threshold.
            start_time = 0.0 if len(series) > 0 and series[0] >= threshold else None
            for j in range(len(series) - 1):
                if series[j] < threshold and series[j + 1] >= threshold:
                    # Rising edge: note onset (original convention: j / fs).
                    start_time = j / fs
                elif series[j] >= threshold and series[j + 1] < threshold and start_time is not None:
                    # Falling edge: close the open note.
                    instrument.notes.append(pm.Note(
                        velocity=velocity,
                        pitch=pitch + base_pitch,
                        start=start_time,
                        end=(j + 1) / fs,
                    ))
                    start_time = None
            if start_time is not None:
                # Note still sounding at the end of the roll — close it there.
                instrument.notes.append(pm.Note(
                    velocity=velocity,
                    pitch=pitch + base_pitch,
                    start=start_time,
                    end=len(series) / fs,
                ))
    return pm_obj
# Function to convert reconstructed data to MIDI files
def convert_to_midi(right_hand, left_hand, file_name="output.mid", fs=8):
    """Write the two piano rolls to a .mid file and return its path."""
    midi_obj = create_midi_from_piano_roll(right_hand, left_hand, fs=fs)
    midi_obj.write(file_name)
    print(f"MIDI file saved to {file_name}")
    return file_name
# Streamlit interface
# ---------------------------------------------------------------------------
# Streamlit interface
# ---------------------------------------------------------------------------
st.title("GRU-VAE Reconstruction Demo")
model = load_model()


def _plot_hands(right, left, right_title, left_title):
    """Render two piano rolls side by side in the Streamlit page."""
    figure, panes = plt.subplots(1, 2, figsize=(10, 4))
    panes[0].imshow(right, aspect='auto', cmap='gray')
    panes[0].set_title(right_title)
    panes[1].imshow(left, aspect='auto', cmap='gray')
    panes[1].set_title(left_title)
    st.pyplot(figure)


# Let the user supply the MIDI file to reconstruct.
uploaded_file = st.file_uploader("Upload a MIDI file", type=["mid", "midi"])

if uploaded_file is not None:
    st.write("Processing MIDI file...")
    right_hand, left_hand = process_midi(uploaded_file)
    if right_hand is not None and left_hand is not None:
        # Show the raw slices that will be fed to the model.
        st.write("Original MIDI Data:")
        _plot_hands(right_hand, left_hand, "Right Hand", "Left Hand")

        # The model consumes time-major input (hence the transposes); the
        # first 150 output frames belong to the right hand, the rest to the
        # left — transposed back to (pitch x time) for display/synthesis.
        recon = reconstruct(right_hand.T, left_hand.T, model)
        recon_right = recon[:150].T
        recon_left = recon[150:].T

        st.write("Reconstructed MIDI Data:")
        _plot_hands(
            recon_right,
            recon_left,
            "Right Hand (Reconstructed)",
            "Left Hand (Reconstructed)",
        )

        # Render both versions to MIDI, then to WAV for in-browser playback.
        original_midi = convert_to_midi(right_hand, left_hand, "original_output.mid", fs=8)
        recon_midi = convert_to_midi(recon_right, recon_left, "reconstructed_output.mid", fs=8)
        original_wav_path = midi_to_wav(original_midi, "original_output.wav")
        recon_wav_path = midi_to_wav(recon_midi, "reconstructed_output.wav")

        st.write("Original MIDI Playback:")
        st.audio(original_wav_path, format='audio/wav')
        st.write("Reconstructed MIDI Playback:")
        st.audio(recon_wav_path, format='audio/wav')