# Nekochu's picture
# Init
# a554f96 verified
"""
Multitrack MIDI Composer - Combined MIDI generation tools
- Simple MIDI Composer: Demo mode chord progressions
- Multitrack Generator: AI multi-instrument composition with genre selection
CPU-only HuggingFace Space
"""
import os
import sys
import tempfile
import argparse
import struct
import wave
from typing import List, Tuple, Optional
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import gradio as gr
import numpy as np
# Import meltysynth for pure Python MIDI-to-audio synthesis
try:
import meltysynth as ms
MELTYSYNTH_AVAILABLE = True
except ImportError:
MELTYSYNTH_AVAILABLE = False
print("meltysynth not available - Audio playback disabled")
# Path to the bundled SoundFont file used for MIDI-to-audio rendering.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SOUNDFONT_PATH = os.path.join(SCRIPT_DIR, "TimGM6mb.sf2")
# Lazily initialized synthesizer cache; populated by get_synthesizer().
_synthesizer = None
_synth_settings = None
def get_synthesizer():
    """Return the cached (synthesizer, settings) pair, loading them on first use.

    Returns (None, None)-ish values when meltysynth is unavailable or the
    SoundFont file is missing; callers must check for a None synthesizer.
    """
    global _synthesizer, _synth_settings
    # Fast path: already loaded, or synthesis is impossible anyway.
    if _synthesizer is not None or not MELTYSYNTH_AVAILABLE:
        return _synthesizer, _synth_settings
    if not os.path.exists(SOUNDFONT_PATH):
        print(f"SoundFont not found: {SOUNDFONT_PATH}")
        return _synthesizer, _synth_settings
    print(f"Loading SoundFont: {SOUNDFONT_PATH}")
    font = ms.SoundFont.from_file(SOUNDFONT_PATH)
    _synth_settings = ms.SynthesizerSettings(44100)
    _synthesizer = ms.Synthesizer(font, _synth_settings)
    print("SoundFont loaded successfully!")
    return _synthesizer, _synth_settings
def render_midi_to_audio(midi_path: str) -> Optional[str]:
    """Render a MIDI file to a 16-bit stereo WAV file using meltysynth.

    Args:
        midi_path: Path to the MIDI file to render.

    Returns:
        Path to a temporary WAV file, or None when meltysynth / the
        SoundFont is unavailable or rendering fails.
    """
    if not MELTYSYNTH_AVAILABLE:
        return None
    synth, settings = get_synthesizer()
    if synth is None:
        return None
    try:
        # Load MIDI file and queue it on a sequencer.
        midi_file = ms.MidiFile.from_file(midi_path)
        sequencer = ms.MidiFileSequencer(synth)
        sequencer.play(midi_file, False)
        # Buffer covers the whole piece plus one second for the release tail.
        duration = midi_file.length + 1.0
        buffer_length = int(settings.sample_rate * duration)
        left = ms.create_buffer(buffer_length)
        right = ms.create_buffer(buffer_length)
        sequencer.render(left, right)
        # Vectorized PCM conversion (the original per-sample Python loop was
        # very slow for multi-second renders): clamp to [-1, 1], scale to
        # int16 (astype truncates toward zero, matching int()), and
        # interleave L/R as columns.
        stereo = np.empty((buffer_length, 2), dtype=np.float64)
        stereo[:, 0] = np.clip(np.asarray(left), -1.0, 1.0)
        stereo[:, 1] = np.clip(np.asarray(right), -1.0, 1.0)
        pcm = (stereo * 32767).astype(np.int16)
        # Write WAV file.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            wav_path = f.name
        with wave.open(wav_path, 'w') as wav_file:
            wav_file.setnchannels(2)
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(settings.sample_rate)
            wav_file.writeframes(pcm.tobytes())
        return wav_path
    except Exception as e:
        print(f"Audio render error: {e}")
        return None
# =============================================================================
# Tab 1: Simple MIDI Composer (Demo Mode)
# =============================================================================
try:
from midiutil import MIDIFile
MIDIUTIL_AVAILABLE = True
except ImportError:
MIDIUTIL_AVAILABLE = False
print("midiutil not available - Demo Composer disabled")
def create_piano_roll(notes_data, total_time, title="Piano Roll"):
    """Create a piano-roll visualization and save it as a PNG image.

    Args:
        notes_data: iterable of (start_beat, end_beat, midi_pitch, track_index)
            tuples; the track index selects the rectangle color.
        total_time: x-axis extent in beats.
        title: plot title.

    Returns:
        Path to a temporary PNG file containing the rendered figure.
    """
    import matplotlib
    matplotlib.use('Agg')  # headless backend for server-side rendering
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    fig, ax = plt.subplots(figsize=(12, 6))
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8']
    # Fix: original used enumerate() but never used the index.
    for start, end, pitch, track in notes_data:
        color = colors[track % len(colors)]
        rect = Rectangle((start, pitch - 0.4), end - start, 0.8,
                         facecolor=color, edgecolor='black', linewidth=0.5, alpha=0.8)
        ax.add_patch(rect)
    ax.set_xlim(0, total_time)
    ax.set_ylim(40, 90)  # visible pitch window used by the demo material
    ax.set_xlabel('Time (beats)')
    ax.set_ylabel('MIDI Pitch')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    # Save to PNG file.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
        img_path = f.name
    fig.savefig(img_path, dpi=100, bbox_inches='tight')
    plt.close(fig)
    return img_path
