---
license: mit
tags:
- music-generation
- variational-autoencoder
- piano-roll
- midi
- pytorch
---
# Music VAE: CNN pretrained on the Lakh MIDI Dataset
## Overview
This is a **Convolutional Variational Autoencoder (CNN VAE)** pretrained on the
[Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) (lmd_full, ~175k MIDI files).
It was trained as part of a machine learning course assignment at Purdue University
to give students a meaningful starting point for music-generation tasks.
## Input / Output Format
| Property | Value |
|-----------------|--------------------------------|
| Input shape | `[batch, 1, 88, 32]`, float32 |
| Output shape | `[batch, 1, 88, 32]`, float32 |
| Pitch range | MIDI 21–108 (A0–C8, 88 keys) |
| Time resolution | 16th notes at 120 BPM (0.125 s per step) |
| Segment length | 2 bars = 32 timesteps |
| Value range | [0, 1] (Sigmoid output) |
A value of `1` at `x[b, 0, pitch_idx, time_step]` means that MIDI pitch
`21 + pitch_idx` is active during that 16th-note time step of segment `b`.
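The mapping above can be sketched as a small NumPy preprocessing helper. This is a minimal illustration, not the repository's data pipeline: the names `notes_to_pianoroll` and `segment` are hypothetical, note times are assumed to be in seconds and already at 120 BPM (so one 16th-note step lasts 0.125 s), and a bar is assumed to be 4/4 (16 steps), which gives the 2-bar, 32-step segments.

```python
import numpy as np

LOWEST_PITCH = 21      # MIDI A0 maps to pitch_idx 0
N_PITCHES = 88         # A0 (21) through C8 (108)
STEPS_PER_SECOND = 8   # assumption: 16th notes at 120 BPM, i.e. 0.125 s per step


def notes_to_pianoroll(notes, n_steps):
    """Hypothetical helper: (midi_pitch, start_sec, end_sec) notes -> [88, n_steps] binary roll."""
    roll = np.zeros((N_PITCHES, n_steps), dtype=np.float32)
    for pitch, start, end in notes:
        if not LOWEST_PITCH <= pitch < LOWEST_PITCH + N_PITCHES:
            continue  # outside the 88-key range; dropped in this sketch
        t0 = int(round(start * STEPS_PER_SECOND))
        t1 = max(t0 + 1, int(round(end * STEPS_PER_SECOND)))  # keep at least one step
        roll[pitch - LOWEST_PITCH, t0:t1] = 1.0  # slicing past n_steps is truncated
    return roll


def segment(roll, seg_len=32, stride=16):
    """Hypothetical helper: [88, T] roll -> [N, 1, 88, seg_len] windows (2 bars, 1-bar hop)."""
    starts = range(0, roll.shape[1] - seg_len + 1, stride)
    windows = [roll[:, s:s + seg_len] for s in starts]
    if not windows:  # piece shorter than one segment
        return np.zeros((0, 1, roll.shape[0], seg_len), dtype=roll.dtype)
    return np.stack(windows)[:, np.newaxis]  # add the channel dimension
```

For example, a 64-step roll yields three windows starting at steps 0, 16 and 32, shaped `[3, 1, 88, 32]`; `torch.from_numpy` turns the result into a tensor the model accepts.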
## Architecture Summary
```
ENCODER
Conv2d(1 → 32, k=4, s=2, p=1) + ReLU + BN → [B, 32, 44, 16]
Conv2d(32 → 64, k=4, s=2, p=1) + ReLU + BN → [B, 64, 22, 8]
Conv2d(64 → 128, k=4, s=2, p=1) + ReLU + BN → [B, 128, 11, 4]
Conv2d(128 → 256, k=4, s=2, p=1) + ReLU + BN → [B, 256, 5, 2]
Flatten → [B, 2560]
Linear → mu [B, 256]
Linear → log_var [B, 256]
REPARAMETERISATION
z = mu + eps * exp(0.5 * log_var), eps ~ N(0, I)
DECODER
Linear(256 → 2560) → Reshape [B, 256, 5, 2]
ConvTranspose2d(256 → 128, k=4, s=2, p=1, output_padding=(1,0)) → [B, 128, 11, 4]
ConvTranspose2d(128 → 64, k=4, s=2, p=1) → [B, 64, 22, 8]
ConvTranspose2d(64 → 32, k=4, s=2, p=1) → [B, 32, 44, 16]
ConvTranspose2d(32 → 1, k=4, s=2, p=1) + Sigmoid → [B, 1, 88, 32]
```
- **Latent dimension**: 256
- **Trainable parameters**: ~4.2M
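The shapes in the summary follow from PyTorch's output-size formulas for `Conv2d` and `ConvTranspose2d` (dilation 1). The plain-Python sketch below recomputes them; the helper names are illustrative and not part of `src/model.py`. It also shows why the first decoder layer needs `output_padding=(1,0)`: the encoder floors 11 pitch rows down to 5, and without the extra row the decoder would end at 80 pitches instead of 88.

```python
def conv_out(size, k=4, s=2, p=1):
    # nn.Conv2d output size along one axis (dilation = 1)
    return (size + 2 * p - k) // s + 1


def convT_out(size, k=4, s=2, p=1, output_padding=0):
    # nn.ConvTranspose2d output size along one axis (dilation = 1)
    return (size - 1) * s - 2 * p + k + output_padding


def encoder_shapes(h=88, w=32, channels=(32, 64, 128, 256)):
    """(channels, height, width) after each encoder conv, starting from a [1, 88, 32] input."""
    shapes = []
    for c in channels:
        h, w = conv_out(h), conv_out(w)
        shapes.append((c, h, w))
    return shapes


def decoder_shapes(h=5, w=2, channels=(128, 64, 32, 1), first_output_padding=(1, 0)):
    """(channels, height, width) after each transposed conv, starting from [256, 5, 2]."""
    shapes = []
    for i, c in enumerate(channels):
        op_h, op_w = first_output_padding if i == 0 else (0, 0)
        h, w = convT_out(h, output_padding=op_h), convT_out(w, output_padding=op_w)
        shapes.append((c, h, w))
    return shapes


enc = encoder_shapes()
c, h, w = enc[-1]
print(enc, "flatten:", c * h * w)   # last shape (256, 5, 2), flatten: 2560
print(decoder_shapes())             # ends at (1, 88, 32)
print(decoder_shapes(first_output_padding=(0, 0))[-1])  # (1, 80, 32) without the padding
```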
## Loading the Model (Course Assignment)
```python
import torch
from model import MusicVAE  # copy src/model.py into your project

# Load the checkpoint on the CPU first; move the model to a GPU afterwards if available
ckpt = torch.load("best_model.pt", map_location="cpu")
config = ckpt["config"]
model = MusicVAE(latent_dim=config["latent_dim"])
model.load_state_dict(ckpt["model_state"])
model.eval()  # inference mode: BatchNorm uses its running statistics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

with torch.no_grad():  # no gradients are needed for inference
    # Generate new piano rolls
    samples = model.sample(n=4, device=device)  # [4, 1, 88, 32]

    # Encode two segments and reconstruct them
    x = ...  # your [2, 1, 88, 32] piano-roll tensor (two 2-bar segments)
    x_recon, mu, log_var = model(x.to(device))

    # Interpolate between the latent means of the two segments
    z1 = mu[0:1]  # first segment
    z2 = mu[1:2]  # second segment
    interp = model.interpolate(z1, z2, steps=8)
```
Also see `src/utils.py` for `pianoroll_to_midi()` and `visualize_pianoroll()`.
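Because the decoder outputs per-cell probabilities in [0, 1] rather than a binary roll, some thresholding is needed before exporting to MIDI. The sketch below is a hypothetical stand-in for that step, not the repository's `pianoroll_to_midi()`: it assumes a fixed threshold of 0.5 and 120 BPM timing, and returns `(midi_pitch, start_sec, end_sec)` tuples that a MIDI library could then write out.

```python
import numpy as np

LOWEST_PITCH = 21         # row 0 of the roll is MIDI 21 (A0)
SECONDS_PER_STEP = 0.125  # assumption: one 16th note at 120 BPM


def roll_to_notes(probs, threshold=0.5):
    """Hypothetical helper: [88, T] Sigmoid outputs -> list of (pitch, start_sec, end_sec)."""
    active = np.asarray(probs) >= threshold  # binarise with a fixed threshold
    n_pitches, n_steps = active.shape
    notes = []
    for row in range(n_pitches):
        t = 0
        while t < n_steps:
            if active[row, t]:
                start = t
                while t < n_steps and active[row, t]:  # extend over consecutive active steps
                    t += 1
                notes.append((LOWEST_PITCH + row,
                              start * SECONDS_PER_STEP,
                              t * SECONDS_PER_STEP))
            else:
                t += 1
    return notes
```

The binary roll has no separate onset channel, so a pitch repeated on consecutive steps comes back as one held note. For a generated sample, `roll_to_notes(samples[0, 0].cpu().numpy())` would give its notes under these assumptions.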
## Training Details
- **Dataset**: Lakh MIDI Dataset (lmd_full)
- **Piano roll**: 88-pitch binary, 16th-note resolution, tempo normalised to 120 BPM
- **Segments**: 2 bars (32 frames), stride 1 bar (16 frames)
- **Loss**: BCE reconstruction + β-annealed KL (β: 0 → 1 over 50 epochs) + free bits (λ = 0.5)
- **Optimizer**: Adam, lr=1e-3, ReduceLROnPlateau (patience=10, factor=0.5, min_lr=1e-5)
- **Batch size**: 256 | **Epochs**: 100 | **Gradient clip**: 1.0
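The loss described above can be sketched in NumPy as follows. Several details are assumptions rather than documented facts: β is taken to rise linearly from 0 to 1 over the first 50 epochs, the free-bits floor λ = 0.5 is applied per latent dimension to the batch-averaged KL, and the BCE is summed over the 88 × 32 cells and averaged over the batch. The function names are hypothetical; the array operations map directly onto `torch` tensor ops.

```python
import numpy as np


def beta_schedule(epoch, warmup_epochs=50):
    # Assumed linear KL annealing: beta = 0 at epoch 0, capped at 1 from `warmup_epochs` on.
    return min(1.0, epoch / warmup_epochs)


def vae_loss(x, x_recon, mu, log_var, epoch, free_bits=0.5, eps=1e-7):
    """Hypothetical loss sketch. x, x_recon: [B, 1, 88, 32]; mu, log_var: [B, latent_dim].

    Returns (total, recon, kl), each a scalar averaged over the batch.
    """
    p = np.clip(x_recon, eps, 1.0 - eps)  # avoid log(0)
    # Binary cross-entropy per cell, summed over pitch x time, averaged over the batch.
    bce = -(x * np.log(p) + (1.0 - x) * np.log(1.0 - p))
    recon = bce.reshape(bce.shape[0], -1).sum(axis=1).mean()
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per latent dimension, averaged over the batch.
    kl_per_dim = (0.5 * (mu ** 2 + np.exp(log_var) - 1.0 - log_var)).mean(axis=0)
    # Free bits (assumed per dimension): each dimension is charged at least `free_bits` nats.
    kl = np.maximum(kl_per_dim, free_bits).sum()
    total = recon + beta_schedule(epoch) * kl
    return total, recon, kl
```

In a sketch like this, dimensions whose KL already exceeds λ are penalised normally, while those below the floor receive no KL gradient, which is the usual motivation for free bits: it discourages the posterior from collapsing onto the prior early in training.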
## Citation
If you use this model in your work, please cite the Lakh MIDI Dataset:
```bibtex
@phdthesis{Raffel2016,
  author = {Colin Raffel},
  title  = {Learning-Based Methods for Comparing Sequences, with Applications
            to Audio-to-{MIDI} Alignment and Matching},
  school = {Columbia University},
  year   = {2016}
}
```