# -*- coding: utf-8 -*-
"""loop_downloadOnly_MidiMusicGenApp.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1gFFJW56DAbeLYqKWqi6jZN_einOQBf4j
"""

import gc
import os
import re  # used by extract_context; missing from the original imports
import tempfile

import gradio as gr
import torch
from transformers import GPT2LMHeadModel
from miditokenizer import MIDITokenizer
from genprocessor import GENProcessor, generated_tokens_to_midi
from midi2audio import FluidSynth
from pydub import AudioSegment

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load model and tokenizer
torch.serialization.add_safe_globals([set, GPT2LMHeadModel])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load('model_complete_18epochs.pkl', map_location=device, weights_only=False)
tokenizer = MIDITokenizer()
processor = GENProcessor()
model.eval()


class State:
    def __init__(self):
        self.generated_text = None  # Stores the decoded representation of the music

state = State()


# Functions to adjust timing and combine generated song parts
def adjust_midi_timing(midi_data, start_time=0):
    """Adjust MIDI timing with an optional start time.
    Prevent large gaps based on ticks_per_beat."""
    try:
        # Keep the tempo track separate
        tempo_track = midi_data['tracks'][0]
        ticks_per_beat = midi_data['metadata']['ticks_per_beat']

        # Calculate thresholds based on ticks_per_beat
        gap_threshold = ticks_per_beat * 2
        small_increment = ticks_per_beat // 8  # Eighth note

        # Gather all other events and sort them by time
        all_events = []
        for track in midi_data['tracks'][1:]:
            all_events.extend(track)
        all_events.sort(key=lambda x: x['time'])

        # Walk through the times sequentially, collapsing large gaps
        sequential_events = []
        current_time = all_events[0]['time'] if all_events else 0
        for event in all_events:
            if event['time'] - current_time > gap_threshold:
                event['time'] = current_time + small_increment
            current_time = event['time']
            sequential_events.append(event)

        # Find the first non-zero time
        first_time = min((event['time'] for event in sequential_events if event['time'] != 0), default=0)

        adjusted_data = {'metadata': midi_data['metadata'], 'tracks': [tempo_track]}

        # Shift all events so the sequence starts at start_time
        adjusted_track = []
        for event in sequential_events:
            adjusted_event = event.copy()
            if event['time'] != 0:
                adjusted_event['time'] = (event['time'] - first_time) + start_time
            else:
                adjusted_event['time'] = start_time
            adjusted_track.append(adjusted_event)

        adjusted_data['tracks'].append(adjusted_track)
        return adjusted_data

    except Exception as e:
        print(f"Error adjusting MIDI timing: {str(e)}")
        return midi_data


def combine_tracks(first, continued):
    """Combine two generated sequences into a single song."""
    processor = GENProcessor()

    # Check whether the first input is already decoded or needs decoding
    if isinstance(first, str):
        gendecoded = processor.decode_midi_file(first)
    else:
        gendecoded = first

    first_midi = adjust_midi_timing(gendecoded, start_time=0)

    # Get the last time from the single note track
    last_time = max(event['time'] for event in first_midi['tracks'][1])

    # Adjust the timing of the second MIDI so it starts where the first ends
    second_midi = adjust_midi_timing(processor.decode_midi_file(continued), start_time=last_time)

    # Combine into a single song
    full_song = {
        'metadata': first_midi['metadata'],
        'tracks': [
            first_midi['tracks'][0],  # Keep the tempo track
            first_midi['tracks'][1] + second_midi['tracks'][1]  # Concatenate the note tracks
        ]
    }
    return full_song
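
# For reference, a minimal sketch of the decoded structure the two helpers above
# assume, inferred from how they index it. The exact fields produced by
# GENProcessor.decode_midi_file may differ; the values below are illustrative only:
#
# example_decoded = {
#     'metadata': {'ticks_per_beat': 480},
#     'tracks': [
#         # tracks[0]: tempo/time-signature track (its second event carries 'numerator')
#         [{'type': 'set_tempo', 'time': 0, 'tempo': 500000},
#          {'type': 'time_signature', 'time': 0, 'numerator': 4, 'denominator': 4}],
#         # tracks[1]: the single note track, each event keyed by absolute 'time'
#         [{'type': 'note_on', 'time': 0, 'channel': 0, 'note': 60, 'velocity': 80},
#          {'type': 'note_off', 'time': 480, 'channel': 0, 'note': 60, 'velocity': 0}],
#     ],
# }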
def extract_context(text_or_midi, num_events=3):
    """Get the metadata, composer, and last few events from the generated output
    to build a new prompt that continues the sequence."""
    result_metadata = None

    if isinstance(text_or_midi, dict):
        # Extract metadata from the decoded MIDI structure
        ticks = text_or_midi['metadata']['ticks_per_beat']
        numerator = text_or_midi['tracks'][0][1]['numerator']
        composer = text_or_midi.get('composer', 'Bach')  # Default to Bach if missing

        # Get the last events from the note track
        all_events = text_or_midi['tracks'][1]
        sorted_events = sorted(all_events, key=lambda x: x['time'])
        last_events = sorted_events[-num_events:]

        # Format the metadata string
        result_metadata = (f"<|START_METADATA|> <|composer_{composer}|> "
                           f"ticks_per_beat={ticks} <|START_TRACK|> "
                           f"tempo=500000 time=0 numerator={numerator} denominator=4")

        # Format the context string from the last events
        context = " ".join(
            f"<{event['type']}> time={event['time']} channel={event['channel']} "
            + (f"note={event['note']} velocity={event['velocity']}"
               if 'note' in event
               else f"control={event['control']} value={event['value']}")
            for event in last_events
        )
        last_time = last_events[-1]['time'] if last_events else 0

    else:
        # Input is a string of generated tokens
        if "<|START_METADATA|>" in text_or_midi:
            composer_match = re.search(r"<\|composer_([^|]+)\|>", text_or_midi)
            ticks_match = re.search(r"ticks[_]?(?:per_)?beat=(\d+)", text_or_midi)
            time_sig_match = re.search(r"numerator=(\d+)", text_or_midi)

            if composer_match and ticks_match:
                composer = composer_match.group(1)
                numerator = 4
                if time_sig_match:
                    try:
                        num = int(time_sig_match.group(1))
                        numerator = min(4, max(2, num))
                    except ValueError:
                        pass
                result_metadata = (f"<|START_METADATA|> <|composer_{composer}|> "
                                   f"ticks_per_beat={max(75, int(ticks_match.group(1)))} <|START_TRACK|> "
                                   f"tempo=500000 time=0 numerator={numerator} denominator=4")

        # Get the last num_events complete events and the final timestamp
        matches = re.finditer(r"<(note_on|control_change|note_off)>.*?(value=\d+|velocity=\d+)", text_or_midi)
        events = list(matches)[-num_events:]
        context = " ".join(event.group(0) for event in events)

        # Take the timestamp of the final event (mirrors the dict branch above;
        # the original re.search picked the first match instead of the last)
        time_matches = re.findall(r"time=(\d+)", context)
        last_time = int(time_matches[-1]) if time_matches else None

    return result_metadata, context, last_time


def continue_sequence(generated_text, num_loops=1):
    """Continue the sequence: extract context from the previous sequence to build
    a new prompt, generate a continuation, and append it to the previous sequence."""
    full_song = generated_text
    for i in range(num_loops):
        print(f"Generating loop {i+1}/{num_loops}")
        metadata, context, last_time = extract_context(full_song)
        print(metadata + context)
        continued = generate_music(metadata + context)
        full_song = combine_tracks(full_song, continued)
    return full_song


# Functions to generate music
def generate_music(prompt):
    """Generate music based on a given prompt."""
    # Tokenize
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    )

    # Generate
    output_sequences = model.generate(
        input_ids=inputs["input_ids"].to(model.device),
        attention_mask=inputs["attention_mask"].to(model.device),
        max_length=1024,
        do_sample=True,
        temperature=0.6,  # Adjust creativity
        top_k=30,
        top_p=0.90,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated sequence
    generated_text = tokenizer.decode(output_sequences[0])
    return generated_text
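
# Illustrative only: the shape of the token text these functions pass around,
# reconstructed from the f-strings and regexes in extract_context above. The
# actual vocabulary is defined by MIDITokenizer, so treat this as an assumption:
#
# "<|START_METADATA|> <|composer_Bach|> ticks_per_beat=480 <|START_TRACK|> "
# "tempo=500000 time=0 numerator=4 denominator=4 "
# "<note_on> time=0 channel=0 note=60 velocity=80 "
# "<note_off> time=480 channel=0 note=60 velocity=0"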
def generate_wrapper(composer):
    try:
        # Clear memory before generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Format the prompt with the selected composer
        prompt = f"<|START_METADATA|> <|composer_{composer}|> ticks_per_beat="
        generated_text = generate_music(prompt)
        midi_data = adjust_midi_timing(processor.decode_midi_file(generated_text))
        state.generated_text = midi_data

        # Create a temp file for the MIDI output
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
            generated_tokens_to_midi(midi_data, tmp.name)
            return tmp.name, gr.update(visible=True)
    finally:
        # Clear memory after generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def continue_wrapper():
    """Wrapper for continue_sequence."""
    if state.generated_text is not None:
        # Continue the sequence from the stored result
        extended_text = continue_sequence(state.generated_text, num_loops=8)
        state.generated_text = extended_text  # Update the stored result

        # Create a temp file for the MIDI output
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp:
            generated_tokens_to_midi(extended_text, tmp.name)
            return tmp.name
    return None


with gr.Blocks() as iface:
    gr.Markdown("""
    # MAI: MIDI AI Music Generation Model

    Select a composer whose musical style you'd like to emulate, then generate an
    original sequence inspired by that composer's sound. Generation takes a few
    minutes; once it's ready, you can download the MIDI file. If you like the
    opening, you can continue the sequence to make your song longer, or try again.
    Continuing the sequence may also take a few minutes.
    """)

    with gr.Column():
        composer_input = gr.Dropdown(
            choices=["Bach", "Chopin"],
            label="Select Composer",
            value="Bach"
        )
        generate_btn = gr.Button("Generate Music")
        output_file = gr.File(label="Generated MIDI File")
        continue_btn = gr.Button("Continue Sequence (add 10 seconds)", visible=False)

    # Reset the outputs first, then run generation once the reset has applied
    generate_btn.click(
        lambda: (None, gr.update(visible=False)),
        None,
        [output_file, continue_btn],
        queue=False
    ).success(
        generate_wrapper,
        inputs=composer_input,
        outputs=[output_file, continue_btn]
    )

    continue_btn.click(
        fn=continue_wrapper,
        outputs=output_file
    )

iface.launch()
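
# The FluidSynth and AudioSegment imports above are currently unused. A minimal
# sketch of how they could render the generated MIDI to downloadable audio; the
# soundfont path and the helper name are assumptions, and MP3 export via pydub
# additionally requires ffmpeg on the system:
def midi_to_mp3(midi_path, soundfont_path="/usr/share/sounds/sf2/FluidR3_GM.sf2"):
    """Render a MIDI file to MP3 via an intermediate WAV file (illustrative sketch)."""
    wav_path = midi_path.replace(".mid", ".wav")
    mp3_path = midi_path.replace(".mid", ".mp3")
    FluidSynth(soundfont_path).midi_to_audio(midi_path, wav_path)   # MIDI -> WAV
    AudioSegment.from_wav(wav_path).export(mp3_path, format="mp3")  # WAV -> MP3
    return mp3_path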