def create_demo_midi(tempo: int = 120, length_bars: int = 4):
    """Write a short demo piece (C-Am-F-G chords plus a melody) to MIDI.

    Returns:
        Tuple (midi_path, image_path, audio_path, status_message); the first
        three are None when creation fails or midiutil is missing.
    """
    if not MIDIUTIL_AVAILABLE:
        return None, None, None, "midiutil not installed"
    try:
        midi = MIDIFile(2)  # Two tracks: melody and chords
        midi.addTempo(0, 0, tempo)
        midi.addProgramChange(0, 0, 0, 0)  # Piano for melody
        midi.addProgramChange(1, 0, 0, 0)  # Piano for chords
        # Simple chord progression (C - Am - F - G)
        progression = [
            [60, 64, 67],  # C major
            [57, 60, 64],  # A minor
            [53, 57, 60],  # F major
            [55, 59, 62],  # G major
        ]
        # Melody notes
        melody = [
            72, 74, 76, 77, 76, 74, 72, 71,
            69, 71, 72, 74, 72, 71, 69, 67,
            65, 67, 69, 71, 72, 71, 69, 67,
            67, 69, 71, 72, 74, 76, 77, 79,
        ]
        # (start, end, pitch, track) tuples for the piano-roll image.
        roll_notes = []
        for bar in range(length_bars):
            bar_start = bar * 4  # beats
            chord = progression[bar % len(progression)]
            # Whole-bar chord on track 1.
            for pitch in chord:
                midi.addNote(1, 0, pitch, bar_start, 4, 60)
                roll_notes.append((bar_start, bar_start + 4, pitch, 1))
            # Eight melody notes per bar on track 0.
            for step in range(8):
                onset = bar_start + step * 0.5
                pitch = melody[(bar * 8 + step) % len(melody)]
                midi.addNote(0, 0, pitch, onset, 0.4, 90)
                roll_notes.append((onset, onset + 0.4, pitch, 0))
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as f:
            midi.writeFile(f)
            midi_path = f.name
        # Create visualization.
        img_path = create_piano_roll(
            roll_notes, length_bars * 4, f"Demo: {length_bars} bars at {tempo} BPM"
        )
        # Render audio.
        audio_path = render_midi_to_audio(midi_path)
        status = f"Created: {length_bars} bars at {tempo} BPM"
        status += " - Audio rendered!" if audio_path else " - Download MIDI to play"
        return midi_path, img_path, audio_path, status
    except Exception as e:
        return None, None, None, f"Error: {str(e)}"
