Spaces:
Sleeping
Sleeping
| """ | |
| Multitrack MIDI Composer - Combined MIDI generation tools | |
| - Simple MIDI Composer: Demo mode chord progressions | |
| - Multitrack Generator: AI multi-instrument composition with genre selection | |
| CPU-only HuggingFace Space | |
| """ | |
| import os | |
| import sys | |
| import tempfile | |
| import argparse | |
| import struct | |
| import wave | |
| from typing import List, Tuple, Optional | |
| os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" | |
| import gradio as gr | |
| import numpy as np | |
# Import meltysynth for pure Python MIDI-to-audio synthesis.
# The dependency is optional: when the import fails, MELTYSYNTH_AVAILABLE is
# False and render_midi_to_audio() returns None instead of a WAV path.
try:
    import meltysynth as ms
    MELTYSYNTH_AVAILABLE = True
except ImportError:
    MELTYSYNTH_AVAILABLE = False
    print("meltysynth not available - Audio playback disabled")

# Path to SoundFont file (looked up in the directory that contains this script)
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SOUNDFONT_PATH = os.path.join(SCRIPT_DIR, "TimGM6mb.sf2")

# Global synthesizer (loaded once).
# Both names are filled in lazily by get_synthesizer() and stay None until
# the SoundFont has been loaded successfully.
_synthesizer = None
_synth_settings = None
def get_synthesizer():
    """Return the shared ``(synthesizer, settings)`` pair, loading it on demand.

    The SoundFont at ``SOUNDFONT_PATH`` is loaded the first time this is
    called with meltysynth available; later calls reuse the cached objects.
    Both elements of the returned tuple are ``None`` when meltysynth is
    missing or the SoundFont file does not exist.
    """
    global _synthesizer, _synth_settings

    # Nothing to do when already cached or when synthesis is unavailable.
    if _synthesizer is not None or not MELTYSYNTH_AVAILABLE:
        return _synthesizer, _synth_settings

    if not os.path.exists(SOUNDFONT_PATH):
        print(f"SoundFont not found: {SOUNDFONT_PATH}")
        return _synthesizer, _synth_settings

    print(f"Loading SoundFont: {SOUNDFONT_PATH}")
    font = ms.SoundFont.from_file(SOUNDFONT_PATH)
    _synth_settings = ms.SynthesizerSettings(44100)
    _synthesizer = ms.Synthesizer(font, _synth_settings)
    print("SoundFont loaded successfully!")
    return _synthesizer, _synth_settings
def _stereo_floats_to_pcm16(left, right) -> bytes:
    """Interleave two float channels into little-endian 16-bit PCM bytes."""
    stereo = np.stack(
        (np.asarray(left, dtype=np.float64), np.asarray(right, dtype=np.float64)),
        axis=1,
    )
    # Treat NaN as silence so the integer cast below is always well defined.
    np.nan_to_num(stereo, copy=False, nan=0.0)
    np.clip(stereo, -1.0, 1.0, out=stereo)
    # astype() truncates toward zero, matching int(sample * 32767).
    return (stereo * 32767).astype("<i2").tobytes()


def render_midi_to_audio(midi_path: str) -> Optional[str]:
    """Render a MIDI file to WAV audio using meltysynth.

    Args:
        midi_path: Path of the MIDI file to render.

    Returns:
        Path of a newly created temporary 16-bit stereo WAV file, or ``None``
        when meltysynth / the SoundFont is unavailable or rendering fails
        (the error is printed, not raised).
    """
    if not MELTYSYNTH_AVAILABLE:
        return None
    synth, settings = get_synthesizer()
    if synth is None:
        return None

    wav_path = None
    try:
        # Load MIDI file
        midi_file = ms.MidiFile.from_file(midi_path)
        sequencer = ms.MidiFileSequencer(synth)
        sequencer.play(midi_file, False)

        # Calculate buffer size (duration + 1 second for tail)
        duration = midi_file.length + 1.0
        buffer_length = int(settings.sample_rate * duration)

        # Create buffers and render
        left = ms.create_buffer(buffer_length)
        right = ms.create_buffer(buffer_length)
        sequencer.render(left, right)

        # Vectorised clamp + int16 conversion; a per-sample Python loop over
        # every sample of the piece is needlessly slow on a CPU-only host.
        pcm_bytes = _stereo_floats_to_pcm16(left, right)

        # Write WAV file
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            wav_path = f.name
        with wave.open(wav_path, 'w') as wav_file:
            wav_file.setnchannels(2)
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(settings.sample_rate)
            wav_file.writeframes(pcm_bytes)
        return wav_path
    except Exception as e:
        print(f"Audio render error: {e}")
        # Do not leave a half-written temporary file behind.
        if wav_path is not None:
            try:
                os.remove(wav_path)
            except OSError:
                pass
        return None
# =============================================================================
# Tab 1: Simple MIDI Composer (Demo Mode)
# =============================================================================
# midiutil is optional; create_demo_midi() checks MIDIUTIL_AVAILABLE and
# reports "midiutil not installed" when the import below failed.
try:
    from midiutil import MIDIFile
    MIDIUTIL_AVAILABLE = True
except ImportError:
    MIDIUTIL_AVAILABLE = False
    print("midiutil not available - Demo Composer disabled")
