---
license: mit
tags:
- music-generation
- variational-autoencoder
- piano-roll
- midi
- pytorch
---

# Music VAE: CNN pretrained on the Lakh MIDI Dataset

## Overview

This is a **Convolutional Variational Autoencoder (CNN VAE)** pretrained on the [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) (lmd_full, ~175k MIDI files). It was trained as part of a machine learning course assignment at Purdue University to give students a meaningful starting point for music-generation tasks.

## Input / Output Format

| Property        | Value                          |
|-----------------|--------------------------------|
| Input shape     | `[batch, 1, 88, 32]`, float32  |
| Output shape    | `[batch, 1, 88, 32]`, float32  |
| Pitch range     | MIDI 21–108 (A0 – C8, 88 keys) |
| Time resolution | 16th notes at 120 BPM          |
| Segment length  | 2 bars = 32 timesteps          |
| Value range     | [0, 1] (Sigmoid output)        |

A tensor value of `1` at position `[pitch_idx, time_step]` means that pitch `21 + pitch_idx` is active at that 16th-note time step.
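
To make the indexing convention concrete, here is a minimal sketch that builds one 88 × 32 binary piano roll from a list of note events. The helper name `notes_to_pianoroll` and the `(midi_pitch, start_step, duration_steps)` tuple format are hypothetical and chosen for illustration; the preprocessing actually used for this model may treat out-of-range pitches or overlong notes differently.

```python
PITCH_MIN, PITCH_MAX = 21, 108          # A0..C8, as in the table above
N_PITCHES = PITCH_MAX - PITCH_MIN + 1   # 88 keys
N_STEPS = 32                            # 2 bars of 16th notes


def notes_to_pianoroll(notes):
    """Build an 88 x 32 binary roll from (midi_pitch, start_step, duration_steps).

    Hypothetical helper: the note format is an assumption for this sketch.
    """
    roll = [[0.0] * N_STEPS for _ in range(N_PITCHES)]
    for pitch, start, duration in notes:
        if not PITCH_MIN <= pitch <= PITCH_MAX:
            continue  # outside the 88-key range: skipped in this sketch
        end = min(start + duration, N_STEPS)  # clip notes at the segment end
        for t in range(max(start, 0), end):
            roll[pitch - PITCH_MIN][t] = 1.0  # row index = pitch - 21
    return roll


# C major triad (C4, E4, G4) held for the first beat, i.e. four 16th notes
roll = notes_to_pianoroll([(60, 0, 4), (64, 0, 4), (67, 0, 4)])
```

With PyTorch installed, `torch.tensor(roll, dtype=torch.float32).reshape(1, 1, 88, 32)` turns this nested list into a single-example batch in the input shape listed above.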
## Architecture Summary

```
ENCODER
  Conv2d(1   → 32,  k=4, s=2, p=1) + ReLU + BN  → [B, 32,  44, 16]
  Conv2d(32  → 64,  k=4, s=2, p=1) + ReLU + BN  → [B, 64,  22, 8]
  Conv2d(64  → 128, k=4, s=2, p=1) + ReLU + BN  → [B, 128, 11, 4]
  Conv2d(128 → 256, k=4, s=2, p=1) + ReLU + BN  → [B, 256, 5,  2]
  Flatten → 2560
  Linear → mu      [B, 256]
  Linear → log_var [B, 256]

REPARAMETERISATION
  z = mu + eps * exp(0.5 * log_var),  eps ~ N(0, I)

DECODER
  Linear(256 → 2560) → Reshape [B, 256, 5, 2]
  ConvTranspose2d(256 → 128, k=4, s=2, p=1, output_padding=(1,0)) → [B, 128, 11, 4]
  ConvTranspose2d(128 → 64,  k=4, s=2, p=1)                       → [B, 64,  22, 8]
  ConvTranspose2d(64  → 32,  k=4, s=2, p=1)                       → [B, 32,  44, 16]
  ConvTranspose2d(32  → 1,   k=4, s=2, p=1) + Sigmoid             → [B, 1,   88, 32]
```

- **Latent dimension**: 256
- **Trainable parameters**: ~4.2M

## Loading the Model (Course Assignment)

```python
import torch
from model import MusicVAE  # copy src/model.py into your project

# Load the checkpoint on CPU, then rebuild the model from its saved config
ckpt = torch.load("best_model.pt", map_location="cpu")
config = ckpt["config"]
model = MusicVAE(latent_dim=config["latent_dim"])
model.load_state_dict(ckpt["model_state"])
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

with torch.no_grad():  # inference only, no gradients needed
    # Generate new piano rolls
    samples = model.sample(n=4, device=device)  # [4, 1, 88, 32]

    # Encode a batch of two segments and reconstruct them
    x = ...  # your [2, 1, 88, 32] piano-roll tensor (two segments)
    x_recon, mu, log_var = model(x.to(device))

    # Interpolate between the two encoded segments in latent space
    z1 = mu[0:1]  # first example
    z2 = mu[1:2]  # second example (requires batch size >= 2)
    interp = model.interpolate(z1, z2, steps=8)
```

Also see `src/utils.py` for `pianoroll_to_midi()` and `visualize_pianoroll()`.
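
The feature-map sizes in the Architecture Summary can be checked by hand with the standard output-size formulas for `Conv2d` and `ConvTranspose2d` (dilation 1). The sketch below does that arithmetic in plain Python; the helper names `conv_out` and `deconv_out` are illustrative and not part of `src/model.py`.

```python
def conv_out(size, k=4, s=2, p=1):
    # Conv2d output size (dilation 1): floor((size + 2p - k) / s) + 1
    return (size + 2 * p - k) // s + 1


def deconv_out(size, k=4, s=2, p=1, output_padding=0):
    # ConvTranspose2d output size (dilation 1):
    # (size - 1) * s - 2p + k + output_padding
    return (size - 1) * s - 2 * p + k + output_padding


# Encoder: four stride-2 convolutions starting from an 88 x 32 piano roll
h, w = 88, 32
encoder_shapes = []
for _ in range(4):
    h, w = conv_out(h), conv_out(w)
    encoder_shapes.append((h, w))

flat_features = 256 * h * w  # channels * height * width after the last conv

# Decoder: the first layer uses output_padding=(1, 0) on the pitch axis
h, w = deconv_out(h, output_padding=1), deconv_out(w)
decoder_shapes = [(h, w)]
for _ in range(3):
    h, w = deconv_out(h), deconv_out(w)
    decoder_shapes.append((h, w))
```

The `output_padding=(1, 0)` on the first transposed convolution is needed because the encoder's last layer maps height 11 to 5 by flooring; without the extra row the decoder would produce height 10 and end at 80 pitches instead of 88.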
## Training Details

- **Dataset**: Lakh MIDI Dataset (lmd_full)
- **Piano roll**: 88-pitch binary, 16th-note resolution, tempo normalised to 120 BPM
- **Segments**: 2 bars (32 frames), stride 1 bar (16 frames)
- **Loss**: BCE reconstruction + β-annealed KL (β: 0 → 1 over 50 epochs) + free bits (λ=0.5)
- **Optimizer**: Adam, lr=1e-3, ReduceLROnPlateau (patience=10, factor=0.5, min_lr=1e-5)
- **Batch size**: 256 | **Epochs**: 100 | **Gradient clip**: 1.0

## Citation

If you use this model in your work, please cite the Lakh MIDI Dataset:

```bibtex
@phdthesis{Raffel2016,
  author = {Colin Raffel},
  title  = {Learning-Based Methods for Comparing Sequences, with Applications to Audio-to-{MIDI} Alignment and Matching},
  school = {Columbia University},
  year   = {2016}
}
```
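
## Appendix: Loss Sketch

The loss listed under Training Details combines three pieces: a BCE reconstruction term, a KL term weighted by an annealed β, and a free-bits floor on the KL. The plain-Python sketch below shows one common way these fit together for a single example. It rests on assumptions not stated in this card: the β ramp is taken to be linear, and free bits are applied per latent dimension as `max(λ, KL_d)`. All function names are hypothetical, and the actual training code may differ (for example, by averaging the KL over the batch before applying the floor).

```python
import math


def beta_at(epoch, warmup_epochs=50):
    # Assumed linear ramp: beta goes 0 -> 1 over the first 50 epochs
    return min(1.0, epoch / warmup_epochs)


def kl_per_dim(mu, log_var):
    # KL( N(mu, sigma^2) || N(0, 1) ) for each latent dimension
    return [0.5 * (m * m + math.exp(lv) - 1.0 - lv)
            for m, lv in zip(mu, log_var)]


def kl_free_bits(mu, log_var, free_bits=0.5):
    # Assumed per-dimension floor: dimensions below `free_bits` nats
    # contribute a constant, so they are not pushed further toward the prior
    return sum(max(free_bits, kl) for kl in kl_per_dim(mu, log_var))


def bce(target, pred, eps=1e-7):
    # Binary cross-entropy summed over all piano-roll cells (flattened lists)
    total = 0.0
    for y, p in zip(target, pred):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total


def vae_loss(target, pred, mu, log_var, epoch):
    # Reconstruction term plus the beta-weighted, floored KL term
    return bce(target, pred) + beta_at(epoch) * kl_free_bits(mu, log_var)
```

At epoch 0 the KL weight is zero, so the model is trained as a plain autoencoder at first; the floor then limits how strongly individual latent dimensions are pulled toward the prior once β grows.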