# =============================================================================
# Tab 2: Multitrack Generator (Transformer-based)
# =============================================================================
try:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import note_seq
from note_seq.protobuf.music_pb2 import NoteSequence
from note_seq.constants import STANDARD_PPQ
from matplotlib.figure import Figure
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
print("PyTorch/Transformers not available - Multitrack Generator disabled")
SAMPLE_RATE = 44100  # Audio sample rate in Hz (44.1 kHz)
GM_INSTRUMENTS = [
"Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano",
"Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2", "Harpsichord",
"Clavi", "Celesta", "Glockenspiel", "Music Box", "Vibraphone", "Marimba",
"Xylophone", "Tubular Bells", "Dulcimer", "Drawbar Organ", "Percussive Organ",
"Rock Organ", "Church Organ", "Reed Organ", "Accordion", "Harmonica",
"Tango Accordion", "Acoustic Guitar (nylon)", "Acoustic Guitar (steel)",
"Electric Guitar (jazz)", "Electric Guitar (clean)", "Electric Guitar (muted)",
"Overdriven Guitar", "Distortion Guitar", "Guitar Harmonics", "Acoustic Bass",
"Electric Bass (finger)", "Electric Bass (pick)", "Fretless Bass", "Slap Bass 1",
"Slap Bass 2", "Synth Bass 1", "Synth Bass 2", "Violin", "Viola", "Cello",
"Contrabass", "Tremolo Strings", "Pizzicato Strings", "Orchestral Harp",
"Timpani", "String Ensemble 1", "String Ensemble 2", "Synth Strings 1",
"Synth Strings 2", "Choir Aahs", "Voice Oohs", "Synth Choir", "Orchestra Hit",
"Trumpet", "Trombone", "Tuba", "Muted Trumpet", "French Horn", "Brass Section",
"Synth Brass 1", "Synth Brass 2", "Soprano Sax", "Alto Sax", "Tenor Sax",
"Baritone Sax", "Oboe", "English Horn", "Bassoon", "Clarinet", "Piccolo",
"Flute", "Recorder", "Pan Flute", "Blown Bottle", "Shakuhachi", "Whistle",
"Ocarina", "Lead 1 (square)", "Lead 2 (sawtooth)", "Lead 3 (calliope)",
"Lead 4 (chiff)", "Lead 5 (charang)", "Lead 6 (voice)", "Lead 7 (fifths)",
"Lead 8 (bass + lead)", "Pad 1 (new age)", "Pad 2 (warm)", "Pad 3 (polysynth)",
"Pad 4 (choir)", "Pad 5 (bowed)", "Pad 6 (metallic)", "Pad 7 (halo)",
"Pad 8 (sweep)", "FX 1 (rain)", "FX 2 (soundtrack)", "FX 3 (crystal)",
"FX 4 (atmosphere)", "FX 5 (brightness)", "FX 6 (goblins)", "FX 7 (echoes)",
"FX 8 (sci-fi)", "Sitar", "Banjo", "Shamisen", "Koto", "Kalimba", "Bagpipe",
"Fiddle", "Shanai", "Tinkle Bell", "Agogo", "Steel Drums", "Woodblock",
"Taiko Drum", "Melodic Tom", "Synth Drum", "Reverse Cymbal", "Guitar Fret Noise",
"Breath Noise", "Seashore", "Bird Tweet", "Telephone Ring", "Helicopter",
"Applause", "Gunshot",
]
# Global model and tokenizer (lazy-loaded caches; see get_model_and_tokenizer()).
device = None
tokenizer = None
model = None
def get_model_and_tokenizer():
    """Return the cached (model, tokenizer) pair, loading both on first call.

    Always runs on CPU; subsequent calls reuse the module-level cache.
    """
    global model, tokenizer, device
    # Fast path: both pieces already loaded.
    if model is not None and tokenizer is not None:
        return model, tokenizer
    device = torch.device("cpu")
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
    print("Loading model on CPU...")
    model = AutoModelForCausalLM.from_pretrained("juancopi81/lmd-8bars-2048-epochs40_v4")
    model = model.to(device)
    model.eval()
    print("Model loaded successfully!")
    return model, tokenizer
def empty_note_sequence(qpm: float = 120.0, total_time: float = 0.0) -> "NoteSequence":
    """Return a fresh NoteSequence carrying one tempo event and standard PPQ."""
    sequence = NoteSequence()
    tempo = sequence.tempos.add()
    tempo.qpm = qpm
    sequence.ticks_per_quarter = STANDARD_PPQ
    sequence.total_time = total_time
    return sequence