def create_piano_roll(notes_data, total_time, title="Piano Roll"):
    """Draw a piano roll for ``notes_data`` and return the path of the PNG.

    ``notes_data`` is an iterable of ``(start, end, pitch, track)`` tuples;
    each note becomes one rectangle coloured by its track. The x axis spans
    ``0..total_time`` and the y axis is fixed to MIDI pitches 40-90.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8']
    fig, ax = plt.subplots(figsize=(12, 6))

    for onset, offset, pitch, track in notes_data:
        ax.add_patch(
            Rectangle(
                (onset, pitch - 0.4),
                offset - onset,
                0.8,
                facecolor=palette[track % len(palette)],
                edgecolor='black',
                linewidth=0.5,
                alpha=0.8,
            )
        )

    ax.set_xlim(0, total_time)
    ax.set_ylim(40, 90)
    ax.set_xlabel('Time (beats)')
    ax.set_ylabel('MIDI Pitch')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

    # Reserve a temp file name, then let matplotlib write the image to it.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as handle:
        img_path = handle.name
    fig.savefig(img_path, dpi=100, bbox_inches='tight')
    plt.close(fig)
    return img_path
def create_demo_midi(tempo: int = 120, length_bars: int = 4):
    """Build the demo piece (C-Am-F-G chords plus an eighth-note melody).

    Returns ``(midi_path, image_path, audio_path, status)``. On failure, or
    when midiutil is missing, the first three entries are ``None`` and
    ``status`` carries the message. ``audio_path`` is ``None`` whenever audio
    rendering was not possible.
    """
    if not MIDIUTIL_AVAILABLE:
        return None, None, None, "midiutil not installed"

    try:
        song = MIDIFile(2)  # track 0: melody, track 1: chords
        song.addTempo(0, 0, tempo)
        for track in (0, 1):
            song.addProgramChange(track, 0, 0, 0)  # piano on both tracks

        # One chord per bar, cycling C - Am - F - G.
        progression = (
            (60, 64, 67),  # C major
            (57, 60, 64),  # A minor
            (53, 57, 60),  # F major
            (55, 59, 62),  # G major
        )
        tune = (
            72, 74, 76, 77, 76, 74, 72, 71,
            69, 71, 72, 74, 72, 71, 69, 67,
            65, 67, 69, 71, 72, 71, 69, 67,
            67, 69, 71, 72, 74, 76, 77, 79,
        )
        steps_per_bar = 8

        # (start, end, pitch, track) tuples for the piano roll.
        notes_data = []
        for bar in range(length_bars):
            bar_start = bar * 4  # in beats

            for pitch in progression[bar % len(progression)]:
                song.addNote(1, 0, pitch, bar_start, 4, 60)
                notes_data.append((bar_start, bar_start + 4, pitch, 1))

            for step in range(steps_per_bar):
                onset = bar_start + (step * 0.5)
                pitch = tune[(bar * steps_per_bar + step) % len(tune)]
                song.addNote(0, 0, pitch, onset, 0.4, 90)
                notes_data.append((onset, onset + 0.4, pitch, 0))

        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as handle:
            song.writeFile(handle)
            midi_path = handle.name

        image_path = create_piano_roll(
            notes_data, length_bars * 4, f"Demo: {length_bars} bars at {tempo} BPM"
        )
        audio_path = render_midi_to_audio(midi_path)

        outcome = " - Audio rendered!" if audio_path else " - Download MIDI to play"
        status = f"Created: {length_bars} bars at {tempo} BPM" + outcome
        return midi_path, image_path, audio_path, status
    except Exception as e:
        return None, None, None, f"Error: {str(e)}"
# =============================================================================
# Tab 2: Multitrack Generator (Transformer-based)
# =============================================================================
# All heavy dependencies are imported together: if any one of them is missing
# the whole generator is disabled through TORCH_AVAILABLE = False.
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import note_seq
    from note_seq.protobuf.music_pb2 import NoteSequence
    from note_seq.constants import STANDARD_PPQ
    from matplotlib.figure import Figure
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch/Transformers not available - Multitrack Generator disabled")
# NOTE(review): SAMPLE_RATE is not referenced anywhere else in the visible
# source (get_synthesizer() hard-codes 44100) - confirm whether it is still needed.
SAMPLE_RATE = 44100

# Instrument display names. get_instruments() uses the integer from an
# "INST=<n>" token as a 0-based index into this list.
GM_INSTRUMENTS = [
    "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano",
    "Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2", "Harpsichord",
    "Clavi", "Celesta", "Glockenspiel", "Music Box", "Vibraphone", "Marimba",
    "Xylophone", "Tubular Bells", "Dulcimer", "Drawbar Organ", "Percussive Organ",
    "Rock Organ", "Church Organ", "Reed Organ", "Accordion", "Harmonica",
    "Tango Accordion", "Acoustic Guitar (nylon)", "Acoustic Guitar (steel)",
    "Electric Guitar (jazz)", "Electric Guitar (clean)", "Electric Guitar (muted)",
    "Overdriven Guitar", "Distortion Guitar", "Guitar Harmonics", "Acoustic Bass",
    "Electric Bass (finger)", "Electric Bass (pick)", "Fretless Bass", "Slap Bass 1",
    "Slap Bass 2", "Synth Bass 1", "Synth Bass 2", "Violin", "Viola", "Cello",
    "Contrabass", "Tremolo Strings", "Pizzicato Strings", "Orchestral Harp",
    "Timpani", "String Ensemble 1", "String Ensemble 2", "Synth Strings 1",
    "Synth Strings 2", "Choir Aahs", "Voice Oohs", "Synth Choir", "Orchestra Hit",
    "Trumpet", "Trombone", "Tuba", "Muted Trumpet", "French Horn", "Brass Section",
    "Synth Brass 1", "Synth Brass 2", "Soprano Sax", "Alto Sax", "Tenor Sax",
    "Baritone Sax", "Oboe", "English Horn", "Bassoon", "Clarinet", "Piccolo",
    "Flute", "Recorder", "Pan Flute", "Blown Bottle", "Shakuhachi", "Whistle",
    "Ocarina", "Lead 1 (square)", "Lead 2 (sawtooth)", "Lead 3 (calliope)",
    "Lead 4 (chiff)", "Lead 5 (charang)", "Lead 6 (voice)", "Lead 7 (fifths)",
    "Lead 8 (bass + lead)", "Pad 1 (new age)", "Pad 2 (warm)", "Pad 3 (polysynth)",
    "Pad 4 (choir)", "Pad 5 (bowed)", "Pad 6 (metallic)", "Pad 7 (halo)",
    "Pad 8 (sweep)", "FX 1 (rain)", "FX 2 (soundtrack)", "FX 3 (crystal)",
    "FX 4 (atmosphere)", "FX 5 (brightness)", "FX 6 (goblins)", "FX 7 (echoes)",
    "FX 8 (sci-fi)", "Sitar", "Banjo", "Shamisen", "Koto", "Kalimba", "Bagpipe",
    "Fiddle", "Shanai", "Tinkle Bell", "Agogo", "Steel Drums", "Woodblock",
    "Taiko Drum", "Melodic Tom", "Synth Drum", "Reverse Cymbal", "Guitar Fret Noise",
    "Breath Noise", "Seashore", "Bird Tweet", "Telephone Ring", "Helicopter",
    "Applause", "Gunshot",
]
# Global model and tokenizer.
# All three stay None until get_model_and_tokenizer() has run; `device` is
# also read by generate_new_instrument() to place the input tensor.
device = None
tokenizer = None
model = None
def get_model_and_tokenizer():
    """Return ``(model, tokenizer)``, loading both onto the CPU on first use.

    The loaded objects are stored in the module-level ``model``,
    ``tokenizer`` and ``device`` globals so later calls return immediately.
    """
    global model, tokenizer, device

    # Fast path: everything is already cached.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    device = torch.device("cpu")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")

    print("Loading model on CPU...")
    model = AutoModelForCausalLM.from_pretrained("juancopi81/lmd-8bars-2048-epochs40_v4")
    model = model.to(device)
    model.eval()  # inference only

    print("Model loaded successfully!")
    return model, tokenizer
def empty_note_sequence(qpm: float = 120.0, total_time: float = 0.0) -> "NoteSequence":
    """Return a note-free ``NoteSequence`` with one tempo entry of ``qpm``.

    ``ticks_per_quarter`` is set to ``STANDARD_PPQ`` and ``total_time`` to the
    given value.
    """
    sequence = NoteSequence()
    tempo_entry = sequence.tempos.add()
    tempo_entry.qpm = qpm
    sequence.ticks_per_quarter = STANDARD_PPQ
    sequence.total_time = total_time
    return sequence
def token_sequence_to_note_sequence(
    token_sequence: str,
    qpm: float = 120.0,
    use_program: bool = True,
    use_drums: bool = True,
    instrument_mapper: Optional[dict] = None,
    only_piano: bool = False,
) -> "NoteSequence":
    """Decode a token string (or pre-split token list) into a ``NoteSequence``.

    Tokens are processed left to right and decoding stops at ``PIECE_END``:

    * ``TRACK_START`` starts a new track and resets the bar counter.
    * ``INST=<x>`` selects the program, or the drum kit for ``DRUMS``.
    * ``BAR_START`` / ``BAR_END`` move the time cursor from bar to bar.
    * ``NOTE_ON=<p>`` opens a note (default length: four sixteenths) and
      ``NOTE_OFF=<p>`` ends the note opened for that pitch in the current bar.
    * ``TIME_DELTA=<n>`` advances the cursor by ``n`` sixteenth notes.

    Every other token is ignored, as are tokens whose numeric part does not
    parse. Afterwards, notes get one instrument index per distinct
    ``(program, is_drum)`` pair; with ``only_piano`` all non-drum notes are
    forced to instrument 0 / program 0.
    """
    tokens = token_sequence.split() if isinstance(token_sequence, str) else token_sequence

    result = empty_note_sequence(qpm)
    sixteenth = 0.25 * 60 / qpm
    bar_seconds = 4.0 * 60 / qpm

    # Decoder state.
    program = 1
    drum_mode = False
    instrument_slot = 0
    tracks_seen = 0
    bar_index = 0
    cursor = 0
    open_notes = {}  # pitch -> note started in the current bar

    for token in tokens:
        if token == "PIECE_END":
            break

        if token == "TRACK_START":
            bar_index = 0
            tracks_seen += 1
            continue

        if token.startswith("INST"):
            name = token.rsplit("=", 1)[-1]
            if name != "DRUMS" and use_program:
                if instrument_mapper is not None and name in instrument_mapper:
                    name = instrument_mapper[name]
                try:
                    program = int(name)
                except ValueError:
                    program = 0
                instrument_slot = tracks_seen
                drum_mode = False
            # Checked separately (not elif): the mapper may have produced "DRUMS".
            if name == "DRUMS" and use_drums:
                instrument_slot = 0
                program = 0
                drum_mode = True
            continue

        if token == "BAR_START":
            cursor = bar_index * bar_seconds
            open_notes = {}
            continue

        if token == "BAR_END":
            bar_index += 1
            continue

        if token.startswith("NOTE_ON"):
            try:
                pitch = int(token.rsplit("=", 1)[-1])
                started = result.notes.add()
                started.start_time = cursor
                started.end_time = cursor + 4 * sixteenth
                started.pitch = pitch
                started.instrument = instrument_slot
                started.program = program
                started.velocity = 80
                started.is_drum = drum_mode
                open_notes[pitch] = started
            except ValueError:
                pass
            continue

        if token.startswith("NOTE_OFF"):
            try:
                pitch = int(token.rsplit("=", 1)[-1])
                if pitch in open_notes:
                    open_notes[pitch].end_time = cursor
            except ValueError:
                pass
            continue

        if token.startswith("TIME_DELTA"):
            try:
                cursor += float(token.rsplit("=", 1)[-1]) * sixteenth
            except ValueError:
                pass
            continue

        # PIECE_START, TRACK_END, KEYS_*, KEY=, DENSITY=, [PAD] and anything
        # unknown carry no note information and are skipped.

    # Make the instruments right: one index per (program, is_drum) pair,
    # numbered in order of first appearance.
    slot_by_voice = {}
    for note in result.notes:
        voice = (note.program, note.is_drum)
        note.instrument = slot_by_voice.setdefault(voice, len(slot_by_voice))

    if only_piano:
        for note in result.notes:
            if note.is_drum:
                continue
            note.instrument = 0
            note.program = 0

    return result
def create_seed_string(genre: str = "OTHER") -> str:
    """Return the prompt that starts a new piece.

    ``"RANDOM"`` yields a bare ``PIECE_START`` so no genre is given to the
    model; any other value is embedded as ``GENRE=<genre>`` followed by
    ``TRACK_START``.
    """
    return (
        "PIECE_START"
        if genre == "RANDOM"
        else f"PIECE_START GENRE={genre} TRACK_START"
    )
def get_instruments(text_sequence: str) -> List[str]:
    """List the display name of every ``INST=`` token, in order of appearance.

    ``INST=DRUMS`` becomes ``"Drums"``. A numeric value is looked up in
    ``GM_INSTRUMENTS``, falling back to ``"Program <n>"`` when it is out of
    range. Tokens whose value is not an integer are skipped.
    """
    marker = "INST="
    names = []
    for token in text_sequence.split():
        if not token.startswith(marker):
            continue

        value = token[len(marker):]
        if value == "DRUMS":
            names.append("Drums")
            continue

        try:
            program = int(value)
        except ValueError:
            continue

        if 0 <= program < len(GM_INSTRUMENTS):
            names.append(GM_INSTRUMENTS[program])
        else:
            names.append(f"Program {program}")
    return names
def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """Extend ``seed`` by one generated track and return the full token string.

    Generation stops at the ``TRACK_END`` token or after 2048 new tokens.
    When sampling, up to five attempts are made until the newly generated
    part contains a ``NOTE_ON`` token; the last attempt is returned even if
    it contains none.

    Args:
        seed: Token string used as the prompt.
        temp: Sampling temperature. A value of 0 (or less) selects greedy
            decoding, because sampling requires a strictly positive
            temperature and the UI slider can be set to 0.

    Returns:
        The decoded sequence including the seed.
    """
    model, tok = get_model_and_tokenizer()

    # Everything derived from the seed is loop-invariant: encode it once.
    input_ids = tok.encode(seed, return_tensors="pt").to(device)
    seed_length = input_ids.shape[-1]
    eos_token_id = tok.encode("TRACK_END")[0]

    temperature = float(temp)
    sampling = temperature > 0.0
    generation_kwargs = {
        "max_new_tokens": 2048,
        "do_sample": sampling,
        "eos_token_id": eos_token_id,
    }
    if sampling:
        generation_kwargs["temperature"] = temperature

    # Greedy decoding is deterministic, so retrying it would be pointless.
    max_attempts = 5 if sampling else 1

    generated_sequence = seed
    for _ in range(max_attempts):
        with torch.no_grad():
            generated_ids = model.generate(input_ids, **generation_kwargs)
        generated_sequence = tok.decode(generated_ids[0])
        new_tokens = tok.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_tokens:
            break

    # Return last attempt even if no NOTE_ON found
    return generated_sequence
def create_noteseq_piano_roll(note_sequence, title="Generated Music"):
    """Render ``note_sequence`` as a piano-roll PNG and return its path.

    Notes are drawn as rectangles coloured by ``note.instrument``. The axes
    are fitted to the notes; an empty sequence produces a placeholder image
    with the text "No notes generated".
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    # Color by instrument
    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8', '#F39C12', '#9B59B6', '#1ABC9C']
    fig, ax = plt.subplots(figsize=(14, 6))
    notes = note_sequence.notes

    if len(notes) == 0:
        ax.text(0.5, 0.5, 'No notes generated', ha='center', va='center', fontsize=14)
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 127)
    else:
        lowest = min(n.pitch for n in notes)
        highest = max(n.pitch for n in notes)
        last_end = max(n.end_time for n in notes)

        for n in notes:
            box = Rectangle(
                (n.start_time, n.pitch - 0.4),
                n.end_time - n.start_time,
                0.8,
                facecolor=palette[n.instrument % len(palette)],
                edgecolor='black',
                linewidth=0.3,
                alpha=0.8,
            )
            ax.add_patch(box)

        ax.set_xlim(0, last_end + 0.5)
        ax.set_ylim(lowest - 2, highest + 2)

    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('MIDI Pitch')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

    # Save to PNG
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as handle:
        img_path = handle.name
    fig.savefig(img_path, dpi=100, bbox_inches='tight')
    plt.close(fig)
    return img_path
