Spaces:
Sleeping
Sleeping
| from anticipation import ops | |
| from anticipation.sample import generate | |
| from anticipation.tokenize import extract_instruments | |
| from anticipation.convert import events_to_midi,midi_to_events, compound_to_midi | |
| from anticipation.config import * | |
| from anticipation.vocab import * | |
| from anticipation.convert import midi_to_compound | |
| import mido | |
| from agents.utils import load_midi_metadata | |
# Checkpoint identifiers for the anticipatory music transformer, listed from
# fastest / lowest sample quality to slowest / highest sample quality.
SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL = (
    'stanford-crfm/music-small-800k',   # faster inference, worse sample quality
    'stanford-crfm/music-medium-800k',  # slower inference, better sample quality
    'stanford-crfm/music-large-800k',   # slowest inference, best sample quality
)
def harmonize_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
    """Regenerate the accompaniment of ``midi`` around its melody (instrument 0).

    The melody is given to the model as anticipatory controls. The accompaniment
    before ``start_time`` is the prompt, and the accompaniment after ``end_time``
    is appended unchanged to the generated result.

    Args:
        model: anticipatory music transformer forwarded to ``generate``.
        midi: MIDI input accepted by ``midi_to_events``.
        start_time: start of the window to regenerate, in seconds (the unit of
            ``ops.max_time(..., seconds=True)``). Measured from the first note,
            because the piece is translated to start at time 0.
        end_time: end of the window to regenerate, in seconds.
        original_tempo: tempo for the output ``set_tempo`` message
            (microseconds per beat).
        original_time_sig: ``(numerator, denominator)`` for the output
            ``time_signature`` message.
        top_p: nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        mido.MidiFile: a tempo/time-signature track, then the melody track(s),
        then the accompaniment track(s). Exact duplicate tracks are removed.
    """
    out_tpb = 480  # ticks per beat of the output file (mido's default)

    def _note_tracks(part):
        """Yield the non-meta tracks of ``part``, retimed to ``out_tpb`` at ``original_tempo``."""
        # Without a set_tempo message, MIDI playback assumes 500000 us/beat (120 BPM).
        src_tempo = next(
            (m.tempo for t in part.tracks for m in t if m.type == 'set_tempo'), 500000)
        scale = (src_tempo * out_tpb) / (part.ticks_per_beat * original_tempo)
        for track in part.tracks:
            if all(m.is_meta for m in track):
                continue  # meta-only or empty track: superseded by our own meta track
            retimed = mido.MidiTrack()
            src_abs = dst_prev = 0
            for m in track:
                # Rescale absolute positions (not deltas) so rounding never accumulates.
                src_abs += m.time
                dst_abs = round(src_abs * scale)
                retimed.append(m.copy(time=dst_abs - dst_prev))
                dst_prev = dst_abs
            yield retimed

    events = midi_to_events(midi)
    print("Midi converted to events")
    # Take the whole piece and shift it so the first note starts at 0.
    # clip_duration=False keeps notes still sounding at the last onset intact.
    # The one-second headroom keeps the seconds->ticks conversion from
    # excluding the final onset.
    # NOTE(review): because of this shift, start/end are relative to the first
    # note, not the file start. Confirm that is what the caller sends.
    segment = ops.clip(events, 0, ops.max_time(events, seconds=True) + 1.0, clip_duration=False)
    segment = ops.translate(segment, -ops.min_time(segment, seconds=False))

    # Instrument 0 is the melody, returned as control tokens; the rest is accompaniment.
    accompaniment_events, melody = extract_instruments(segment, [0])
    print("Melody extracted")
    print("Start time:", start_time)
    print("End time:", end_time)

    history = ops.clip(accompaniment_events, 0, start_time, clip_duration=False)
    tail_end = ops.max_time(segment, seconds=True) + 1.0
    anticipated = [CONTROL_OFFSET + tok
                   for tok in ops.clip(accompaniment_events, end_time, tail_end, clip_duration=False)]
    accompaniment = generate(model, start_time, end_time, inputs=history, controls=melody,
                             top_p=top_p, debug=False)
    # Re-attach the untouched accompaniment that follows the window.
    accompaniment = ops.combine(accompaniment, anticipated)
    print("Accompaniment generated")

    mel_mid = events_to_midi(melody)
    acc_mid = events_to_midi(accompaniment)

    combined = mido.MidiFile(ticks_per_beat=out_tpb)
    print("Midi built")
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))

    # Meta track first, then melody, then accompaniment. Every instrument track
    # is kept; slicing off the first track used to drop the melody.
    candidates = [meta, *_note_tracks(mel_mid), *_note_tracks(acc_mid)]
    print(f"Melody tracks: {len(mel_mid.tracks)}")
    print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(candidates)}")

    # Drop exact duplicates. str(msg) includes the delta time, so tracks that
    # differ only in timing are kept.
    seen = set()
    for track in candidates:
        key = tuple(str(m) for m in track)
        if key not in seen:
            seen.add(key)
            combined.tracks.append(track)
    print(f"Final track count: {len(combined.tracks)}")
    print("Output Midi metadata added")
    return combined
