# -*- coding: utf-8 -*-
"""loop_downloadOnly_MidiMusicGenApp.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1gFFJW56DAbeLYqKWqi6jZN_einOQBf4j
"""
import gc
import os
import re
import tempfile

import gradio as gr
import torch
from transformers import GPT2LMHeadModel
from miditokenizer import MIDITokenizer
from genprocessor import GENProcessor, generated_tokens_to_midi

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Silence the HF tokenizers fork warning

# Load model and tokenizer
torch.serialization.add_safe_globals([set, GPT2LMHeadModel])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load('model_complete_18epochs.pkl', map_location=device, weights_only=False)
tokenizer = MIDITokenizer()
processor = GENProcessor()
model.eval()
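
# Note: the checkpoint is a fully pickled model, so torch.load() needs
# weights_only=False (the default flipped to True in PyTorch 2.6) plus the
# add_safe_globals allowlist above. Unpickling can execute arbitrary code,
# so only load checkpoints you trust.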

class State:
    def __init__(self):
        self.generated_text = None  # Store the text representation of the music

state = State()
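
# Caveat: this module-level State instance is shared by every visitor to the
# app, so concurrent users can overwrite each other's last generation. A
# per-session alternative would be gr.State, which Gradio scopes to each
# browser session.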

# Functions to adjust timing & combine generated song parts
def adjust_midi_timing(midi_data, start_time=0):
    """Adjust MIDI timing with an optional start time. Prevent large gaps based on ticks_per_beat."""
    try:
        # Keep the tempo track separate
        tempo_track = midi_data['tracks'][0]
        ticks_per_beat = midi_data['metadata']['ticks_per_beat']

        # Calculate thresholds based on ticks_per_beat
        gap_threshold = ticks_per_beat * 2
        small_increment = ticks_per_beat // 8  # An eighth of a beat

        # Gather all non-tempo events and sort them by time
        all_events = []
        for track in midi_data['tracks'][1:]:
            all_events.extend(track)
        all_events.sort(key=lambda x: x['time'])

        # Walk the events in order, clamping any gap larger than the threshold
        sequential_events = []
        current_time = all_events[0]['time'] if all_events else 0
        for event in all_events:
            if event['time'] - current_time > gap_threshold:
                event['time'] = current_time + small_increment
            current_time = event['time']
            sequential_events.append(event)

        # Find the first non-zero time so the sequence can be shifted to start_time
        first_time = min((event['time'] for event in sequential_events if event['time'] != 0), default=0)

        adjusted_data = {'metadata': midi_data['metadata'], 'tracks': [tempo_track]}

        # Shift all events so the sequence begins at start_time
        adjusted_track = []
        for event in sequential_events:
            adjusted_event = event.copy()
            if event['time'] != 0:
                adjusted_event['time'] = (event['time'] - first_time) + start_time
            else:
                adjusted_event['time'] = start_time
            adjusted_track.append(adjusted_event)
        adjusted_data['tracks'].append(adjusted_track)
        return adjusted_data
    except Exception as e:
        print(f"Error adjusting MIDI timing: {e}")
        return midi_data
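
# For reference, the decoded MIDI structure these helpers assume (inferred from
# how the fields are accessed, so treat it as illustrative) looks roughly like:
#
#   {
#       'metadata': {'ticks_per_beat': 480},
#       'tracks': [
#           [...],                                  # track 0: tempo/time-signature events
#           [{'type': 'note_on', 'time': 120, 'channel': 0,
#             'note': 60, 'velocity': 80}, ...],    # track 1: note/control events
#       ],
#   }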

def combine_tracks(first, continued):
    """Combine two generated sequences into a single song."""
    processor = GENProcessor()

    # Decode the first input if it is still a raw text sequence
    if isinstance(first, str):
        gendecoded = processor.decode_midi_file(first)
    else:
        gendecoded = first
    first_midi = adjust_midi_timing(gendecoded, start_time=0)

    # Get the last event time from the single note track
    last_time = max(event['time'] for event in first_midi['tracks'][1])

    # Shift the second sequence so it starts where the first ends
    second_midi = adjust_midi_timing(processor.decode_midi_file(continued), start_time=last_time)

    # Combine into a single song
    full_song = {
        'metadata': first_midi['metadata'],
        'tracks': [
            first_midi['tracks'][0],  # Keep the tempo track
            first_midi['tracks'][1] + second_midi['tracks'][1],  # Concatenate note tracks
        ],
    }
    return full_song

def extract_context(text_or_midi, num_events=3):
    """Get the metadata, composer, and last few events from a generated sequence
    to build a new prompt that continues it."""
    result_metadata = None
    if isinstance(text_or_midi, dict):
        # Extract metadata from the decoded MIDI structure
        ticks = text_or_midi['metadata']['ticks_per_beat']
        numerator = text_or_midi['tracks'][0][1]['numerator']
        composer = text_or_midi.get('composer', 'Bach')  # Default to Bach if missing

        # Get the last events from the note track
        all_events = text_or_midi['tracks'][1]
        sorted_events = sorted(all_events, key=lambda x: x['time'])
        last_events = sorted_events[-num_events:]

        # Format the metadata string
        result_metadata = (f"<|START_METADATA|> <|composer_{composer}|><metadata> "
                           f"ticks_per_beat={ticks} <|START_TRACK|> "
                           f"tempo=500000 <time_signature> time=0 numerator={numerator} denominator=4")

        # Format the context string from the last events
        context = " ".join(f"<{event['type']}> time={event['time']} channel={event['channel']} " +
                           (f"note={event['note']} velocity={event['velocity']}" if 'note' in event else
                            f"control={event['control']} value={event['value']}")
                           for event in last_events)
        last_time = last_events[-1]['time'] if last_events else 0
    else:  # Input is a raw text sequence
        if "<|START_METADATA|>" in text_or_midi:
            composer_match = re.search(r"<\|composer_([^|]+)\|>", text_or_midi)
            ticks_match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", text_or_midi)
            time_sig_match = re.search(r"numerator=(\d+)", text_or_midi)
            if composer_match and ticks_match:
                composer = composer_match.group(1)
                numerator = 4
                if time_sig_match:
                    try:
                        num = int(time_sig_match.group(1))
                        numerator = min(4, max(2, num))  # Clamp to a sane 2-4 range
                    except ValueError:
                        pass
                result_metadata = (f"<|START_METADATA|> <|composer_{composer}|><metadata> "
                                   f"ticks_per_beat={max(75, int(ticks_match.group(1)))} <|START_TRACK|> "
                                   f"tempo=500000 <time_signature> time=0 numerator={numerator} denominator=4")

        # Get the last num_events complete events and the final timestamp
        matches = re.finditer(r"<(note_on|control_change|note_off)>.*?(value=\d+|velocity=\d+)", text_or_midi)
        events = list(matches)[-num_events:]
        context = " ".join(event.group(0) for event in events)

        last_time = None
        times = re.findall(r"time=(\d+)", context)
        if times:
            last_time = int(times[-1])  # Take the final timestamp, not the first
    return result_metadata, context, last_time
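
# For illustration, a continuation prompt assembled from the metadata and
# context above looks roughly like this (values are examples, not real output):
#
#   <|START_METADATA|> <|composer_Bach|><metadata> ticks_per_beat=480
#   <|START_TRACK|> tempo=500000 <time_signature> time=0 numerator=4
#   denominator=4 <note_on> time=960 channel=0 note=60 velocity=80 ...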

def continue_sequence(generated_text, num_loops=1):
    """Repeatedly extend the sequence: extract context from the previous pass to
    build a new prompt, generate a continuation, and append it."""
    full_song = generated_text
    for i in range(num_loops):
        print(f"Generating loop {i+1}/{num_loops}")
        metadata, context, last_time = extract_context(full_song)
        prompt = (metadata or "") + context  # metadata can be None if parsing failed
        print(prompt)
        continued = generate_music(prompt)
        full_song = combine_tracks(full_song, continued)
    return full_song

# Functions to generate music
def generate_music(prompt):
    """Generate music based on a given prompt."""
    # Tokenize
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    )

    # Generate
    output_sequences = model.generate(
        input_ids=inputs["input_ids"].to(model.device),
        attention_mask=inputs["attention_mask"].to(model.device),
        max_length=1024,
        do_sample=True,
        temperature=0.6,  # Adjust creativity: lower is more conservative
        top_k=30,
        top_p=0.90,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated sequence
    generated_text = tokenizer.decode(output_sequences[0])
    return generated_text
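
# Sampling notes: do_sample with temperature=0.6, top_k=30, and top_p=0.90 keeps
# output close to the training distribution; raising temperature or top_p makes
# sequences more adventurous but less coherent. max_length=1024 covers the
# prompt plus the continuation, so longer prompts leave less room for new events.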

def generate_wrapper(composer):
    try:
        # Clear memory before generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Format the prompt with the selected composer
        prompt = f"<|START_METADATA|> <|composer_{composer}|><metadata> ticks_per_beat="
        generated_text = generate_music(prompt)
        midi_data = adjust_midi_timing(processor.decode_midi_file(generated_text))
        state.generated_text = midi_data

        # Write the MIDI to a temp file for Gradio to serve; delete=False keeps
        # the file on disk after the handle closes
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
            generated_tokens_to_midi(midi_data, tmp.name)
            return tmp.name, gr.update(visible=True)
    finally:
        # Clear memory after generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def continue_wrapper():
    """Wrapper for continue_sequence."""
    if state.generated_text is not None:
        # Extend the stored sequence using continue_sequence
        extended_text = continue_sequence(state.generated_text, num_loops=8)
        state.generated_text = extended_text  # Update the stored sequence

        # Create a temp file for the extended MIDI
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
            generated_tokens_to_midi(extended_text, tmp.name)
            return tmp.name
    return None

with gr.Blocks() as iface:
    gr.Markdown("""
    # MAI: MIDI AI Music Generation Model
    Select a composer whose musical style you'd like to emulate, then generate an original sequence inspired by that composer's sound.
    Generation takes a few minutes. Once it's ready, you can download the MIDI file.
    If you like the opening, you can continue the sequence to make your song longer, repeat the continuation, or generate again from scratch. Continuing the sequence may also take a few minutes.
    """)
    with gr.Column():
        composer_input = gr.Dropdown(
            choices=["Bach", "Chopin"],
            label="Select Composer",
            value="Bach"
        )
        generate_btn = gr.Button("Generate Music")
        output_file = gr.File(label="Generated MIDI File")
        continue_btn = gr.Button("Continue Sequence (add 10 seconds)", visible=False)
    # Clear the previous file and hide the continue button, then generate
    generate_btn.click(
        lambda: (None, gr.update(visible=False)),
        None,
        [output_file, continue_btn],
        queue=False
    ).success(
        generate_wrapper,
        inputs=composer_input,
        outputs=[output_file, continue_btn]
    )
    continue_btn.click(
        fn=continue_wrapper,
        outputs=output_file
    )

iface.launch()
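
# When running in Colab, a public link can be created with
# iface.launch(share=True) instead of the bare launch() above.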