| --- |
| license: mit |
| tags: |
| - music-generation |
| - variational-autoencoder |
| - piano-roll |
| - midi |
| - pytorch |
| --- |
| |
| # Music VAE: CNN pretrained on the Lakh MIDI Dataset |
|
|
| ## Overview |
|
|
| This is a **Convolutional Variational Autoencoder (CNN VAE)** pretrained on the |
| [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) (lmd_full, ~175k MIDI files). |
| |
| It was trained as part of a machine learning course assignment at Purdue University |
| to give students a meaningful starting point for music-generation tasks. |
| |
| ## Input / Output Format |
| |
| | Property | Value | |
| |-----------------|--------------------------------| |
| | Input shape     | `[batch, 1, 88, 32]`, float32  | |
| | Output shape    | `[batch, 1, 88, 32]`, float32  | |
| | Pitch range     | MIDI 21–108 (A0–C8, 88 keys)   | |
| | Time resolution | 16th notes at 120 BPM | |
| | Segment length | 2 bars = 32 timesteps | |
| | Value range | [0, 1] (Sigmoid output) | |
| |
| A tensor value of `1` at position `[pitch_idx, time_step]` means that pitch |
| `21 + pitch_idx` is active at that 16th-note time step. |
|
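As a minimal sketch of this indexing convention, the snippet below builds one 88 × 32 roll from a list of notes. The helper name `notes_to_pianoroll` and its `(midi_pitch, start_step, end_step)` tuple format are assumptions made for illustration; they are not taken from `src/utils.py`.

```python
LOWEST_PITCH = 21   # MIDI number of A0, the first row of the roll
NUM_PITCHES = 88    # A0 to C8
NUM_STEPS = 32      # 2 bars of 16th notes


def notes_to_pianoroll(notes):
    """Build an 88 x 32 binary roll from (midi_pitch, start_step, end_step) tuples.

    Hypothetical helper: end_step is exclusive, pitches outside 21..108 are
    skipped, and steps outside 0..31 are clipped.
    """
    roll = [[0.0] * NUM_STEPS for _ in range(NUM_PITCHES)]
    for midi_pitch, start_step, end_step in notes:
        pitch_idx = midi_pitch - LOWEST_PITCH
        if not 0 <= pitch_idx < NUM_PITCHES:
            continue  # outside the 88-key range
        for step in range(max(start_step, 0), min(end_step, NUM_STEPS)):
            roll[pitch_idx][step] = 1.0
    return roll


# Middle C (MIDI 60) held for one quarter note, i.e. 16th-note steps 0..3.
roll = notes_to_pianoroll([(60, 0, 4)])
print(roll[60 - LOWEST_PITCH][:6])  # [1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```

Wrapping the result as `torch.tensor(roll).reshape(1, 1, 88, 32)` should then give a float32 tensor in the input shape listed in the table.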
|
| ## Architecture Summary |
|
|
| ``` |
| ENCODER |
| Conv2d(1 → 32, k=4, s=2, p=1) + ReLU + BN → [B, 32, 44, 16] |
| Conv2d(32 → 64, k=4, s=2, p=1) + ReLU + BN → [B, 64, 22, 8] |
| Conv2d(64 → 128, k=4, s=2, p=1) + ReLU + BN → [B, 128, 11, 4] |
| Conv2d(128 → 256, k=4, s=2, p=1) + ReLU + BN → [B, 256, 5, 2] |
| Flatten → 2560 |
| Linear → mu [B, 256] |
| Linear → log_var [B, 256] |
| |
| REPARAMETERISATION |
| z = mu + eps * exp(0.5 * log_var), eps ~ N(0, I) |
| |
| DECODER |
| Linear(256 → 2560) → Reshape [B, 256, 5, 2] |
| ConvTranspose2d(256 → 128, k=4, s=2, p=1, output_padding=(1,0)) → [B, 128, 11, 4] |
| ConvTranspose2d(128 → 64, k=4, s=2, p=1) → [B, 64, 22, 8] |
| ConvTranspose2d(64 → 32, k=4, s=2, p=1) → [B, 32, 44, 16] |
| ConvTranspose2d(32 → 1, k=4, s=2, p=1) + Sigmoid → [B, 1, 88, 32] |
| ``` |
|
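The feature-map sizes in the summary can be checked by hand with the standard output-size formulas for `Conv2d` and `ConvTranspose2d` (dilation 1). The sketch below is only that arithmetic, not the model code; the function names are illustrative.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output length of a Conv2d along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel=4, stride=2, padding=1, output_padding=0):
    """Output length of a ConvTranspose2d along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Encoder: four stride-2 convolutions applied to an 88 x 32 piano roll.
encoder_shapes = [(88, 32)]
for _ in range(4):
    h, w = encoder_shapes[-1]
    encoder_shapes.append((conv_out(h), conv_out(w)))

# Decoder: four stride-2 transposed convolutions; only the first one uses
# output_padding=(1, 0), i.e. one extra row on the pitch axis.
decoder_shapes = [encoder_shapes[-1]]
for layer in range(4):
    h, w = decoder_shapes[-1]
    pad_h = 1 if layer == 0 else 0
    decoder_shapes.append(
        (conv_transpose_out(h, output_padding=pad_h), conv_transpose_out(w))
    )

print(encoder_shapes)  # [(88, 32), (44, 16), (22, 8), (11, 4), (5, 2)]
print(decoder_shapes)  # [(5, 2), (11, 4), (22, 8), (44, 16), (88, 32)]
```

The `output_padding=(1,0)` on the first decoder layer is there because the last encoder convolution floors the odd height 11 down to 5; without it the transposed convolution would return a height of 10 and the output would no longer match the 88-row input.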
|
| - **Latent dimension**: 256 |
| - **Trainable parameters**: ~4.2M |
|
|
| ## Loading the Model (Course Assignment) |
|
|
| ```python |
| import torch |
| from model import MusicVAE # copy src/model.py into your project |
| |
| # Load checkpoint |
| ckpt = torch.load("best_model.pt", map_location="cpu") |
| config = ckpt["config"] |
| |
| model = MusicVAE(latent_dim=config["latent_dim"]) |
| model.load_state_dict(ckpt["model_state"]) |
| model.eval() |
| |
| # Generate new piano rolls |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| model = model.to(device) |
| samples = model.sample(n=4, device=device) # [4, 1, 88, 32] |
| |
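| # Encode a batch of segments and reconstruct it |
| x = ... # your [N, 1, 88, 32] piano-roll tensor (N >= 2 for the interpolation below) |
| with torch.no_grad(): # inference only, no gradients needed |
|     x_recon, mu, log_var = model(x.to(device)) |
| |
| # Interpolate between the latent means of the first two examples |
| z1 = mu[0:1] # first example, shape [1, 256] |
| z2 = mu[1:2] # second example, shape [1, 256] |
| interp = model.interpolate(z1, z2, steps=8) |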
| ``` |
|
|
| Also see `src/utils.py` for `pianoroll_to_midi()` and `visualize_pianoroll()`. |
|
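For readers who want to inspect generated rolls without those utilities, here is a rough sketch of turning sigmoid outputs into note events. It is not the `pianoroll_to_midi()` implementation: the function name `pianoroll_to_notes`, the 0.5 threshold, and the `(midi_pitch, start_sec, end_sec)` output format are all assumptions.

```python
STEP_SECONDS = 60.0 / 120 / 4   # one 16th note at 120 BPM = 0.125 s
LOWEST_PITCH = 21               # MIDI number of row 0 (A0)


def pianoroll_to_notes(roll, threshold=0.5):
    """Convert an 88 x T roll of probabilities into note tuples.

    Hypothetical helper: a cell counts as active when value >= threshold,
    and each run of consecutive active steps becomes one note.
    """
    notes = []
    for pitch_idx, row in enumerate(roll):
        start = None  # step at which the current run began, if any
        for step, value in enumerate(row):
            active = value >= threshold
            if active and start is None:
                start = step
            elif not active and start is not None:
                notes.append((LOWEST_PITCH + pitch_idx,
                              start * STEP_SECONDS, step * STEP_SECONDS))
                start = None
        if start is not None:  # note still sounding at the end of the segment
            notes.append((LOWEST_PITCH + pitch_idx,
                          start * STEP_SECONDS, len(row) * STEP_SECONDS))
    return notes


# One 88 x 32 roll with middle C (row 39) active on the first four steps.
example_roll = [[0.0] * 32 for _ in range(88)]
example_roll[39][0:4] = [0.9, 0.8, 0.7, 0.6]
print(pianoroll_to_notes(example_roll))  # [(60, 0.0, 0.5)]
```

A generated sample could be passed in as `samples[0, 0].cpu().tolist()`. Because the roll is binary with no onset channel, repeated notes on adjacent steps are merged into one longer note.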
|
| ## Training Details |
|
|
| - **Dataset**: Lakh MIDI Dataset (lmd_full) |
| - **Piano roll**: 88 pitches, binary, 16th-note resolution, tempo normalised to 120 BPM |
| - **Segments**: 2 bars (32 frames), stride 1 bar (16 frames) |
| - **Loss**: BCE reconstruction + β-annealed KL (β: 0 → 1 over 50 epochs) + free bits (λ=0.5) |
| - **Optimizer**: Adam, lr=1e-3, ReduceLROnPlateau (patience=10, factor=0.5, min_lr=1e-5) |
| - **Batch size**: 256 | **Epochs**: 100 | **Gradient clip**: 1.0 |
|
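The loss line above leaves some details open, so the following PyTorch sketch shows one plausible reading rather than the actual training code. It assumes a linear β warm-up, BCE summed over the roll and averaged over the batch, and free bits applied per latent dimension to the batch-mean KL (in nats). The names `kl_beta` and `vae_loss` are hypothetical.

```python
import torch
import torch.nn.functional as F


def kl_beta(epoch, warmup_epochs=50):
    """Assumed linear schedule: beta rises from 0 to 1 over warmup_epochs."""
    return min(1.0, epoch / warmup_epochs)


def vae_loss(x_recon, x, mu, log_var, beta, free_bits=0.5):
    """BCE reconstruction + beta * KL with a per-dimension free-bits floor."""
    batch_size = x.size(0)
    # Reconstruction: sum over all cells of the roll, mean over the batch.
    recon = F.binary_cross_entropy(x_recon, x, reduction="sum") / batch_size
    # KL(q(z|x) || N(0, I)) for each latent dimension, shape [B, latent_dim].
    kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
    # Free bits: average over the batch, then floor each dimension at free_bits
    # so dimensions below the floor contribute no KL gradient.
    kl = torch.clamp(kl_per_dim.mean(dim=0), min=free_bits).sum()
    return recon + beta * kl, recon, kl


# Toy call with a 4-dimensional latent and an all-0.5 reconstruction.
x = torch.zeros(2, 1, 88, 32)
x_recon = torch.full_like(x, 0.5)
mu, log_var = torch.zeros(2, 4), torch.zeros(2, 4)
total, recon, kl = vae_loss(x_recon, x, mu, log_var, beta=kl_beta(25))
```

In this toy call the raw KL is zero in every dimension, so the free-bits floor alone sets the KL term to 4 × 0.5 = 2.0 nats.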
|
| ## Citation |
|
|
| If you use this model in your work, please cite the Lakh MIDI Dataset: |
|
|
| ```bibtex |
| @phdthesis{Raffel2016, |
| author = {Colin Raffel}, |
| title = {Learning-Based Methods for Comparing Sequences, with Applications |
| to Audio-to-{MIDI} Alignment and Matching}, |
| school = {Columbia University}, |
| year = {2016} |
| } |
| ``` |
|
|