def harmonizer(ai_model, midi_file, start_time, end_time, top_p):
    """Harmonize the melody (instrument 0) of a MIDI file and return the result.

    Args:
        ai_model: anticipatory music transformer used for generation. Any device
            placement (e.g. ``.cuda()``) is the caller's responsibility.
        midi_file: a loaded ``mido.MidiFile``. A filesystem path is also
            accepted and opened with ``mido.MidiFile``.
        start_time: start of the section to re-harmonize, in seconds, measured
            from the first note of the piece.
        end_time: end of the section to re-harmonize, in seconds.
        top_p: nucleus-sampling threshold used during generation.

    Returns:
        mido.MidiFile: the harmonized MIDI.
    """
    if not hasattr(midi_file, 'tracks'):
        midi_file = mido.MidiFile(midi_file)
    print(f"Original MIDI tracks: {len(midi_file.tracks)}")

    # Log (but do not fix) out-of-range note parameters in the input.
    for msg in (m for track in midi_file.tracks for m in track):
        if msg.type not in ('note_on', 'note_off'):
            continue
        if msg.velocity > 127 or msg.velocity < 0:
            print(f"Invalid velocity: {msg.velocity}")
        if msg.note > 127 or msg.note < 0:
            print(f"Invalid pitch: {msg.note}")

    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
    print("Midi metadata loaded")
    model = ai_model
    print("Model loaded")

    harmonized_midi = harmonize_midi(model, midi, start_time, end_time,
                                     original_tempo, original_time_sig, top_p)
    print("Midi generated")
    print(f"Harmonized MIDI tracks: {len(harmonized_midi.tracks)}")

    # Final safety clamp of note/velocity values before returning.
    for track in harmonized_midi.tracks:
        for msg in track:
            if msg.type in ('note_on', 'note_off'):
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    print("Midi saved")
    return harmonized_midi
def infill_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
    """Regenerate every instrument of ``midi`` between ``start_time`` and ``end_time``.

    Events before ``start_time`` are the prompt. Events after ``end_time`` are
    given to the model as anticipated controls and appended unchanged.

    Args:
        model: anticipatory music transformer forwarded to ``generate``.
        midi: MIDI input accepted by ``midi_to_events``.
        start_time: start of the window to regenerate, in seconds (the unit of
            ``ops.max_time(..., seconds=True)``), measured from the start of
            the event stream. No translation is applied here.
        end_time: end of the window to regenerate, in seconds.
        original_tempo: tempo for the output ``set_tempo`` message
            (microseconds per beat).
        original_time_sig: ``(numerator, denominator)`` for the output
            ``time_signature`` message.
        top_p: nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        mido.MidiFile: a tempo/time-signature track followed by the generated
        instrument tracks. Exact duplicate tracks are removed.
    """
    out_tpb = 480  # ticks per beat of the output file (mido's default)

    def _note_tracks(part):
        """Yield the non-meta tracks of ``part``, retimed to ``out_tpb`` at ``original_tempo``."""
        # Without a set_tempo message, MIDI playback assumes 500000 us/beat (120 BPM).
        src_tempo = next(
            (m.tempo for t in part.tracks for m in t if m.type == 'set_tempo'), 500000)
        scale = (src_tempo * out_tpb) / (part.ticks_per_beat * original_tempo)
        for track in part.tracks:
            if all(m.is_meta for m in track):
                continue  # meta-only or empty track: superseded by our own meta track
            retimed = mido.MidiTrack()
            src_abs = dst_prev = 0
            for m in track:
                # Rescale absolute positions (not deltas) so rounding never accumulates.
                src_abs += m.time
                dst_abs = round(src_abs * scale)
                retimed.append(m.copy(time=dst_abs - dst_prev))
                dst_prev = dst_abs
            yield retimed

    events = midi_to_events(midi)
    print("Midi converted to events")
    # History and continuation are clipped from the same untranslated events.
    # The tail bound previously came from a translated copy, which cut off the
    # end of pieces with leading silence. One second of headroom keeps the
    # final onset from being excluded.
    tail_end = ops.max_time(events, seconds=True) + 1.0
    history = ops.clip(events, 0, start_time, clip_duration=False)
    anticipated = [CONTROL_OFFSET + tok
                   for tok in ops.clip(events, end_time, tail_end, clip_duration=False)]

    infilling = generate(model, start_time, end_time, inputs=history, controls=anticipated,
                         top_p=top_p, debug=False)
    # Re-attach the untouched continuation after the window.
    full_events = ops.combine(infilling, anticipated)
    print("Infilling generated")

    full_mid = events_to_midi(full_events)

    combined = mido.MidiFile(ticks_per_beat=out_tpb)
    print("Midi built")
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))

    candidates = [meta, *_note_tracks(full_mid)]
    print(f"Generated tracks: {len(full_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(candidates)}")

    # Drop exact duplicates. str(msg) includes the delta time, so tracks that
    # differ only in timing are kept.
    seen = set()
    for track in candidates:
        key = tuple(str(m) for m in track)
        if key not in seen:
            seen.add(key)
            combined.tracks.append(track)
    print(f"Final track count: {len(combined.tracks)}")
    print("Output Midi metadata added")
    return combined