def get_outputs_from_string(generated_sequence: str, qpm: int = 120):
    """Turn a token string into every artefact the UI displays.

    Returns ``(midi_path, img_path, instruments_str, num_tokens, audio_path)``
    where ``instruments_str`` is a "- name" list with one line per
    instrument, ``num_tokens`` is the whitespace-separated token count as a
    string, and ``audio_path`` is ``None`` when audio could not be rendered.
    """
    names = get_instruments(generated_sequence)
    instruments_str = "\n".join([f"- {name}" for name in names])

    sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    # Piano-roll image drawn with the matplotlib helper.
    img_path = create_noteseq_piano_roll(sequence, f"Generated at {qpm} BPM")

    num_tokens = str(len(generated_sequence.split()))

    # Reserve a temp file name, then let note_seq write the MIDI data to it.
    with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as handle:
        midi_path = handle.name
    note_seq.note_sequence_to_midi_file(sequence, midi_path)

    audio_path = render_midi_to_audio(midi_path)

    return midi_path, img_path, instruments_str, num_tokens, audio_path
def generate_song(genre: str = "OTHER", temp: float = 0.75, text_sequence: str = "", qpm: int = 120):
    """Generate one more track and return the six UI outputs.

    An empty ``text_sequence`` starts a new piece seeded from ``genre``; any
    other value is used as the prompt and continued. Returns
    ``(midi_file, fig, instruments_str, generated_sequence, num_tokens, audio_path)``.
    """
    if not TORCH_AVAILABLE:
        return None, None, "PyTorch not available", "", "0", None

    prompt = text_sequence if text_sequence != "" else create_seed_string(genre)
    generated_sequence = generate_new_instrument(seed=prompt, temp=temp)

    outputs = get_outputs_from_string(generated_sequence, qpm)
    midi_file, fig, instruments_str, num_tokens, audio_path = outputs
    return midi_file, fig, instruments_str, generated_sequence, num_tokens, audio_path