def token_sequence_to_note_sequence(
    token_sequence: str,
    qpm: float = 120.0,
    use_program: bool = True,
    use_drums: bool = True,
    instrument_mapper: Optional[dict] = None,
    only_piano: bool = False,
) -> "NoteSequence":
    """Convert a sequence of tokens into a sequence of notes.

    Args:
        token_sequence: whitespace-separated token string (or pre-split list)
            in the PIECE_START/TRACK_START/BAR_START/NOTE_ON/... vocabulary.
        qpm: tempo in quarter notes per minute; fixes the duration of a 16th
            note and of a 4/4 bar in seconds.
        use_program: honor INST=<program> tokens for non-drum tracks.
        use_drums: map INST=DRUMS tokens onto the drum flag.
        instrument_mapper: optional token-value remapping applied before the
            int() conversion of the program number.
        only_piano: if True, force every non-drum note to program 0.

    Returns:
        A note_seq NoteSequence with instrument indices renumbered compactly.
    """
    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()
    note_sequence = empty_note_sequence(qpm)
    # Durations derived from the tempo: one 16th note and one 4/4 bar, in seconds.
    note_length_16th = 0.25 * 60 / qpm
    bar_length = 4.0 * 60 / qpm
    # Parser state carried across tokens.
    current_program = 1
    current_is_drum = False
    current_instrument = 0
    track_count = 0
    current_bar_index = 0
    current_time = 0
    current_notes = {}  # pitch -> most recently opened note, for NOTE_OFF
    for token in token_sequence:
        if token == "PIECE_START":
            pass
        elif token == "PIECE_END":
            break
        elif token == "TRACK_START":
            # Each track restarts at bar 0: tracks play in parallel.
            current_bar_index = 0
            track_count += 1
        elif token == "TRACK_END":
            pass
        elif token == "KEYS_START":
            pass
        elif token == "KEYS_END":
            pass
        elif token.startswith("KEY="):
            pass
        elif token.startswith("INST"):
            instrument = token.split("=")[-1]
            if instrument != "DRUMS" and use_program:
                if instrument_mapper is not None:
                    if instrument in instrument_mapper:
                        instrument = instrument_mapper[instrument]
                try:
                    current_program = int(instrument)
                except ValueError:
                    # Unparseable program token: fall back to program 0.
                    current_program = 0
                current_instrument = track_count
                current_is_drum = False
            if instrument == "DRUMS" and use_drums:
                current_instrument = 0
                current_program = 0
                current_is_drum = True
        elif token == "BAR_START":
            # Rewind the clock to the start of the current bar.
            current_time = current_bar_index * bar_length
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
        elif token.startswith("NOTE_ON"):
            try:
                pitch = int(token.split("=")[-1])
                note = note_sequence.notes.add()
                note.start_time = current_time
                # Default length of 4 sixteenths (a quarter note); a matching
                # NOTE_OFF token may shorten it below.
                note.end_time = current_time + 4 * note_length_16th
                note.pitch = pitch
                note.instrument = current_instrument
                note.program = current_program
                note.velocity = 80
                note.is_drum = current_is_drum
                current_notes[pitch] = note
            except ValueError:
                pass
        elif token.startswith("NOTE_OFF"):
            try:
                pitch = int(token.split("=")[-1])
                if pitch in current_notes:
                    note = current_notes[pitch]
                    note.end_time = current_time
            except ValueError:
                pass
        elif token.startswith("TIME_DELTA"):
            try:
                # Advance the clock by a multiple of 16th-note durations.
                delta = float(token.split("=")[-1]) * note_length_16th
                current_time += delta
            except ValueError:
                pass
        elif token.startswith("DENSITY="):
            pass
        elif token == "[PAD]":
            pass
    # Make the instruments right: renumber so each distinct
    # (program, is_drum) pair gets its own compact instrument index.
    instruments_drums = []
    for note in note_sequence.notes:
        pair = [note.program, note.is_drum]
        if pair not in instruments_drums:
            instruments_drums += [pair]
        note.instrument = instruments_drums.index(pair)
    if only_piano:
        for note in note_sequence.notes:
            if not note.is_drum:
                note.instrument = 0
                note.program = 0
    return note_sequence
def create_seed_string(genre: str = "OTHER") -> str:
    """Build the prompt that seeds generation of a new piece.

    "RANDOM" yields a bare "PIECE_START" so the model is free to pick a
    genre; any other value pins the GENRE token and opens the first track.
    """
    if genre == "RANDOM":
        return "PIECE_START"
    return " ".join(["PIECE_START", f"GENRE={genre}", "TRACK_START"])
def get_instruments(text_sequence: str) -> List[str]:
    """Extract human-readable instrument names from a token sequence.

    INST=DRUMS becomes "Drums"; INST=<n> is looked up in GM_INSTRUMENTS,
    falling back to "Program <n>" for out-of-range values. Unparseable
    values are skipped.
    """
    names = []
    for token in text_sequence.split():
        if not token.startswith("INST="):
            continue
        value = token[5:]
        if value == "DRUMS":
            names.append("Drums")
            continue
        try:
            program = int(value)
        except ValueError:
            continue
        if 0 <= program < len(GM_INSTRUMENTS):
            names.append(GM_INSTRUMENTS[program])
        else:
            names.append(f"Program {program}")
    return names
def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """Generate a new instrument sequence from a given seed and temperature.

    Retries up to 5 times until the newly generated portion contains at
    least one NOTE_ON token; otherwise returns the last attempt.

    Args:
        seed: token prompt to continue from.
        temp: sampling temperature for model.generate.

    Returns:
        The full decoded token sequence (seed included).
    """
    mdl, tok = get_model_and_tokenizer()  # avoid shadowing the module-level `model`
    seed_length = len(tok.encode(seed))
    # Fix: encoding the seed and the EOS token is loop-invariant; the
    # original redid both on every retry.
    input_ids = tok.encode(seed, return_tensors="pt").to(device)
    eos_token_id = tok.encode("TRACK_END")[0]
    # Fix: initialize so the return is defined even if the loop never runs.
    generated_sequence = seed
    max_attempts = 5
    for _ in range(max_attempts):
        with torch.no_grad():
            generated_ids = mdl.generate(
                input_ids,
                max_new_tokens=2048,
                do_sample=True,
                temperature=temp,
                eos_token_id=eos_token_id,
            )
        generated_sequence = tok.decode(generated_ids[0])
        new_generated_sequence = tok.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return generated_sequence
    # Return last attempt even if no NOTE_ON was found.
    return generated_sequence
