|
|
from src.utils.midi import ids_to_midi, midi_to_ids |
|
|
from src.model.pianoformer import PianoT5Gemma |
|
|
from miditoolkit import MidiFile |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
from transformers import LogitsProcessorList, LogitsProcessor |
|
|
from tqdm import tqdm |
|
|
import torch |
|
|
from src.utils.midi import normalize_midi, merge_and_sort |
|
|
from miditoolkit import MidiFile, Note, TempoChange, Instrument, ControlChange |
|
|
import bisect |
|
|
|
|
|
class BatchSparseForcedTokenProcessor(LogitsProcessor):
    """Logits processor for score-conditioned performance generation.

    Every 8th decoder step is forced to reproduce the corresponding token of
    the (padded) source sequence; all other steps are restricted to the valid
    token-id range for their slot within the 8-token event group.  Optionally
    reports generation progress through a callback.
    """

    def __init__(self, input_ids, config, target_len, origin_len, already, weight, progress_callback):
        # Per-sample forcing map: decoder step -> token id that must be emitted.
        # Only every 8th position of the source sequence is forced.
        self.batch_map = [{j: input_ids[i][j] for j in range(0, len(input_ids[i]), 8)} for i in range(len(input_ids))]
        # valid_id_range[k] = (lo, hi): allowed id interval for slot k of an
        # 8-token event group (ids outside [lo, hi) are masked).
        self.valid_id_range = config.valid_id_range
        self.target_len = target_len  # absolute end position of this window
        self.origin_len = origin_len  # absolute start position of this window
        self.already = already        # progress fraction completed before this window
        self.weight = weight          # progress fraction this window contributes
        self.progress_callback = progress_callback

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.progress_callback:
            # Linear progress within this window, scaled into the global range.
            self.progress_callback(
                (input_ids.shape[1] - self.origin_len) / (self.target_len - self.origin_len) * self.weight + self.already
            )
        step = input_ids.shape[1] - 1
        batch_size = scores.shape[0]
        for i in range(batch_size):
            sample_map = self.batch_map[i]
            if step in sample_map:
                # Forced position: mask everything, then allow only the
                # score token (log-prob 0 => probability 1 after softmax).
                forced_token_id = sample_map[step]
                scores[i] = float('-inf')
                scores[i, forced_token_id] = 0.0
            else:
                # BUGFIX: the original reassigned `step = step % 8` here, which
                # leaked the reduced value into later loop iterations; when the
                # reduced value collided with a map key (e.g. 0), subsequent
                # samples wrongly took the forced-token branch.  Use a local.
                slot = step % 8
                scores[i, :self.valid_id_range[slot][0]] = float('-inf')
                scores[i, self.valid_id_range[slot][1]:] = float('-inf')
        return scores
