Spaces:
Sleeping
Sleeping
Commit
·
572abf8
1
Parent(s):
753bd5a
Initial commit
Browse files- Dockerfile +31 -0
- README.md +4 -4
- __pycache__/utils.cpython-311.pyc +0 -0
- agents/__innit__.py +0 -0
- agents/__pycache__/agents.cpython-311.pyc +0 -0
- agents/__pycache__/harmonize.cpython-311.pyc +0 -0
- agents/__pycache__/harmonize.cpython-312.pyc +0 -0
- agents/__pycache__/utils.cpython-311.pyc +0 -0
- agents/agents.py +422 -0
- agents/utils.py +15 -0
- anticipation/__init__.py +9 -0
- anticipation/__pycache__/__init__.cpython-311.pyc +0 -0
- anticipation/__pycache__/config.cpython-311.pyc +0 -0
- anticipation/__pycache__/convert.cpython-311.pyc +0 -0
- anticipation/__pycache__/ops.cpython-311.pyc +0 -0
- anticipation/__pycache__/sample.cpython-311.pyc +0 -0
- anticipation/__pycache__/tokenize.cpython-311.pyc +0 -0
- anticipation/__pycache__/visuals.cpython-311.pyc +0 -0
- anticipation/__pycache__/vocab.cpython-311.pyc +0 -0
- anticipation/config-original.py +60 -0
- anticipation/config.py +60 -0
- anticipation/convert-original.py +342 -0
- anticipation/convert.py +365 -0
- anticipation/ops.py +285 -0
- anticipation/sample.py +280 -0
- anticipation/tokenize.py +219 -0
- anticipation/visuals.py +65 -0
- anticipation/vocab.py +58 -0
- api.py +240 -0
- examples/full-score3.mid +0 -0
- examples/strawberry.mid +0 -0
- requirements.txt +11 -0
Dockerfile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.9-slim
|
| 2 |
+
|
| 3 |
+
# Create dedicated user with home directory
|
| 4 |
+
RUN useradd -m -u 1000 user
|
| 5 |
+
|
| 6 |
+
# Set Hugging Face cache to user's writable directory
|
| 7 |
+
ENV HF_HOME=/home/user/.cache/huggingface
|
| 8 |
+
ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface
|
| 9 |
+
|
| 10 |
+
# Create cache directory with proper permissions
|
| 11 |
+
RUN mkdir -p ${HF_HOME} && chown -R user:user /home/user
|
| 12 |
+
|
| 13 |
+
# Set working directory (app will live here)
|
| 14 |
+
WORKDIR /app
|
| 15 |
+
|
| 16 |
+
# Install dependencies as root
|
| 17 |
+
COPY requirements.txt .
|
| 18 |
+
RUN pip install --no-cache-dir -r requirements.txt gunicorn
|
| 19 |
+
|
| 20 |
+
# Copy app files (maintain ownership)
|
| 21 |
+
COPY --chown=user:user . .
|
| 22 |
+
|
| 23 |
+
RUN rm -rf /root/.cache/pip
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Switch to non-root user
|
| 27 |
+
USER user
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
EXPOSE 7860
|
| 31 |
+
CMD ["gunicorn", "--workers", "1", "--timeout", "120", "--bind", "0.0.0.0:7860", "api:app"]
|
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
|
|
|
| 1 |
---
|
| 2 |
+
title: InScoreAI
|
| 3 |
+
emoji: 📚
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (229 Bytes). View file
|
|
|
agents/__innit__.py
ADDED
|
File without changes
|
agents/__pycache__/agents.cpython-311.pyc
ADDED
|
Binary file (18.1 kB). View file
|
|
|
agents/__pycache__/harmonize.cpython-311.pyc
ADDED
|
Binary file (19.2 kB). View file
|
|
|
agents/__pycache__/harmonize.cpython-312.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
agents/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (691 Bytes). View file
|
|
|
agents/agents.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from anticipation import ops
|
| 2 |
+
from anticipation.sample import generate
|
| 3 |
+
from anticipation.tokenize import extract_instruments
|
| 4 |
+
from anticipation.convert import events_to_midi,midi_to_events, compound_to_midi
|
| 5 |
+
from anticipation.config import *
|
| 6 |
+
from anticipation.vocab import *
|
| 7 |
+
from anticipation.convert import midi_to_compound
|
| 8 |
+
import mido
|
| 9 |
+
from agents.utils import load_midi_metadata
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
SMALL_MODEL = 'stanford-crfm/music-small-800k' # faster inference, worse sample quality
|
| 13 |
+
MEDIUM_MODEL = 'stanford-crfm/music-medium-800k' # slower inference, better sample quality
|
| 14 |
+
LARGE_MODEL = 'stanford-crfm/music-large-800k' # slowest inference, best sample quality
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def harmonize_midi(model, midi, start_time, end_time,original_tempo,original_time_sig,top_p):
|
| 20 |
+
|
| 21 |
+
# Turn full midi to events
|
| 22 |
+
events = midi_to_events(midi)
|
| 23 |
+
|
| 24 |
+
print("Midi converted to events")
|
| 25 |
+
|
| 26 |
+
# Get clip from 0 to end of full midi
|
| 27 |
+
|
| 28 |
+
segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
|
| 29 |
+
segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
|
| 30 |
+
|
| 31 |
+
# Extract melody and accompaniment
|
| 32 |
+
events, melody = extract_instruments(segment, [0])
|
| 33 |
+
|
| 34 |
+
print("Melody extracted")
|
| 35 |
+
|
| 36 |
+
print("Start time:", start_time)
|
| 37 |
+
print("End time:", end_time)
|
| 38 |
+
|
| 39 |
+
# Get initial prompt
|
| 40 |
+
history = ops.clip(events, 0, start_time, clip_duration=False)
|
| 41 |
+
|
| 42 |
+
anticipated = [CONTROL_OFFSET + tok for tok in ops.clip(events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
|
| 43 |
+
|
| 44 |
+
# Generate accompaniment conditioning on melody
|
| 45 |
+
accompaniment = generate(model, start_time, end_time, inputs=history, controls=melody, top_p=top_p, debug=False)
|
| 46 |
+
|
| 47 |
+
# Append anticipated continuation to accompaniment
|
| 48 |
+
accompaniment = ops.combine(accompaniment, anticipated)
|
| 49 |
+
|
| 50 |
+
print("Accompaniment generated")
|
| 51 |
+
|
| 52 |
+
# 1) render each voice separately
|
| 53 |
+
mel_mid = events_to_midi(melody)
|
| 54 |
+
acc_mid = events_to_midi(accompaniment)
|
| 55 |
+
|
| 56 |
+
# 2) build a fresh MidiFile
|
| 57 |
+
combined = mido.MidiFile()
|
| 58 |
+
combined.ticks_per_beat = mel_mid.ticks_per_beat # or TIME_RESOLUTION//2
|
| 59 |
+
|
| 60 |
+
print("Midi built")
|
| 61 |
+
|
| 62 |
+
# 3) meta‐track with tempo & time signature
|
| 63 |
+
meta = mido.MidiTrack()
|
| 64 |
+
meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
|
| 65 |
+
meta.append(mido.MetaMessage('time_signature',
|
| 66 |
+
numerator=original_time_sig[0],
|
| 67 |
+
denominator=original_time_sig[1]))
|
| 68 |
+
combined.tracks.append(meta)
|
| 69 |
+
|
| 70 |
+
# 4) append melody *then* accompaniment
|
| 71 |
+
combined.tracks.extend(mel_mid.tracks[1:]) # Skip existing meta track
|
| 72 |
+
combined.tracks.extend(acc_mid.tracks[1:])
|
| 73 |
+
# 5) save in exactly that order
|
| 74 |
+
|
| 75 |
+
for track in combined.tracks:
|
| 76 |
+
for msg in track:
|
| 77 |
+
if msg.type in ['note_on', 'note_off']:
|
| 78 |
+
# Ensure valid MIDI values
|
| 79 |
+
if hasattr(msg, 'velocity'):
|
| 80 |
+
msg.velocity = min(max(msg.velocity, 0), 127)
|
| 81 |
+
if hasattr(msg, 'note'):
|
| 82 |
+
msg.note = min(max(msg.note, 0), 127)
|
| 83 |
+
|
| 84 |
+
print(f"Melody tracks: {len(mel_mid.tracks)}")
|
| 85 |
+
print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
|
| 86 |
+
print(f"Combined tracks before cleanup: {len(combined.tracks)}")
|
| 87 |
+
|
| 88 |
+
# Add track cleanup (keep only unique tracks):
|
| 89 |
+
unique_tracks = []
|
| 90 |
+
seen = set()
|
| 91 |
+
for track in combined.tracks:
|
| 92 |
+
track_hash = str([msg.hex() for msg in track])
|
| 93 |
+
if track_hash not in seen:
|
| 94 |
+
unique_tracks.append(track)
|
| 95 |
+
seen.add(track_hash)
|
| 96 |
+
combined.tracks = unique_tracks
|
| 97 |
+
|
| 98 |
+
print(f"Final track count: {len(combined.tracks)}")
|
| 99 |
+
|
| 100 |
+
print("Output Midi metadata added")
|
| 101 |
+
|
| 102 |
+
return combined
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def harmonizer(ai_model,midi_file, start_time, end_time,top_p):
|
| 107 |
+
"""
|
| 108 |
+
this function harmonizes a melody in a MIDI file
|
| 109 |
+
returns the harmonized MIDI
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
midi_file: path to the MIDI file
|
| 113 |
+
start_time: start time of the selected measure (melody you want to harmonize) in milliseconds
|
| 114 |
+
end_time: end time of the selected measure in milliseconds
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
print(f"Original MIDI tracks: {len(midi_file.tracks)}")
|
| 118 |
+
|
| 119 |
+
# Load metadata and model...
|
| 120 |
+
|
| 121 |
+
# Log original note parameters
|
| 122 |
+
for track in midi_file.tracks:
|
| 123 |
+
for msg in track:
|
| 124 |
+
if msg.type in ['note_on', 'note_off']:
|
| 125 |
+
if msg.velocity > 127 or msg.velocity < 0:
|
| 126 |
+
print(f"Invalid velocity: {msg.velocity}")
|
| 127 |
+
if msg.note > 127 or msg.note < 0:
|
| 128 |
+
print(f"Invalid pitch: {msg.note}")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Load original MIDI and extract metadata
|
| 132 |
+
midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
|
| 133 |
+
|
| 134 |
+
print("Midi metadata loaded")
|
| 135 |
+
|
| 136 |
+
# load an anticipatory music transformer
|
| 137 |
+
model = ai_model # add .cuda() if you have a GPU
|
| 138 |
+
|
| 139 |
+
print("Model loaded")
|
| 140 |
+
|
| 141 |
+
harmonized_midi = harmonize_midi(model, midi, start_time, end_time, original_tempo,original_time_sig,top_p)
|
| 142 |
+
|
| 143 |
+
print("Midi generated")
|
| 144 |
+
|
| 145 |
+
print(f"Harmonized MIDI tracks: {len(harmonized_midi.tracks)}")
|
| 146 |
+
|
| 147 |
+
# Add MIDI validation
|
| 148 |
+
for track in harmonized_midi.tracks:
|
| 149 |
+
for msg in track:
|
| 150 |
+
if msg.type in ['note_on', 'note_off']:
|
| 151 |
+
# Clamp invalid values
|
| 152 |
+
msg.velocity = min(max(msg.velocity, 0), 127)
|
| 153 |
+
msg.note = min(max(msg.note, 0), 127)
|
| 154 |
+
|
| 155 |
+
print("Midi saved")
|
| 156 |
+
|
| 157 |
+
return harmonized_midi
|
| 158 |
+
|
| 159 |
+
def infill_midi(model, midi, start_time, end_time,original_tempo,original_time_sig,top_p):
|
| 160 |
+
|
| 161 |
+
# Turn full midi to events
|
| 162 |
+
events = midi_to_events(midi)
|
| 163 |
+
|
| 164 |
+
print("Midi converted to events")
|
| 165 |
+
|
| 166 |
+
# Get clip from 0 to end of full midi
|
| 167 |
+
|
| 168 |
+
segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
|
| 169 |
+
segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
|
| 170 |
+
|
| 171 |
+
# Get initial prompt
|
| 172 |
+
history = ops.clip(events, 0, start_time, clip_duration=False)
|
| 173 |
+
|
| 174 |
+
anticipated = [CONTROL_OFFSET + tok for tok in ops.clip(events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
|
| 175 |
+
|
| 176 |
+
# Generate accompaniment conditioning on melody
|
| 177 |
+
infilling = generate(model, start_time, end_time, inputs=history, controls=anticipated, top_p=top_p, debug=False)
|
| 178 |
+
|
| 179 |
+
# Append anticipated continuation to accompaniment
|
| 180 |
+
full_events = ops.combine(infilling, anticipated)
|
| 181 |
+
|
| 182 |
+
print("Accompaniment generated")
|
| 183 |
+
|
| 184 |
+
# 1) render each voice separately
|
| 185 |
+
full_mid = events_to_midi(full_events)
|
| 186 |
+
|
| 187 |
+
# 2) build a fresh MidiFile
|
| 188 |
+
combined = mido.MidiFile()
|
| 189 |
+
combined.ticks_per_beat = full_mid.ticks_per_beat # or TIME_RESOLUTION//2
|
| 190 |
+
|
| 191 |
+
print("Midi built")
|
| 192 |
+
|
| 193 |
+
# 3) meta‐track with tempo & time signature
|
| 194 |
+
meta = mido.MidiTrack()
|
| 195 |
+
meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
|
| 196 |
+
meta.append(mido.MetaMessage('time_signature',
|
| 197 |
+
numerator=original_time_sig[0],
|
| 198 |
+
denominator=original_time_sig[1]))
|
| 199 |
+
combined.tracks.append(meta)
|
| 200 |
+
|
| 201 |
+
# 4) append melody *then* accompaniment
|
| 202 |
+
combined.tracks.extend(full_mid.tracks[:]) # Skip existing meta track
|
| 203 |
+
|
| 204 |
+
# 5) save in exactly that order
|
| 205 |
+
|
| 206 |
+
for track in combined.tracks:
|
| 207 |
+
for msg in track:
|
| 208 |
+
if msg.type in ['note_on', 'note_off']:
|
| 209 |
+
# Ensure valid MIDI values
|
| 210 |
+
if hasattr(msg, 'velocity'):
|
| 211 |
+
msg.velocity = min(max(msg.velocity, 0), 127)
|
| 212 |
+
if hasattr(msg, 'note'):
|
| 213 |
+
msg.note = min(max(msg.note, 0), 127)
|
| 214 |
+
|
| 215 |
+
print(f"Melody tracks: {len(full_mid.tracks)}")
|
| 216 |
+
print(f"Accompaniment tracks: {len(full_mid.tracks)}")
|
| 217 |
+
print(f"Combined tracks before cleanup: {len(combined.tracks)}")
|
| 218 |
+
|
| 219 |
+
# Add track cleanup (keep only unique tracks):
|
| 220 |
+
unique_tracks = []
|
| 221 |
+
seen = set()
|
| 222 |
+
for track in combined.tracks:
|
| 223 |
+
track_hash = str([msg.hex() for msg in track])
|
| 224 |
+
if track_hash not in seen:
|
| 225 |
+
unique_tracks.append(track)
|
| 226 |
+
seen.add(track_hash)
|
| 227 |
+
combined.tracks = unique_tracks
|
| 228 |
+
|
| 229 |
+
print(f"Final track count: {len(combined.tracks)}")
|
| 230 |
+
|
| 231 |
+
print("Output Midi metadata added")
|
| 232 |
+
|
| 233 |
+
return combined
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def infiller(ai_model,midi_file, start_time, end_time,top_p):
|
| 238 |
+
"""
|
| 239 |
+
this function harmonizes a melody in a MIDI file
|
| 240 |
+
returns the harmonized MIDI
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
midi_file: path to the MIDI file
|
| 244 |
+
start_time: start time of the selected measure (melody you want to harmonize) in milliseconds
|
| 245 |
+
end_time: end time of the selected measure in milliseconds
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
print(f"Original MIDI tracks: {len(midi_file.tracks)}")
|
| 249 |
+
|
| 250 |
+
# Load metadata and model...
|
| 251 |
+
|
| 252 |
+
# Log original note parameters
|
| 253 |
+
for track in midi_file.tracks:
|
| 254 |
+
for msg in track:
|
| 255 |
+
if msg.type in ['note_on', 'note_off']:
|
| 256 |
+
if msg.velocity > 127 or msg.velocity < 0:
|
| 257 |
+
print(f"Invalid velocity: {msg.velocity}")
|
| 258 |
+
if msg.note > 127 or msg.note < 0:
|
| 259 |
+
print(f"Invalid pitch: {msg.note}")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# Load original MIDI and extract metadata
|
| 263 |
+
midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
|
| 264 |
+
|
| 265 |
+
print("Midi metadata loaded")
|
| 266 |
+
|
| 267 |
+
# load an anticipatory music transformer
|
| 268 |
+
model = ai_model # add .cuda() if you have a GPU
|
| 269 |
+
|
| 270 |
+
print("Model loaded")
|
| 271 |
+
|
| 272 |
+
infilled_midi = infill_midi(model, midi, start_time, end_time, original_tempo,original_time_sig,top_p)
|
| 273 |
+
|
| 274 |
+
print("Midi generated")
|
| 275 |
+
|
| 276 |
+
print(f"Harmonized MIDI tracks: {len(infilled_midi.tracks)}")
|
| 277 |
+
|
| 278 |
+
# Add MIDI validation
|
| 279 |
+
for track in infilled_midi.tracks:
|
| 280 |
+
for msg in track:
|
| 281 |
+
if msg.type in ['note_on', 'note_off']:
|
| 282 |
+
# Clamp invalid values
|
| 283 |
+
msg.velocity = min(max(msg.velocity, 0), 127)
|
| 284 |
+
msg.note = min(max(msg.note, 0), 127)
|
| 285 |
+
|
| 286 |
+
print("Midi saved")
|
| 287 |
+
|
| 288 |
+
return infilled_midi
|
| 289 |
+
|
| 290 |
+
def change_melody_midi(model, midi, start_time, end_time,original_tempo,original_time_sig,top_p):
|
| 291 |
+
|
| 292 |
+
events = midi_to_events(midi)
|
| 293 |
+
segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
|
| 294 |
+
segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
|
| 295 |
+
|
| 296 |
+
# Extract melody (instrument 0) as events and accompaniment as controls
|
| 297 |
+
instruments = list(ops.get_instruments(segment).keys())
|
| 298 |
+
accompaniment_instruments = [instr for instr in instruments if instr != 0]
|
| 299 |
+
melody_events, accompaniment_controls = extract_instruments(segment, accompaniment_instruments)
|
| 300 |
+
|
| 301 |
+
# Get initial prompt (melody before start_time)
|
| 302 |
+
history = ops.clip(melody_events, 0, start_time, clip_duration=False)
|
| 303 |
+
|
| 304 |
+
# Include accompaniment controls for the entire duration
|
| 305 |
+
controls = accompaniment_controls # Full accompaniment as controls
|
| 306 |
+
|
| 307 |
+
# Generate new melody conditioned on accompaniment
|
| 308 |
+
infilling = generate(model, start_time, end_time, inputs=history, controls=controls, top_p=top_p, debug=False)
|
| 309 |
+
|
| 310 |
+
# Append anticipated continuation
|
| 311 |
+
anticipated_melody = [CONTROL_OFFSET + tok for tok in ops.clip(melody_events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
|
| 312 |
+
full_events = ops.combine(infilling, anticipated_melody)
|
| 313 |
+
|
| 314 |
+
acc_mid = events_to_midi(accompaniment_controls)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# Render and combine MIDI tracks
|
| 318 |
+
full_mid = events_to_midi(full_events)
|
| 319 |
+
combined = mido.MidiFile()
|
| 320 |
+
combined.ticks_per_beat = full_mid.ticks_per_beat # or TIME_RESOLUTION//2
|
| 321 |
+
|
| 322 |
+
print("Midi built")
|
| 323 |
+
|
| 324 |
+
# 3) meta‐track with tempo & time signature
|
| 325 |
+
meta = mido.MidiTrack()
|
| 326 |
+
meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
|
| 327 |
+
meta.append(mido.MetaMessage('time_signature',
|
| 328 |
+
numerator=original_time_sig[0],
|
| 329 |
+
denominator=original_time_sig[1]))
|
| 330 |
+
combined.tracks.append(meta)
|
| 331 |
+
|
| 332 |
+
# 4) append melody *then* accompaniment
|
| 333 |
+
combined.tracks.extend(full_mid.tracks[:]) # Skip existing meta track
|
| 334 |
+
combined.tracks.extend(acc_mid.tracks[:]) # Skip existing meta track
|
| 335 |
+
|
| 336 |
+
# 5) save in exactly that order
|
| 337 |
+
|
| 338 |
+
for track in combined.tracks:
|
| 339 |
+
for msg in track:
|
| 340 |
+
if msg.type in ['note_on', 'note_off']:
|
| 341 |
+
# Ensure valid MIDI values
|
| 342 |
+
if hasattr(msg, 'velocity'):
|
| 343 |
+
msg.velocity = min(max(msg.velocity, 0), 127)
|
| 344 |
+
if hasattr(msg, 'note'):
|
| 345 |
+
msg.note = min(max(msg.note, 0), 127)
|
| 346 |
+
|
| 347 |
+
print(f"Melody tracks: {len(full_mid.tracks)}")
|
| 348 |
+
print(f"Accompaniment tracks: {len(full_mid.tracks)}")
|
| 349 |
+
print(f"Combined tracks before cleanup: {len(combined.tracks)}")
|
| 350 |
+
|
| 351 |
+
# Add track cleanup (keep only unique tracks):
|
| 352 |
+
unique_tracks = []
|
| 353 |
+
seen = set()
|
| 354 |
+
for track in combined.tracks:
|
| 355 |
+
track_hash = str([msg.hex() for msg in track])
|
| 356 |
+
if track_hash not in seen:
|
| 357 |
+
unique_tracks.append(track)
|
| 358 |
+
seen.add(track_hash)
|
| 359 |
+
combined.tracks = unique_tracks
|
| 360 |
+
|
| 361 |
+
print(f"Final track count: {len(combined.tracks)}")
|
| 362 |
+
|
| 363 |
+
print("Output Midi metadata added")
|
| 364 |
+
|
| 365 |
+
return combined
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def change_melody(ai_model,midi_file, start_time, end_time,top_p):
|
| 370 |
+
"""
|
| 371 |
+
this function harmonizes a melody in a MIDI file
|
| 372 |
+
returns the harmonized MIDI
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
midi_file: path to the MIDI file
|
| 376 |
+
start_time: start time of the selected measure (melody you want to harmonize) in milliseconds
|
| 377 |
+
end_time: end time of the selected measure in milliseconds
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
print(f"Original MIDI tracks: {len(midi_file.tracks)}")
|
| 381 |
+
|
| 382 |
+
# Load metadata and model...
|
| 383 |
+
|
| 384 |
+
# Log original note parameters
|
| 385 |
+
for track in midi_file.tracks:
|
| 386 |
+
for msg in track:
|
| 387 |
+
if msg.type in ['note_on', 'note_off']:
|
| 388 |
+
if msg.velocity > 127 or msg.velocity < 0:
|
| 389 |
+
print(f"Invalid velocity: {msg.velocity}")
|
| 390 |
+
if msg.note > 127 or msg.note < 0:
|
| 391 |
+
print(f"Invalid pitch: {msg.note}")
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# Load original MIDI and extract metadata
|
| 395 |
+
midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
|
| 396 |
+
|
| 397 |
+
print("Midi metadata loaded")
|
| 398 |
+
|
| 399 |
+
# load an anticipatory music transformer
|
| 400 |
+
model = ai_model # add .cuda() if you have a GPU
|
| 401 |
+
|
| 402 |
+
print("Model loaded")
|
| 403 |
+
|
| 404 |
+
change_melody_gen_midi = change_melody_midi(model, midi, start_time, end_time, original_tempo,original_time_sig,top_p)
|
| 405 |
+
|
| 406 |
+
print("Midi generated")
|
| 407 |
+
|
| 408 |
+
print(f"Harmonized MIDI tracks: {len(change_melody_gen_midi.tracks)}")
|
| 409 |
+
|
| 410 |
+
# Add MIDI validation
|
| 411 |
+
for track in change_melody_gen_midi.tracks:
|
| 412 |
+
for msg in track:
|
| 413 |
+
if msg.type in ['note_on', 'note_off']:
|
| 414 |
+
# Clamp invalid values
|
| 415 |
+
msg.velocity = min(max(msg.velocity, 0), 127)
|
| 416 |
+
msg.note = min(max(msg.note, 0), 127)
|
| 417 |
+
|
| 418 |
+
print("Midi saved")
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
return change_melody_gen_midi
|
agents/utils.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
def load_midi_metadata(midi_file):
|
| 4 |
+
|
| 5 |
+
original_tempo = 500000 # default tempo (120 BPM)
|
| 6 |
+
original_time_sig = (4, 4) # default time signature
|
| 7 |
+
|
| 8 |
+
for msg in midi_file:
|
| 9 |
+
if msg.type == 'set_tempo':
|
| 10 |
+
original_tempo = msg.tempo
|
| 11 |
+
elif msg.type == 'time_signature':
|
| 12 |
+
original_time_sig = (msg.numerator, msg.denominator)
|
| 13 |
+
|
| 14 |
+
return midi_file, original_tempo, original_time_sig
|
| 15 |
+
|
anticipation/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Infrastructure for constructing anticipatory infilling models.
|
| 2 |
+
|
| 3 |
+
This model provides infrastructure to preprocess Midi music datasets
|
| 4 |
+
for training anticipatory music infilling models. For more context, see:
|
| 5 |
+
|
| 6 |
+
Anticipatory Music Transformer
|
| 7 |
+
John Thickstun, David Hall, Chris Donahue, Percy Liang
|
| 8 |
+
Preprint Report, 2023
|
| 9 |
+
"""
|
anticipation/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (598 Bytes). View file
|
|
|
anticipation/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
anticipation/__pycache__/convert.cpython-311.pyc
ADDED
|
Binary file (19 kB). View file
|
|
|
anticipation/__pycache__/ops.cpython-311.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
anticipation/__pycache__/sample.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
anticipation/__pycache__/tokenize.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
anticipation/__pycache__/visuals.cpython-311.pyc
ADDED
|
Binary file (4.02 kB). View file
|
|
|
anticipation/__pycache__/vocab.cpython-311.pyc
ADDED
|
Binary file (2.62 kB). View file
|
|
|
anticipation/config-original.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Global configuration for anticipatory infilling models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# model hyper-parameters
|
| 6 |
+
|
| 7 |
+
CONTEXT_SIZE = 1024 # model context
|
| 8 |
+
EVENT_SIZE = 3 # each event/control is encoded as 3 tokens
|
| 9 |
+
M = 341 # model context (1024 = 1 + EVENT_SIZE*M)
|
| 10 |
+
DELTA = 5 # anticipation time in seconds
|
| 11 |
+
|
| 12 |
+
assert CONTEXT_SIZE == 1+EVENT_SIZE*M
|
| 13 |
+
|
| 14 |
+
# vocabulary constants
|
| 15 |
+
|
| 16 |
+
MAX_TIME_IN_SECONDS = 100 # exclude very long training sequences
|
| 17 |
+
MAX_DURATION_IN_SECONDS = 10 # maximum duration of a note
|
| 18 |
+
TIME_RESOLUTION = 100 # 10ms time resolution = 100 bins/second
|
| 19 |
+
|
| 20 |
+
MAX_PITCH = 128 # 128 MIDI pitches
|
| 21 |
+
MAX_INSTR = 129 # 129 MIDI instruments (128 + drums)
|
| 22 |
+
MAX_NOTE = MAX_PITCH*MAX_INSTR # note = pitch x instrument
|
| 23 |
+
|
| 24 |
+
MAX_INTERARRIVAL_IN_SECONDS = 10 # maximum interarrival time (for MIDI-like encoding)
|
| 25 |
+
|
| 26 |
+
# preprocessing settings
|
| 27 |
+
|
| 28 |
+
PREPROC_WORKERS = 16
|
| 29 |
+
|
| 30 |
+
COMPOUND_SIZE = 5 # event size in the intermediate compound tokenization
|
| 31 |
+
MAX_TRACK_INSTR = 16 # exclude tracks with large numbers of instruments
|
| 32 |
+
MAX_TRACK_TIME_IN_SECONDS = 3600 # exclude very long tracks (longer than 1 hour)
|
| 33 |
+
MIN_TRACK_TIME_IN_SECONDS = 10 # exclude very short tracks (less than 10 seconds)
|
| 34 |
+
MIN_TRACK_EVENTS = 100 # exclude very short tracks (less than 100 events)
|
| 35 |
+
|
| 36 |
+
# LakhMIDI dataset splits
|
| 37 |
+
|
| 38 |
+
LAKH_SPLITS = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
|
| 39 |
+
LAKH_VALID = ['e']
|
| 40 |
+
LAKH_TEST = ['f']
|
| 41 |
+
|
| 42 |
+
# derived quantities
|
| 43 |
+
|
| 44 |
+
MAX_TIME = TIME_RESOLUTION*MAX_TIME_IN_SECONDS
|
| 45 |
+
MAX_DUR = TIME_RESOLUTION*MAX_DURATION_IN_SECONDS
|
| 46 |
+
|
| 47 |
+
MAX_INTERARRIVAL = TIME_RESOLUTION*MAX_INTERARRIVAL_IN_SECONDS
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if __name__ == '__main__':
|
| 51 |
+
print('Model constants:')
|
| 52 |
+
print(f' -> anticipation interval: {DELTA}s')
|
| 53 |
+
print('Vocabulary constants:')
|
| 54 |
+
print(f' -> maximum time of a sequence: {MAX_TIME_IN_SECONDS}s')
|
| 55 |
+
print(f' -> maximum duration of a note: {MAX_DURATION_IN_SECONDS}s')
|
| 56 |
+
print(f' -> time resolution: {TIME_RESOLUTION}bins/s ({1000//TIME_RESOLUTION}ms)')
|
| 57 |
+
print(f' -> maximum interarrival-time (MIDI-like encoding): {MAX_INTERARRIVAL_IN_SECONDS}s')
|
| 58 |
+
print('Preprocessing constants:')
|
| 59 |
+
print(f' -> maximum time of a track: {MAX_TRACK_TIME_IN_SECONDS}s')
|
| 60 |
+
print(f' -> minimum events in a track: {MIN_TRACK_EVENTS}s')
|
anticipation/config.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Global configuration for anticipatory infilling models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# model hyper-parameters
|
| 6 |
+
|
| 7 |
+
CONTEXT_SIZE = 1024 # model context
|
| 8 |
+
EVENT_SIZE = 3 # each event/control is encoded as 3 tokens
|
| 9 |
+
M = 341 # model context (1024 = 1 + EVENT_SIZE*M)
|
| 10 |
+
DELTA = 5 # anticipation time in seconds
|
| 11 |
+
|
| 12 |
+
assert CONTEXT_SIZE == 1+EVENT_SIZE*M
|
| 13 |
+
|
| 14 |
+
# vocabulary constants
|
| 15 |
+
|
| 16 |
+
MAX_TIME_IN_SECONDS = 100 # exclude very long training sequences
|
| 17 |
+
MAX_DURATION_IN_SECONDS = 10 # maximum duration of a note
|
| 18 |
+
TIME_RESOLUTION = 100 # 10ms time resolution = 100 bins/second
|
| 19 |
+
|
| 20 |
+
MAX_PITCH = 128 # 128 MIDI pitches
|
| 21 |
+
MAX_INSTR = 129 # 129 MIDI instruments (128 + drums)
|
| 22 |
+
MAX_NOTE = MAX_PITCH*MAX_INSTR # note = pitch x instrument
|
| 23 |
+
|
| 24 |
+
MAX_INTERARRIVAL_IN_SECONDS = 10 # maximum interarrival time (for MIDI-like encoding)
|
| 25 |
+
|
| 26 |
+
# preprocessing settings
|
| 27 |
+
|
| 28 |
+
PREPROC_WORKERS = 16
|
| 29 |
+
|
| 30 |
+
COMPOUND_SIZE = 5 # event size in the intermediate compound tokenization
|
| 31 |
+
MAX_TRACK_INSTR = 16 # exclude tracks with large numbers of instruments
|
| 32 |
+
MAX_TRACK_TIME_IN_SECONDS = 3600 # exclude very long tracks (longer than 1 hour)
|
| 33 |
+
MIN_TRACK_TIME_IN_SECONDS = 10 # exclude very short tracks (less than 10 seconds)
|
| 34 |
+
MIN_TRACK_EVENTS = 100 # exclude very short tracks (less than 100 events)
|
| 35 |
+
|
| 36 |
+
# LakhMIDI dataset splits
|
| 37 |
+
|
| 38 |
+
LAKH_SPLITS = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
|
| 39 |
+
LAKH_VALID = ['e']
|
| 40 |
+
LAKH_TEST = ['f']
|
| 41 |
+
|
| 42 |
+
# derived quantities
|
| 43 |
+
|
| 44 |
+
MAX_TIME = TIME_RESOLUTION*MAX_TIME_IN_SECONDS
|
| 45 |
+
MAX_DUR = TIME_RESOLUTION*MAX_DURATION_IN_SECONDS
|
| 46 |
+
|
| 47 |
+
MAX_INTERARRIVAL = TIME_RESOLUTION*MAX_INTERARRIVAL_IN_SECONDS
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if __name__ == '__main__':
|
| 51 |
+
print('Model constants:')
|
| 52 |
+
print(f' -> anticipation interval: {DELTA}s')
|
| 53 |
+
print('Vocabulary constants:')
|
| 54 |
+
print(f' -> maximum time of a sequence: {MAX_TIME_IN_SECONDS}s')
|
| 55 |
+
print(f' -> maximum duration of a note: {MAX_DURATION_IN_SECONDS}s')
|
| 56 |
+
print(f' -> time resolution: {TIME_RESOLUTION}bins/s ({1000//TIME_RESOLUTION}ms)')
|
| 57 |
+
print(f' -> maximum interarrival-time (MIDI-like encoding): {MAX_INTERARRIVAL_IN_SECONDS}s')
|
| 58 |
+
print('Preprocessing constants:')
|
| 59 |
+
print(f' -> maximum time of a track: {MAX_TRACK_TIME_IN_SECONDS}s')
|
| 60 |
+
print(f' -> minimum events in a track: {MIN_TRACK_EVENTS}s')
|
anticipation/convert-original.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for converting to and from Midi data and encoded/tokenized data.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
import mido
|
| 8 |
+
|
| 9 |
+
from anticipation.config import *
|
| 10 |
+
from anticipation.vocab import *
|
| 11 |
+
from anticipation.ops import unpad
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def midi_to_interarrival(midifile, debug=False, stats=False):
    """Encode a MIDI file as a MIDI-like interarrival-time token stream.

    Arguments:
        midifile: path to a MIDI file (anything mido.MidiFile accepts).
        debug: print a diagnostic for message types we don't recognize.
        stats: additionally return how many interarrival times were clipped.

    Returns the token list, or (tokens, truncations) when stats is True.
    """
    # message types that carry no information for this encoding
    ignored = {
        'time_signature', 'aftertouch', 'polytouch', 'pitchwheel',
        'sequencer_specific', 'control_change', 'track_name', 'text',
        'end_of_track', 'lyrics', 'key_signature', 'copyright', 'marker',
        'instrument_name', 'cue_marker', 'device_name', 'sequence_number',
        'channel_prefix', 'midi_port', 'smpte_offset', 'sysex',
    }

    midi = mido.MidiFile(midifile)

    tokens = []
    elapsed = 0  # seconds since the previous note event

    program = defaultdict(int)  # channel -> program; default 0 = piano
    tempo = 500000              # default tempo: 500000 microseconds per beat
    truncations = 0
    for message in midi:
        elapsed += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            program[message.channel] = message.program
        elif message.type in ('note_on', 'note_off'):
            ticks = round(TIME_RESOLUTION * elapsed)
            delta_ticks = min(ticks, MAX_INTERARRIVAL - 1)  # clip long gaps
            if delta_ticks != ticks:
                truncations += 1

            if delta_ticks > 0:  # time elapsed since last token
                tokens.append(MIDI_TIME_OFFSET + delta_ticks)

            # special case: channel 9 is drums!
            inst = 128 if message.channel == 9 else program[message.channel]
            is_onset = message.type == 'note_on' and message.velocity > 0
            offset = MIDI_START_OFFSET if is_onset else MIDI_END_OFFSET
            tokens.append(offset + (2**7)*inst + message.note)
            elapsed = 0
        elif message.type == 'set_tempo':
            tempo = message.tempo  # we work in real time; kept for parity
        elif message.type in ignored:
            pass  # metadata / expression data we don't attempt to model
        elif debug:
            print('UNHANDLED MESSAGE', message.type, message)

    if stats:
        return tokens, truncations

    return tokens
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def interarrival_to_midi(tokens, debug=False):
    """Decode an interarrival-time token stream into a mido.MidiFile.

    One track is created per instrument.  Drums (instrument code 128) are
    routed to MIDI channel 9; other instruments claim channels in order of
    first appearance, skipping channel 9.
    """
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    channels = {}  # instrument -> (track, time of last event, channel)
    now = 0        # absolute time in ticks
    num_tracks = 0

    def new_track(instrument):
        # build a fresh track (with its program_change) for this instrument
        nonlocal num_tracks
        channel = 9 if instrument == 128 else num_tracks  # drums -> channel 9
        track = mido.MidiTrack()
        mid.tracks.append(track)
        track.append(mido.Message('program_change', channel=channel,
                                  program=0 if instrument == 128 else instrument))
        num_tracks += 1
        if num_tracks == 9:
            num_tracks += 1  # skip the drums channel when assigning
        return track, 0, channel

    for token in tokens:
        if token == MIDI_SEPARATOR:
            continue

        if token < MIDI_START_OFFSET:
            now += token - MIDI_TIME_OFFSET  # a time-shift token
        elif token < MIDI_END_OFFSET:
            code = token - MIDI_START_OFFSET
            instrument, pitch = code // 2**7, code % 2**7
            try:
                track, last, channel = channels[instrument]
            except KeyError:
                track, last, channel = new_track(instrument)
            track.append(mido.Message('note_on', note=pitch, channel=channel,
                                      velocity=96, time=now-last))
            channels[instrument] = (track, now, channel)
        else:
            code = token - MIDI_END_OFFSET
            instrument, pitch = code // 2**7, code % 2**7
            try:
                track, last, channel = channels[instrument]
            except KeyError:
                # shouldn't happen: every offset should follow its onset
                if debug:
                    print('IGNORING bad offset')
                continue
            track.append(mido.Message('note_off', note=pitch, channel=channel,
                                      time=now-last))
            channels[instrument] = (track, now, channel)

    return mid
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def midi_to_compound(midifile, debug=False):
    """Tokenize a MIDI file into flat (time, duration, note, instr, velocity)
    "compound" words, one 5-token group per note.

    Times and durations are quantized to TIME_RESOLUTION ticks per second.
    Notes whose offset never arrives keep the placeholder duration -1.
    """
    # message types that carry no information for this encoding
    ignored = {
        'time_signature', 'aftertouch', 'polytouch', 'pitchwheel',
        'sequencer_specific', 'control_change', 'track_name', 'text',
        'end_of_track', 'lyrics', 'key_signature', 'copyright', 'marker',
        'instrument_name', 'cue_marker', 'device_name', 'sequence_number',
        'channel_prefix', 'midi_port', 'smpte_offset', 'sysex',
    }

    midi = mido.MidiFile(midifile) if isinstance(midifile, str) else midifile

    tokens = []
    note_count = 0
    pending = defaultdict(list)  # (instr, note, channel) -> queue of open onsets

    now = 0
    program = defaultdict(int)  # channel -> program; default 0 = piano
    tempo = 500000              # default tempo: 500000 microseconds per beat
    for message in midi:
        now += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            program[message.channel] = message.program
        elif message.type in ('note_on', 'note_off'):
            # special case: channel 9 is drums!
            instr = 128 if message.channel == 9 else program[message.channel]

            if message.type == 'note_on' and message.velocity > 0:  # onset
                tokens.extend([
                    round(TIME_RESOLUTION*now),  # quantized onset time
                    -1,                          # duration placeholder
                    message.note,
                    instr,
                    message.velocity,
                ])
                pending[(instr, message.note, message.channel)].append((note_count, now))
                note_count += 1
            else:  # offset: resolve the oldest matching open note
                try:
                    open_idx, onset_time = pending[(instr, message.note, message.channel)].pop(0)
                except IndexError:
                    if debug:
                        print('WARNING: ignoring bad offset')
                else:
                    tokens[5*open_idx + 1] = round(TIME_RESOLUTION*(now-onset_time))
        elif message.type == 'set_tempo':
            tempo = message.tempo  # we work in real time; kept for parity
        elif message.type in ignored:
            pass  # metadata / expression data we don't attempt to model
        elif debug:
            print('UNHANDLED MESSAGE', message.type, message)

    unclosed_count = sum(len(queue) for queue in pending.values())
    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print(' ', midifile)

    return tokens
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def compound_to_midi(tokens, debug=False):
    """Render a compound token list (5 tokens per note) into a mido.MidiFile.

    One track is created per instrument; drums (instrument 128) are routed
    to MIDI channel 9.  Events are written in time order with delta times.
    """
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    # index every onset (event type 0) and offset (1) by absolute tick time
    it = iter(tokens)
    time_index = defaultdict(list)
    for start, duration, note, instrument, velocity in zip(it, it, it, it, it):
        time_index[(start, 0)].append((note, instrument, velocity))
        time_index[(start+duration, 1)].append((note, instrument, velocity))

    state = {}  # instrument -> (track, time of last event, channel)
    num_tracks = 0

    def new_track(instrument):
        # build a fresh track (with its program_change) for this instrument
        nonlocal num_tracks
        channel = 9 if instrument == 128 else num_tracks  # drums -> channel 9
        track = mido.MidiTrack()
        mid.tracks.append(track)
        track.append(mido.Message('program_change', channel=channel,
                                  program=0 if instrument == 128 else instrument))
        num_tracks += 1
        if num_tracks == 9:
            num_tracks += 1  # skip the drums channel when assigning
        return track, 0, channel

    for time_in_ticks, event_type in sorted(time_index.keys()):
        for note, instrument, velocity in time_index[(time_in_ticks, event_type)]:
            if event_type == 0:  # onset
                try:
                    track, last, channel = state[instrument]
                except KeyError:
                    track, last, channel = new_track(instrument)
                track.append(mido.Message('note_on', note=note, channel=channel,
                                          velocity=velocity, time=time_in_ticks-last))
            else:  # offset
                try:
                    track, last, channel = state[instrument]
                except KeyError:
                    # shouldn't happen: every offset should follow its onset
                    if debug:
                        print('IGNORING bad offset')
                    continue
                track.append(mido.Message('note_off', note=note, channel=channel,
                                          time=time_in_ticks-last))
            state[instrument] = (track, time_in_ticks, channel)

    return mid
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def compound_to_events(tokens, stats=False):
    """Convert compound words (time, dur, note, instr, velocity) into the
    3-token event vocabulary (time, dur, note), applying type offsets.

    Velocities are dropped; pitch/instrument pairs are fused into a single
    note code; durations are clipped to MAX_DUR and unknown durations (-1)
    become 250ms.  When stats is True, also return the clip count.
    """
    assert len(tokens) % 5 == 0
    compound = tokens.copy()

    # remove velocities
    del compound[4::5]

    # fuse (pitch, instrument) into a single note code
    assert all(-1 <= tok < 2**7 for tok in compound[2::4])
    assert all(-1 <= tok < 129 for tok in compound[3::4])
    compound[2::4] = [
        NOTE_OFFSET + (SEPARATOR if pitch == -1 else MAX_PITCH*instr + pitch)
        for pitch, instr in zip(compound[2::4], compound[3::4])
    ]
    del compound[3::4]

    # max duration cutoff and set unknown durations to 250ms
    truncations = sum(1 for tok in compound[1::3] if tok >= MAX_DUR)
    compound[1::3] = [
        DUR_OFFSET + (TIME_RESOLUTION//4 if tok == -1 else min(tok, MAX_DUR-1))
        for tok in compound[1::3]
    ]

    assert min(compound[0::3]) >= 0
    compound[0::3] = [TIME_OFFSET + tok for tok in compound[0::3]]

    assert len(compound) % 3 == 0

    if stats:
        return compound, truncations

    return compound
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def events_to_compound(tokens, debug=False):
    """Expand event-vocabulary tokens back into 5-token compound words.

    Control-offset tokens are shifted into the event range, type offsets
    are stripped, tracks separated by SEPARATOR are laid out sequentially
    in time, and a default velocity of 72 is attached to every note.
    """
    events = unpad(tokens)

    # move all tokens to zero-offset for synthesis
    events = [t if t == SEPARATOR or t < CONTROL_OFFSET else t - CONTROL_OFFSET
              for t in events]

    # strip the per-type vocabulary offsets
    events[0::3] = [t if t == SEPARATOR else t - TIME_OFFSET for t in events[0::3]]
    events[1::3] = [t if t == SEPARATOR else t - DUR_OFFSET for t in events[1::3]]
    events[2::3] = [t if t == SEPARATOR else t - NOTE_OFFSET for t in events[2::3]]

    # shift each track past the end of the previous one for synthesis
    offset = 0
    track_max = 0  # latest note end seen in the current track
    for j, (start, dur, note) in enumerate(zip(events[0::3], events[1::3], events[2::3])):
        if note == SEPARATOR:
            offset += track_max
            track_max = 0
            if debug:
                print('Sequence Boundary')
        else:
            track_max = max(track_max, start + dur)
            events[3*j] += offset

    # strip sequence separators
    assert len([t for t in events if t == SEPARATOR]) % 3 == 0
    events = [t for t in events if t != SEPARATOR]

    assert len(events) % 3 == 0
    count = len(events) // 3
    out = [0] * (5 * count)
    out[0::5] = events[0::3]
    out[1::5] = events[1::3]
    out[2::5] = [t % 2**7 for t in events[2::3]]   # pitch
    out[3::5] = [t // 2**7 for t in events[2::3]]  # instrument
    out[4::5] = count * [72]                       # default velocity

    assert max(out[1::5]) < MAX_DUR
    assert max(out[2::5]) < MAX_PITCH
    assert max(out[3::5]) < MAX_INSTR
    assert all(t >= 0 for t in out)

    return out
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def events_to_midi(tokens, debug=False):
    """Convert event-vocabulary tokens to a mido.MidiFile (via compound)."""
    compound = events_to_compound(tokens, debug=debug)
    return compound_to_midi(compound, debug=debug)
|
| 340 |
+
|
| 341 |
+
def midi_to_events(midifile, debug=False):
    """Tokenize a MIDI file into the event vocabulary (via compound)."""
    compound = midi_to_compound(midifile, debug=debug)
    return compound_to_events(compound)
|
anticipation/convert.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for converting to and from Midi data and encoded/tokenized data.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
import mido
|
| 8 |
+
|
| 9 |
+
from anticipation.config import *
|
| 10 |
+
from anticipation.vocab import *
|
| 11 |
+
from anticipation.ops import unpad
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def midi_to_interarrival(midifile, debug=False, stats=False):
    """Encode a MIDI file as a MIDI-like interarrival-time token stream.

    Arguments:
        midifile: path to a MIDI file (anything mido.MidiFile accepts).
        debug: print a diagnostic for message types we don't recognize.
        stats: additionally return how many interarrival times were clipped.

    Returns the token list, or (tokens, truncations) when stats is True.
    """
    # message types that carry no information for this encoding
    ignored = {
        'time_signature', 'aftertouch', 'polytouch', 'pitchwheel',
        'sequencer_specific', 'control_change', 'track_name', 'text',
        'end_of_track', 'lyrics', 'key_signature', 'copyright', 'marker',
        'instrument_name', 'cue_marker', 'device_name', 'sequence_number',
        'channel_prefix', 'midi_port', 'smpte_offset', 'sysex',
    }

    midi = mido.MidiFile(midifile)

    tokens = []
    elapsed = 0  # seconds since the previous note event

    program = defaultdict(int)  # channel -> program; default 0 = piano
    tempo = 500000              # default tempo: 500000 microseconds per beat
    truncations = 0
    for message in midi:
        elapsed += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            program[message.channel] = message.program
        elif message.type in ('note_on', 'note_off'):
            ticks = round(TIME_RESOLUTION * elapsed)
            delta_ticks = min(ticks, MAX_INTERARRIVAL - 1)  # clip long gaps
            if delta_ticks != ticks:
                truncations += 1

            if delta_ticks > 0:  # time elapsed since last token
                tokens.append(MIDI_TIME_OFFSET + delta_ticks)

            # special case: channel 9 is drums!
            inst = 128 if message.channel == 9 else program[message.channel]
            is_onset = message.type == 'note_on' and message.velocity > 0
            offset = MIDI_START_OFFSET if is_onset else MIDI_END_OFFSET
            tokens.append(offset + (2**7)*inst + message.note)
            elapsed = 0
        elif message.type == 'set_tempo':
            tempo = message.tempo  # we work in real time; kept for parity
        elif message.type in ignored:
            pass  # metadata / expression data we don't attempt to model
        elif debug:
            print('UNHANDLED MESSAGE', message.type, message)

    if stats:
        return tokens, truncations

    return tokens
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def interarrival_to_midi(tokens, debug=False):
    """Decode an interarrival-time token stream into a mido.MidiFile.

    One track is created per instrument.  Drums (instrument code 128) are
    routed to MIDI channel 9; other instruments claim channels in order of
    first appearance, skipping channel 9.
    """
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    channels = {}  # instrument -> (track, time of last event, channel)
    now = 0        # absolute time in ticks
    num_tracks = 0

    def new_track(instrument):
        # build a fresh track (with its program_change) for this instrument
        nonlocal num_tracks
        channel = 9 if instrument == 128 else num_tracks  # drums -> channel 9
        track = mido.MidiTrack()
        mid.tracks.append(track)
        track.append(mido.Message('program_change', channel=channel,
                                  program=0 if instrument == 128 else instrument))
        num_tracks += 1
        if num_tracks == 9:
            num_tracks += 1  # skip the drums channel when assigning
        return track, 0, channel

    for token in tokens:
        if token == MIDI_SEPARATOR:
            continue

        if token < MIDI_START_OFFSET:
            now += token - MIDI_TIME_OFFSET  # a time-shift token
        elif token < MIDI_END_OFFSET:
            code = token - MIDI_START_OFFSET
            instrument, pitch = code // 2**7, code % 2**7
            try:
                track, last, channel = channels[instrument]
            except KeyError:
                track, last, channel = new_track(instrument)
            track.append(mido.Message('note_on', note=pitch, channel=channel,
                                      velocity=96, time=now-last))
            channels[instrument] = (track, now, channel)
        else:
            code = token - MIDI_END_OFFSET
            instrument, pitch = code // 2**7, code % 2**7
            try:
                track, last, channel = channels[instrument]
            except KeyError:
                # shouldn't happen: every offset should follow its onset
                if debug:
                    print('IGNORING bad offset')
                continue
            track.append(mido.Message('note_off', note=pitch, channel=channel,
                                      time=now-last))
            channels[instrument] = (track, now, channel)

    return mid
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def midi_to_compound(midifile, debug=False):
    """Tokenize a MIDI file into flat (time, duration, note, instr, velocity)
    "compound" words, one 5-token group per note.

    The instr field packs the GM program and a remapped output channel as
    (program << 4) | channel so the channel assignment survives a round
    trip through compound_to_midi.  Drums (input channel 9) are encoded as
    program 128 on channel 9.  Notes whose offset never arrives keep the
    placeholder duration -1.

    Fixes over the previous version:
      * notes on a channel that never saw a program_change no longer crash
        on a None output channel (channels are now claimed lazily);
      * repeated program_change messages on a channel now update the program
        instead of being silently ignored;
      * output channel assignment wraps within the valid 0-15 range
        (still skipping the drum channel 9).
    """
    # message types that carry no information for this encoding
    ignored = {
        'time_signature', 'aftertouch', 'polytouch', 'pitchwheel',
        'sequencer_specific', 'control_change', 'track_name', 'text',
        'end_of_track', 'lyrics', 'key_signature', 'copyright', 'marker',
        'instrument_name', 'cue_marker', 'device_name', 'sequence_number',
        'channel_prefix', 'midi_port', 'smpte_offset', 'sysex',
    }

    midi = mido.MidiFile(midifile) if isinstance(midifile, str) else midifile

    tokens = []
    note_idx = 0
    open_notes = defaultdict(list)  # (instr, note, channel) -> open onsets

    # map each *input* channel to its current program and output channel
    instruments = defaultdict(lambda: {'program': 0, 'channel': None})
    next_channel = 0

    def claim(in_channel):
        # lazily assign an output channel the first time an input channel
        # appears (program_change OR a bare note event)
        nonlocal next_channel
        entry = instruments[in_channel]
        if entry['channel'] is None:
            entry['channel'] = next_channel
            next_channel += 1
            if next_channel == 9:   # never hand out the drum channel
                next_channel += 1
            if next_channel > 15:   # MIDI only has channels 0-15: wrap
                next_channel = 0
        return entry

    time = 0
    tempo = 500000  # default tempo: 500000 microseconds per beat
    for message in midi:
        time += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            if message.channel != 9:  # channel 9 is always drums
                claim(message.channel)['program'] = message.program
        elif message.type in ('note_on', 'note_off'):
            # special case: channel 9 is drums!
            if message.channel == 9:
                instr, channel = 128, 9
            else:
                entry = claim(message.channel)
                instr, channel = entry['program'], entry['channel']
            compound_instr = (instr << 4) | channel

            if message.type == 'note_on' and message.velocity > 0:  # onset
                # Our compound word is: (time, duration, note, instr, velocity)
                tokens.append(round(TIME_RESOLUTION*time))  # quantized onset
                tokens.append(-1)  # duration placeholder, filled at the offset
                tokens.append(message.note)
                tokens.append(compound_instr)
                tokens.append(message.velocity)

                open_notes[(instr, message.note, message.channel)].append((note_idx, time))
                note_idx += 1
            else:  # offset: resolve the oldest matching open note
                try:
                    open_idx, onset_time = open_notes[(instr, message.note, message.channel)].pop(0)
                except IndexError:
                    if debug:
                        print('WARNING: ignoring bad offset')
                else:
                    tokens[5*open_idx + 1] = round(TIME_RESOLUTION*(time-onset_time))
        elif message.type == 'set_tempo':
            tempo = message.tempo  # we work in real time; kept for parity
        elif message.type in ignored:
            pass  # metadata / expression data we don't attempt to model
        elif debug:
            print('UNHANDLED MESSAGE', message.type, message)

    unclosed_count = sum(len(v) for v in open_notes.values())
    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print(' ', midifile)

    return tokens
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def compound_to_midi(tokens, debug=False):
    """Render a compound token list (5 tokens per note) into a mido.MidiFile.

    The instr field of each word packs (program << 4) | channel as produced
    by this module's midi_to_compound; one track is created per distinct
    packed value and events are written in time order with delta times.
    Drums decode to channel 9 (the packed program bit 128 is masked off,
    yielding GM program 0 on the drum channel).

    Fixes over the previous version:
      * removed the loop that iterated EVERY flattened token (times,
        durations, velocities, ...) as if it were a packed instrument code,
        creating garbage tracks and program_change messages;
      * the instrument field is now decoded with the packed scheme instead
        of the legacy raw-program scheme, so tracks and channels match what
        midi_to_compound encoded.
    """
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    # index every onset (event type 0) and offset (1) by absolute tick time
    it = iter(tokens)
    time_index = defaultdict(list)
    for time_in_ticks, duration, note, instrument, velocity in zip(it, it, it, it, it):
        time_index[(time_in_ticks, 0)].append((note, instrument, velocity))
        time_index[(time_in_ticks+duration, 1)].append((note, instrument, velocity))

    track_state = {}  # packed instrument -> (track, time of last event, channel)

    def get_track(instrument):
        # fetch (or create on first use) the track for this packed value
        try:
            return track_state[instrument]
        except KeyError:
            program = (instrument >> 4) & 0x7F  # drums (instr 128) -> program 0
            channel = instrument & 0x0F
            track = mido.MidiTrack()
            mid.tracks.append(track)
            track.append(mido.Message('program_change', channel=channel,
                                      program=program))
            state = (track, 0, channel)
            track_state[instrument] = state
            return state

    for time_in_ticks, event_type in sorted(time_index.keys()):
        for note, instrument, velocity in time_index[(time_in_ticks, event_type)]:
            if event_type == 0:  # onset
                track, previous_time, channel = get_track(instrument)
                track.append(mido.Message(
                    'note_on', note=note, channel=channel, velocity=velocity,
                    time=time_in_ticks-previous_time))
            else:  # offset
                try:
                    track, previous_time, channel = track_state[instrument]
                except KeyError:
                    # shouldn't happen: every offset should follow its onset
                    if debug:
                        print('IGNORING bad offset')
                    continue
                track.append(mido.Message(
                    'note_off', note=note, channel=channel,
                    time=time_in_ticks-previous_time))
            track_state[instrument] = (track, time_in_ticks, channel)

    return mid
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def compound_to_events(tokens, stats=False):
    """Convert compound words (time, dur, note, instr, velocity) into the
    3-token event vocabulary (time, dur, note), applying type offsets.

    Velocities are dropped; pitch/instrument pairs are fused into a single
    note code; durations are clipped to MAX_DUR and unknown durations (-1)
    become 250ms.  When stats is True, also return the clip count.

    NOTE(review): this module's midi_to_compound packs (program << 4) |
    channel into the instr slot; such values can exceed 129 and would trip
    the instrument-range assertion below — confirm which producer feeds
    this function.
    """
    assert len(tokens) % 5 == 0
    compound = tokens.copy()

    # remove velocities
    del compound[4::5]

    # fuse (pitch, instrument) into a single note code
    assert all(-1 <= tok < 2**7 for tok in compound[2::4])
    assert all(-1 <= tok < 129 for tok in compound[3::4])
    compound[2::4] = [
        NOTE_OFFSET + (SEPARATOR if pitch == -1 else MAX_PITCH*instr + pitch)
        for pitch, instr in zip(compound[2::4], compound[3::4])
    ]
    del compound[3::4]

    # max duration cutoff and set unknown durations to 250ms
    truncations = sum(1 for tok in compound[1::3] if tok >= MAX_DUR)
    compound[1::3] = [
        DUR_OFFSET + (TIME_RESOLUTION//4 if tok == -1 else min(tok, MAX_DUR-1))
        for tok in compound[1::3]
    ]

    assert min(compound[0::3]) >= 0
    compound[0::3] = [TIME_OFFSET + tok for tok in compound[0::3]]

    assert len(compound) % 3 == 0

    if stats:
        return compound, truncations

    return compound
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def events_to_compound(tokens, debug=False):
    """Convert 3-tuple event tokens (time, duration, note) back into the
    5-tuple compound representation (time, duration, pitch, instrument,
    velocity), concatenating multiple separator-delimited tracks along the
    time axis for synthesis.
    """
    tokens = unpad(tokens)

    # move all tokens to zero-offset for synthesis
    tokens = [tok - CONTROL_OFFSET if tok >= CONTROL_OFFSET and tok != SEPARATOR else tok
              for tok in tokens]

    # remove type offsets
    tokens[0::3] = [tok - TIME_OFFSET if tok != SEPARATOR else tok for tok in tokens[0::3]]
    tokens[1::3] = [tok - DUR_OFFSET if tok != SEPARATOR else tok for tok in tokens[1::3]]
    tokens[2::3] = [tok - NOTE_OFFSET if tok != SEPARATOR else tok for tok in tokens[2::3]]

    offset = 0  # add max time from previous track for synthesis
    track_max = 0  # keep track of max time in track
    for j, (time,dur,note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
        if note == SEPARATOR:
            # a new track begins: shift all subsequent times past the previous track
            offset += track_max
            track_max = 0
            if debug:
                print('Sequence Boundary')
        else:
            track_max = max(track_max, time+dur)
            tokens[3*j] += offset

    # strip sequence separators
    assert len([tok for tok in tokens if tok == SEPARATOR]) % 3 == 0
    tokens = [tok for tok in tokens if tok != SEPARATOR]

    assert len(tokens) % 3 == 0
    out = 5*(len(tokens)//3)*[0]
    out[0::5] = tokens[0::3]
    out[1::5] = tokens[1::3]
    # split the combined note token back into (pitch, instrument)
    out[2::5] = [tok - (2**7)*(tok//2**7) for tok in tokens[2::3]]
    out[3::5] = [tok//2**7 for tok in tokens[2::3]]
    out[4::5] = (len(tokens)//3)*[72]  # default velocity

    assert max(out[1::5]) < MAX_DUR
    assert max(out[2::5]) < MAX_PITCH
    assert max(out[3::5]) < MAX_INSTR
    assert all(tok >= 0 for tok in out)

    return out
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def events_to_midi(tokens, debug=False):
    """Render an event-tokenized sequence to a mido MidiFile."""
    compound = events_to_compound(tokens, debug=debug)
    return compound_to_midi(compound, debug=debug)
|
| 363 |
+
|
| 364 |
+
def midi_to_events(midifile, debug=False):
    """Tokenize a MIDI file into the 3-tuple event vocabulary."""
    compound = midi_to_compound(midifile, debug=debug)
    return compound_to_events(compound)
|
anticipation/ops.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for operating on encoded Midi sequences.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
from anticipation.config import *
|
| 8 |
+
from anticipation.vocab import *
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def print_tokens(tokens):
    """Pretty-print a triple-encoded token sequence, one event per line."""
    print('---------------------')
    triples = zip(tokens[0::3], tokens[1::3], tokens[2::3])
    for idx, (tm, dur, note) in enumerate(triples):
        if note == SEPARATOR:
            # separators occupy all three slots of a triple
            assert tm == SEPARATOR and dur == SEPARATOR
            print(idx, 'SEPARATOR')
        elif note == REST:
            # rests carry a real (event-vocab) time and a zero duration
            assert tm < CONTROL_OFFSET
            assert dur == DUR_OFFSET+0
            print(idx, tm, 'REST')
        elif note < CONTROL_OFFSET:
            # ordinary event
            instr, pitch = divmod(note - NOTE_OFFSET, 2**7)
            print(idx, tm - TIME_OFFSET, dur - DUR_OFFSET, instr, pitch)
        else:
            # anticipated control
            instr, pitch = divmod(note - ANOTE_OFFSET, 2**7)
            print(idx, tm - ATIME_OFFSET, dur - ADUR_OFFSET, instr, pitch, '(A)')
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def clip(tokens, start, end, clip_duration=True, seconds=True):
    """Keep only events whose onset lies in [start, end]; optionally truncate
    durations so no kept note rings past `end`."""
    if seconds:
        start = int(TIME_RESOLUTION*start)
        end = int(TIME_RESOLUTION*end)

    result = []
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        anticipated = note >= CONTROL_OFFSET
        onset = time - (ATIME_OFFSET if anticipated else TIME_OFFSET)
        length = dur - (ADUR_OFFSET if anticipated else DUR_OFFSET)

        # drop events that start outside the window
        if not (start <= onset <= end):
            continue

        # truncate extended notes
        if clip_duration and onset + length > end:
            dur -= onset + length - end

        result.extend([time, dur, note])

    return result
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def mask(tokens, start, end):
    """Remove events whose onset (in seconds) falls strictly inside (start, end)."""
    result = []
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        base = ATIME_OFFSET if note >= CONTROL_OFFSET else TIME_OFFSET
        onset_seconds = (time - base)/float(TIME_RESOLUTION)
        if not (start < onset_seconds < end):
            result.extend([time, dur, note])
    return result
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def delete(tokens, criterion):
    """Drop every (time, dur, note) triple for which criterion(triple) is true."""
    kept = []
    for triple in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        if not criterion(triple):
            kept.extend(triple)
    return kept
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def sort(tokens):
    """Stably order a sequence of (time, dur, note) triples by onset time.

    Operates on events or controls, but not an interleaved mix of both.
    """
    triples = list(zip(tokens[0::3], tokens[1::3], tokens[2::3]))
    triples.sort(key=lambda triple: triple[0])  # stable: ties keep input order

    ordered = []
    for triple in triples:
        ordered.extend(triple)

    return ordered
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def split(tokens):
    """Partition an interleaved sequence into (events, controls)."""
    events, controls = [], []
    for triple in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # the note token decides which stream the triple belongs to
        bucket = events if triple[2] < CONTROL_OFFSET else controls
        bucket.extend(triple)
    return events, controls
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def pad(tokens, end_time=None, density=TIME_RESOLUTION):
    """Insert REST tokens so consecutive onsets are never more than `density`
    ticks apart, padding out to end_time (defaults to the sequence's max time)."""
    # NB: falsy end_time (None or 0) falls back to the sequence length
    if not end_time:
        end_time = max_time(tokens, seconds=False)
    end_time = TIME_OFFSET + end_time

    padded = []
    prev = TIME_OFFSET+0
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # must pad before separation, anticipation
        assert note < CONTROL_OFFSET

        # insert pad tokens to ensure the desired density
        while time > prev + density:
            prev += density
            padded.extend([prev, DUR_OFFSET+0, REST])

        padded.extend([time, dur, note])
        prev = time

    # continue padding out to the requested end time
    while end_time > prev + density:
        prev += density
        padded.extend([prev, DUR_OFFSET+0, REST])

    return padded
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def unpad(tokens):
    """Strip REST padding triples from a sequence."""
    stripped = []
    for triple in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        if triple[2] != REST:
            stripped.extend(triple)
    return stripped
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def anticipate(events, controls, delta=DELTA*TIME_RESOLUTION):
    """
    Interleave a sequence of events with anticipated controls.

    Inputs:
        events   : a sequence of events
        controls : a sequence of time-localized controls
        delta    : the anticipation interval (ticks)

    Returns:
        tokens   : interleaved events and anticipated controls
        controls : unconsumed controls (control time > max_time(events) + delta)
    """
    if len(controls) == 0:
        return events, controls

    tokens = []
    event_time = 0
    control_time = controls[0] - ATIME_OFFSET
    for time, dur, note in zip(events[0::3],events[1::3],events[2::3]):
        # emit every control whose time falls within delta of the current event time
        while event_time >= control_time - delta:
            tokens.extend(controls[0:3])
            controls = controls[3:] # consume this control
            control_time = controls[0] - ATIME_OFFSET if len(controls) > 0 else float('inf')

        assert note < CONTROL_OFFSET  # events must not already be control-encoded
        event_time = time - TIME_OFFSET
        tokens.extend([time, dur, note])

    return tokens, controls
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def sparsity(tokens):
    """Return the largest gap (ticks) between consecutive onsets in the sequence."""
    largest_gap = 0
    prev = TIME_OFFSET+0
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        if note == SEPARATOR:
            continue
        assert note < CONTROL_OFFSET  # don't operate on interleaved sequences

        gap = time - prev
        if gap > largest_gap:
            largest_gap = gap
        prev = time

    return largest_gap
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def min_time(tokens, seconds=True, instr=None):
    """Earliest onset in the sequence (0 if empty), optionally restricted to
    one instrument; returned in seconds unless seconds=False (ticks)."""
    earliest = None
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # stop calculating at sequence separator
        if note == SEPARATOR:
            break

        if note < CONTROL_OFFSET:
            time, note = time - TIME_OFFSET, note - NOTE_OFFSET
        else:
            time, note = time - ATIME_OFFSET, note - ANOTE_OFFSET

        # min time of a particular instrument
        if instr is not None and instr != note//2**7:
            continue

        if earliest is None or time < earliest:
            earliest = time

    if earliest is None:
        earliest = 0
    return earliest/float(TIME_RESOLUTION) if seconds else earliest
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def max_time(tokens, seconds=True, instr=None):
    """Latest onset in the sequence (0 if empty), optionally restricted to
    one instrument; returned in seconds unless seconds=False (ticks)."""
    latest = 0
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # keep checking for max_time, even if it appears after a separator
        # (this is important because we use this check for vocab overflow in tokenization)
        if note == SEPARATOR:
            continue

        if note < CONTROL_OFFSET:
            time, note = time - TIME_OFFSET, note - NOTE_OFFSET
        else:
            time, note = time - ATIME_OFFSET, note - ANOTE_OFFSET

        # max time of a particular instrument
        if instr is not None and instr != note//2**7:
            continue

        if time > latest:
            latest = time

    return latest/float(TIME_RESOLUTION) if seconds else latest
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def get_instruments(tokens):
    """Count note events per instrument; returns a defaultdict {instrument: count}."""
    counts = defaultdict(int)
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        if note >= SPECIAL_OFFSET:
            continue  # special tokens (separator/rest) carry no instrument

        base = NOTE_OFFSET if note < CONTROL_OFFSET else ANOTE_OFFSET
        counts[(note - base)//2**7] += 1

    return counts
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def translate(tokens, dt, seconds=False):
    """Shift every onset forward by dt; translation stops after a SEPARATOR."""
    if seconds:
        dt = int(TIME_RESOLUTION*dt)

    shifted = []
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        if note == SEPARATOR:
            # pass the boundary through untouched and stop translating after EOT
            shifted.extend([time, dur, note])
            dt = 0
            continue

        base = TIME_OFFSET if note < CONTROL_OFFSET else ATIME_OFFSET
        # the shift must not move any event before time zero
        assert time - base + dt >= 0
        shifted.extend([time+dt, dur, note])

    return shifted
|
| 283 |
+
|
| 284 |
+
def combine(events, controls):
    """Merge controls (stripped of their vocab offset) into events, time-sorted."""
    demoted = [tok - CONTROL_OFFSET for tok in controls]
    return sort(events + demoted)
|
anticipation/sample.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API functions for sampling from anticipatory infilling models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from anticipation import ops
|
| 13 |
+
from anticipation.config import *
|
| 14 |
+
from anticipation.vocab import *
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def safe_logits(logits, idx):
    """Mask vocabulary regions that must never be sampled at sequence
    position idx (position within the interleaved triple stream)."""
    logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf') # don't generate controls
    logits[SPECIAL_OFFSET:] = -float('inf') # don't generate special tokens

    # don't generate stuff in the wrong time slot
    if idx % 3 == 0:
        # time slot: mask durations and notes
        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
    elif idx % 3 == 1:
        # duration slot: mask times and notes
        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
    elif idx % 3 == 2:
        # note slot: mask times and durations
        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')

    return logits
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def nucleus(logits, top_p):
    """Nucleus (top-p) filtering: mask all but the smallest set of tokens
    whose cumulative probability exceeds top_p. Adapted from the
    Hugging Face transformers sampling utilities; operates on a 1-D logits
    vector in place and returns it."""
    # from HF implementation
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float("inf")

    return logits
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def future_logits(logits, curtime):
    """ don't sample events in the past """
    if curtime > 0:
        # mask all time tokens strictly earlier than curtime (ticks)
        logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf')

    return logits
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def instr_logits(logits, full_history):
    """ don't sample more than 16 instruments """
    instrs = ops.get_instruments(full_history)
    if len(instrs) < 15: # 16 - 1 to account for the reserved drum track
        return logits

    # at the channel limit: only allow instruments already in use
    for instr in range(MAX_INSTR):
        if instr not in instrs:
            logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')

    return logits
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def add_token(model, z, tokens, top_p, current_time, debug=False):
    """Sample one (time, duration, note) triple from the model.

    z            : control-code prefix ([ANTICIPATE] or [AUTOREGRESS])
    tokens       : the full generated history so far (triples)
    top_p        : nucleus sampling threshold
    current_time : earliest onset (ticks) allowed for the new event

    Returns the new triple with its time in absolute sequence coordinates.
    """
    assert len(tokens) % 3 == 0

    history = tokens.copy()
    # 1017: presumably leaves room for the control code and the 3 new tokens
    # inside the model's context window — TODO confirm against model config
    lookback = max(len(tokens) - 1017, 0)
    history = history[lookback:] # Markov window
    offset = ops.min_time(history, seconds=False)
    history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer

    new_token = []
    with torch.no_grad():
        # sample the time, duration, and note slots one at a time
        for i in range(3):
            input_tokens = torch.tensor(z + history + new_token).unsqueeze(0).to(model.device)
            logits = model(input_tokens).logits[0,-1]

            idx = input_tokens.shape[1]-1
            logits = safe_logits(logits, idx)
            if i == 0:
                # time slot: forbid onsets in the (relativized) past
                logits = future_logits(logits, current_time - offset)
            elif i == 2:
                # note slot: respect the 16-channel instrument limit
                logits = instr_logits(logits, tokens)
            logits = nucleus(logits, top_p)

            probs = F.softmax(logits, dim=-1)
            token = torch.multinomial(probs, 1)
            new_token.append(int(token))

    new_token[0] += offset # revert to full sequence timing
    if debug:
        print(f' OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')

    return new_token
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def generate(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
    """Generate events in [start_time, end_time) seconds with anticipation.

    inputs   : existing events; those past start_time are replayed as
               anticipated controls and merged back into the output
    controls : anticipated control triples (control-vocab encoded)
    delta    : anticipation interval in ticks

    Returns the full sorted event sequence (generated + given future events).
    """
    if inputs is None:
        inputs = []

    if controls is None:
        controls = []

    start_time = int(TIME_RESOLUTION*start_time)
    end_time = int(TIME_RESOLUTION*end_time)

    # prompt is events up to start_time
    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)

    # treat events beyond start_time as controls
    future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Future')
        ops.print_tokens(future)

    # clip controls that precede the sequence
    # NOTE(review): DELTA is a seconds-valued constant but seconds=False here —
    # verify the intended unit (elsewhere DELTA*TIME_RESOLUTION is used for ticks)
    controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)

    if debug:
        print('Controls')
        ops.print_tokens(controls)

    # anticipatory mode is only needed when there is something to anticipate
    z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 else [AUTOREGRESS]
    if debug:
        print('AR Mode' if z[0] == AUTOREGRESS else 'AAR Mode')

    # interleave the controls with the events
    tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))

    if debug:
        print('Prompt')
        ops.print_tokens(tokens)

    current_time = ops.max_time(prompt, seconds=False)
    if debug:
        print('Current time:', current_time)

    with tqdm(range(end_time-start_time)) as progress:
        if controls:
            atime, adur, anote = controls[0:3]
            anticipated_tokens = controls[3:]
            anticipated_time = atime - ATIME_OFFSET
        else:
            # nothing to anticipate
            anticipated_time = math.inf

        while True:
            # emit controls that fall within delta of the current time
            while current_time >= anticipated_time - delta:
                tokens.extend([atime, adur, anote])
                if debug:
                    note = anote - ANOTE_OFFSET
                    instr = note//2**7
                    print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)

                if len(anticipated_tokens) > 0:
                    atime, adur, anote = anticipated_tokens[0:3]
                    anticipated_tokens = anticipated_tokens[3:]
                    anticipated_time = atime - ATIME_OFFSET
                else:
                    # nothing more to anticipate
                    anticipated_time = math.inf

            new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))
            new_time = new_token[0] - TIME_OFFSET
            if new_time >= end_time:
                break

            if debug:
                new_note = new_token[2] - NOTE_OFFSET
                new_instr = new_note//2**7
                new_pitch = new_note - (2**7)*new_instr
                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)

            tokens.extend(new_token)
            dt = new_time - current_time
            assert dt >= 0
            current_time = new_time
            progress.update(dt)

    # discard the anticipated copies; re-merge the original future events
    events, _ = ops.split(tokens)
    return ops.sort(ops.unpad(events) + future)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def generate_ar(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
    """Purely autoregressive generation in [start_time, end_time) seconds.

    Controls are folded into the ordinary event stream (no anticipation);
    given events beyond start_time are backfilled into the context as the
    model's generated time passes them.
    """
    if inputs is None:
        inputs = []

    if controls is None:
        controls = []
    else:
        # treat controls as ordinary tokens
        controls = [token-CONTROL_OFFSET for token in controls]

    start_time = int(TIME_RESOLUTION*start_time)
    end_time = int(TIME_RESOLUTION*end_time)

    inputs = ops.sort(inputs + controls)

    # prompt is events up to start_time
    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
    if debug:
        print('Prompt')
        ops.print_tokens(prompt)

    # treat events beyond start_time as controls
    controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Future')
        ops.print_tokens(controls)

    z = [AUTOREGRESS]
    if debug:
        print('AR Mode')

    current_time = ops.max_time(prompt, seconds=False)
    if debug:
        print('Current time:', current_time)

    tokens = prompt
    with tqdm(range(end_time-start_time)) as progress:
        if controls:
            atime, adur, anote = controls[0:3]
            anticipated_tokens = controls[3:]
            anticipated_time = atime - TIME_OFFSET
        else:
            # nothing to anticipate
            anticipated_time = math.inf

        while True:
            new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))
            new_time = new_token[0] - TIME_OFFSET
            if new_time >= end_time:
                break

            dt = new_time - current_time
            assert dt >= 0
            current_time = new_time

            # backfill anything that should have come before the new token
            while current_time >= anticipated_time:
                tokens.extend([atime, adur, anote])
                if debug:
                    note = anote - NOTE_OFFSET
                    instr = note//2**7
                    print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)

                if len(anticipated_tokens) > 0:
                    atime, adur, anote = anticipated_tokens[0:3]
                    anticipated_tokens = anticipated_tokens[3:]
                    anticipated_time = atime - TIME_OFFSET
                else:
                    # nothing more to anticipate
                    anticipated_time = math.inf

            if debug:
                new_note = new_token[2] - NOTE_OFFSET
                new_instr = new_note//2**7
                new_pitch = new_note - (2**7)*new_instr
                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)

            tokens.extend(new_token)
            progress.update(dt)

    if anticipated_time != math.inf:
        tokens.extend([atime, adur, anote])

    # NOTE(review): controls backfilled into `tokens` above are also still
    # present in `controls` here, so the final sort may contain duplicate
    # triples — verify whether downstream de-duplicates or this is intended
    return ops.sort(ops.unpad(tokens) + controls)
|
anticipation/tokenize.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Top-level functions for preprocessing data to be used for training.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from anticipation import ops
|
| 10 |
+
from anticipation.config import *
|
| 11 |
+
from anticipation.vocab import *
|
| 12 |
+
from anticipation.convert import compound_to_events, midi_to_interarrival
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def extract_spans(all_events, rate):
    """Partition events into (events, controls) by anticipating random
    contiguous time spans: span starts arrive with exponentially distributed
    gaps at the given rate (per second), and each span lasts DELTA seconds.
    Events inside a span are re-encoded as controls."""
    events = []
    controls = []
    span = True
    next_span = end_span = TIME_OFFSET+0
    for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
        assert(note not in [SEPARATOR, REST]) # shouldn't be in the sequence yet

        # end of an anticipated span; decide when to do it again (next_span)
        if span and time >= end_span:
            span = False
            next_span = time+int(TIME_RESOLUTION*np.random.exponential(1./rate))

        # anticipate a 3-second span
        if (not span) and time >= next_span:
            span = True
            end_span = time + DELTA*TIME_RESOLUTION

        if span:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
ANTICIPATION_RATES = 10
|
| 43 |
+
def extract_random(all_events, rate):
    """Randomly promote events to controls, each independently with
    probability rate/ANTICIPATION_RATES."""
    events = []
    controls = []
    threshold = rate/float(ANTICIPATION_RATES)
    for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
        assert(note not in [SEPARATOR, REST]) # shouldn't be in the sequence yet

        if np.random.random() < threshold:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def extract_instruments(all_events, instruments):
    """Split events so that notes played by any instrument in `instruments`
    become controls; everything else stays an event."""
    events = []
    controls = []
    for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
        assert note < CONTROL_OFFSET # shouldn't be in the sequence yet
        assert note not in [SEPARATOR, REST] # these shouldn't either

        if (note-NOTE_OFFSET)//2**7 in instruments:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def maybe_tokenize(compound_tokens):
    """Tokenize a compound stream into events, filtering out unusable tracks.

    Returns (events, truncations, status) where status is:
      0 - ok; 1 - track too short; 2 - track too long; 3 - too many instruments.
    events and truncations are None whenever status != 0.
    """
    # skip sequences with very few events
    if len(compound_tokens) < COMPOUND_SIZE*MIN_TRACK_EVENTS:
        return None, None, 1 # short track

    events, truncations = compound_to_events(compound_tokens, stats=True)
    end_time = ops.max_time(events, seconds=False)

    # don't want to deal with extremely short tracks
    if end_time < TIME_RESOLUTION*MIN_TRACK_TIME_IN_SECONDS:
        return None, None, 1 # short track

    # don't want to deal with extremely long tracks
    if end_time > TIME_RESOLUTION*MAX_TRACK_TIME_IN_SECONDS:
        return None, None, 2 # long track

    # skip sequences with more instruments than MIDI channels (16)
    if len(ops.get_instruments(events)) > MAX_TRACK_INSTR:
        return None, None, 3 # too many instruments

    return events, truncations, 0
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def tokenize_ia(datafiles, output, augment_factor, idx=0, debug=False):
    """Tokenize tracks into interarrival-time (MIDI-like) sequences.

    Reads each '.compound.txt' file in *datafiles*, filters tracks with
    maybe_tokenize, re-encodes the surviving original MIDI files with
    midi_to_interarrival, and writes fixed-length sequences of
    CONTEXT_SIZE tokens to *output*, one whitespace-separated sequence
    per line.

    Args:
        datafiles: paths to '.compound.txt' token files.
        output: path of the output text file.
        augment_factor: must be 1 (interarrival data can't be augmented).
        idx: worker index, used only to position the tqdm progress bar.
        debug: if True, print a summary after processing.

    Returns:
        (seqcount, rest_count, n_short, n_long, n_many_instr,
         n_inexpressible, all_truncations); rest_count is always 0 here.
    """
    assert augment_factor == 1 # can't augment interarrival-tokenized data

    all_truncations = 0
    seqcount = rest_count = 0
    stats = 4*[0] # (short, long, too many instruments, inexpressible)
    np.random.seed(0)

    with open(output, 'w') as outfile:
        # leftover tokens carry over between files so sequences can span
        # track boundaries (tracks are delimited by MIDI_SEPARATOR)
        concatenated_tokens = []
        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
            with open(filename, 'r') as f:
                # run the arrival-time filter only for its status code
                _, _, status = maybe_tokenize([int(token) for token in f.read().split()])

            if status > 0:
                stats[status-1] += 1
                continue

            filename = filename[:-len('.compound.txt')] # get the original MIDI

            # already parsed; shouldn't raise an exception
            tokens, truncations = midi_to_interarrival(filename, stats=True)
            tokens[0:0] = [MIDI_SEPARATOR]
            concatenated_tokens.extend(tokens)
            all_truncations += truncations

            # write out full sequences to file
            while len(concatenated_tokens) >= CONTEXT_SIZE:
                seq = concatenated_tokens[0:CONTEXT_SIZE]
                concatenated_tokens = concatenated_tokens[CONTEXT_SIZE:]
                outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
                seqcount += 1

    if debug:
        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))

    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def tokenize(datafiles, output, augment_factor, idx=0, debug=False):
    """Tokenize compound-token tracks into arrival-time training sequences.

    Each surviving track is emitted *augment_factor* times with a cycling
    augmentation scheme (k % 10): 0 = none, 1 = span, 2-5 = random-rate,
    6-9 = instrument subset. Events and their anticipated controls are
    interleaved by ops.anticipate, concatenated across tracks, and
    written to *output* as fixed-length lines of EVENT_SIZE*M tokens,
    each prefixed with a global control token (ANTICIPATE/AUTOREGRESS).

    Args:
        datafiles: paths to '.compound.txt' token files.
        output: path of the output text file (one sequence per line).
        augment_factor: number of augmented copies generated per track.
        idx: worker index, used only to position the tqdm progress bar.
        debug: if True, print a summary after processing.

    Returns:
        (seqcount, rest_count, n_short, n_long, n_many_instr,
         n_inexpressible, all_truncations)
    """
    tokens = []
    all_truncations = 0
    seqcount = rest_count = 0
    stats = 4*[0] # (short, long, too many instruments, inexpressible)
    np.random.seed(0)

    with open(output, 'w') as outfile:
        # leftover tokens carry over between tracks/augmentations so
        # sequences can span track boundaries
        concatenated_tokens = []
        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
            with open(filename, 'r') as f:
                all_events, truncations, status = maybe_tokenize([int(token) for token in f.read().split()])

            if status > 0:
                stats[status-1] += 1
                continue

            instruments = list(ops.get_instruments(all_events).keys())
            end_time = ops.max_time(all_events, seconds=False)

            # different random augmentations
            for k in range(augment_factor):
                if k % 10 == 0:
                    # no augmentation
                    events = all_events.copy()
                    controls = []
                elif k % 10 == 1:
                    # span augmentation
                    lmbda = .05
                    events, controls = extract_spans(all_events, lmbda)
                elif k % 10 < 6:
                    # random augmentation
                    r = np.random.randint(1,ANTICIPATION_RATES)
                    events, controls = extract_random(all_events, r)
                else:
                    if len(instruments) > 1:
                        # instrument augmentation: at least one, but not all instruments
                        u = 1+np.random.randint(len(instruments)-1)
                        subset = np.random.choice(instruments, u, replace=False)
                        events, controls = extract_instruments(all_events, subset)
                    else:
                        # no augmentation
                        events = all_events.copy()
                        controls = []

                # only pick the global control flag now if no partial
                # sequence is pending; otherwise z still describes the
                # augmentation that opened the pending sequence
                if len(concatenated_tokens) == 0:
                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS

                all_truncations += truncations
                events = ops.pad(events, end_time)
                rest_count += sum(1 if tok == REST else 0 for tok in events[2::3])
                tokens, controls = ops.anticipate(events, controls)
                assert len(controls) == 0 # should have consumed all controls (because of padding)
                tokens[0:0] = [SEPARATOR, SEPARATOR, SEPARATOR]
                concatenated_tokens.extend(tokens)

                # write out full sequences to file
                while len(concatenated_tokens) >= EVENT_SIZE*M:
                    seq = concatenated_tokens[0:EVENT_SIZE*M]
                    concatenated_tokens = concatenated_tokens[EVENT_SIZE*M:]

                    # relativize time to the context
                    seq = ops.translate(seq, -ops.min_time(seq, seconds=False), seconds=False)
                    assert ops.min_time(seq, seconds=False) == 0
                    # drop sequences whose relativized times overflow the
                    # time vocabulary ("inexpressible")
                    if ops.max_time(seq, seconds=False) >= MAX_TIME:
                        stats[3] += 1
                        continue

                    # if seq contains SEPARATOR, global controls describe the first sequence
                    seq.insert(0, z)

                    outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
                    seqcount += 1

                    # grab the current augmentation controls if we didn't already
                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS

    if debug:
        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))

    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