def create_noteseq_piano_roll(note_sequence, title="Generated Music"):
    """Render a NoteSequence as a piano-roll PNG and return the image path."""
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    fig, ax = plt.subplots(figsize=(14, 6))
    # One color per compact instrument index.
    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8', '#F39C12', '#9B59B6', '#1ABC9C']
    notes = note_sequence.notes
    if not len(notes):
        ax.text(0.5, 0.5, 'No notes generated', ha='center', va='center', fontsize=14)
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 127)
    else:
        pitches = [n.pitch for n in notes]
        lowest, highest = min(pitches), max(pitches)
        latest = max(n.end_time for n in notes)
        for n in notes:
            ax.add_patch(Rectangle(
                (n.start_time, n.pitch - 0.4),
                n.end_time - n.start_time,
                0.8,
                facecolor=palette[n.instrument % len(palette)],
                edgecolor='black',
                linewidth=0.3,
                alpha=0.8,
            ))
        ax.set_xlim(0, latest + 0.5)
        ax.set_ylim(lowest - 2, highest + 2)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('MIDI Pitch')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    # Save to PNG.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
        img_path = f.name
    fig.savefig(img_path, dpi=100, bbox_inches='tight')
    plt.close(fig)
    return img_path
def get_outputs_from_string(generated_sequence: str, qpm: int = 120):
    """Convert a generated token sequence into its output artifacts.

    Returns:
        Tuple (midi_path, image_path, instruments_markdown, num_tokens_str,
        audio_path_or_None).
    """
    instruments_str = "\n".join(
        f"- {name}" for name in get_instruments(generated_sequence)
    )
    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
    # Piano-roll visualization via the custom matplotlib helper.
    img_path = create_noteseq_piano_roll(note_sequence, f"Generated at {qpm} BPM")
    num_tokens = str(len(generated_sequence.split()))
    # Save MIDI file.
    with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as f:
        midi_path = f.name
    note_seq.note_sequence_to_midi_file(note_sequence, midi_path)
    # Render audio (may be None when synthesis is unavailable).
    audio_path = render_midi_to_audio(midi_path)
    return midi_path, img_path, instruments_str, num_tokens, audio_path
def generate_song(genre: str = "OTHER", temp: float = 0.75, text_sequence: str = "", qpm: int = 120):
    """Generate a song from a genre, temperature, optional seed text, and tempo.

    Returns:
        Tuple (midi_path, image_path, instruments_markdown, token_sequence,
        num_tokens_str, audio_path).
    """
    if not TORCH_AVAILABLE:
        return None, None, "PyTorch not available", "", "0", None
    # Empty seed text means "start fresh from the genre prompt".
    seed_string = create_seed_string(genre) if text_sequence == "" else text_sequence
    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    midi_file, fig, instruments_str, num_tokens, audio_path = get_outputs_from_string(
        generated_sequence, qpm
    )
    return midi_file, fig, instruments_str, generated_sequence, num_tokens, audio_path
def remove_last_instrument(text_sequence: str, qpm: int = 120):
    """Remove the last instrument track from a song string.

    If removing the track would leave no music (one or zero TRACK_START
    markers), a replacement is generated instead.

    Returns:
        Tuple (midi_path, image_path, instruments_markdown, token_sequence,
        num_tokens_str, audio_path).
    """
    if not TORCH_AVAILABLE:
        return None, None, "PyTorch not available", "", "0", None
    tracks = text_sequence.split("TRACK_START")
    # Drop the final track and reassemble.
    new_song = "TRACK_START".join(tracks[:-1])
    if len(tracks) == 2:
        # Only one track existed: regenerate from the remaining header.
        # Fix: forward qpm (the original fell back to the 120 BPM default,
        # ignoring the user's tempo).
        return generate_song(text_sequence=new_song, qpm=qpm)
    if len(tracks) == 1:
        # No TRACK_START at all: start over from scratch.
        return generate_song(text_sequence="", qpm=qpm)
    midi_file, fig, instruments_str, num_tokens, audio_path = get_outputs_from_string(
        new_song, qpm
    )
    return midi_file, fig, instruments_str, new_song, num_tokens, audio_path