def infiller(ai_model, midi_file, start_time, end_time, top_p):
    """Regenerate all instruments of a MIDI file inside a time window.

    Args:
        ai_model: anticipatory music transformer used for generation. Any device
            placement (e.g. ``.cuda()``) is the caller's responsibility.
        midi_file: a loaded ``mido.MidiFile``. A filesystem path is also
            accepted and opened with ``mido.MidiFile``.
        start_time: start of the section to regenerate, in seconds, measured
            from the start of the piece's event stream.
        end_time: end of the section to regenerate, in seconds.
        top_p: nucleus-sampling threshold used during generation.

    Returns:
        mido.MidiFile: the infilled MIDI.
    """
    if not hasattr(midi_file, 'tracks'):
        midi_file = mido.MidiFile(midi_file)
    print(f"Original MIDI tracks: {len(midi_file.tracks)}")

    # Log (but do not fix) out-of-range note parameters in the input.
    for msg in (m for track in midi_file.tracks for m in track):
        if msg.type not in ('note_on', 'note_off'):
            continue
        if msg.velocity > 127 or msg.velocity < 0:
            print(f"Invalid velocity: {msg.velocity}")
        if msg.note > 127 or msg.note < 0:
            print(f"Invalid pitch: {msg.note}")

    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
    print("Midi metadata loaded")
    model = ai_model
    print("Model loaded")

    infilled_midi = infill_midi(model, midi, start_time, end_time,
                                original_tempo, original_time_sig, top_p)
    print("Midi generated")
    print(f"Harmonized MIDI tracks: {len(infilled_midi.tracks)}")

    # Final safety clamp of note/velocity values before returning.
    for track in infilled_midi.tracks:
        for msg in track:
            if msg.type in ('note_on', 'note_off'):
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    print("Midi saved")
    return infilled_midi