def remove_last_instrument(text_sequence: str, qpm: int = 120):
    """Remove the last instrument from a song string.

    The sequence is split on ``TRACK_START`` and the final track is dropped.
    If no track remains afterwards, a new track is generated from what is
    left (the piece header, or an empty prompt).

    Args:
        text_sequence: Current token sequence.
        qpm: Tempo used to render the outputs. It is forwarded on every
            path, including regeneration, so the selected tempo is kept.

    Returns:
        ``(midi_file, fig, instruments_str, new_song, num_tokens, audio_path)``.
    """
    if not TORCH_AVAILABLE:
        return None, None, "PyTorch not available", "", "0", None

    tracks = text_sequence.split("TRACK_START")
    new_song = "TRACK_START".join(tracks[:-1])

    if len(tracks) <= 2:
        # Zero or one track existed, so nothing is left to render. With one
        # track new_song is the piece header; with none it is "".
        # generate_song() returns the six outputs in the order documented above.
        return generate_song(text_sequence=new_song, qpm=qpm)

    midi_file, fig, instruments_str, num_tokens, audio_path = get_outputs_from_string(
        new_song, qpm
    )
    return midi_file, fig, instruments_str, new_song, num_tokens, audio_path
def regenerate_last_instrument(text_sequence: str, qpm: int = 120):
    """Regenerate the track that belongs to the last ``INST=`` token.

    Everything after the last ``INST=<x>`` token is cut off and the remainder
    is used as the prompt for ``generate_song``. Without any ``INST=`` token
    a new piece is generated from an empty prompt.
    """
    if not TORCH_AVAILABLE:
        return None, None, "PyTorch not available", "", "0", None

    inst_pos = text_sequence.rfind("INST=")
    if inst_pos == -1:
        prompt = ""
    else:
        token_end = text_sequence.find(" ", inst_pos)
        # With no space after the INST= token it is the final token, so the
        # whole string is kept.
        prompt = text_sequence if token_end == -1 else text_sequence[:token_end]

    midi_file, fig, instruments_str, new_song, num_tokens, audio_path = generate_song(
        text_sequence=prompt, qpm=qpm
    )
    return midi_file, fig, instruments_str, new_song, num_tokens, audio_path
