""" Multitrack MIDI Composer - Combined MIDI generation tools - Simple MIDI Composer: Demo mode chord progressions - Multitrack Generator: AI multi-instrument composition with genre selection CPU-only HuggingFace Space """ import os import sys import tempfile import argparse import struct import wave from typing import List, Tuple, Optional os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" import gradio as gr import numpy as np # Import meltysynth for pure Python MIDI-to-audio synthesis try: import meltysynth as ms MELTYSYNTH_AVAILABLE = True except ImportError: MELTYSYNTH_AVAILABLE = False print("meltysynth not available - Audio playback disabled") # Path to SoundFont file SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) SOUNDFONT_PATH = os.path.join(SCRIPT_DIR, "TimGM6mb.sf2") # Global synthesizer (loaded once) _synthesizer = None _synth_settings = None def get_synthesizer(): """Load the SoundFont synthesizer (cached).""" global _synthesizer, _synth_settings if _synthesizer is None and MELTYSYNTH_AVAILABLE: if os.path.exists(SOUNDFONT_PATH): print(f"Loading SoundFont: {SOUNDFONT_PATH}") sound_font = ms.SoundFont.from_file(SOUNDFONT_PATH) _synth_settings = ms.SynthesizerSettings(44100) _synthesizer = ms.Synthesizer(sound_font, _synth_settings) print("SoundFont loaded successfully!") else: print(f"SoundFont not found: {SOUNDFONT_PATH}") return _synthesizer, _synth_settings def render_midi_to_audio(midi_path: str) -> Optional[str]: """Render a MIDI file to WAV audio using meltysynth.""" if not MELTYSYNTH_AVAILABLE: return None synth, settings = get_synthesizer() if synth is None: return None try: # Load MIDI file midi_file = ms.MidiFile.from_file(midi_path) sequencer = ms.MidiFileSequencer(synth) sequencer.play(midi_file, False) # Calculate buffer size (duration + 1 second for tail) duration = midi_file.length + 1.0 buffer_length = int(settings.sample_rate * duration) # Create buffers and render left = ms.create_buffer(buffer_length) right = ms.create_buffer(buffer_length) sequencer.render(left, right) # Convert to interleaved stereo int16 samples = [] for i in range(buffer_length): # Clamp and convert to int16 l_sample = max(-1.0, min(1.0, left[i])) r_sample = max(-1.0, min(1.0, right[i])) samples.append(int(l_sample * 32767)) samples.append(int(r_sample * 32767)) # Write WAV file with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f: wav_path = f.name with wave.open(wav_path, 'w') as wav_file: wav_file.setnchannels(2) wav_file.setsampwidth(2) # 16-bit wav_file.setframerate(settings.sample_rate) wav_file.writeframes(struct.pack(f'<{len(samples)}h', *samples)) return wav_path except Exception as e: print(f"Audio render error: {e}") return None # ============================================================================= # Tab 1: Simple MIDI Composer (Demo Mode) # ============================================================================= try: from midiutil import MIDIFile MIDIUTIL_AVAILABLE = True except ImportError: MIDIUTIL_AVAILABLE = False print("midiutil not available - Demo Composer disabled") def create_piano_roll(notes_data, total_time, title="Piano Roll"): """Create a piano roll visualization and save as PNG image.""" import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.patches import Rectangle fig, ax = plt.subplots(figsize=(12, 6)) colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8'] for i, (start, end, pitch, track) in enumerate(notes_data): color = colors[track % len(colors)] rect = Rectangle((start, pitch - 0.4), end - start, 0.8, facecolor=color, edgecolor='black', linewidth=0.5, alpha=0.8) ax.add_patch(rect) ax.set_xlim(0, total_time) ax.set_ylim(40, 90) ax.set_xlabel('Time (beats)') ax.set_ylabel('MIDI Pitch') ax.set_title(title) ax.grid(True, alpha=0.3) # Save to PNG file with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: img_path = f.name fig.savefig(img_path, dpi=100, bbox_inches='tight') plt.close(fig) return img_path def create_demo_midi(tempo: int = 120, length_bars: int = 4): """Create demo MIDI with chord progression and melody, with visualization and audio.""" if not MIDIUTIL_AVAILABLE: return None, None, None, "midiutil not installed" try: midi = MIDIFile(2) # Two tracks: melody and chords midi.addTempo(0, 0, tempo) midi.addProgramChange(0, 0, 0, 0) # Piano for melody midi.addProgramChange(1, 0, 0, 0) # Piano for chords # Simple chord progression (C - Am - F - G) chords = [ [60, 64, 67], # C major [57, 60, 64], # A minor [53, 57, 60], # F major [55, 59, 62], # G major ] # Melody notes melody_notes = [ 72, 74, 76, 77, 76, 74, 72, 71, 69, 71, 72, 74, 72, 71, 69, 67, 65, 67, 69, 71, 72, 71, 69, 67, 67, 69, 71, 72, 74, 76, 77, 79, ] # Collect notes for visualization notes_data = [] for bar in range(length_bars): bar_time = bar * 4 # In beats chord_idx = bar % len(chords) # Add chord notes for note in chords[chord_idx]: midi.addNote(1, 0, note, bar_time, 4, 60) notes_data.append((bar_time, bar_time + 4, note, 1)) # Add melody notes (8 per bar) for i in range(8): note_time = bar_time + (i * 0.5) note_idx = (bar * 8 + i) % len(melody_notes) midi.addNote(0, 0, melody_notes[note_idx], note_time, 0.4, 90) notes_data.append((note_time, note_time + 0.4, melody_notes[note_idx], 0)) with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as f: midi.writeFile(f) midi_path = f.name # Create visualization total_time = length_bars * 4 fig = create_piano_roll(notes_data, total_time, f"Demo: {length_bars} bars at {tempo} BPM") # Render audio audio_path = render_midi_to_audio(midi_path) status = f"Created: {length_bars} bars at {tempo} BPM" if audio_path: status += " - Audio rendered!" else: status += " - Download MIDI to play" return midi_path, fig, audio_path, status except Exception as e: return None, None, None, f"Error: {str(e)}" # ============================================================================= # Tab 2: Multitrack Generator (Transformer-based) # ============================================================================= try: import torch from transformers import AutoTokenizer, AutoModelForCausalLM import note_seq from note_seq.protobuf.music_pb2 import NoteSequence from note_seq.constants import STANDARD_PPQ from matplotlib.figure import Figure TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False print("PyTorch/Transformers not available - Multitrack Generator disabled") SAMPLE_RATE = 44100 GM_INSTRUMENTS = [ "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano", "Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2", "Harpsichord", "Clavi", "Celesta", "Glockenspiel", "Music Box", "Vibraphone", "Marimba", "Xylophone", "Tubular Bells", "Dulcimer", "Drawbar Organ", "Percussive Organ", "Rock Organ", "Church Organ", "Reed Organ", "Accordion", "Harmonica", "Tango Accordion", "Acoustic Guitar (nylon)", "Acoustic Guitar (steel)", "Electric Guitar (jazz)", "Electric Guitar (clean)", "Electric Guitar (muted)", "Overdriven Guitar", "Distortion Guitar", "Guitar Harmonics", "Acoustic Bass", "Electric Bass (finger)", "Electric Bass (pick)", "Fretless Bass", "Slap Bass 1", "Slap Bass 2", "Synth Bass 1", "Synth Bass 2", "Violin", "Viola", "Cello", "Contrabass", "Tremolo Strings", "Pizzicato Strings", "Orchestral Harp", "Timpani", "String Ensemble 1", "String Ensemble 2", "Synth Strings 1", "Synth Strings 2", "Choir Aahs", "Voice Oohs", "Synth Choir", "Orchestra Hit", "Trumpet", "Trombone", "Tuba", "Muted Trumpet", "French Horn", "Brass Section", "Synth Brass 1", "Synth Brass 2", "Soprano Sax", "Alto Sax", "Tenor Sax", "Baritone Sax", "Oboe", "English Horn", "Bassoon", "Clarinet", "Piccolo", "Flute", "Recorder", "Pan Flute", "Blown Bottle", "Shakuhachi", "Whistle", "Ocarina", "Lead 1 (square)", "Lead 2 (sawtooth)", "Lead 3 (calliope)", "Lead 4 (chiff)", "Lead 5 (charang)", "Lead 6 (voice)", "Lead 7 (fifths)", "Lead 8 (bass + lead)", "Pad 1 (new age)", "Pad 2 (warm)", "Pad 3 (polysynth)", "Pad 4 (choir)", "Pad 5 (bowed)", "Pad 6 (metallic)", "Pad 7 (halo)", "Pad 8 (sweep)", "FX 1 (rain)", "FX 2 (soundtrack)", "FX 3 (crystal)", "FX 4 (atmosphere)", "FX 5 (brightness)", "FX 6 (goblins)", "FX 7 (echoes)", "FX 8 (sci-fi)", "Sitar", "Banjo", "Shamisen", "Koto", "Kalimba", "Bagpipe", "Fiddle", "Shanai", "Tinkle Bell", "Agogo", "Steel Drums", "Woodblock", "Taiko Drum", "Melodic Tom", "Synth Drum", "Reverse Cymbal", "Guitar Fret Noise", "Breath Noise", "Seashore", "Bird Tweet", "Telephone Ring", "Helicopter", "Applause", "Gunshot", ] # Global model and tokenizer device = None tokenizer = None model = None def get_model_and_tokenizer(): """Load model and tokenizer on CPU.""" global model, tokenizer, device if model is None or tokenizer is None: device = torch.device("cpu") print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer") print("Loading model on CPU...") model = AutoModelForCausalLM.from_pretrained("juancopi81/lmd-8bars-2048-epochs40_v4") model = model.to(device) model.eval() print("Model loaded successfully!") return model, tokenizer def empty_note_sequence(qpm: float = 120.0, total_time: float = 0.0) -> "NoteSequence": """Create an empty note sequence.""" note_sequence = NoteSequence() note_sequence.tempos.add().qpm = qpm note_sequence.ticks_per_quarter = STANDARD_PPQ note_sequence.total_time = total_time return note_sequence def token_sequence_to_note_sequence( token_sequence: str, qpm: float = 120.0, use_program: bool = True, use_drums: bool = True, instrument_mapper: Optional[dict] = None, only_piano: bool = False, ) -> "NoteSequence": """Convert a sequence of tokens into a sequence of notes.""" if isinstance(token_sequence, str): token_sequence = token_sequence.split() note_sequence = empty_note_sequence(qpm) note_length_16th = 0.25 * 60 / qpm bar_length = 4.0 * 60 / qpm current_program = 1 current_is_drum = False current_instrument = 0 track_count = 0 current_bar_index = 0 current_time = 0 current_notes = {} for token in token_sequence: if token == "PIECE_START": pass elif token == "PIECE_END": break elif token == "TRACK_START": current_bar_index = 0 track_count += 1 elif token == "TRACK_END": pass elif token == "KEYS_START": pass elif token == "KEYS_END": pass elif token.startswith("KEY="): pass elif token.startswith("INST"): instrument = token.split("=")[-1] if instrument != "DRUMS" and use_program: if instrument_mapper is not None: if instrument in instrument_mapper: instrument = instrument_mapper[instrument] try: current_program = int(instrument) except ValueError: current_program = 0 current_instrument = track_count current_is_drum = False if instrument == "DRUMS" and use_drums: current_instrument = 0 current_program = 0 current_is_drum = True elif token == "BAR_START": current_time = current_bar_index * bar_length current_notes = {} elif token == "BAR_END": current_bar_index += 1 elif token.startswith("NOTE_ON"): try: pitch = int(token.split("=")[-1]) note = note_sequence.notes.add() note.start_time = current_time note.end_time = current_time + 4 * note_length_16th note.pitch = pitch note.instrument = current_instrument note.program = current_program note.velocity = 80 note.is_drum = current_is_drum current_notes[pitch] = note except ValueError: pass elif token.startswith("NOTE_OFF"): try: pitch = int(token.split("=")[-1]) if pitch in current_notes: note = current_notes[pitch] note.end_time = current_time except ValueError: pass elif token.startswith("TIME_DELTA"): try: delta = float(token.split("=")[-1]) * note_length_16th current_time += delta except ValueError: pass elif token.startswith("DENSITY="): pass elif token == "[PAD]": pass # Make the instruments right instruments_drums = [] for note in note_sequence.notes: pair = [note.program, note.is_drum] if pair not in instruments_drums: instruments_drums += [pair] note.instrument = instruments_drums.index(pair) if only_piano: for note in note_sequence.notes: if not note.is_drum: note.instrument = 0 note.program = 0 return note_sequence def create_seed_string(genre: str = "OTHER") -> str: """Create a seed string for generating a new piece.""" if genre == "RANDOM": return "PIECE_START" return f"PIECE_START GENRE={genre} TRACK_START" def get_instruments(text_sequence: str) -> List[str]: """Extract the list of instruments from a text sequence.""" instruments = [] parts = text_sequence.split() for part in parts: if part.startswith("INST="): if part[5:] == "DRUMS": instruments.append("Drums") else: try: index = int(part[5:]) if 0 <= index < len(GM_INSTRUMENTS): instruments.append(GM_INSTRUMENTS[index]) else: instruments.append(f"Program {index}") except ValueError: pass return instruments def generate_new_instrument(seed: str, temp: float = 0.75) -> str: """Generate a new instrument sequence from a given seed and temperature.""" model, tok = get_model_and_tokenizer() seed_length = len(tok.encode(seed)) # Retry until we get a valid generation with notes max_attempts = 5 for attempt in range(max_attempts): input_ids = tok.encode(seed, return_tensors="pt") input_ids = input_ids.to(device) eos_token_id = tok.encode("TRACK_END")[0] with torch.no_grad(): generated_ids = model.generate( input_ids, max_new_tokens=2048, do_sample=True, temperature=temp, eos_token_id=eos_token_id, ) generated_sequence = tok.decode(generated_ids[0]) new_generated_sequence = tok.decode(generated_ids[0][seed_length:]) if "NOTE_ON" in new_generated_sequence: return generated_sequence # Return last attempt even if no NOTE_ON found return generated_sequence def create_noteseq_piano_roll(note_sequence, title="Generated Music"): """Create a piano roll visualization from a NoteSequence using matplotlib.""" import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.patches import Rectangle fig, ax = plt.subplots(figsize=(14, 6)) # Color by instrument colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8', '#F39C12', '#9B59B6', '#1ABC9C'] if len(note_sequence.notes) == 0: ax.text(0.5, 0.5, 'No notes generated', ha='center', va='center', fontsize=14) ax.set_xlim(0, 10) ax.set_ylim(0, 127) else: min_pitch = min(n.pitch for n in note_sequence.notes) max_pitch = max(n.pitch for n in note_sequence.notes) max_time = max(n.end_time for n in note_sequence.notes) for note in note_sequence.notes: color = colors[note.instrument % len(colors)] rect = Rectangle( (note.start_time, note.pitch - 0.4), note.end_time - note.start_time, 0.8, facecolor=color, edgecolor='black', linewidth=0.3, alpha=0.8 ) ax.add_patch(rect) ax.set_xlim(0, max_time + 0.5) ax.set_ylim(min_pitch - 2, max_pitch + 2) ax.set_xlabel('Time (seconds)') ax.set_ylabel('MIDI Pitch') ax.set_title(title) ax.grid(True, alpha=0.3) # Save to PNG with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: img_path = f.name fig.savefig(img_path, dpi=100, bbox_inches='tight') plt.close(fig) return img_path def get_outputs_from_string(generated_sequence: str, qpm: int = 120): """Convert a generated sequence into various output formats.""" instruments = get_instruments(generated_sequence) instruments_str = "\n".join(f"- {instrument}" for instrument in instruments) note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm) # Create visualization using custom matplotlib function img_path = create_noteseq_piano_roll(note_sequence, f"Generated at {qpm} BPM") num_tokens = str(len(generated_sequence.split())) # Save MIDI file with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as f: midi_path = f.name note_seq.note_sequence_to_midi_file(note_sequence, midi_path) # Render audio audio_path = render_midi_to_audio(midi_path) return midi_path, img_path, instruments_str, num_tokens, audio_path def generate_song(genre: str = "OTHER", temp: float = 0.75, text_sequence: str = "", qpm: int = 120): """Generate a song given a genre, temperature, initial text sequence, and tempo.""" if not TORCH_AVAILABLE: return None, None, "PyTorch not available", "", "0", None if text_sequence == "": seed_string = create_seed_string(genre) else: seed_string = text_sequence generated_sequence = generate_new_instrument(seed=seed_string, temp=temp) midi_file, fig, instruments_str, num_tokens, audio_path = get_outputs_from_string( generated_sequence, qpm ) return midi_file, fig, instruments_str, generated_sequence, num_tokens, audio_path def remove_last_instrument(text_sequence: str, qpm: int = 120): """Remove the last instrument from a song string.""" if not TORCH_AVAILABLE: return None, None, "PyTorch not available", "", "0", None tracks = text_sequence.split("TRACK_START") modified_tracks = tracks[:-1] new_song = "TRACK_START".join(modified_tracks) if len(tracks) == 2: midi_file, fig, instruments_str, new_song, num_tokens, audio_path = generate_song( text_sequence=new_song ) elif len(tracks) == 1: midi_file, fig, instruments_str, new_song, num_tokens, audio_path = generate_song( text_sequence="" ) else: midi_file, fig, instruments_str, num_tokens, audio_path = get_outputs_from_string( new_song, qpm ) return midi_file, fig, instruments_str, new_song, num_tokens, audio_path def regenerate_last_instrument(text_sequence: str, qpm: int = 120): """Regenerate the last instrument in a song string.""" if not TORCH_AVAILABLE: return None, None, "PyTorch not available", "", "0", None last_inst_index = text_sequence.rfind("INST=") if last_inst_index == -1: midi_file, fig, instruments_str, new_song, num_tokens, audio_path = generate_song( text_sequence="", qpm=qpm ) else: next_space_index = text_sequence.find(" ", last_inst_index) if next_space_index == -1: # No space after INST=, use the whole remaining string new_seed = text_sequence else: new_seed = text_sequence[:next_space_index] midi_file, fig, instruments_str, new_song, num_tokens, audio_path = generate_song( text_sequence=new_seed, qpm=qpm ) return midi_file, fig, instruments_str, new_song, num_tokens, audio_path def change_tempo(text_sequence: str, qpm: int): """Change the tempo of a song string.""" if not TORCH_AVAILABLE: return None, None, "PyTorch not available", "", "0", None if not text_sequence or text_sequence.strip() == "": return None, None, "No sequence to process", "", "0", None midi_file, fig, instruments_str, num_tokens, audio_path = get_outputs_from_string( text_sequence, qpm=qpm ) return midi_file, fig, instruments_str, text_sequence, num_tokens, audio_path # ============================================================================= # SkyTNT Model Integration # ============================================================================= try: from skytnt_generator import generate_midi as skytnt_generate, get_available_instruments, get_available_drum_kits, ONNX_AVAILABLE SKYTNT_AVAILABLE = ONNX_AVAILABLE except ImportError: SKYTNT_AVAILABLE = False print("SkyTNT generator not available") def generate_skytnt(instruments, drum_kit, bpm, max_events, temp, top_p, top_k, seed_rand, seed): """Generate music using SkyTNT model.""" if not SKYTNT_AVAILABLE: return None, None, "SkyTNT model not available", None # Parse instruments instr_list = instruments if instruments else [] actual_seed = None if seed_rand else int(seed) try: midi_path = skytnt_generate( instruments=instr_list, drum_kit=drum_kit, bpm=int(bpm), max_events=int(max_events), temp=temp, top_p=top_p, top_k=int(top_k), seed=actual_seed ) if midi_path is None: return None, None, "Generation failed", None # Create visualization img_path = create_skytnt_piano_roll(midi_path) # Render audio audio_path = render_midi_to_audio(midi_path) status = f"Generated with {len(instr_list)} instruments at {bpm} BPM" if audio_path: status += " - Audio rendered!" return midi_path, img_path, status, audio_path except Exception as e: return None, None, f"Error: {str(e)}", None def create_skytnt_piano_roll(midi_path: str): """Create piano roll visualization from MIDI file.""" import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.patches import Rectangle try: import MIDI midi_data = MIDI.midi2score(open(midi_path, 'rb').read()) fig, ax = plt.subplots(figsize=(14, 6)) colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8', '#F39C12', '#9B59B6', '#1ABC9C'] ticks_per_beat = midi_data[0] all_notes = [] for track_idx, track in enumerate(midi_data[1:]): for event in track: if event[0] == 'note': start_time = event[1] / ticks_per_beat duration = event[2] / ticks_per_beat channel = event[3] pitch = event[4] all_notes.append((start_time, duration, pitch, channel)) if not all_notes: ax.text(0.5, 0.5, 'No notes generated', ha='center', va='center', fontsize=14) ax.set_xlim(0, 10) ax.set_ylim(0, 127) else: min_pitch = min(n[2] for n in all_notes) max_pitch = max(n[2] for n in all_notes) max_time = max(n[0] + n[1] for n in all_notes) for start, dur, pitch, channel in all_notes: color = colors[channel % len(colors)] rect = Rectangle((start, pitch - 0.4), dur, 0.8, facecolor=color, edgecolor='black', linewidth=0.3, alpha=0.8) ax.add_patch(rect) ax.set_xlim(0, max_time + 0.5) ax.set_ylim(min_pitch - 2, max_pitch + 2) ax.set_xlabel('Time (beats)') ax.set_ylabel('MIDI Pitch') ax.set_title('Generated Music (SkyTNT)') ax.grid(True, alpha=0.3) with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: img_path = f.name fig.savefig(img_path, dpi=100, bbox_inches='tight') plt.close(fig) return img_path except Exception as e: print(f"Visualization error: {e}") return None # ============================================================================= # Gradio Interface # ============================================================================= GENRES = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"] def build_ui(): """Build Gradio interface for AI Multitrack MIDI Composer.""" with gr.Blocks(title="Multitrack MIDI Composer") as demo: gr.Markdown(""" # 🎹 Multitrack MIDI Composer AI-powered multi-instrument music generation. Choose a model and generate! """) with gr.Tabs(): # Tab 1: Multitrack Generator (juancopi81) - Default with gr.TabItem("Multitrack Generator (Genre-based)"): if not TORCH_AVAILABLE: gr.Markdown("⚠️ PyTorch/Transformers not installed.") else: with gr.Row(): with gr.Column(): mt_temp = gr.Slider(0, 1, step=0.05, value=0.85, label="Temperature") mt_genre = gr.Dropdown(choices=GENRES, value="POP", label="Select Genre") with gr.Row(): btn_from_scratch = gr.Button("Start from scratch", variant="primary") btn_continue = gr.Button("Continue Generation") with gr.Row(): btn_remove_last = gr.Button("Remove last instrument") btn_regenerate_last = gr.Button("Regenerate last instrument") with gr.Column(): mt_audio = gr.Audio(label="Listen") with gr.Group(): mt_midi = gr.File(label="Download MIDI File") with gr.Row(): mt_qpm = gr.Slider(60, 140, step=10, value=120, label="Tempo") btn_qpm = gr.Button("Change Tempo") with gr.Row(): with gr.Column(): mt_img = gr.Image(label="Music Visualization") with gr.Column(): mt_instruments = gr.Markdown("### Instruments") mt_sequence = gr.Textbox(label="Token Sequence", lines=3) mt_empty = gr.Textbox(visible=False, value="") mt_tokens = gr.Textbox(visible=False) btn_from_scratch.click( fn=generate_song, inputs=[mt_genre, mt_temp, mt_empty, mt_qpm], outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio] ) btn_continue.click( fn=generate_song, inputs=[mt_genre, mt_temp, mt_sequence, mt_qpm], outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio] ) btn_remove_last.click( fn=remove_last_instrument, inputs=[mt_sequence, mt_qpm], outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio] ) btn_regenerate_last.click( fn=regenerate_last_instrument, inputs=[mt_sequence, mt_qpm], outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio] ) btn_qpm.click( fn=change_tempo, inputs=[mt_sequence, mt_qpm], outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio] ) gr.Markdown("**Model**: [juancopi81/lmd-8bars-2048-epochs40_v4](https://huggingface.co/juancopi81/lmd-8bars-2048-epochs40_v4)") # Tab 2: SkyTNT MIDI Model with gr.TabItem("SkyTNT MIDI Model"): if not SKYTNT_AVAILABLE: gr.Markdown("⚠️ SkyTNT model not available (onnxruntime required).") else: gr.Markdown("Select instruments and generate MIDI events. Processing: ~20 seconds for 200 events.") with gr.Row(): with gr.Column(): sky_instruments = gr.Dropdown( label="Instruments (optional, auto if empty)", choices=get_available_instruments(), multiselect=True, max_choices=10 ) sky_drum_kit = gr.Dropdown( label="Drum Kit", choices=get_available_drum_kits(), value="None" ) sky_bpm = gr.Slider(60, 200, step=5, value=120, label="BPM") sky_max_events = gr.Slider(50, 500, step=50, value=200, label="Max Events") with gr.Column(): with gr.Accordion("Advanced Options", open=False): sky_temp = gr.Slider(0.1, 1.5, step=0.05, value=1.0, label="Temperature") sky_top_p = gr.Slider(0.5, 1.0, step=0.05, value=0.95, label="Top-p") sky_top_k = gr.Slider(1, 50, step=1, value=20, label="Top-k") sky_seed_rand = gr.Checkbox(label="Random Seed", value=True) sky_seed = gr.Number(label="Seed", value=42) btn_sky_generate = gr.Button("Generate", variant="primary") with gr.Row(): sky_audio = gr.Audio(label="Listen") sky_midi = gr.File(label="Download MIDI") sky_img = gr.Image(label="Visualization") sky_status = gr.Textbox(label="Status", interactive=False) btn_sky_generate.click( fn=generate_skytnt, inputs=[sky_instruments, sky_drum_kit, sky_bpm, sky_max_events, sky_temp, sky_top_p, sky_top_k, sky_seed_rand, sky_seed], outputs=[sky_midi, sky_img, sky_status, sky_audio] ) gr.Markdown("**Model**: [skytnt/midi-model](https://huggingface.co/skytnt/midi-model) (ONNX)") gr.Markdown(""" --- **Credits**: Dr. Tristan Behrens (Multitrack) | SkyTNT (MIDI Model) """) return demo # ============================================================================= # CLI Interface # ============================================================================= def cli_main(): """CLI entry point.""" parser = argparse.ArgumentParser(description="Multitrack MIDI Composer") subparsers = parser.add_subparsers(dest="command", help="Commands") # Demo command demo_parser = subparsers.add_parser("demo", help="Generate demo MIDI") demo_parser.add_argument("--tempo", type=int, default=120, help="Tempo in BPM") demo_parser.add_argument("--bars", type=int, default=4, help="Number of bars") demo_parser.add_argument("--output", "-o", type=str, default="demo.mid", help="Output file") # Generate command gen_parser = subparsers.add_parser("generate", help="Generate multitrack music") gen_parser.add_argument("--genre", type=str, default="POP", choices=GENRES) gen_parser.add_argument("--tempo", type=int, default=120, help="Tempo in BPM") gen_parser.add_argument("--temperature", type=float, default=0.85, help="Sampling temperature") gen_parser.add_argument("--output", "-o", type=str, default="output.mid", help="Output file") args = parser.parse_args() if args.command == "demo": midi_path, fig, audio_path, status = create_demo_midi(args.tempo, args.bars) if midi_path: import shutil shutil.copy(midi_path, args.output) print(f"Created: {args.output}") if audio_path: audio_out = args.output.replace('.mid', '.wav') shutil.copy(audio_path, audio_out) print(f"Audio: {audio_out}") print(status) else: print(f"Error: {status}") elif args.command == "generate": if not TORCH_AVAILABLE: print("Error: PyTorch not available. Install: pip install torch transformers note-seq") return print(f"Generating {args.genre} music at {args.tempo} BPM...") midi_path, fig, instruments, sequence, tokens, audio_path = generate_song( genre=args.genre, temp=args.temperature, qpm=args.tempo ) if midi_path: import shutil shutil.copy(midi_path, args.output) print(f"Created: {args.output}") if audio_path: audio_out = args.output.replace('.mid', '.wav') shutil.copy(audio_path, audio_out) print(f"Audio: {audio_out}") print(f"Instruments:\n{instruments}") print(f"Tokens: {tokens}") else: parser.print_help() if __name__ == "__main__": if len(sys.argv) > 1 and sys.argv[1] in ["demo", "generate", "-h", "--help"]: cli_main() else: # Preload model if available if TORCH_AVAILABLE: print("Initializing model...") get_model_and_tokenizer() demo = build_ui() demo.launch()