|
anticipation/visuals.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for inspecting encoded music data.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
import matplotlib
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
import anticipation.ops as ops
|
| 11 |
+
from anticipation.config import *
|
| 12 |
+
from anticipation.vocab import *
|
| 13 |
+
|
| 14 |
+
def visualize(tokens, output, selected=None):
    """Render an arrival-time token sequence as a piano-roll image.

    Builds a (time x pitch) grid where each cell holds a per-instrument
    color index, then saves it with matplotlib.

    Args:
        tokens: flat [time, dur, note, ...] arrival-time token list.
        output: path handed to plt.savefig.
        selected: optional collection of instrument numbers; when given,
            only those instruments are drawn.

    Notes:
        SEPARATOR triples are reported to stdout and skipped; REST tokens
        and drums (instrument 128) are not drawn.
    """
    #colors = ['white', 'silver', 'red', 'sienna', 'darkorange', 'gold', 'yellow', 'palegreen', 'seagreen', 'cyan',
    #          'dodgerblue', 'slategray', 'navy', 'mediumpurple', 'mediumorchid', 'magenta', 'lightpink']
    colors = ['white', '#426aa0', '#b26789', '#de9283', '#eac29f', 'silver', 'red', 'sienna', 'darkorange', 'gold', 'yellow', 'palegreen', 'seagreen', 'cyan', 'dodgerblue', 'slategray', 'navy']

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300

    max_time = ops.max_time(tokens, seconds=False)
    # grid value 0 = background; value i+1 = instruments[i]
    grid = np.zeros([max_time, MAX_PITCH])
    instruments = list(sorted(list(ops.get_instruments(tokens).keys())))
    if 128 in instruments:
        instruments.remove(128)

    for j, (tm, dur, note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
        if note == SEPARATOR:
            assert tm == SEPARATOR and dur == SEPARATOR
            print(j, 'SEPARATOR')
            continue

        if note == REST:
            continue

        # controls are not expected here
        assert note < CONTROL_OFFSET

        # strip vocabulary offsets to recover raw time/duration/note
        tm = tm - TIME_OFFSET
        dur = dur - DUR_OFFSET
        note = note - NOTE_OFFSET
        instr = note//2**7
        pitch = note - (2**7)*instr

        if instr == 128: # drums
            continue # we don't visualize this

        if selected and instr not in selected:
            continue

        grid[tm:tm+dur, pitch] = 1+instruments.index(instr)

    plt.clf()
    plt.axis('off')
    cmap = matplotlib.colors.ListedColormap(colors)
    bounds = list(range(MAX_TRACK_INSTR)) + [16]
    norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
    # flip vertically so high pitches appear at the top
    plt.imshow(np.flipud(grid.T), aspect=16, cmap=cmap, norm=norm, interpolation='none')

    # legend maps each color back to its MIDI instrument number
    patches = [matplotlib.patches.Patch(color=colors[i+1], label=f"{instruments[i]}")
               for i in range(len(instruments))]
    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0. )

    plt.tight_layout()
    plt.savefig(output)
