MIDI Generation Pipeline: Text-to-Music

Complete pipelines for training and inference of text-conditioned MIDI generation using both GPT2-style and Qwen3-based autoregressive models.

Two Architectures

1. GPT2-Style (train_midi_gpt.py)

  • From-scratch GPT2 model with custom vocabulary
  • ~50M parameters (configurable)
  • Fast training, good for experimentation

2. Qwen3-0.6B (train_midi_qwen3.py) ⭐ Recommended

  • Pretrained LLM with vocabulary expansion (inspired by MIDI-LLM)
  • 751M parameters with rich text understanding
  • Tied embeddings automatically handled
  • Apache-2.0 license

Files

  • prepare_dataset.py: preprocess the dataset for the GPT2 pipeline
  • prepare_dataset_qwen3.py: preprocess the dataset for the Qwen3 pipeline (rich prompts)
  • train_midi_gpt.py: train the GPT2-style model
  • train_midi_qwen3.py: fine-tune Qwen3-0.6B with MIDI vocabulary expansion
  • inference_midi_gpt.py: generate MIDI with the GPT2 model
  • inference_midi_qwen3.py: generate MIDI with the Qwen3 model from a local checkpoint
  • inference_trained_model.py: generate MIDI with the trained model from the HF Hub
  • create_synthetic_dataset.py: generate synthetic test data
  • test_end_to_end.py: validate the GPT2 pipeline
  • test_qwen3_e2e.py: validate the Qwen3 pipeline
  • run_qwen3_training.py: one-command GPU training script

Trained Model Available

rahuldshetty/midi-qwen3-v1

A trained Qwen3-0.6B model with expanded MIDI vocabulary (152,188 tokens total).

Quick Inference with Trained Model

# Install dependencies
pip install transformers torch datasets miditok miditoolkit accelerate

# Generate MIDI from a text prompt (simplest way)
python inference_trained_model.py \
    --model_id rahuldshetty/midi-qwen3-v1 \
    --prompt "A dark electronic piece with synth strings in D minor, 140 BPM" \
    --output_path my_song.mid \
    --max_midi_tokens 1024 \
    --temperature 1.0 \
    --top_k 50 \
    --top_p 0.92

# Use a random prompt from the dataset
python inference_trained_model.py \
    --model_id rahuldshetty/midi-qwen3-v1 \
    --dataset_prompt \
    --num_samples 3 \
    --output_path generated.mid

# Generate multiple variations
python inference_trained_model.py \
    --model_id rahuldshetty/midi-qwen3-v1 \
    --prompt "A cheerful piano piece in C major, 120 BPM" \
    --num_samples 5 \
    --temperature 0.8 \
    --output_path song.mid
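In the Python API example, the same three values are passed straight to `model.generate()`, so these CLI flags most likely map onto transformers' standard sampling filters. As a framework-free illustration of what those filters do to one step's logits, here is a minimal sketch; `filter_logits` and `sample` are hypothetical helpers, not code from this repo, and transformers' exact edge-case handling may differ.

```python
import math
import random


def filter_logits(logits, temperature=1.0, top_k=50, top_p=0.92):
    """Return {token_index: probability} after temperature, top-k and top-p filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Temperature: scale the logits, then softmax (shifted by the max for stability).
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    # Rank tokens from most to least likely.
    ranked = sorted(enumerate(exps), key=lambda pair: pair[1], reverse=True)
    # Top-k: keep only the k most likely tokens (k <= 0 disables the filter).
    if top_k > 0:
        ranked = ranked[:top_k]
    # Renormalise over the survivors before applying top-p.
    total = sum(w for _, w in ranked)
    ranked = [(i, w / total) for i, w in ranked]
    # Top-p (nucleus): keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i, p in ranked:
        kept.append((i, p))
        cumulative += p
        if cumulative >= top_p:
            break
    # Renormalise again so the returned probabilities sum to 1.
    z = sum(p for _, p in kept)
    return {i: p / z for i, p in kept}


def sample(logits, rng=random, **kwargs):
    """Draw one token index from the filtered distribution."""
    dist = filter_logits(logits, **kwargs)
    return rng.choices(list(dist), weights=list(dist.values()), k=1)[0]
```

For example, `filter_logits([2.0, 1.0, 0.0, -1.0], top_p=0.5)` keeps only token 0, because its probability alone (about 0.64) already exceeds 0.5. Lowering `--temperature` sharpens the distribution before either cutoff is applied.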

Python API for Inference

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
from miditok import REMI
from pathlib import Path
import json, tempfile, shutil

# 1. Download model
model_id = "rahuldshetty/midi-qwen3-v1"
temp_dir = Path(tempfile.mkdtemp())
snapshot_download(repo_id=model_id, local_dir=str(temp_dir))

# 2. Load and expand tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
with open(temp_dir / "config.json") as f:
    config = json.load(f)
expanded_vocab = config["vocab_size"]  # 152188
n_added = expanded_vocab - len(tokenizer)  # 519 tokens added
midi_vocab_size = n_added - 3  # 516 MIDI tokens

# Add tokens
midi_tokens = [f"<|midi_{i}|>" for i in range(midi_vocab_size)]
tokenizer.add_tokens(["<|midi_start|>", "<|midi_end|>", "<|midi_pad|>"] + midi_tokens)

# 3. Load model
model = AutoModelForCausalLM.from_pretrained(
    str(temp_dir),
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
model.eval()

# 4. Build prompt and tokenize
prompt = "A cheerful jazz piece with piano and saxophone in C major, 120 BPM"
text_ids = tokenizer.encode(prompt, add_special_tokens=False)

# Special tokens
bos_midi = tokenizer.convert_tokens_to_ids("<|midi_start|>")
eos_midi = tokenizer.convert_tokens_to_ids("<|midi_end|>")
pad_midi = tokenizer.convert_tokens_to_ids("<|midi_pad|>")
midi_offset = tokenizer.convert_tokens_to_ids("<|midi_0|>")  # ID of the first MIDI token (len(base tokenizer) + 3)

# 5. Generate MIDI tokens
input_tensor = torch.tensor([text_ids + [bos_midi]], dtype=torch.long, device=model.device)
with torch.no_grad():
    output = model.generate(
        input_tensor,
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        top_p=0.92,
        pad_token_id=pad_midi,
        eos_token_id=eos_midi,
    )

# 6. Extract and decode MIDI tokens
generated = output[0].tolist()
bos_idx = generated.index(bos_midi)
midi_raw = generated[bos_idx + 1:]
midi_raw = [t for t in midi_raw if t not in (eos_midi, pad_midi)]
midi_ids = [t - midi_offset for t in midi_raw if t >= midi_offset]
midi_ids = [t for t in midi_ids if 0 <= t < midi_vocab_size]

# 7. Decode to MIDI file
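midi_tokenizer = REMI(params=str(temp_dir / "midi_tokenizer_init"))
midi = midi_tokenizer.decode(midi_ids)
# miditok >= 3 returns a symusic Score (dump_midi); older releases return a miditoolkit MidiFile (dump)
if hasattr(midi, "dump_midi"):
    midi.dump_midi("output.mid")
else:
    midi.dump("output.mid")
print(f"Generated {len(midi_ids)} MIDI tokens → output.mid")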

# Cleanup
shutil.rmtree(temp_dir, ignore_errors=True)
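Step 6 can be wrapped in a reusable helper. A minimal sketch, where `extract_midi_ids` is a hypothetical name (not part of the repo): it starts after the first `<|midi_start|>`, stops at the first `<|midi_end|>`, skips padding, and maps each remaining ID back into the REMI range by subtracting `midi_offset` (the ID of `<|midi_0|>`, assuming the MIDI tokens were added in order as in step 2).

```python
from typing import List, Sequence


def extract_midi_ids(
    generated: Sequence[int],
    bos_midi: int,
    eos_midi: int,
    pad_midi: int,
    midi_offset: int,
    midi_vocab_size: int,
) -> List[int]:
    """Map LLM token IDs after <|midi_start|> back to REMI token IDs (hypothetical helper)."""
    tokens = list(generated)
    try:
        start = tokens.index(bos_midi) + 1
    except ValueError:
        return []  # no <|midi_start|> found: nothing to decode
    midi_ids = []
    for tok in tokens[start:]:
        if tok == eos_midi:
            break  # stop at the first <|midi_end|>
        if tok == pad_midi:
            continue
        rel = tok - midi_offset
        # Drop anything outside the MIDI block (e.g. stray text tokens).
        if 0 <= rel < midi_vocab_size:
            midi_ids.append(rel)
    return midi_ids
```

Stopping at the first `<|midi_end|>` rather than filtering it out matters when a sequence is generated without `eos_token_id`, since anything sampled after the end marker is not part of the piece.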

Dataset

Processed dataset: rahuldshetty/midi-generation-dataset

Source: B-K/midi-dataset-2 (MidiCaps with MIDI bytes)

Quick Start: Train Your Own

1. Install Dependencies

pip install transformers torch datasets miditok miditoolkit accelerate

2. Prepare Dataset

python prepare_dataset_qwen3.py \
    --dataset B-K/midi-dataset-2 \
    --output_dir ./midi_data_qwen3 \
    --max_seq_len 2048

3. Train

python train_midi_qwen3.py \
    --dataset rahuldshetty/midi-generation-dataset \
    --output_dir ./midi_qwen3_model \
    --num_epochs 20 \
    --batch_size 2 \
    --gradient_accumulation_steps 8 \
    --bf16 \
    --gradient_checkpointing \
    --push_to_hub \
    --hub_model_id yourname/midi-qwen3-v1
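With these flags, and assuming `--batch_size` is the per-device batch on a single GPU (a multi-GPU run would multiply by the number of devices), each optimizer step accumulates gradients over:

```latex
B_{\text{eff}} = \text{batch\_size} \times \text{gradient\_accumulation\_steps} \times N_{\text{GPU}} = 2 \times 8 \times 1 = 16 \ \text{sequences}
```

If training sequences are capped at the 2048 tokens used in the preparation step, that is at most 16 × 2048 = 32,768 tokens per optimizer step.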

Or use the one-command script:

python run_qwen3_training.py

4. Generate MIDI

python inference_midi_qwen3.py \
    --model_dir ./midi_qwen3_model/final \
    --prompt "A cheerful jazz piece with piano and saxophone in C major, 120 BPM" \
    --output_path output.mid \
    --max_midi_tokens 1024

Qwen3 Architecture Details

Vocabulary Expansion

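  • Qwen3 base vocab: 151,669 tokenizer entries (the base model's config reports vocab_size 151,936, which includes unused padding rows)
  • MIDI special tokens: <|midi_start|>, <|midi_end|>, <|midi_pad|>
  • MIDI vocab tokens: <|midi_0|> ... <|midi_515|> (REMI tokenization)
  • Total vocab: 152,188 (151,669 + 3 special + 516 MIDI tokens)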
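The token counts can be checked with a little ID arithmetic. `add_tokens` appends new tokens starting at `len(tokenizer)`, not at the 151,936-row embedding size, and the comments in the Python API example (152,188 total, 519 added) imply a base tokenizer length of 151,669. A minimal sketch, assuming the tokens are appended in the same order as in that example; `midi_vocab_layout` is a hypothetical helper:

```python
def midi_vocab_layout(base_len=151_669, n_midi=516):
    """Hypothetical helper: token IDs after appending the MIDI tokens to the Qwen3 tokenizer.

    base_len is assumed to be len(tokenizer) for the unmodified Qwen3 tokenizer.
    """
    specials = ["<|midi_start|>", "<|midi_end|>", "<|midi_pad|>"]
    # New tokens receive consecutive IDs starting at len(tokenizer), in list order.
    layout = {tok: base_len + i for i, tok in enumerate(specials)}
    midi_offset = base_len + len(specials)  # ID of <|midi_0|>
    layout["<|midi_0|>"] = midi_offset
    layout[f"<|midi_{n_midi - 1}|>"] = midi_offset + n_midi - 1
    layout["total_vocab"] = midi_offset + n_midi  # one past the last ID
    return layout
```

Under these assumptions the first MIDI token, and therefore the offset to subtract when decoding, is 151,672, and the expanded vocabulary ends exactly at the 152,188 reported for the trained model.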

Training Labels

  • Text prefix tokens → label -100 (ignored by the loss)
  • MIDI tokens and MIDI special tokens → their own token IDs
  • The loss therefore covers only the music tokens; the model is never trained to reproduce the prompt
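The masking rule can be sketched in a few lines. This assumes the Hugging Face convention that causal-LM models shift `labels` internally, so the labels are simply a copy of `input_ids` with the prompt positions set to -100; `build_example` and its right-truncation policy are hypothetical, not taken from `train_midi_qwen3.py`.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def build_example(text_ids, midi_ids, bos_midi, eos_midi, midi_offset, max_seq_len=2048):
    """Hypothetical: concatenate prompt + MIDI block and mask the prompt in the labels."""
    # Shift REMI IDs into the expanded LLM vocabulary and wrap them in start/end markers.
    midi_block = [bos_midi] + [midi_offset + t for t in midi_ids] + [eos_midi]
    input_ids = list(text_ids) + midi_block
    # Prompt positions are ignored by the loss; MIDI positions keep their token IDs.
    labels = [IGNORE_INDEX] * len(text_ids) + midi_block
    # Assumption: simple right truncation; the real script may truncate differently.
    return input_ids[:max_seq_len], labels[:max_seq_len]
```

Because -100 is the default `ignore_index` of PyTorch's cross-entropy loss, the prompt positions contribute nothing to the gradient.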

Rich Prompt Format

You are a world-class composer. Please compose some music according to the following description:
Description: [caption]
Genre: [genre]
Mood: [mood]
Key: [key]
Time Signature: [time_signature]
Tempo: [tempo] BPM
Duration: [duration] seconds
Instruments: [instruments]
Chords: [chords]
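One way to fill this template from a dataset row, as a minimal sketch: the metadata keys (`caption`, `genre`, `tempo`, ...) are assumptions modelled on the placeholder names, and skipping empty fields is a guess at how `prepare_dataset_qwen3.py` handles missing metadata, not confirmed behaviour.

```python
HEADER = ("You are a world-class composer. Please compose some music "
          "according to the following description:")

# (label, key) pairs; the keys are hypothetical metadata column names.
FIELDS = [
    ("Description", "caption"),
    ("Genre", "genre"),
    ("Mood", "mood"),
    ("Key", "key"),
    ("Time Signature", "time_signature"),
    ("Tempo", "tempo"),
    ("Duration", "duration"),
    ("Instruments", "instruments"),
    ("Chords", "chords"),
]
UNITS = {"tempo": " BPM", "duration": " seconds"}


def build_prompt(meta):
    """Render a metadata dict into the rich prompt format (hypothetical helper)."""
    lines = [HEADER]
    for label, key in FIELDS:
        value = meta.get(key)
        # Assumption: missing or empty fields are omitted rather than left blank.
        if value is None or (isinstance(value, (str, list, tuple)) and len(value) == 0):
            continue
        if isinstance(value, (list, tuple)):
            value = ", ".join(map(str, value))
        lines.append(f"{label}: {value}{UNITS.get(key, '')}")
    return "\n".join(lines)
```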

Recommended Datasets

  • B-K/midi-dataset-2 (HF): best option, rich metadata + MIDI bytes
  • amaai-lab/MidiCaps (HF): 168K captions (no MIDI bytes)
  • foldl/midi (HF): name + genre + MIDI bytes

Hardware Recommendations

  • GPT2 (50M): t4-small GPU, batch size 4; fast experimentation
  • Qwen3-0.6B: a10g-large GPU, batch size 2; enable --gradient_checkpointing
  • Qwen3-0.6B: a100-large GPU, batch size 4; full training

SOTA References

  • MIDI-LLM (Wu et al., 2025): LLM vocab expansion for MIDI
  • MIDI-GPT (Pasquier et al., 2025): GPT2 for MIDI
  • text2midi (Bhandari et al., AAAI 2025): T5 encoder + decoder