def regenerate_last_instrument(text_sequence: str, qpm: int = 120):
    """Regenerate the most recent instrument track of a song string.

    Keeps everything up to the last INST=... token and lets the model
    regenerate the notes that followed it; without any INST token a brand
    new song is generated.
    """
    if not TORCH_AVAILABLE:
        return None, None, "PyTorch not available", "", "0", None
    last_inst_index = text_sequence.rfind("INST=")
    if last_inst_index == -1:
        return generate_song(text_sequence="", qpm=qpm)
    space_after = text_sequence.find(" ", last_inst_index)
    if space_after == -1:
        # No space after INST=: reuse the whole remaining string as seed.
        seed = text_sequence
    else:
        seed = text_sequence[:space_after]
    return generate_song(text_sequence=seed, qpm=qpm)
def change_tempo(text_sequence: str, qpm: int):
    """Re-render an existing token sequence at a different tempo."""
    if not TORCH_AVAILABLE:
        return None, None, "PyTorch not available", "", "0", None
    if not text_sequence or not text_sequence.strip():
        return None, None, "No sequence to process", "", "0", None
    midi_file, fig, instruments_str, num_tokens, audio_path = get_outputs_from_string(
        text_sequence, qpm=qpm
    )
    return midi_file, fig, instruments_str, text_sequence, num_tokens, audio_path
# =============================================================================
# SkyTNT Model Integration
# =============================================================================
try:
from skytnt_generator import generate_midi as skytnt_generate, get_available_instruments, get_available_drum_kits, ONNX_AVAILABLE
SKYTNT_AVAILABLE = ONNX_AVAILABLE
except ImportError:
SKYTNT_AVAILABLE = False
print("SkyTNT generator not available")
def generate_skytnt(instruments, drum_kit, bpm, max_events, temp, top_p, top_k, seed_rand, seed):
    """Generate music with the SkyTNT ONNX model.

    Returns:
        Tuple (midi_path, image_path, status_message, audio_path); the
        non-status entries are None on failure.
    """
    if not SKYTNT_AVAILABLE:
        return None, None, "SkyTNT model not available", None
    # Normalize inputs: empty selection means "auto"; a fixed seed only
    # applies when randomization is off.
    selected = instruments if instruments else []
    chosen_seed = None if seed_rand else int(seed)
    try:
        midi_path = skytnt_generate(
            instruments=selected,
            drum_kit=drum_kit,
            bpm=int(bpm),
            max_events=int(max_events),
            temp=temp,
            top_p=top_p,
            top_k=int(top_k),
            seed=chosen_seed,
        )
        if midi_path is None:
            return None, None, "Generation failed", None
        img_path = create_skytnt_piano_roll(midi_path)
        audio_path = render_midi_to_audio(midi_path)
        status = f"Generated with {len(selected)} instruments at {bpm} BPM"
        if audio_path:
            status += " - Audio rendered!"
        return midi_path, img_path, status, audio_path
    except Exception as e:
        return None, None, f"Error: {str(e)}", None