|
anticipation/vocab.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The vocabularies used for arrival-time and interarrival-time encodings.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# training sequence vocab
|
| 6 |
+
|
| 7 |
+
from anticipation.config import *
|
| 8 |
+
|
| 9 |
+
# the event block
|
| 10 |
+
EVENT_OFFSET = 0
|
| 11 |
+
TIME_OFFSET = EVENT_OFFSET
|
| 12 |
+
DUR_OFFSET = TIME_OFFSET + MAX_TIME
|
| 13 |
+
NOTE_OFFSET = DUR_OFFSET + MAX_DUR
|
| 14 |
+
REST = NOTE_OFFSET + MAX_NOTE
|
| 15 |
+
|
| 16 |
+
# the control block
|
| 17 |
+
CONTROL_OFFSET = NOTE_OFFSET + MAX_NOTE + 1
|
| 18 |
+
ATIME_OFFSET = CONTROL_OFFSET + 0
|
| 19 |
+
ADUR_OFFSET = ATIME_OFFSET + MAX_TIME
|
| 20 |
+
ANOTE_OFFSET = ADUR_OFFSET + MAX_DUR
|
| 21 |
+
|
| 22 |
+
# the special block
|
| 23 |
+
SPECIAL_OFFSET = ANOTE_OFFSET + MAX_NOTE
|
| 24 |
+
SEPARATOR = SPECIAL_OFFSET
|
| 25 |
+
AUTOREGRESS = SPECIAL_OFFSET + 1
|
| 26 |
+
ANTICIPATE = SPECIAL_OFFSET + 2
|
| 27 |
+
VOCAB_SIZE = ANTICIPATE+1
|
| 28 |
+
|
| 29 |
+
# interarrival-time (MIDI-like) vocab
|
| 30 |
+
MIDI_TIME_OFFSET = 0
|
| 31 |
+
MIDI_START_OFFSET = MIDI_TIME_OFFSET + MAX_INTERARRIVAL
|
| 32 |
+
MIDI_END_OFFSET = MIDI_START_OFFSET + MAX_NOTE
|
| 33 |
+
MIDI_SEPARATOR = MIDI_END_OFFSET + MAX_NOTE
|
| 34 |
+
MIDI_VOCAB_SIZE = MIDI_SEPARATOR + 1
|
| 35 |
+
|
| 36 |
+
if __name__ == '__main__':
|
| 37 |
+
print('Arrival-Time Training Sequence Format:')
|
| 38 |
+
print('Event Offset: ', EVENT_OFFSET)
|
| 39 |
+
print(' -> time offset :', TIME_OFFSET)
|
| 40 |
+
print(' -> duration offset :', DUR_OFFSET)
|
| 41 |
+
print(' -> note offset :', NOTE_OFFSET)
|
| 42 |
+
print(' -> rest token: ', REST)
|
| 43 |
+
print('Anticipated Control Offset: ', CONTROL_OFFSET)
|
| 44 |
+
print(' -> anticipated time offset :', ATIME_OFFSET)
|
| 45 |
+
print(' -> anticipated duration offset :', ADUR_OFFSET)
|
| 46 |
+
print(' -> anticipated note offset :', ANOTE_OFFSET)
|
| 47 |
+
print('Special Token Offset: ', SPECIAL_OFFSET)
|
| 48 |
+
print(' -> separator token: ', SEPARATOR)
|
| 49 |
+
print(' -> autoregression flag: ', AUTOREGRESS)
|
| 50 |
+
print(' -> anticipation flag: ', ANTICIPATE)
|
| 51 |
+
print('Arrival Encoding Vocabulary Size: ', VOCAB_SIZE)
|
| 52 |
+
print('')
|
| 53 |
+
print('Interarrival-Time Training Sequence Format:')
|
| 54 |
+
print(' -> time offset: ', MIDI_TIME_OFFSET)
|
| 55 |
+
print(' -> note-on offset: ', MIDI_START_OFFSET)
|
| 56 |
+
print(' -> note-off offset: ', MIDI_END_OFFSET)
|
| 57 |
+
print(' -> separator token: ', MIDI_SEPARATOR)
|
| 58 |
+
print('Interarrival Encoding Vocabulary Size: ', MIDI_VOCAB_SIZE)
|
api.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agents.agents import harmonizer, infiller, change_melody
|
| 2 |
+
from flask import Flask, request, jsonify
|
| 3 |
+
from flask_cors import CORS
|
| 4 |
+
import mido
|
| 5 |
+
import tempfile
|
| 6 |
+
import os
|
| 7 |
+
import music21
|
| 8 |
+
import traceback
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
import threading
|
| 11 |
+
from transformers import AutoModelForCausalLM
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Flask application object; flask-cors installs its default (permissive)
# CORS handling — an after_request hook elsewhere in this module also
# sets Access-Control-* headers explicitly on every response.
app = Flask(__name__)
CORS(app)
|
| 16 |
+
|
| 17 |
+
@app.after_request
def add_cors_headers(response):
    """Restrict cross-origin access to the production frontend.

    Runs on every response, overwriting the Access-Control-* headers so
    that only the deployed site may call this API from a browser.
    """
    # BUG FIX: the origin was 'https://https://inscoreai.netlify.app/.com'
    # (doubled scheme plus a stray '.com'), which is not a valid origin
    # and therefore matched no browser origin, breaking CORS entirely.
    # An Allow-Origin value must be a bare scheme://host origin, no path.
    response.headers['Access-Control-Allow-Origin'] = 'https://inscoreai.netlify.app'
    response.headers['Access-Control-Allow-Methods'] = 'GET, POST'
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
    return response
