MIDI Generation Pipeline: Text-to-Music
Complete pipelines for training and inference of text-conditioned MIDI generation using both GPT2-style and Qwen3-based autoregressive models.
Two Architectures
1. GPT2-Style (train_midi_gpt.py)
- From-scratch GPT2 model with custom vocabulary
- ~50M parameters (configurable)
- Fast training, good for experimentation
2. Qwen3-0.6B (train_midi_qwen3.py) ⭐ Recommended
- Pretrained LLM with vocabulary expansion (inspired by MIDI-LLM)
- 751M parameters with rich text understanding
- Tied embeddings automatically handled
- Apache-2.0 license
Files
| File | Purpose |
|---|---|
| `prepare_dataset.py` | Preprocess for GPT2 pipeline |
| `prepare_dataset_qwen3.py` | Preprocess for Qwen3 pipeline (rich prompts) |
| `train_midi_gpt.py` | Train GPT2-style model |
| `train_midi_qwen3.py` | Fine-tune Qwen3-0.6B with MIDI vocab expansion |
| `inference_midi_gpt.py` | Generate MIDI with GPT2 model |
| `inference_midi_qwen3.py` | Generate MIDI with Qwen3 model (for local checkpoints) |
| `inference_trained_model.py` | Generate MIDI with trained model from HF Hub |
| `create_synthetic_dataset.py` | Generate synthetic test data |
| `test_end_to_end.py` | Validate GPT2 pipeline |
| `test_qwen3_e2e.py` | Validate Qwen3 pipeline |
| `run_qwen3_training.py` | One-command GPU training script |
Trained Model Available
A trained Qwen3-0.6B model with an expanded MIDI vocabulary (152,188 tokens total) is available on the Hub as `rahuldshetty/midi-qwen3-v1`.
Quick Inference with Trained Model
```bash
# Install dependencies
pip install transformers torch datasets miditok miditoolkit accelerate

# Generate MIDI from a text prompt (simplest way)
python inference_trained_model.py \
  --model_id rahuldshetty/midi-qwen3-v1 \
  --prompt "A dark electronic piece with synth strings in D minor, 140 BPM" \
  --output_path my_song.mid \
  --max_midi_tokens 1024 \
  --temperature 1.0 \
  --top_k 50 \
  --top_p 0.92

# Use a random prompt from the dataset
python inference_trained_model.py \
  --model_id rahuldshetty/midi-qwen3-v1 \
  --dataset_prompt \
  --num_samples 3 \
  --output_path generated.mid

# Generate multiple variations
python inference_trained_model.py \
  --model_id rahuldshetty/midi-qwen3-v1 \
  --prompt "A cheerful piano piece in C major, 120 BPM" \
  --num_samples 5 \
  --temperature 0.8 \
  --output_path song.mid
```
Python API for Inference
```python
import json
import shutil
import tempfile
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from miditok import REMI
from transformers import AutoModelForCausalLM, AutoTokenizer

# 1. Download model
model_id = "rahuldshetty/midi-qwen3-v1"
temp_dir = Path(tempfile.mkdtemp())
snapshot_download(repo_id=model_id, local_dir=str(temp_dir))

# 2. Load the base Qwen3 tokenizer and re-add the tokens used during training
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
with open(temp_dir / "config.json") as f:
    config = json.load(f)
expanded_vocab = config["vocab_size"]      # 152188
n_added = expanded_vocab - len(tokenizer)  # 519 tokens added
midi_vocab_size = n_added - 3              # 516 MIDI tokens (519 minus 3 special tokens)

# Add tokens in training order: 3 special tokens, then <|midi_0|> ... <|midi_515|>
midi_tokens = [f"<|midi_{i}|>" for i in range(midi_vocab_size)]
tokenizer.add_tokens(["<|midi_start|>", "<|midi_end|>", "<|midi_pad|>"] + midi_tokens)

# 3. Load model
model = AutoModelForCausalLM.from_pretrained(
    str(temp_dir),
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
model.eval()

# 4. Build prompt and tokenize
prompt = "A cheerful jazz piece with piano and saxophone in C major, 120 BPM"
text_ids = tokenizer.encode(prompt, add_special_tokens=False)

# Special tokens
bos_midi = tokenizer.convert_tokens_to_ids("<|midi_start|>")
eos_midi = tokenizer.convert_tokens_to_ids("<|midi_end|>")
pad_midi = tokenizer.convert_tokens_to_ids("<|midi_pad|>")
# ID of <|midi_0|>; MIDI token i has ID midi_offset + i.
# Look it up instead of hard-coding 151936 + 3: add_tokens() appends after
# len(tokenizer), which is smaller than the base model's 151,936 embedding rows.
midi_offset = tokenizer.convert_tokens_to_ids("<|midi_0|>")

# 5. Generate MIDI tokens (the input ends with <|midi_start|>)
input_tensor = torch.tensor([text_ids + [bos_midi]], dtype=torch.long, device=model.device)
with torch.no_grad():
    output = model.generate(
        input_tensor,
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        top_p=0.92,
        pad_token_id=pad_midi,
        eos_token_id=eos_midi,
    )

# 6. Extract MIDI tokens and map them back to 0-based REMI IDs
generated = output[0].tolist()
bos_idx = generated.index(bos_midi)
midi_raw = generated[bos_idx + 1:]
midi_raw = [t for t in midi_raw if t not in (eos_midi, pad_midi)]
midi_ids = [t - midi_offset for t in midi_raw if t >= midi_offset]
midi_ids = [t for t in midi_ids if 0 <= t < midi_vocab_size]

# 7. Decode to a MIDI file
midi_tokenizer = REMI(params=str(temp_dir / "midi_tokenizer_init"))
midi = midi_tokenizer.decode(midi_ids)
# miditok >= 3 returns a symusic Score (dump_midi); older versions return a miditoolkit MidiFile (dump)
if hasattr(midi, "dump_midi"):
    midi.dump_midi("output.mid")
else:
    midi.dump("output.mid")
print(f"Generated {len(midi_ids)} MIDI tokens → output.mid")

# Cleanup
shutil.rmtree(temp_dir, ignore_errors=True)
```
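Step 6 above can be packaged as a small function so the same logic also works when several samples are drawn in one call (for example by passing `num_return_sequences=3` to `model.generate`, which returns one row per sample). The sketch below is a hypothetical helper, not part of the repository's scripts; it mirrors the filtering in step 6 but stops at the first `<|midi_end|>` of each row, and it only needs the token IDs already computed in the Python API example.

```python
from typing import List, Sequence


def extract_midi_ids(
    sequence: Sequence[int],
    bos_midi: int,
    eos_midi: int,
    pad_midi: int,
    midi_offset: int,
    midi_vocab_size: int,
) -> List[int]:
    """Return 0-based REMI token IDs from one generated sequence (hypothetical helper)."""
    seq = list(sequence)
    # Everything up to and including <|midi_start|> is the prompt.
    start = seq.index(bos_midi) + 1 if bos_midi in seq else 0
    midi_ids: List[int] = []
    for tok in seq[start:]:
        if tok == eos_midi:
            break  # end of this sample; later positions are padding in batched output
        if tok == pad_midi:
            continue
        local = tok - midi_offset
        if 0 <= local < midi_vocab_size:  # keep only IDs inside the MIDI token range
            midi_ids.append(local)
    return midi_ids


def extract_batch(sequences: Sequence[Sequence[int]], **ids: int) -> List[List[int]]:
    """Apply extract_midi_ids to every row, e.g. output.tolist() from model.generate."""
    return [extract_midi_ids(row, **ids) for row in sequences]
```

With `ids = dict(bos_midi=bos_midi, eos_midi=eos_midi, pad_midi=pad_midi, midi_offset=midi_offset, midi_vocab_size=midi_vocab_size)`, `extract_batch(output.tolist(), **ids)` yields one REMI sequence per sample, each of which can be passed to `midi_tokenizer.decode`.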
Dataset
Processed dataset: rahuldshetty/midi-generation-dataset
Source: B-K/midi-dataset-2 (MidiCaps with MIDI bytes)
Quick Start: Train Your Own
1. Install Dependencies
```bash
pip install transformers torch datasets miditok miditoolkit accelerate
```
2. Prepare Dataset
```bash
python prepare_dataset_qwen3.py \
  --dataset B-K/midi-dataset-2 \
  --output_dir ./midi_data_qwen3 \
  --max_seq_len 2048
```
3. Train
```bash
python train_midi_qwen3.py \
  --dataset rahuldshetty/midi-generation-dataset \
  --output_dir ./midi_qwen3_model \
  --num_epochs 20 \
  --batch_size 2 \
  --gradient_accumulation_steps 8 \
  --bf16 \
  --gradient_checkpointing \
  --push_to_hub \
  --hub_model_id yourname/midi-qwen3-v1
```
Or use the one-command script:
```bash
python run_qwen3_training.py
```
4. Generate MIDI
```bash
python inference_midi_qwen3.py \
  --model_dir ./midi_qwen3_model/final \
  --prompt "A cheerful jazz piece with piano and saxophone in C major, 120 BPM" \
  --output_path output.mid \
  --max_midi_tokens 1024
```
Qwen3 Architecture Details
Vocabulary Expansion
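- Qwen3 base vocab: 151,936 embedding rows (the tokenizer itself has 151,669 entries, and new tokens are appended after those)
- MIDI special tokens: `<|midi_start|>`, `<|midi_end|>`, `<|midi_pad|>`
- MIDI vocab tokens: `<|midi_0|>` ... `<|midi_515|>` (REMI tokenization)
- Total vocab: 152,188 (151,669 + 3 special + 516 MIDI tokens)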
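The ID layout of the expanded vocabulary can be written down explicitly. This is a hypothetical helper (not from the repository) and assumes the tokens are appended by `tokenizer.add_tokens()` in the order used in the Python API example: the three special tokens first, then `<|midi_0|>` onward. The base length of 151,669 is inferred from the reported 152,188-token total minus the 519 added tokens; because it differs from the 151,936-row embedding size, it is safer to derive MIDI IDs from the tokenizer than from 151,936.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MidiVocabLayout:
    """Hypothetical description of where the added tokens land in the expanded vocab."""

    base_len: int              # len(tokenizer) before expansion (assumed 151,669 for Qwen3)
    num_special: int = 3       # <|midi_start|>, <|midi_end|>, <|midi_pad|>
    midi_vocab_size: int = 516 # <|midi_0|> ... <|midi_515|>

    @property
    def midi_start(self) -> int:
        return self.base_len

    @property
    def midi_end(self) -> int:
        return self.base_len + 1

    @property
    def midi_pad(self) -> int:
        return self.base_len + 2

    @property
    def midi_offset(self) -> int:
        # ID of <|midi_0|>
        return self.base_len + self.num_special

    @property
    def total(self) -> int:
        return self.midi_offset + self.midi_vocab_size

    def to_llm_id(self, remi_id: int) -> int:
        """Map a 0-based REMI ID to its ID in the expanded vocabulary."""
        if not 0 <= remi_id < self.midi_vocab_size:
            raise ValueError(f"REMI id out of range: {remi_id}")
        return self.midi_offset + remi_id

    def to_remi_id(self, llm_id: int) -> int:
        """Map an expanded-vocabulary ID back to its 0-based REMI ID."""
        remi_id = llm_id - self.midi_offset
        if not 0 <= remi_id < self.midi_vocab_size:
            raise ValueError(f"not a MIDI token id: {llm_id}")
        return remi_id


layout = MidiVocabLayout(base_len=151_669)
print(layout.midi_offset, layout.total)  # 151672 152188
```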
Training Labels
- Text prefix → `-100` (excluded from the loss)
- MIDI tokens + special tokens → their actual token IDs
- The model learns only music generation, not to reproduce the prompt
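The labeling rule can be illustrated with a minimal sketch that builds one example from already-tokenized text and REMI IDs. `build_example` and its arguments are hypothetical names, not taken from `train_midi_qwen3.py`; the sketch assumes the same sequence layout used at inference time (text, `<|midi_start|>`, MIDI tokens, `<|midi_end|>`) and leaves out padding and truncation to `--max_seq_len`. `-100` is the default `ignore_index` of PyTorch's cross-entropy loss, and Hugging Face causal LMs shift labels internally, so labels are aligned position-by-position with `input_ids`.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # positions with this label are skipped by the cross-entropy loss


def build_example(
    text_ids: Sequence[int],
    midi_ids: Sequence[int],
    bos_midi: int,
    eos_midi: int,
    midi_offset: int,
) -> Tuple[List[int], List[int]]:
    """Return (input_ids, labels) for one text-to-MIDI example (hypothetical helper).

    midi_ids are 0-based REMI IDs; they are shifted into the expanded vocabulary.
    """
    midi_part = [bos_midi] + [midi_offset + t for t in midi_ids] + [eos_midi]
    input_ids = list(text_ids) + midi_part
    # Text prefix is masked; <|midi_start|>, MIDI tokens and <|midi_end|> keep their IDs.
    labels = [IGNORE_INDEX] * len(text_ids) + midi_part
    return input_ids, labels
```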
Rich Prompt Format
```text
You are a world-class composer. Please compose some music according to the following description:
Description: [caption]
Genre: [genre]
Mood: [mood]
Key: [key]
Time Signature: [time_signature]
Tempo: [tempo] BPM
Duration: [duration] seconds
Instruments: [instruments]
Chords: [chords]
```
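When calling the model directly from Python, a prompt in this format can be assembled from a metadata record. This is a minimal sketch: the dictionary keys, the choice to omit empty fields, and joining list values with commas are assumptions, and the exact formatting produced by `prepare_dataset_qwen3.py` is not documented here, so check that script before relying on this.

```python
from typing import List, Mapping

HEADER = (
    "You are a world-class composer. Please compose some music "
    "according to the following description:"
)

# (label shown in the prompt, hypothetical metadata key, suffix after the value)
FIELDS = [
    ("Description", "caption", ""),
    ("Genre", "genre", ""),
    ("Mood", "mood", ""),
    ("Key", "key", ""),
    ("Time Signature", "time_signature", ""),
    ("Tempo", "tempo", " BPM"),
    ("Duration", "duration", " seconds"),
    ("Instruments", "instruments", ""),
    ("Chords", "chords", ""),
]


def build_rich_prompt(meta: Mapping[str, object]) -> str:
    """Fill the rich prompt template from a metadata mapping (sketch)."""
    lines: List[str] = [HEADER]
    for label, key, suffix in FIELDS:
        value = meta.get(key)
        if value is None or value == "":
            continue  # assumption: missing fields are left out
        if isinstance(value, (list, tuple)):
            value = ", ".join(str(v) for v in value)
        lines.append(f"{label}: {value}{suffix}")
    return "\n".join(lines)


print(build_rich_prompt({"caption": "A calm piano piece", "tempo": 90, "instruments": ["piano", "strings"]}))
```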
Recommended Datasets
| Dataset | Link | Description |
|---|---|---|
| B-K/midi-dataset-2 | HF | Best - rich metadata + MIDI bytes |
| amaai-lab/MidiCaps | HF | 168K captions (no MIDI bytes) |
| foldl/midi | HF | Name + genre + MIDI bytes |
Hardware Recommendations
| Model | GPU | Batch | Notes |
|---|---|---|---|
| GPT2 (50M) | t4-small | 4 | Fast experimentation |
| Qwen3-0.6B | a10g-large | 2 | Enable gradient_checkpointing |
| Qwen3-0.6B | a100-large | 4 | Full training |
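When the per-device batch changes between GPUs, the number of samples per optimizer step changes with it unless gradient accumulation is adjusted. Assuming `--batch_size` is the per-device batch and `--gradient_accumulation_steps` the accumulation count (the usual meaning, not confirmed here), the Quick Start command trains with an effective batch of 2 × 8 = 16 on one GPU. The helper names below are hypothetical.

```python
import math


def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Samples contributing to each optimizer step."""
    return per_device * grad_accum * num_devices


def grad_accum_for_target(target: int, per_device: int, num_devices: int = 1) -> int:
    """Smallest accumulation count whose effective batch is at least `target`."""
    return math.ceil(target / (per_device * num_devices))


print(effective_batch_size(2, 8))    # 16 (Quick Start command, one GPU)
print(grad_accum_for_target(16, 4))  # 4 (batch 4 on a100-large keeps the effective batch at 16)
```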
SOTA References
- MIDI-LLM (Wu et al., 2025): LLM vocab expansion for MIDI
- MIDI-GPT (Pasquier et al., 2025): GPT2 for MIDI
- text2midi (Bhandari et al., AAAI 2025): T5 encoder + decoder