def create_skytnt_piano_roll(midi_path: str):
    """Create a piano-roll PNG from a MIDI file (SkyTNT output).

    Args:
        midi_path: path to the MIDI file to visualize.

    Returns:
        Path to a temporary PNG, or None when parsing/plotting fails.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    try:
        import MIDI
        # Fix: read via a context manager so the file handle is closed
        # deterministically (the original leaked an open handle).
        with open(midi_path, 'rb') as midi_file:
            midi_data = MIDI.midi2score(midi_file.read())
        fig, ax = plt.subplots(figsize=(14, 6))
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8', '#F39C12', '#9B59B6', '#1ABC9C']
        # midi2score returns [ticks_per_beat, track1, track2, ...].
        ticks_per_beat = midi_data[0]
        all_notes = []
        # Fix: original enumerated tracks but never used the index.
        for track in midi_data[1:]:
            for event in track:
                if event[0] == 'note':
                    start_time = event[1] / ticks_per_beat
                    duration = event[2] / ticks_per_beat
                    channel = event[3]
                    pitch = event[4]
                    all_notes.append((start_time, duration, pitch, channel))
        if not all_notes:
            ax.text(0.5, 0.5, 'No notes generated', ha='center', va='center', fontsize=14)
            ax.set_xlim(0, 10)
            ax.set_ylim(0, 127)
        else:
            min_pitch = min(n[2] for n in all_notes)
            max_pitch = max(n[2] for n in all_notes)
            max_time = max(n[0] + n[1] for n in all_notes)
            for start, dur, pitch, channel in all_notes:
                color = colors[channel % len(colors)]
                rect = Rectangle((start, pitch - 0.4), dur, 0.8,
                                 facecolor=color, edgecolor='black', linewidth=0.3, alpha=0.8)
                ax.add_patch(rect)
            ax.set_xlim(0, max_time + 0.5)
            ax.set_ylim(min_pitch - 2, max_pitch + 2)
        ax.set_xlabel('Time (beats)')
        ax.set_ylabel('MIDI Pitch')
        ax.set_title('Generated Music (SkyTNT)')
        ax.grid(True, alpha=0.3)
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
            img_path = f.name
        fig.savefig(img_path, dpi=100, bbox_inches='tight')
        plt.close(fig)
        return img_path
    except Exception as e:
        print(f"Visualization error: {e}")
        return None
# =============================================================================
# Gradio Interface
# =============================================================================
# Genre labels offered in the UI; "RANDOM" makes create_seed_string() return
# a bare "PIECE_START" seed so the model chooses freely.
GENRES = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"]
def build_ui():
    """Build the Gradio Blocks interface for the Multitrack MIDI Composer.

    Returns:
        gr.Blocks: the assembled (not yet launched) two-tab app — the
        genre-based transformer generator and the SkyTNT ONNX model.
    """
    with gr.Blocks(title="Multitrack MIDI Composer") as demo:
        gr.Markdown("""
        # 🎹 Multitrack MIDI Composer
        AI-powered multi-instrument music generation. Choose a model and generate!
        """)
        with gr.Tabs():
            # Tab 1: Multitrack Generator (juancopi81) - Default
            with gr.TabItem("Multitrack Generator (Genre-based)"):
                if not TORCH_AVAILABLE:
                    gr.Markdown("⚠️ PyTorch/Transformers not installed.")
                else:
                    with gr.Row():
                        with gr.Column():
                            # Generation controls.
                            mt_temp = gr.Slider(0, 1, step=0.05, value=0.85, label="Temperature")
                            mt_genre = gr.Dropdown(choices=GENRES, value="POP", label="Select Genre")
                            with gr.Row():
                                btn_from_scratch = gr.Button("Start from scratch", variant="primary")
                                btn_continue = gr.Button("Continue Generation")
                            with gr.Row():
                                btn_remove_last = gr.Button("Remove last instrument")
                                btn_regenerate_last = gr.Button("Regenerate last instrument")
                        with gr.Column():
                            # Playback, download, and tempo widgets.
                            mt_audio = gr.Audio(label="Listen")
                            with gr.Group():
                                mt_midi = gr.File(label="Download MIDI File")
                            with gr.Row():
                                mt_qpm = gr.Slider(60, 140, step=10, value=120, label="Tempo")
                                btn_qpm = gr.Button("Change Tempo")
                    with gr.Row():
                        with gr.Column():
                            mt_img = gr.Image(label="Music Visualization")
                        with gr.Column():
                            mt_instruments = gr.Markdown("### Instruments")
                            mt_sequence = gr.Textbox(label="Token Sequence", lines=3)
                    # Hidden helpers: an always-empty seed for "from scratch"
                    # and a token-count holder not shown to the user.
                    mt_empty = gr.Textbox(visible=False, value="")
                    mt_tokens = gr.Textbox(visible=False)
                    # Wire the buttons; every handler writes the same outputs.
                    btn_from_scratch.click(
                        fn=generate_song,
                        inputs=[mt_genre, mt_temp, mt_empty, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    btn_continue.click(
                        fn=generate_song,
                        inputs=[mt_genre, mt_temp, mt_sequence, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    btn_remove_last.click(
                        fn=remove_last_instrument,
                        inputs=[mt_sequence, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    btn_regenerate_last.click(
                        fn=regenerate_last_instrument,
                        inputs=[mt_sequence, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    btn_qpm.click(
                        fn=change_tempo,
                        inputs=[mt_sequence, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    gr.Markdown("**Model**: [juancopi81/lmd-8bars-2048-epochs40_v4](https://huggingface.co/juancopi81/lmd-8bars-2048-epochs40_v4)")
            # Tab 2: SkyTNT MIDI Model
            with gr.TabItem("SkyTNT MIDI Model"):
                if not SKYTNT_AVAILABLE:
                    gr.Markdown("⚠️ SkyTNT model not available (onnxruntime required).")
                else:
                    gr.Markdown("Select instruments and generate MIDI events. Processing: ~20 seconds for 200 events.")
                    with gr.Row():
                        with gr.Column():
                            sky_instruments = gr.Dropdown(
                                label="Instruments (optional, auto if empty)",
                                choices=get_available_instruments(),
                                multiselect=True,
                                max_choices=10
                            )
                            sky_drum_kit = gr.Dropdown(
                                label="Drum Kit",
                                choices=get_available_drum_kits(),
                                value="None"
                            )
                            sky_bpm = gr.Slider(60, 200, step=5, value=120, label="BPM")
                            sky_max_events = gr.Slider(50, 500, step=50, value=200, label="Max Events")
                        with gr.Column():
                            with gr.Accordion("Advanced Options", open=False):
                                sky_temp = gr.Slider(0.1, 1.5, step=0.05, value=1.0, label="Temperature")
                                sky_top_p = gr.Slider(0.5, 1.0, step=0.05, value=0.95, label="Top-p")
                                sky_top_k = gr.Slider(1, 50, step=1, value=20, label="Top-k")
                                sky_seed_rand = gr.Checkbox(label="Random Seed", value=True)
                                sky_seed = gr.Number(label="Seed", value=42)
                            btn_sky_generate = gr.Button("Generate", variant="primary")
                    with gr.Row():
                        sky_audio = gr.Audio(label="Listen")
                        sky_midi = gr.File(label="Download MIDI")
                    sky_img = gr.Image(label="Visualization")
                    sky_status = gr.Textbox(label="Status", interactive=False)
                    btn_sky_generate.click(
                        fn=generate_skytnt,
                        inputs=[sky_instruments, sky_drum_kit, sky_bpm, sky_max_events,
                                sky_temp, sky_top_p, sky_top_k, sky_seed_rand, sky_seed],
                        outputs=[sky_midi, sky_img, sky_status, sky_audio]
                    )
                    gr.Markdown("**Model**: [skytnt/midi-model](https://huggingface.co/skytnt/midi-model) (ONNX)")
        gr.Markdown("""
        ---
        **Credits**: Dr. Tristan Behrens (Multitrack) | SkyTNT (MIDI Model)
        """)
    return demo
# =============================================================================
# CLI Interface
# =============================================================================
def cli_main():
    """CLI entry point.

    Subcommands:
        demo      -- write a fixed chord-progression demo MIDI file.
        generate  -- generate multitrack music with the transformer model.
    """
    import shutil  # hoisted: both subcommands copy temp files
    parser = argparse.ArgumentParser(description="Multitrack MIDI Composer")
    subparsers = parser.add_subparsers(dest="command", help="Commands")
    # Demo command
    demo_parser = subparsers.add_parser("demo", help="Generate demo MIDI")
    demo_parser.add_argument("--tempo", type=int, default=120, help="Tempo in BPM")
    demo_parser.add_argument("--bars", type=int, default=4, help="Number of bars")
    demo_parser.add_argument("--output", "-o", type=str, default="demo.mid", help="Output file")
    # Generate command
    gen_parser = subparsers.add_parser("generate", help="Generate multitrack music")
    gen_parser.add_argument("--genre", type=str, default="POP", choices=GENRES)
    gen_parser.add_argument("--tempo", type=int, default=120, help="Tempo in BPM")
    gen_parser.add_argument("--temperature", type=float, default=0.85, help="Sampling temperature")
    gen_parser.add_argument("--output", "-o", type=str, default="output.mid", help="Output file")
    args = parser.parse_args()
    if args.command == "demo":
        midi_path, fig, audio_path, status = create_demo_midi(args.tempo, args.bars)
        if midi_path:
            shutil.copy(midi_path, args.output)
            print(f"Created: {args.output}")
            if audio_path:
                # Fix: derive the .wav name from the extension; the previous
                # str.replace('.mid', '.wav') also rewrote '.mid' substrings
                # in directory names or '.midi' extensions.
                audio_out = os.path.splitext(args.output)[0] + '.wav'
                shutil.copy(audio_path, audio_out)
                print(f"Audio: {audio_out}")
            print(status)
        else:
            print(f"Error: {status}")
    elif args.command == "generate":
        if not TORCH_AVAILABLE:
            print("Error: PyTorch not available. Install: pip install torch transformers note-seq")
            return
        print(f"Generating {args.genre} music at {args.tempo} BPM...")
        midi_path, fig, instruments, sequence, tokens, audio_path = generate_song(
            genre=args.genre, temp=args.temperature, qpm=args.tempo
        )
        if midi_path:
            shutil.copy(midi_path, args.output)
            print(f"Created: {args.output}")
            if audio_path:
                audio_out = os.path.splitext(args.output)[0] + '.wav'
                shutil.copy(audio_path, audio_out)
                print(f"Audio: {audio_out}")
            print(f"Instruments:\n{instruments}")
            print(f"Tokens: {tokens}")
    else:
        parser.print_help()
if __name__ == "__main__":
    # CLI subcommands and help flags are handled by argparse; any other
    # invocation (e.g. the HF Space entry point) launches the Gradio UI.
    if len(sys.argv) > 1 and sys.argv[1] in ["demo", "generate", "-h", "--help"]:
        cli_main()
    else:
        # Preload model if available so the first UI request does not pay
        # the model-loading cost.
        if TORCH_AVAILABLE:
            print("Initializing model...")
            get_model_and_tokenizer()
        demo = build_ui()
        demo.launch()