File size: 6,842 Bytes
3b38d19
081442b
 
 
b672e6e
081442b
b672e6e
081442b
b672e6e
081442b
b672e6e
 
081442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b672e6e
081442b
 
 
 
 
 
b672e6e
 
 
081442b
b672e6e
081442b
b672e6e
081442b
 
b672e6e
081442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b672e6e
081442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b38d19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import streamlit as st
import torch
import numpy as np
import matplotlib.pyplot as plt

import pretty_midi as pm

from VAE import VAE

import pretty_midi as pm
from scipy.io.wavfile import write



# Select the compute device once at import time: prefer CUDA when a GPU is
# available, otherwise fall back to CPU. Used for both model loading and inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load VAE model
@st.cache_resource
def load_model():
    """Build the VAE, load its trained weights, and return it in eval mode.

    Cached by Streamlit (`st.cache_resource`) so the checkpoint is read from
    disk only once per server process.

    Returns:
        VAE: the model on `device`, in evaluation mode.
    """
    model = VAE(input_dim=76, hidden_dim=512, latent_dim=256)
    state = torch.load("vae_model_all.pth", map_location=device)
    model.load_state_dict(state)
    # .to() and .eval() both return the module, so the chain is safe.
    return model.to(device).eval()

# Function to process the uploaded MIDI file
def process_midi(file):
    """Extract fixed-size right- and left-hand piano rolls from a MIDI file.

    Piano instruments (General MIDI programs 0-7) are ranked by mean pitch;
    the highest-pitched part is treated as the right hand and the lowest as
    the left. Both rolls are cropped to the window the VAE was trained on:
    pitches 24-99 (76 rows) and 150 frames starting at frame 26 (fs = 10 Hz).

    Args:
        file: file-like object (e.g. a Streamlit upload) readable by PrettyMIDI.

    Returns:
        (right_hand, left_hand): two (76, 150) numpy arrays, or (None, None)
        on any failure (an error is shown via st.error in that case).
    """
    try:
        mid = pm.PrettyMIDI(file)
        fs = 10  # frames per second of the piano-roll grid
        # Hoisted: the same sampling grid was previously rebuilt for every
        # get_piano_roll() call.
        sample_times = np.arange(0, mid.get_end_time(), 1.0 / fs)

        # Keep instruments and their mean pitches in PARALLEL lists.
        # BUGFIX: the original indexed `mid.instruments` with argmax/argmin of
        # the filtered pitch list, which misaligns whenever an instrument is
        # skipped (non-piano program or empty roll).
        piano_insts = []
        pitch_means = []
        for inst in mid.instruments:
            if inst.program // 8 > 0:  # programs 0-7 are pianos; skip the rest
                continue
            roll = inst.get_piano_roll(times=sample_times)
            if np.sum(roll) == 0:  # instrument has no sounding notes
                continue
            piano_insts.append(inst)
            pitch_means.append(np.average(np.where(roll)[0]))

        if not pitch_means:
            st.error("No valid piano data found.")
            return None, None

        right_roll = piano_insts[int(np.argmax(pitch_means))].get_piano_roll(times=sample_times)
        if len(pitch_means) == 1:
            # Single part: treat it as the right hand, left hand silent.
            left_roll = np.zeros_like(right_roll)
        else:
            left_roll = piano_insts[int(np.argmin(pitch_means))].get_piano_roll(times=sample_times)

        # Crop to the model's training window.
        pitch_start, pitch_stop, duration = 24, 100, 150
        right_hand = right_roll[pitch_start:pitch_stop, 26 : 26 + duration]
        left_hand = left_roll[pitch_start:pitch_stop, 26 : 26 + duration]

        if np.sum(right_hand) == 0 or np.sum(left_hand) == 0:
            st.error("Invalid data detected in MIDI file.")
            return None, None

        return right_hand, left_hand
    except Exception as e:
        # Best-effort: surface the problem in the UI rather than crashing the app.
        st.error(f"Error processing MIDI: {e}")
        return None, None

# Run the VAE model for reconstruction
def reconstruct(right, left, model):
    """Run the VAE on a pair of (time, pitch) piano rolls and return the reconstruction.

    Args:
        right: right-hand roll, array-like convertible to float32.
        left: left-hand roll, same shape as `right`.
        model: VAE whose forward returns (reconstruction, ...) as a 4-tuple.

    Returns:
        numpy array: the reconstruction with the batch dimension removed.
    """
    right_t = torch.as_tensor(np.asarray(right), dtype=torch.float32).to(device)
    left_t = torch.as_tensor(np.asarray(left), dtype=torch.float32).to(device)

    # Stack both hands along the time axis, then add a batch dimension.
    batch = torch.cat((right_t, left_t), dim=0).unsqueeze(0)

    with torch.no_grad():
        reconstruction, *_ = model(batch)  # discard mu / logvar / extras

    return reconstruction.squeeze(0).cpu().numpy()


def midi_to_wav(midi_file, wav_file="output.wav", volume_increase_db=17):
    """Synthesize a MIDI file and write it as a 16-bit PCM WAV at 44.1 kHz.

    The audio is peak-normalized to 90% of full scale.

    Args:
        midi_file: path (or file-like) readable by PrettyMIDI.
        wav_file: output WAV path.
        volume_increase_db: kept for backward compatibility; currently unused
            (normalization always targets 90% of full scale).

    Returns:
        str: the path of the written WAV file.
    """
    midi_data = pm.PrettyMIDI(midi_file)
    audio_data = midi_data.synthesize(fs=44100)

    # BUGFIX: guard against silent synthesis — the original divided by the
    # peak unconditionally, producing NaNs (and a corrupt WAV) when max == 0.
    peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
    if peak > 0:
        samples = np.int16(audio_data / peak * 32767 * 0.9)
    else:
        samples = np.zeros(audio_data.shape, dtype=np.int16)

    write(wav_file, 44100, samples)
    return wav_file