def change_tempo(text_sequence: str, qpm: int):
    """Re-render an existing token sequence at a new tempo.

    The sequence itself is returned unchanged; only the MIDI, image and audio
    outputs are rebuilt. An empty or whitespace-only sequence yields the
    "No sequence to process" placeholder outputs.
    """
    if not TORCH_AVAILABLE:
        return None, None, "PyTorch not available", "", "0", None

    has_content = bool(text_sequence) and text_sequence.strip() != ""
    if not has_content:
        return None, None, "No sequence to process", "", "0", None

    outputs = get_outputs_from_string(text_sequence, qpm=qpm)
    midi_file, fig, instruments_str, num_tokens, audio_path = outputs
    return midi_file, fig, instruments_str, text_sequence, num_tokens, audio_path
# =============================================================================
# SkyTNT Model Integration
# =============================================================================
# The generator lives in the local skytnt_generator module. SKYTNT_AVAILABLE
# mirrors that module's ONNX_AVAILABLE flag and is False when the import fails.
# The imported helper names exist only if the import succeeded, so they must
# be used behind a SKYTNT_AVAILABLE check.
try:
    from skytnt_generator import generate_midi as skytnt_generate, get_available_instruments, get_available_drum_kits, ONNX_AVAILABLE
    SKYTNT_AVAILABLE = ONNX_AVAILABLE