|
| 24 |
+
|
| 25 |
+
def midi_to_musicxml(midi_file_path):
    """Convert a MIDI file to a MusicXML string.

    Parses *midi_file_path* with music21, writes MusicXML to a
    uniquely-named temporary file, reads it back, and returns the XML
    text. Any parsing/IO exception is logged with a traceback and
    re-raised.
    """
    try:
        midi_path_str = str(midi_file_path)

        # Parse and convert to MusicXML
        score = music21.converter.parse(midi_path_str)

        # Unique temp path so concurrent requests don't collide
        temp_output = os.path.join(tempfile.gettempdir(), f"output_{uuid4().hex}.musicxml")
        try:
            # Write to temporary file, then read back as a string
            score.write('musicxml', temp_output)
            with open(temp_output, 'r') as f:
                musicxml_str = f.read()
        finally:
            # FIX: clean up even when write/read raises — previously a
            # failed conversion leaked the temp file in tmpdir.
            if os.path.exists(temp_output):
                os.unlink(temp_output)

        return musicxml_str
    except Exception as e:
        print(f"Conversion error: {str(e)}")
        traceback.print_exc()
        raise
|
| 51 |
+
|
| 52 |
+
def load_model():
    """Return the shared music-generation model, loading it lazily.

    Thread-safe: MODEL_LOCK serializes concurrent first calls so the
    checkpoint is loaded at most once per process.
    """
    global MODEL
    with MODEL_LOCK:
        if MODEL is None:
            print("⏳ Loading music generation model...")
            # local_files_only / force_download=False: never re-download,
            # the checkpoint is expected in the local HF cache
            MODEL = AutoModelForCausalLM.from_pretrained(
                'stanford-crfm/music-small-800k',
                local_files_only=True,
                force_download=False,
            )
            # Add .cuda() here if using GPU
            print("✅ Model loaded successfully!")
    return MODEL