# Create a MIDI stream from piano roll data
def create_midi_from_piano_roll(right_hand, left_hand, fs=8):
    """Convert two (pitch, time) piano-roll arrays into a PrettyMIDI object.

    Each hand becomes its own Acoustic Grand Piano instrument. A note onset is
    a rising edge through `threshold`; an offset is a falling edge. Row index 0
    corresponds to MIDI pitch 24 (matching the crop in process_midi).

    Args:
        right_hand: (pitch, time) array of activations in [0, 1].
        left_hand: same shape/convention as right_hand.
        fs: frames per second of the roll (frame index / fs = seconds).

    Returns:
        pretty_midi.PrettyMIDI: the assembled two-instrument MIDI object.
    """
    pm_obj = pm.PrettyMIDI()
    threshold = 0.92  # activation level that counts as "note on"

    for hand_data in [right_hand, left_hand]:
        instrument = pm.Instrument(program=0)  # Acoustic Grand Piano
        pm_obj.instruments.append(instrument)

        for pitch, series in enumerate(hand_data):
            # BUGFIX: a note already sounding at frame 0 has no rising edge,
            # so the original dropped it entirely. Seed the onset here.
            start_time = 0.0 if len(series) > 0 and series[0] >= threshold else None

            for j in range(len(series) - 1):
                if series[j] < threshold and series[j + 1] >= threshold:
                    # Rising edge: note starts at the next frame's time slot.
                    start_time = j / fs
                elif series[j] >= threshold and series[j + 1] < threshold and start_time is not None:
                    # Falling edge: close the open note.
                    instrument.notes.append(pm.Note(
                        velocity=100,
                        pitch=pitch + 24,  # row 0 == MIDI pitch 24
                        start=start_time,
                        end=(j + 1) / fs,
                    ))
                    start_time = None

            # Note still sounding at the end of the roll: close it at the end.
            if start_time is not None:
                instrument.notes.append(pm.Note(
                    velocity=100,
                    pitch=pitch + 24,
                    start=start_time,
                    end=len(series) / fs,
                ))

    return pm_obj


# Function to convert reconstructed data to MIDI files
def convert_to_midi(right_hand, left_hand, file_name="output.mid", fs=8):
    """Render the two piano rolls to a standard MIDI file on disk.

    Args:
        right_hand: (pitch, time) right-hand roll.
        left_hand: (pitch, time) left-hand roll.
        file_name: destination .mid path.
        fs: frames per second of the rolls.

    Returns:
        str: the path of the written MIDI file.
    """
    create_midi_from_piano_roll(right_hand, left_hand, fs=fs).write(file_name)
    print(f"MIDI file saved to {file_name}")
    return file_name


# Streamlit interface
# Top-level Streamlit UI flow: upload a MIDI file, split it into hands,
# show original vs. VAE-reconstructed rolls, and play both back as audio.
st.title("GRU-VAE Reconstruction Demo")
model = load_model()
    

# File upload widget; only MIDI extensions accepted.
uploaded_file = st.file_uploader("Upload a MIDI file", type=["mid", "midi"])

if uploaded_file is not None:
    st.write("Processing MIDI file...")
    # Returns (76, 150) pitch-by-time rolls, or (None, None) on failure
    # (process_midi reports the error to the UI itself).
    right_hand, left_hand = process_midi(uploaded_file)

    if right_hand is not None and left_hand is not None:
        # Display original data as grayscale piano rolls.
        st.write("Original MIDI Data:")
        fig, axs = plt.subplots(1, 2, figsize=(10, 4))
        axs[0].imshow(right_hand, aspect='auto', cmap='gray')
        axs[0].set_title("Right Hand")
        axs[1].imshow(left_hand, aspect='auto', cmap='gray')
        axs[1].set_title("Left Hand")
        st.pyplot(fig)

        # Reconstruction: the model consumes (time, pitch) rolls, so transpose
        # in; the output stacks both hands along time (150 frames each),
        # so split at 150 and transpose back to (pitch, time) for display.
        recon_data = reconstruct(right_hand.T, left_hand.T, model)
        recon_right = recon_data[:150].T
        recon_left = recon_data[150:].T

        # Display reconstructed data side by side with the same layout.
        st.write("Reconstructed MIDI Data:")
        fig, axs = plt.subplots(1, 2, figsize=(10, 4))
        axs[0].imshow(recon_right, aspect='auto', cmap='gray')
        axs[0].set_title("Right Hand (Reconstructed)")
        axs[1].imshow(recon_left, aspect='auto', cmap='gray')
        axs[1].set_title("Left Hand (Reconstructed)")
        st.pyplot(fig)

        # Convert both versions to MIDI files on disk, then to WAV for playback.
        original_midi = convert_to_midi(right_hand, left_hand, "original_output.mid", fs=8)
        recon_midi = convert_to_midi(recon_right, recon_left, "reconstructed_output.mid", fs=8)

        # Synthesize WAVs (44.1 kHz, peak-normalized) for the audio widgets.
        original_wav_path = midi_to_wav(original_midi, "original_output.wav")
        recon_wav_path = midi_to_wav(recon_midi, "reconstructed_output.wav")

        st.write("Original MIDI Playback:")
        st.audio(original_wav_path, format='audio/wav')

        st.write("Reconstructed MIDI Playback:")
        st.audio(recon_wav_path, format='audio/wav')