def change_melody_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
    """Regenerate the melody (instrument 0) of ``midi`` inside a time window.

    All non-zero instruments are the accompaniment. They are supplied to the
    model as anticipatory controls for the whole piece. The melody before
    ``start_time`` is the prompt, and the melody after ``end_time`` is appended
    unchanged.

    Args:
        model: anticipatory music transformer forwarded to ``generate``.
        midi: MIDI input accepted by ``midi_to_events``.
        start_time: start of the window to regenerate, in seconds (the unit of
            ``ops.max_time(..., seconds=True)``). Measured from the first note,
            because the piece is translated to start at time 0.
        end_time: end of the window to regenerate, in seconds.
        original_tempo: tempo for the output ``set_tempo`` message
            (microseconds per beat).
        original_time_sig: ``(numerator, denominator)`` for the output
            ``time_signature`` message.
        top_p: nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        mido.MidiFile: a tempo/time-signature track, then the melody track(s),
        then the accompaniment track(s). Exact duplicate tracks are removed.
    """
    out_tpb = 480  # ticks per beat of the output file (mido's default)

    def _note_tracks(part):
        """Yield the non-meta tracks of ``part``, retimed to ``out_tpb`` at ``original_tempo``."""
        # Without a set_tempo message, MIDI playback assumes 500000 us/beat (120 BPM).
        src_tempo = next(
            (m.tempo for t in part.tracks for m in t if m.type == 'set_tempo'), 500000)
        scale = (src_tempo * out_tpb) / (part.ticks_per_beat * original_tempo)
        for track in part.tracks:
            if all(m.is_meta for m in track):
                continue  # meta-only or empty track: superseded by our own meta track
            retimed = mido.MidiTrack()
            src_abs = dst_prev = 0
            for m in track:
                # Rescale absolute positions (not deltas) so rounding never accumulates.
                src_abs += m.time
                dst_abs = round(src_abs * scale)
                retimed.append(m.copy(time=dst_abs - dst_prev))
                dst_prev = dst_abs
            yield retimed

    events = midi_to_events(midi)
    # Take the whole piece and shift it so the first note starts at 0.
    # clip_duration=False keeps notes still sounding at the last onset intact.
    # The one-second headroom keeps the final onset from being excluded.
    # NOTE(review): because of this shift, start/end are relative to the first
    # note, not the file start. Confirm that is what the caller sends.
    segment = ops.clip(events, 0, ops.max_time(events, seconds=True) + 1.0, clip_duration=False)
    segment = ops.translate(segment, -ops.min_time(segment, seconds=False))

    # Every instrument except 0 becomes an accompaniment control; instrument 0
    # stays as regular events (the melody).
    accompaniment_instruments = [instr for instr in ops.get_instruments(segment).keys() if instr != 0]
    melody_events, accompaniment_controls = extract_instruments(segment, accompaniment_instruments)

    history = ops.clip(melody_events, 0, start_time, clip_duration=False)
    infilling = generate(model, start_time, end_time, inputs=history,
                         controls=accompaniment_controls, top_p=top_p, debug=False)

    # Re-attach the untouched melody that follows the window.
    tail_end = ops.max_time(segment, seconds=True) + 1.0
    anticipated_melody = [CONTROL_OFFSET + tok
                          for tok in ops.clip(melody_events, end_time, tail_end, clip_duration=False)]
    full_events = ops.combine(infilling, anticipated_melody)

    acc_mid = events_to_midi(accompaniment_controls)
    full_mid = events_to_midi(full_events)

    combined = mido.MidiFile(ticks_per_beat=out_tpb)
    print("Midi built")
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))

    # Meta track first, then the (new) melody, then the accompaniment.
    candidates = [meta, *_note_tracks(full_mid), *_note_tracks(acc_mid)]
    print(f"Melody tracks: {len(full_mid.tracks)}")
    print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(candidates)}")

    # Drop exact duplicates. str(msg) includes the delta time, so tracks that
    # differ only in timing are kept.
    seen = set()
    for track in candidates:
        key = tuple(str(m) for m in track)
        if key not in seen:
            seen.add(key)
            combined.tracks.append(track)
    print(f"Final track count: {len(combined.tracks)}")
    print("Output Midi metadata added")
    return combined
def change_melody(ai_model, midi_file, start_time, end_time, top_p):
    """Regenerate the melody (instrument 0) of a MIDI file inside a time window.

    The existing accompaniment conditions the generation.

    Args:
        ai_model: anticipatory music transformer used for generation. Any device
            placement (e.g. ``.cuda()``) is the caller's responsibility.
        midi_file: a loaded ``mido.MidiFile``. A filesystem path is also
            accepted and opened with ``mido.MidiFile``.
        start_time: start of the section whose melody is replaced, in seconds,
            measured from the first note of the piece.
        end_time: end of that section, in seconds.
        top_p: nucleus-sampling threshold used during generation.

    Returns:
        mido.MidiFile: the MIDI with the regenerated melody.
    """
    if not hasattr(midi_file, 'tracks'):
        midi_file = mido.MidiFile(midi_file)
    print(f"Original MIDI tracks: {len(midi_file.tracks)}")

    # Log (but do not fix) out-of-range note parameters in the input.
    for msg in (m for track in midi_file.tracks for m in track):
        if msg.type not in ('note_on', 'note_off'):
            continue
        if msg.velocity > 127 or msg.velocity < 0:
            print(f"Invalid velocity: {msg.velocity}")
        if msg.note > 127 or msg.note < 0:
            print(f"Invalid pitch: {msg.note}")

    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
    print("Midi metadata loaded")
    model = ai_model
    print("Model loaded")

    change_melody_gen_midi = change_melody_midi(model, midi, start_time, end_time,
                                                original_tempo, original_time_sig, top_p)
    print("Midi generated")
    print(f"Harmonized MIDI tracks: {len(change_melody_gen_midi.tracks)}")

    # Final safety clamp of note/velocity values before returning.
    for track in change_melody_gen_midi.tracks:
        for msg in track:
            if msg.type in ('note_on', 'note_off'):
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    print("Midi saved")
    return change_melody_gen_midi