|
| 61 |
+
|
| 62 |
+
# Model loading setup: shared singleton plus a lock guarding first load
MODEL = None
MODEL_LOCK = threading.Lock()

# Initialize model when app starts, so the first request doesn't pay the
# checkpoint-loading latency.
# NOTE(review): load_model() doesn't appear to touch the Flask context,
# so the app_context() wrapper is presumably unnecessary — confirm
# before removing.
with app.app_context():
    load_model()
|
| 69 |
+
|
| 70 |
+
@app.route('/upload', methods=['POST'])
def handle_upload():
    """Harmonize a selected time range of an uploaded MIDI file.

    Form fields:
        midi_file  -- the MIDI upload (required)
        start_time -- selection start in milliseconds (default '0')
        end_time   -- selection end in milliseconds (default '0')
        top_p      -- nucleus-sampling threshold (default '0.95')

    Returns JSON {status, musicxml} on success, or {status, message}
    with HTTP 400 on failure.
    """
    temp_midi_path = None
    try:
        # FIX: parse top_p inside the try block — a malformed value
        # previously raised an unhandled ValueError (HTTP 500) instead
        # of the JSON 400 response every other error gets.
        top_p = float(request.form.get('top_p', '0.95'))

        # Validate input
        if 'midi_file' not in request.files:
            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400

        midi_file = request.files['midi_file']
        start_time = request.form.get('start_time', '0')
        end_time = request.form.get('end_time', '0')

        # Create temporary MIDI file with random name
        temp_dir = tempfile.gettempdir()
        temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")

        # Save uploaded MIDI to temp file
        midi_file.save(temp_midi_path)

        # Process MIDI: harmonize the selected window (ms -> seconds)
        midi = mido.MidiFile(temp_midi_path)
        model = load_model()
        harmonized_midi = harmonizer(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)

        # Save harmonized MIDI (overwriting temp file)
        harmonized_midi.save(temp_midi_path)

        # Convert to MusicXML string
        musicxml_str = midi_to_musicxml(temp_midi_path)

        # Final type verification
        if not isinstance(musicxml_str, str):
            raise TypeError(f"Expected string but got {type(musicxml_str)}")

        return jsonify({
            "status": "success",
            "musicxml": musicxml_str
        })

    except Exception as e:
        print(f"Error processing request: {str(e)}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 400
    finally:
        # Clean up temp file
        if temp_midi_path and os.path.exists(temp_midi_path):
            try:
                os.unlink(temp_midi_path)
            except Exception as e:
                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
|
| 125 |
+
|
| 126 |
+
@app.route('/uploadinfill', methods=['POST'])
def handle_upload_infilling():
    """Infill (regenerate) a selected time range of an uploaded MIDI file.

    Form fields:
        midi_file  -- the MIDI upload (required)
        start_time -- selection start in milliseconds (default '0')
        end_time   -- selection end in milliseconds (default '0')
        top_p      -- nucleus-sampling threshold (default '0.95')

    Returns JSON {status, musicxml} on success, or {status, message}
    with HTTP 400 on failure.
    """
    temp_midi_path = None
    try:
        # FIX: parse top_p inside the try block — a malformed value
        # previously raised an unhandled ValueError (HTTP 500) instead
        # of the JSON 400 response every other error gets.
        top_p = float(request.form.get('top_p', '0.95'))

        # Validate input
        if 'midi_file' not in request.files:
            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400

        midi_file = request.files['midi_file']
        start_time = request.form.get('start_time', '0')
        end_time = request.form.get('end_time', '0')

        # Create temporary MIDI file with random name
        temp_dir = tempfile.gettempdir()
        temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")

        # Save uploaded MIDI to temp file
        midi_file.save(temp_midi_path)

        # Process MIDI: infill the selected window (ms -> seconds)
        midi = mido.MidiFile(temp_midi_path)
        model = load_model()
        infilled_midi = infiller(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)

        # Save infilled MIDI (overwriting temp file)
        infilled_midi.save(temp_midi_path)

        # Convert to MusicXML string
        musicxml_str = midi_to_musicxml(temp_midi_path)

        # Final type verification
        if not isinstance(musicxml_str, str):
            raise TypeError(f"Expected string but got {type(musicxml_str)}")

        return jsonify({
            "status": "success",
            "musicxml": musicxml_str
        })

    except Exception as e:
        print(f"Error processing request: {str(e)}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 400
    finally:
        # Clean up temp file
        if temp_midi_path and os.path.exists(temp_midi_path):
            try:
                os.unlink(temp_midi_path)
            except Exception as e:
                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
|
| 181 |
+
|
| 182 |
+
@app.route('/uploadchangemelody', methods=['POST'])
def handle_upload_changemelody():
    """Replace the melody in a selected time range of an uploaded MIDI file.

    Form fields:
        midi_file  -- the MIDI upload (required)
        start_time -- selection start in milliseconds (default '0')
        end_time   -- selection end in milliseconds (default '0')
        top_p      -- nucleus-sampling threshold (default '0.95')

    Returns JSON {status, musicxml} on success, or {status, message}
    with HTTP 400 on failure.
    """
    temp_midi_path = None
    try:
        # FIX: parse top_p inside the try block — a malformed value
        # previously raised an unhandled ValueError (HTTP 500) instead
        # of the JSON 400 response every other error gets.
        top_p = float(request.form.get('top_p', '0.95'))

        # Validate input
        if 'midi_file' not in request.files:
            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400

        midi_file = request.files['midi_file']
        start_time = request.form.get('start_time', '0')
        end_time = request.form.get('end_time', '0')

        # Create temporary MIDI file with random name
        temp_dir = tempfile.gettempdir()
        temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")

        # Save uploaded MIDI to temp file
        midi_file.save(temp_midi_path)

        # Process MIDI: regenerate the melody in the window (ms -> seconds)
        midi = mido.MidiFile(temp_midi_path)
        model = load_model()
        changed_melody_midi = change_melody(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)

        # Save regenerated MIDI (overwriting temp file)
        changed_melody_midi.save(temp_midi_path)

        # Convert to MusicXML string
        musicxml_str = midi_to_musicxml(temp_midi_path)

        # Final type verification
        if not isinstance(musicxml_str, str):
            raise TypeError(f"Expected string but got {type(musicxml_str)}")

        return jsonify({
            "status": "success",
            "musicxml": musicxml_str
        })

    except Exception as e:
        print(f"Error processing request: {str(e)}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 400
    finally:
        # Clean up temp file
        if temp_midi_path and os.path.exists(temp_midi_path):
            try:
                os.unlink(temp_midi_path)
            except Exception as e:
                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# Development entry point; requirements.txt includes gunicorn, which is
# presumably the production server — this branch is not hit under it.
# NOTE(review): debug=True enables the interactive Werkzeug debugger;
# never expose this mode on a public interface.
if __name__ == '__main__':
    app.run(debug=True, port=5000)
|
examples/full-score3.mid
ADDED
|
Binary file (1.36 kB). View file
|
|
|
examples/strawberry.mid
ADDED
|
Binary file (24.2 kB). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
matplotlib == 3.7.1
|
| 2 |
+
midi2audio == 0.1.1
|
| 3 |
+
mido == 1.2.10
|
| 4 |
+
numpy >= 1.22.4
|
| 5 |
+
torch >= 2.0.1
|
| 6 |
+
transformers == 4.29.2
|
| 7 |
+
tqdm == 4.65.0
|
| 8 |
+
flask==3.1.1
|
| 9 |
+
flask-cors==5.0.1
|
| 10 |
+
music21
|
| 11 |
+
gunicorn
|