except ImportError:
    SKYTNT_AVAILABLE = False
    print("SkyTNT generator not available")
def generate_skytnt(instruments, drum_kit, bpm, max_events, temp, top_p, top_k, seed_rand, seed):
    """Generate music using SkyTNT model.

    Args:
        instruments: Selected instrument names; empty or ``None`` means none.
        drum_kit: Drum kit name passed through to the generator.
        bpm, max_events, top_k: Numeric UI values, converted to ``int``.
        temp, top_p: Sampling parameters passed through unchanged.
        seed_rand: When true, no fixed seed is passed to the generator.
        seed: Seed used when ``seed_rand`` is false. A blank field (``None``)
            is treated like ``seed_rand``.

    Returns:
        ``(midi_path, img_path, status, audio_path)``. On any failure the
        paths are ``None`` and ``status`` describes the problem.
    """
    if not SKYTNT_AVAILABLE:
        return None, None, "SkyTNT model not available", None

    # Parse instruments
    instr_list = instruments if instruments else []

    try:
        # Converted inside the try block so that an unusable seed value is
        # reported through the status text instead of raising out of the
        # event handler.
        actual_seed = None if (seed_rand or seed is None) else int(seed)

        midi_path = skytnt_generate(
            instruments=instr_list,
            drum_kit=drum_kit,
            bpm=int(bpm),
            max_events=int(max_events),
            temp=temp,
            top_p=top_p,
            top_k=int(top_k),
            seed=actual_seed
        )
        if midi_path is None:
            return None, None, "Generation failed", None

        # Create visualization
        img_path = create_skytnt_piano_roll(midi_path)

        # Render audio
        audio_path = render_midi_to_audio(midi_path)

        status = f"Generated with {len(instr_list)} instruments at {bpm} BPM"
        if audio_path:
            status += " - Audio rendered!"
        return midi_path, img_path, status, audio_path
    except Exception as e:
        return None, None, f"Error: {str(e)}", None