|
|
|
|
|
@torch.no_grad()
def batch_performance_render(
    model,
    score_midi_objs,
    max_context_length=4096,
    overlap_ratio=0.5,
    temperature=1.0,
    top_p=0.95,
    device="cpu",
    progress_callback=None
):
    """Render a batch of score MIDI objects into expressive performances.

    Long scores are processed in overlapping windows of at most
    ``max_context_length`` tokens; consecutive windows are stitched by
    re-generating the trailing 20% of each overlap region so the seam stays
    musically coherent.

    Args:
        model: seq2seq model exposing ``config`` and ``generate`` (HF-style).
        score_midi_objs: list of score MIDI objects accepted by ``midi_to_ids``.
        max_context_length: window size in tokens; must be <= 4096.
        overlap_ratio: fraction of each window shared with the next (0..1).
        temperature / top_p: sampling hyper-parameters passed to ``generate``.
        device: torch device string for input tensors.
        progress_callback: optional ``callback(fraction)`` called during decoding.

    Returns:
        List of performance MIDI objects (one per input), via ``ids_to_midi``.

    Raises:
        ValueError: if ``max_context_length`` exceeds 4096.
    """
    def slide_window(total_len, window_len):
        # Produce (start, end) spans covering [0, total_len) with overlap.
        if total_len <= window_len:
            return [(0, total_len)]
        window_len = window_len // 8 * 8  # align windows to 8-token events
        out = []
        start = 0
        # BUGFIX: the original advanced by int(window_len*(1-overlap_ratio))//8*8
        # with no lower bound, which is 0 when overlap_ratio is close to 1 and
        # loops forever.  Clamp the stride to at least one event (8 tokens).
        stride = max(8, int(window_len * (1 - overlap_ratio)) // 8 * 8)
        while start + window_len <= total_len:
            out.append((start, start + window_len))
            start += stride
        if out[-1][1] != total_len:
            # Final partial window covering the tail.
            out.append((start, total_len))
        return out

    if max_context_length > 4096:
        raise ValueError("You should set max_context_length <= 4096!")
    batch_ids = [torch.tensor(midi_to_ids(model.config, score_midi_obj), dtype=torch.long).to(device) for score_midi_obj in score_midi_objs]
    # Remember each sample's true length so padding can be stripped at the end.
    len_list = [len(ids) for ids in batch_ids]

    input_ids = pad_sequence(batch_ids, batch_first=True, padding_value=model.config.pad_token_id)
    windows = slide_window(input_ids.shape[1], max_context_length)

    prev_output = None  # only the previous window's output is ever reused
    res_tensor = None
    for i in tqdm(range(len(windows))):
        start, end = windows[i]
        logits_processor = LogitsProcessorList([
            BatchSparseForcedTokenProcessor(
                input_ids[:, start:end],
                model.config,
                end,
                start,
                i / len(windows),    # progress completed before this window
                1 / len(windows),    # progress weight of this window
                progress_callback,
            )
        ])
        if i == 0:
            output = model.generate(
                input_ids[:, start:end],
                do_sample=True,
                max_new_tokens=end - start,
                logits_processor=logits_processor,
                temperature=temperature,
                top_p=top_p,
            )
            res_tensor = output[:, 1:]  # drop the leading BOS token
        else:
            last_start, last_end = windows[i - 1]
            # Re-generate the trailing 20% of the overlap region for a smooth seam.
            length = int(((last_end - last_start) - (start - last_start)) * 0.2)
            # Prime the decoder with the already-generated overlap (minus the
            # re-generated tail), prefixed by BOS.
            decoder_input_ids = prev_output[:, start - last_start:last_end - last_start - length]
            start_tensor = torch.tensor([[model.config.bos_token_id] for _ in range(input_ids.shape[0])], dtype=torch.long).to(device)
            decoder_input_ids = torch.cat([start_tensor, decoder_input_ids], dim=1)

            output = model.generate(
                input_ids[:, start:end],
                decoder_input_ids=decoder_input_ids,
                do_sample=True,
                max_new_tokens=end - last_end + length,
                logits_processor=logits_processor,
                temperature=temperature,
                top_p=top_p,
            )
            # BUGFIX: the original used res_tensor[:, :-length], which is an
            # EMPTY slice when length == 0 (tiny/zero overlap) and silently
            # discarded everything generated so far.  Use an explicit end index.
            keep = res_tensor.shape[1] - length
            res_tensor = torch.cat([res_tensor[:, :keep], output[:, -(end - last_end + length):]], dim=1)
        prev_output = output
    res_tensor = res_tensor.cpu().numpy().tolist()

    res = []
    for i in range(len(res_tensor)):
        # Strip padding back to each sample's original token length.
        res.append(ids_to_midi(model.config, res_tensor[i][:len_list[i]]))
    return res
|
|
|
|
|
|
|
|
def map_midi(score_midi_obj, performance_midi_obj):
    """Align a quantized score MIDI with a recorded performance MIDI.

    Builds a new MIDI (ticks_per_beat=500) whose notes keep the score's
    pitches/grid but take timing, velocity, durations and sustain-pedal CCs
    from the performance.  Tempo changes are derived from the ratio between
    score inter-onset intervals and performance inter-onset intervals.

    NOTE(review): assumes both inputs have at least one note in
    instruments[0] and that `merge_and_sort`/`normalize_midi` yield note
    lists of equal length aligned one-to-one by index — TODO confirm
    against the callers.
    """
    def compute_duration(start_time, target_duration, tempo_list):
        # Convert a duration in performance milliseconds into ticks, starting
        # at `start_time` and integrating across tempo segments.
        if target_duration <= 0:
            return 0
        if not tempo_list:
            # Fall back to a single 120 BPM tempo at tick 0.
            tempo_list = [TempoChange(120, 0)]
        tempo_times = [t.time for t in tempo_list]
        # Index of the tempo segment containing start_time.
        start_tempo_idx = bisect.bisect_right(tempo_times, start_time) - 1
        if start_tempo_idx < 0:
            start_tempo_idx = 0
        total_ticks_duration = 0.0
        time_remaining_ms = float(target_duration)
        current_tick = start_time
        current_tempo_idx = start_tempo_idx
        # Walk tempo segments, consuming milliseconds until exhausted.
        while time_remaining_ms > 1e-9:
            current_tempo_event = tempo_list[current_tempo_idx]
            current_bpm = current_tempo_event.tempo
            # 500 ticks per beat (matches MidiFile(ticks_per_beat=500) below).
            ms_per_tick = (60 * 1000.0 / current_bpm) / 500
            # Segment extends to the next tempo change, or forever for the last.
            end_of_segment_tick = float('inf')
            if current_tempo_idx + 1 < len(tempo_list):
                end_of_segment_tick = tempo_list[current_tempo_idx + 1].time
            ticks_in_segment = end_of_segment_tick - current_tick
            ms_in_segment = ticks_in_segment * ms_per_tick
            if time_remaining_ms <= ms_in_segment:
                # Duration ends inside this segment.
                ticks_needed = time_remaining_ms / ms_per_tick
                total_ticks_duration += ticks_needed
                time_remaining_ms = 0
            else:
                # Consume the whole segment and move to the next tempo.
                total_ticks_duration += ticks_in_segment
                time_remaining_ms -= ms_in_segment
                current_tick = end_of_segment_tick
                current_tempo_idx += 1
        return round(total_ticks_duration)

    def ms_to_tick(target_ms, tempo_list):
        # Convert an absolute time in milliseconds (from tick 0) to a tick
        # position, integrating across tempo segments.
        if target_ms <= 0:
            return 0
        if not tempo_list:
            # Fall back to a single 120 BPM tempo at tick 0.
            tempo_list = [TempoChange(120, 0)]
        accumulated_ms = 0.0
        # Scan the bounded segments between consecutive tempo changes.
        for i in range(len(tempo_list) - 1):
            current_tempo_event = tempo_list[i]
            next_tempo_event = tempo_list[i+1]
            current_bpm = current_tempo_event.tempo
            ticks_in_segment = next_tempo_event.time - current_tempo_event.time
            # Zero-length segments contribute no time.
            if ticks_in_segment == 0:
                continue
            ms_per_tick = (60 * 1000.0 / current_bpm) / 500
            ms_in_segment = ticks_in_segment * ms_per_tick
            if target_ms <= accumulated_ms + ms_in_segment:
                # Target falls inside this segment.
                ms_into_segment = target_ms - accumulated_ms
                ticks_needed = ms_into_segment / ms_per_tick
                final_tick = current_tempo_event.time + ticks_needed
                return round(final_tick)
            accumulated_ms += ms_in_segment
        # Target is beyond the last tempo change: extrapolate at the last tempo.
        last_tempo_event = tempo_list[-1]
        last_bpm = last_tempo_event.tempo
        ms_per_tick = (60 * 1000.0 / last_bpm) / 500
        ms_into_segment = target_ms - accumulated_ms
        ticks_needed = ms_into_segment / ms_per_tick
        final_tick = last_tempo_event.time + ticks_needed
        return round(final_tick)

    norm_score = merge_and_sort(score_midi_obj)
    norm_performance = normalize_midi(performance_midi_obj)

    score_notes = norm_score.instruments[0].notes
    performance_notes = norm_performance.instruments[0].notes
    performance_ccs = norm_performance.instruments[0].control_changes

    # Collect one entry per distinct score onset:
    # (score onset, matching performance onset, index of first note at onset).
    # Both note lists are shifted in place so their first onset is 0.
    start_list = []
    last = -1
    score_start = score_notes[0].start
    performance_start = performance_notes[0].start
    for i in range(len(score_notes)):
        performance_notes[i].end -= performance_start
        performance_notes[i].start -= performance_start
        score_notes[i].end -= score_start
        score_notes[i].start -= score_start
        if score_notes[i].start != last:
            start_list.append((score_notes[i].start, performance_notes[i].start, i))
            last = score_notes[i].start

    # Shift pedal CCs by the same performance offset.
    for i in range(len(performance_ccs)):
        performance_ccs[i].time -= performance_start

    # Inter-onset intervals between consecutive distinct onsets,
    # in score ticks and in performance time respectively.
    score_interval_list = []
    performance_interval_list = []
    for i in range(len(start_list)-1):
        score_interval_list.append(start_list[i+1][0] - start_list[i][0])
        performance_interval_list.append(start_list[i+1][1] - start_list[i][1])

    # Derive a local BPM per onset interval: score ticks per performance time,
    # scaled by the 120 BPM baseline, clamped to [10, 500].
    tempo_list = []
    start_note_offset = []
    for i in range(len(score_interval_list)):
        if performance_interval_list[i] != 0:
            bpm = 120.0 / performance_interval_list[i] * score_interval_list[i]
        else:
            # Simultaneous performance onsets: cap at the maximum tempo.
            bpm = 500
        start_note_offset.append(0)
        tempo_list.append(max(min(bpm, 500), 10))

    # Prefix-sum of per-interval offsets.  NOTE(review): all entries are 0
    # above, so this loop is currently a no-op — possibly a vestigial hook
    # for onset-offset correction; confirm before removing.
    for i in range(1, len(start_note_offset)):
        start_note_offset[i] += start_note_offset[i-1]

    # Expand per-interval values to per-note values: each note inherits the
    # tempo/offset of the onset interval it belongs to.
    note_tempo_list = []
    note_performance_align = []
    note_start_offset = [0]
    cnt = 0
    for i in range(len(score_notes)):
        # NOTE(review): `len(start_list) - 2` keeps cnt a valid index into
        # tempo_list (which has len(start_list)-1 entries) — verify the last
        # onset group intentionally reuses the penultimate interval's tempo.
        if cnt < len(start_list) - 2 and i >= start_list[cnt + 1][2]:
            cnt += 1
        note_tempo_list.append(tempo_list[cnt])
        note_performance_align.append(start_list[cnt][1])
        note_start_offset.append(start_note_offset[cnt])

    # Apply cumulative start offsets to score onsets (no-op while offsets are 0).
    for i in range(len(score_notes)):
        score_notes[i].start += note_start_offset[i]
    # Per-note inter-onset interval in (possibly shifted) score ticks.
    note_interval_list = [0]
    for i in range(len(score_notes)-1):
        note_interval_list.append(score_notes[i+1].start - score_notes[i].start)

    # Micro-timing: the residual between each performance onset and the onset
    # predicted by integrating the piecewise tempo, expressed in score ticks.
    micro_shift_list = [0]
    cnt = 1  # NOTE(review): assigned but never used below
    last_time = 0
    for i in range(1, len(score_notes)):
        last_time += note_interval_list[i] / note_tempo_list[i-1] * 120
        micro_shift_list.append((performance_notes[i].start - last_time) / 120 * note_tempo_list[i-1])

    # Assemble the output MIDI: notes at score-grid + micro-shift positions,
    # tempo changes deduplicated (only emitted when the rounded BPM changes).
    res = MidiFile(ticks_per_beat=500)
    res_notes = []
    start_time_list = []
    tempo_list_filter = []
    cc_list = []
    last = -1
    for i in range(len(score_notes)):
        start_time_list.append(round(score_notes[i].start + micro_shift_list[i]))
        if last != round(note_tempo_list[i]):
            tempo_list_filter.append(TempoChange(round(note_tempo_list[i]), round(score_notes[i].start + micro_shift_list[i])))
            last = round(note_tempo_list[i])
    for i in range(len(score_notes)):
        res_notes.append(
            Note(
                performance_notes[i].velocity,
                score_notes[i].pitch,
                start_time_list[i],
                # Note length: performance duration (ms-like units) converted
                # to ticks under the reconstructed tempo map.
                start_time_list[i]+compute_duration(start_time_list[i], performance_notes[i].duration, tempo_list_filter)
            )
        )

    # Sustain pedal (CC 64) events mapped through the same tempo map.
    for cc in performance_ccs:
        cc_list.append(ControlChange(64, cc.value, ms_to_tick(cc.time, tempo_list_filter)))

    res.tempo_changes = tempo_list_filter
    res.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=res_notes, control_changes=cc_list))
    res.time_signature_changes = norm_score.time_signature_changes
    res.key_signature_changes = norm_score.key_signature_changes
    return res
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Intentionally empty: this module is meant to be imported, not run directly.
    pass