def create_skytnt_piano_roll(midi_path: str):
    """Create piano roll visualization from MIDI file.

    The file is parsed with ``MIDI.midi2score``; every ``'note'`` event is
    drawn as a rectangle coloured by its channel, with times converted from
    ticks to beats.

    Args:
        midi_path: Path of the MIDI file to visualise.

    Returns:
        Path of the written PNG, or ``None`` if parsing or drawing failed
        (the error is printed, not raised).
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    fig = None
    try:
        import MIDI

        # Context manager so the file handle is closed deterministically.
        with open(midi_path, 'rb') as midi_fh:
            midi_data = MIDI.midi2score(midi_fh.read())

        fig, ax = plt.subplots(figsize=(14, 6))
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8', '#F39C12', '#9B59B6', '#1ABC9C']

        # midi_data[0] is the tick resolution; the remaining items are tracks.
        ticks_per_beat = midi_data[0]
        all_notes = [
            (event[1] / ticks_per_beat, event[2] / ticks_per_beat, event[4], event[3])
            for track in midi_data[1:]
            for event in track
            if event[0] == 'note'
        ]  # (start, duration, pitch, channel)

        if not all_notes:
            ax.text(0.5, 0.5, 'No notes generated', ha='center', va='center', fontsize=14)
            ax.set_xlim(0, 10)
            ax.set_ylim(0, 127)
        else:
            min_pitch = min(n[2] for n in all_notes)
            max_pitch = max(n[2] for n in all_notes)
            max_time = max(n[0] + n[1] for n in all_notes)

            for start, dur, pitch, channel in all_notes:
                color = colors[channel % len(colors)]
                rect = Rectangle((start, pitch - 0.4), dur, 0.8,
                                 facecolor=color, edgecolor='black', linewidth=0.3, alpha=0.8)
                ax.add_patch(rect)

            ax.set_xlim(0, max_time + 0.5)
            ax.set_ylim(min_pitch - 2, max_pitch + 2)

        ax.set_xlabel('Time (beats)')
        ax.set_ylabel('MIDI Pitch')
        ax.set_title('Generated Music (SkyTNT)')
        ax.grid(True, alpha=0.3)

        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
            img_path = f.name
        fig.savefig(img_path, dpi=100, bbox_inches='tight')
        return img_path
    except Exception as e:
        print(f"Visualization error: {e}")
        return None
    finally:
        # Close the figure on the error path too, so failed renders do not
        # accumulate open figures.
        if fig is not None:
            plt.close(fig)
# =============================================================================
# Gradio Interface
# =============================================================================
# Genre choices shown in the UI dropdown and accepted by the CLI --genre flag.
# "RANDOM" is special-cased by create_seed_string(): it produces a seed
# without a GENRE= token.
GENRES = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"]
def build_ui():
    """Build Gradio interface for AI Multitrack MIDI Composer.

    Creates a ``gr.Blocks`` app with two tabs - the genre-based multitrack
    generator and the SkyTNT MIDI model - wires every button to its handler
    and returns the (not yet launched) Blocks object. A tab whose backend is
    unavailable only shows a warning message.
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a flattened copy of this file (as was the indentation
    # inside the triple-quoted Markdown strings) - confirm it matches the
    # intended layout.
    with gr.Blocks(title="Multitrack MIDI Composer") as demo:
        gr.Markdown("""
        # 🎹 Multitrack MIDI Composer
        AI-powered multi-instrument music generation. Choose a model and generate!
        """)
        with gr.Tabs():
            # Tab 1: Multitrack Generator (juancopi81) - Default
            with gr.TabItem("Multitrack Generator (Genre-based)"):
                if not TORCH_AVAILABLE:
                    gr.Markdown("⚠️ PyTorch/Transformers not installed.")
                else:
                    with gr.Row():
                        # Generation controls
                        with gr.Column():
                            mt_temp = gr.Slider(0, 1, step=0.05, value=0.85, label="Temperature")
                            mt_genre = gr.Dropdown(choices=GENRES, value="POP", label="Select Genre")
                            with gr.Row():
                                btn_from_scratch = gr.Button("Start from scratch", variant="primary")
                                btn_continue = gr.Button("Continue Generation")
                            with gr.Row():
                                btn_remove_last = gr.Button("Remove last instrument")
                                btn_regenerate_last = gr.Button("Regenerate last instrument")
                        # Playback, download and tempo
                        with gr.Column():
                            mt_audio = gr.Audio(label="Listen")
                            with gr.Group():
                                mt_midi = gr.File(label="Download MIDI File")
                                with gr.Row():
                                    mt_qpm = gr.Slider(60, 140, step=10, value=120, label="Tempo")
                                    btn_qpm = gr.Button("Change Tempo")
                    with gr.Row():
                        with gr.Column():
                            mt_img = gr.Image(label="Music Visualization")
                        with gr.Column():
                            mt_instruments = gr.Markdown("### Instruments")
                    mt_sequence = gr.Textbox(label="Token Sequence", lines=3)
                    # Hidden helpers: mt_empty supplies an always-empty prompt
                    # for "Start from scratch"; mt_tokens receives the token count.
                    mt_empty = gr.Textbox(visible=False, value="")
                    mt_tokens = gr.Textbox(visible=False)
                    # Every handler returns the same six outputs in this order.
                    btn_from_scratch.click(
                        fn=generate_song,
                        inputs=[mt_genre, mt_temp, mt_empty, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    btn_continue.click(
                        fn=generate_song,
                        inputs=[mt_genre, mt_temp, mt_sequence, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    btn_remove_last.click(
                        fn=remove_last_instrument,
                        inputs=[mt_sequence, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    btn_regenerate_last.click(
                        fn=regenerate_last_instrument,
                        inputs=[mt_sequence, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    btn_qpm.click(
                        fn=change_tempo,
                        inputs=[mt_sequence, mt_qpm],
                        outputs=[mt_midi, mt_img, mt_instruments, mt_sequence, mt_tokens, mt_audio]
                    )
                    gr.Markdown("**Model**: [juancopi81/lmd-8bars-2048-epochs40_v4](https://huggingface.co/juancopi81/lmd-8bars-2048-epochs40_v4)")
            # Tab 2: SkyTNT MIDI Model
            with gr.TabItem("SkyTNT MIDI Model"):
                if not SKYTNT_AVAILABLE:
                    gr.Markdown("⚠️ SkyTNT model not available (onnxruntime required).")
                else:
                    gr.Markdown("Select instruments and generate MIDI events. Processing: ~20 seconds for 200 events.")
                    with gr.Row():
                        with gr.Column():
                            sky_instruments = gr.Dropdown(
                                label="Instruments (optional, auto if empty)",
                                choices=get_available_instruments(),
                                multiselect=True,
                                max_choices=10
                            )
                            sky_drum_kit = gr.Dropdown(
                                label="Drum Kit",
                                choices=get_available_drum_kits(),
                                value="None"
                            )
                            sky_bpm = gr.Slider(60, 200, step=5, value=120, label="BPM")
                            sky_max_events = gr.Slider(50, 500, step=50, value=200, label="Max Events")
                        with gr.Column():
                            with gr.Accordion("Advanced Options", open=False):
                                sky_temp = gr.Slider(0.1, 1.5, step=0.05, value=1.0, label="Temperature")
                                sky_top_p = gr.Slider(0.5, 1.0, step=0.05, value=0.95, label="Top-p")
                                sky_top_k = gr.Slider(1, 50, step=1, value=20, label="Top-k")
                                sky_seed_rand = gr.Checkbox(label="Random Seed", value=True)
                                sky_seed = gr.Number(label="Seed", value=42)
                    btn_sky_generate = gr.Button("Generate", variant="primary")
                    with gr.Row():
                        sky_audio = gr.Audio(label="Listen")
                        sky_midi = gr.File(label="Download MIDI")
                    sky_img = gr.Image(label="Visualization")
                    sky_status = gr.Textbox(label="Status", interactive=False)
                    # The output order matches generate_skytnt()'s return tuple.
                    btn_sky_generate.click(
                        fn=generate_skytnt,
                        inputs=[sky_instruments, sky_drum_kit, sky_bpm, sky_max_events,
                                sky_temp, sky_top_p, sky_top_k, sky_seed_rand, sky_seed],
                        outputs=[sky_midi, sky_img, sky_status, sky_audio]
                    )
                    gr.Markdown("**Model**: [skytnt/midi-model](https://huggingface.co/skytnt/midi-model) (ONNX)")
        gr.Markdown("""
        ---
        **Credits**: Dr. Tristan Behrens (Multitrack) | SkyTNT (MIDI Model)
        """)
    return demo
| # ============================================================================= | |
| # CLI Interface | |
| # ============================================================================= | |
def _save_cli_outputs(midi_path, audio_path, output):
    """Copy the generated MIDI to ``output`` and the WAV (if any) beside it."""
    import shutil

    shutil.copy(midi_path, output)
    print(f"Created: {output}")
    if audio_path:
        # Swap only the real extension. A plain str.replace('.mid', '.wav')
        # leaves names without ".mid" unchanged (so the WAV would overwrite
        # the MIDI file) and also rewrites ".mid" inside directory names.
        audio_out = os.path.splitext(output)[0] + ".wav"
        if audio_out == output:
            # The requested output already ends in ".wav"; keep both files.
            audio_out = output + ".wav"
        shutil.copy(audio_path, audio_out)
        print(f"Audio: {audio_out}")


def cli_main():
    """CLI entry point.

    Sub-commands:
        demo      Write the built-in demo piece (``--tempo``, ``--bars``, ``-o``).
        generate  Generate multitrack music with the transformer model
                  (``--genre``, ``--tempo``, ``--temperature``, ``-o``).

    Without a sub-command the help text is printed. When audio could be
    rendered, a ``.wav`` file is written next to the MIDI output.
    """
    parser = argparse.ArgumentParser(description="Multitrack MIDI Composer")
    subparsers = parser.add_subparsers(dest="command", help="Commands")

    # Demo command
    demo_parser = subparsers.add_parser("demo", help="Generate demo MIDI")
    demo_parser.add_argument("--tempo", type=int, default=120, help="Tempo in BPM")
    demo_parser.add_argument("--bars", type=int, default=4, help="Number of bars")
    demo_parser.add_argument("--output", "-o", type=str, default="demo.mid", help="Output file")

    # Generate command
    gen_parser = subparsers.add_parser("generate", help="Generate multitrack music")
    gen_parser.add_argument("--genre", type=str, default="POP", choices=GENRES)
    gen_parser.add_argument("--tempo", type=int, default=120, help="Tempo in BPM")
    gen_parser.add_argument("--temperature", type=float, default=0.85, help="Sampling temperature")
    gen_parser.add_argument("--output", "-o", type=str, default="output.mid", help="Output file")

    args = parser.parse_args()

    if args.command == "demo":
        midi_path, fig, audio_path, status = create_demo_midi(args.tempo, args.bars)
        if midi_path:
            _save_cli_outputs(midi_path, audio_path, args.output)
            print(status)
        else:
            print(f"Error: {status}")
    elif args.command == "generate":
        if not TORCH_AVAILABLE:
            print("Error: PyTorch not available. Install: pip install torch transformers note-seq")
            return
        print(f"Generating {args.genre} music at {args.tempo} BPM...")
        midi_path, fig, instruments, sequence, tokens, audio_path = generate_song(
            genre=args.genre, temp=args.temperature, qpm=args.tempo
        )
        if midi_path:
            _save_cli_outputs(midi_path, audio_path, args.output)
            print(f"Instruments:\n{instruments}")
            print(f"Tokens: {tokens}")
    else:
        parser.print_help()
if __name__ == "__main__":
    # A known sub-command or a help flag selects the CLI; anything else
    # (including no arguments) starts the Gradio app.
    cli_triggers = ("demo", "generate", "-h", "--help")
    run_cli = len(sys.argv) > 1 and sys.argv[1] in cli_triggers

    if run_cli:
        cli_main()
    else:
        # Preload model if available
        if TORCH_AVAILABLE:
            print("Initializing model...")
            get_model_and_tokenizer()
        demo = build_ui()
        demo.launch()