# NOTE: repository-page residue kept as a comment (uploader: yhj137, commit message: "update", commit: 651903a)
from miditoolkit import MidiFile, Note, Instrument, TempoChange, ControlChange
import bisect
import numpy as np
import os
from copy import copy
import random
from collections import defaultdict
"""
def normalize_midi(midi_obj, target_ticks_per_beat = 500, target_tempo = 120):
ticks_per_beat = midi_obj.ticks_per_beat
merged_events = []
for i in range(len(midi_obj.instruments)):
filter_control_changes = []
for cc in midi_obj.instruments[i].control_changes:
if cc.number == 64:
filter_control_changes.append(cc)
merged_events.extend(midi_obj.instruments[i].notes + filter_control_changes)
merged_events.sort(key=lambda x: (x.start, x.pitch) if isinstance(x, Note) else (x.time, x.number))
time_interval = []
last_time = 0
for note in merged_events:
if isinstance(note, Note):
time_interval.append(note.start - last_time)
last_time = note.start
else:
time_interval.append(note.time - last_time)
last_time = note.time
output_notes = []
output_cc = []
ind = -1
now_tempo = 120
now_time = 0
for i, note in enumerate(merged_events):
if isinstance(note, Note):
time = note.start
else:
time = note.time
while ind + 1 < len(midi_obj.tempo_changes) and time >= midi_obj.tempo_changes[ind+1].time:
now_tempo = midi_obj.tempo_changes[ind+1].tempo
ind += 1
ratio = target_ticks_per_beat * target_tempo / now_tempo / ticks_per_beat
start_time = time_interval[i] * ratio + now_time
if isinstance(note, Note):
end_time = (note.end - note.start) * ratio + start_time
output_notes.append(Note(note.velocity, note.pitch, round(start_time), round(end_time)))
else:
output_cc.append(ControlChange(64, note.value, round(start_time)))
now_time = round(start_time)
output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
output_midi_obj.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=output_notes, control_changes=output_cc))
output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
for note in output_notes:
output_midi_obj.max_tick = max(output_midi_obj.max_tick, note.end)
for cc in output_cc:
output_midi_obj.max_tick = max(output_midi_obj.max_tick, cc.time)
return output_midi_obj
"""
"""
def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):
# 创建一个新的、干净的MidiFile对象用于输出
output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
# 获取原始MIDI的tick到秒的精确映射
# 这是最关键的一步,partitura和miditoolkit都有类似功能
# miditoolkit的get_tick_to_time_mapping()可以处理所有tempo变化
tick_to_time_map = midi_obj.get_tick_to_time_mapping()
# 计算从秒转换回目标tick的比例因子
# 目标MIDI中,每秒对应的tick数 = target_ticks_per_beat * (target_tempo / 60)
seconds_to_target_ticks_factor = target_ticks_per_beat * (target_tempo / 60.0)
merged_notes = []
merged_cc = []
# 遍历所有乐器轨道
for instrument in midi_obj.instruments:
# 只处理非鼓组的乐器
if not instrument.is_drum:
# --- 处理音符 (Notes) ---
for note in instrument.notes:
# 1. 将原始tick转换为绝对秒数
start_time_sec = tick_to_time_map[note.start]
end_time_sec = tick_to_time_map[note.end]
# 2. 将绝对秒数转换为目标tick
new_start_tick = round(start_time_sec * seconds_to_target_ticks_factor)
new_end_tick = round(end_time_sec * seconds_to_target_ticks_factor)
# 避免duration为0的音符
if new_start_tick == new_end_tick:
new_end_tick += 1
merged_notes.append(Note(velocity=note.velocity,
pitch=note.pitch,
start=new_start_tick,
end=new_end_tick))
# --- 处理延音踏板 (CC #64) ---
for cc in instrument.control_changes:
if cc.number == 64:
# 1. 将原始tick转换为绝对秒数
time_sec = tick_to_time_map[cc.time]
# 2. 将绝对秒数转换为目标tick
new_time_tick = round(time_sec * seconds_to_target_ticks_factor)
merged_cc.append(ControlChange(number=64,
value=cc.value,
time=new_time_tick))
# --- 排序并创建新乐器 ---
# 按开始时间排序,对于同时开始的事件,CC优先于Note
merged_notes.sort(key=lambda x: (x.start, x.pitch))
merged_cc.sort(key=lambda x: (x.time, x.number))
output_instrument = Instrument(program=0, is_drum=False, name="Piano")
output_instrument.notes = merged_notes
output_instrument.control_changes = merged_cc
output_midi_obj.instruments.append(output_instrument)
# --- 正确计算 max_tick ---
max_tick = 0
if output_instrument.notes:
max_tick = max(max_tick, max(n.end for n in output_instrument.notes))
if output_instrument.control_changes:
max_tick = max(max_tick, max(c.time for c in output_instrument.control_changes))
output_midi_obj.max_tick = max_tick
return output_midi_obj
"""
def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):
    """Flatten a MidiFile into one piano track on a fixed tick/tempo grid.

    Notes and sustain-pedal events (CC 64) of every non-drum instrument are
    merged. Each source tick is looked up in
    ``midi_obj.get_tick_to_time_mapping()`` (tick -> seconds) and re-expressed
    in ticks of an output file that has a single tempo event at tick 0.

    Clean-up applied to the notes:
      * every converted note is at least one tick long;
      * within one pitch, a note that reaches or overlaps the next note's
        start is cut at that start, and discarded if the cut leaves it empty.

    Args:
        midi_obj (MidiFile): Source MIDI object.
        target_ticks_per_beat (int): Resolution of the output file.
        target_tempo (float): Tempo (BPM) of the output file.

    Returns:
        MidiFile: New object holding one "Piano" instrument whose notes are
        sorted by (start, pitch) and whose pedal events are sorted by time.
        ``max_tick`` is the last event tick plus one beat of padding.
    """
    seconds_at = midi_obj.get_tick_to_time_mapping()
    ticks_per_second = target_ticks_per_beat * (target_tempo / 60.0)

    def to_target_tick(source_tick):
        return round(seconds_at[source_tick] * ticks_per_second)

    melodic_tracks = [track for track in midi_obj.instruments if not track.is_drum]

    # --- 1. Convert every note and bucket it by pitch ---
    lanes = defaultdict(list)
    for track in melodic_tracks:
        for source_note in track.notes:
            begin = to_target_tick(source_note.start)
            finish = to_target_tick(source_note.end)
            if finish <= begin:
                # Rounding may collapse a short note; keep it one tick long.
                finish = begin + 1
            lanes[source_note.pitch].append(
                Note(
                    velocity=source_note.velocity,
                    pitch=source_note.pitch,
                    start=begin,
                    end=finish,
                )
            )

    # --- 2. Resolve same-pitch overlaps ---
    cleaned_notes = []
    for pitch in sorted(lanes):
        lane = sorted(lanes[pitch], key=lambda n: n.start)
        for earlier, later in zip(lane, lane[1:]):
            if earlier.end >= later.start:
                # End the earlier note exactly where the later one begins.
                earlier.end = later.start
                if earlier.start >= earlier.end:
                    # Nothing left of it: flag with an invalid pitch so the
                    # filter below removes it.
                    earlier.pitch = -1
        cleaned_notes.extend(n for n in lane if n.pitch != -1)

    # --- 3. Convert the sustain-pedal events ---
    pedal_events = [
        ControlChange(number=64, value=cc.value, time=to_target_tick(cc.time))
        for track in melodic_tracks
        for cc in track.control_changes
        if cc.number == 64
    ]

    # --- 4. Order the events and assemble the output ---
    cleaned_notes.sort(key=lambda n: (n.start, n.pitch))
    pedal_events.sort(key=lambda c: (c.time, c.number))

    piano = Instrument(program=0, is_drum=False, name="Piano")
    piano.notes = cleaned_notes
    piano.control_changes = pedal_events

    normalized = MidiFile(ticks_per_beat=target_ticks_per_beat)
    normalized.tempo_changes.append(TempoChange(target_tempo, 0))
    normalized.instruments.append(piano)

    # --- 5. max_tick: last event plus one beat so nothing is truncated ---
    last_tick = max(
        [0]
        + [n.end for n in cleaned_notes]
        + [c.time for c in pedal_events]
    )
    normalized.max_tick = last_tick + target_ticks_per_beat
    return normalized
def merge_and_sort(midi_obj, target_ticks_per_beat=500):
    """Merge all non-drum notes into one piano track at a new tick resolution.

    Unlike ``normalize_midi`` this keeps the musical (tick) timing: every tick
    is only rescaled by ``target_ticks_per_beat / midi_obj.ticks_per_beat``;
    tempo changes are not consulted. Time- and key-signature events are
    carried over with rescaled times.

    The signature events are shallow-copied before their ``time`` is rewritten.
    The previous version shared the input's lists and rescaled the events in
    place, which silently modified ``midi_obj`` (and rescaled twice when the
    function was called again on the same object).

    Clean-up applied to the notes (same rules as ``normalize_midi``):
      * every rescaled note is at least one tick long;
      * within one pitch, a note that reaches or overlaps the next note's
        start is cut at that start, and discarded if the cut leaves it empty.

    Args:
        midi_obj (MidiFile): Source MIDI object; it is not modified.
        target_ticks_per_beat (int): Resolution of the output file
            (default 500, the value that used to be hard-coded).

    Returns:
        MidiFile: New object with one "Piano" instrument whose notes are
        sorted by (start, pitch).
    """
    tick_ratio = target_ticks_per_beat / midi_obj.ticks_per_beat

    def rescale(tick):
        return round(tick * tick_ratio)

    def rescaled_copy(event):
        # Copy first so the caller's event object keeps its original time.
        clone = copy(event)
        clone.time = rescale(event.time)
        return clone

    output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output_midi_obj.time_signature_changes = [
        rescaled_copy(event) for event in midi_obj.time_signature_changes
    ]
    output_midi_obj.key_signature_changes = [
        rescaled_copy(event) for event in midi_obj.key_signature_changes
    ]

    # Rescale every melodic note and bucket it by pitch.
    notes_by_pitch = defaultdict(list)
    for instrument in midi_obj.instruments:
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            start = rescale(note.start)
            end = rescale(note.end)
            if end <= start:
                # Rounding may collapse a short note; keep it one tick long.
                end = start + 1
            notes_by_pitch[note.pitch].append(
                Note(velocity=note.velocity, pitch=note.pitch, start=start, end=end)
            )

    # Trim same-pitch overlaps; drop notes that the trim leaves empty.
    merged_notes = []
    for pitch in sorted(notes_by_pitch):
        lane = sorted(notes_by_pitch[pitch], key=lambda n: n.start)
        for current_note, next_note in zip(lane, lane[1:]):
            if current_note.end >= next_note.start:
                current_note.end = next_note.start
        merged_notes.extend(n for n in lane if n.start < n.end)
    merged_notes.sort(key=lambda n: (n.start, n.pitch))

    output_instrument = Instrument(program=0, is_drum=False, name="Piano")
    output_instrument.notes = merged_notes
    output_midi_obj.instruments.append(output_instrument)
    return output_midi_obj
def midi_to_ids(config, midi_obj, normalize=True):
    """Encode the first instrument of a MIDI object as a flat token list.

    Each note yields eight tokens, in this order:
      0. pitch     = ``config.pitch_start + note.pitch``
      1. interval  = ``config.timing_start`` + ticks since the previous
                     note's start (since tick 0 for the first note)
      2. velocity  = ``config.velocity_start + note.velocity``
      3. duration  = ``config.timing_start + note.duration``
      4-7. pedal   = ``config.pedal_start`` + value of the latest control
                     change at or before the note onset and at 1/4, 2/4 and
                     3/4 of the gap to the next note (0 when none precedes).
                     The last note uses a gap of 4990 ticks.
    Token ``k`` is clamped to
    ``[config.valid_id_range[k][0], config.valid_id_range[k][1] - 1]``.

    Args:
        config: Object providing ``pitch_start``, ``timing_start``,
            ``velocity_start``, ``pedal_start`` and ``valid_id_range``.
        midi_obj: MIDI object; only ``instruments[0]`` is read.
        normalize (bool): When true, ``midi_obj`` is first passed through
            ``normalize_midi``.

    Returns:
        list: ``8 * number_of_notes`` token ids.
    """
    source = normalize_midi(midi_obj) if normalize else midi_obj
    track = source.instruments[0]
    pedal_events = track.control_changes
    pedal_times = [event.time for event in pedal_events]

    def pedal_value_at(tick):
        # Value of the last pedal event at or before `tick`; 0 if there is none.
        position = bisect.bisect_right(pedal_times, tick)
        return pedal_events[position - 1].value if position else 0

    def clamp(slot, token):
        bounds = config.valid_id_range[slot]
        return min(bounds[1] - 1, max(bounds[0], token))

    # Onset-to-onset gaps, plus a trailing placeholder gap for the last note.
    gaps = []
    previous_onset = 0
    for note in track.notes:
        gaps.append(note.start - previous_onset)
        previous_onset = note.start
    gaps.append(4990)

    tokens = []
    onset = 0
    for index, note in enumerate(track.notes):
        gap, next_gap = gaps[index], gaps[index + 1]
        onset = onset + gap
        probe_ticks = [onset] + [onset + next_gap * quarter / 4 for quarter in (1, 2, 3)]
        raw = [
            config.pitch_start + note.pitch,
            config.timing_start + gap,
            config.velocity_start + note.velocity,
            config.timing_start + note.duration,
        ]
        raw.extend(config.pedal_start + pedal_value_at(tick) for tick in probe_ticks)
        tokens.extend(clamp(slot, token) for slot, token in enumerate(raw))
    return tokens
def ids_to_midi(config, ids, target_ticks_per_beat=500, target_tempo=120, pedal_ratio=1.0):
    """Decode a flat token list (eight tokens per note) into a MidiFile.

    Token layout per note: pitch, interval, velocity, duration and four pedal
    values; the offsets ``config.pitch_start``, ``config.timing_start``,
    ``config.velocity_start`` and ``config.pedal_start`` are subtracted again.
    A note starts ``interval`` ticks after the previous note's start.

    For every note four CC 64 events are generated: at the onset, one step
    after it, and two steps / one step before the next onset, where
    ``step = next_interval / 4 * pedal_ratio`` (the last note uses a next
    interval of 4990 ticks). Events that repeat the previously kept value
    (initially 0) are left out of the written track.

    Args:
        config: Object providing the ``*_start`` token offsets.
        ids: Token sequence whose length is a multiple of 8.
        target_ticks_per_beat (int): Resolution of the output file.
        target_tempo (float): Tempo (BPM) written at tick 0.
        pedal_ratio (float): Scales the spacing of the pedal events.

    Returns:
        MidiFile: One "Piano" instrument; ``max_tick`` is one past the last
        note end or generated pedal tick (including the filtered-out ones).
    """
    offsets = range(0, len(ids), 8)
    gaps = [ids[offset + 1] - config.timing_start for offset in offsets]
    gaps.append(4990)  # placeholder gap after the final note

    notes = []
    raw_pedals = []
    onset = 0
    for index, offset in enumerate(offsets):
        pitch = ids[offset] - config.pitch_start
        velocity = ids[offset + 2] - config.velocity_start
        duration = ids[offset + 3] - config.timing_start
        pedal_a = ids[offset + 4] - config.pedal_start
        pedal_b = ids[offset + 5] - config.pedal_start
        pedal_c = ids[offset + 6] - config.pedal_start
        pedal_d = ids[offset + 7] - config.pedal_start

        onset = onset + gaps[index]
        notes.append(Note(velocity, pitch, onset, onset + duration))

        next_gap = gaps[index + 1]
        step = next_gap / 4 * pedal_ratio
        raw_pedals.extend([
            ControlChange(64, pedal_a, onset),
            ControlChange(64, pedal_b, round(onset + step)),
            ControlChange(64, pedal_c, round(onset + next_gap - step * 2)),
            ControlChange(64, pedal_d, round(onset + next_gap - step)),
        ])

    # Keep only the events that actually change the pedal value.
    changing_pedals = []
    current_value = 0
    for event in raw_pedals:
        if event.value != current_value:
            changing_pedals.append(event)
            current_value = event.value

    final_tick = max(
        [0]
        + [note.end for note in notes]
        + [event.time for event in raw_pedals]
    ) + 1

    output = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output.instruments.append(
        Instrument(
            program=0,
            is_drum=False,
            name="Piano",
            notes=notes,
            control_changes=changing_pedals,
        )
    )
    output.tempo_changes.append(TempoChange(target_tempo, 0))
    output.max_tick = final_tick
    return output
def read_corresp(corresp_path):
    """Parse a tab-separated correspondence file into index pairs.

    The first line is skipped. In every other line, columns 0/1/3 are read as
    the id, onset time and pitch of a score note, and columns 5/6/8 as the
    same fields of a performance note; ``*`` in column 0 or 5 means that side
    is absent. Ids on each side are replaced by their rank after sorting that
    side's notes by (onset time, pitch, id).

    Pairs are collected line by line until the first line whose score column
    is ``*``. If the first or the last collected pair has no performance
    partner, it is assigned the smallest / largest performance rank found in
    the file.

    Args:
        corresp_path (str): Path of the correspondence file.

    Returns:
        list[tuple[int, int]]: ``(score_rank, performance_rank)`` pairs sorted
        ascending; ``performance_rank`` is -1 for unmatched score notes.
    """
    with open(corresp_path, "r") as handle:
        rows = [line.split("\t") for line in handle.readlines()[1:]]

    # First pass: gather the sort keys of both sides.
    score_keys = []
    performance_keys = set()
    for fields in rows:
        if fields[0] != '*':
            score_keys.append((float(fields[1]), int(fields[3]), int(fields[0])))
        if fields[5] != '*':
            performance_keys.add((float(fields[6]), int(fields[8]), int(fields[5])))

    score_rank = {key[2]: rank for rank, key in enumerate(sorted(score_keys))}
    performance_rank = {key[2]: rank for rank, key in enumerate(sorted(performance_keys))}

    # Second pass: pair up the ranks until the score column runs out.
    pairs = []
    for fields in rows:
        if fields[0] == '*':
            break
        partner = performance_rank[int(fields[5])] if fields[5] != '*' else -1
        pairs.append((score_rank[int(fields[0])], partner))

    matched_ranks = [performance_rank[int(fields[5])] for fields in rows if fields[5] != '*']

    # Anchor both ends so later interpolation has boundary values.
    if pairs[0][1] == -1:
        pairs[0] = (pairs[0][0], min(matched_ranks))
    if pairs[-1][1] == -1:
        pairs[-1] = (pairs[-1][0], max(matched_ranks))
    pairs.sort()
    return pairs
def interpolate(a, b):
    """Fill the NaN entries of ``b`` by linear interpolation over ``a``.

    A ramp from 0 to 1e-5 is added to ``a`` before interpolating, so equal
    neighbouring positions become slightly increasing. Entries of ``b`` that
    are not NaN are kept as given; positions outside the known range take the
    nearest known value (``np.interp`` behaviour).

    Args:
        a: Sequence of positions, same length as ``b``.
        b: Sequence of values, with ``nan`` marking the unknown ones.

    Returns:
        list[int]: All values, each passed through ``round``.
    """
    positions = np.array(a) + np.linspace(0, 1e-5, len(a))
    values = np.array(b)
    anchors = np.logical_not(np.isnan(values)).nonzero()[0]
    filled = np.interp(positions, positions[anchors], values[anchors])
    # Restore the known values verbatim at their own positions.
    filled[anchors] = values[anchors]
    return list(map(round, filled.tolist()))
def segment_sequences(x, label, unknown_ids, total_notes, max_consecutive_missing, min_segment_notes):
    """Split two parallel token sequences at long runs of unknown notes.

    Notes are indexed ``0 .. total_notes - 1`` and occupy eight tokens each.
    Whenever ``max_consecutive_missing`` notes in a row are listed in
    ``unknown_ids``, the notes before that run form a segment and the next
    segment starts right after the note that completed the run. Segments
    with fewer than ``min_segment_notes`` notes are discarded.

    Args:
        x: Input token sequence (sliceable).
        label: Target token sequence, sliced with the same bounds as ``x``.
        unknown_ids: Indices of the notes considered missing.
        total_notes (int): Number of notes to scan.
        max_consecutive_missing (int): Run length that triggers a cut.
        min_segment_notes (int): Minimum number of notes per kept segment.

    Returns:
        tuple[list, list]: Slices of ``x`` and the matching slices of
        ``label``. With no unknown ids the whole sequences are returned as a
        single segment, or nothing if they are too short.
    """
    if not unknown_ids:
        return ([x], [label]) if total_notes >= min_segment_notes else ([], [])

    missing = set(unknown_ids)
    x_parts, label_parts = [], []
    segment_start = 0
    gap_length = 0
    for note_index in range(total_notes):
        gap_length = gap_length + 1 if note_index in missing else 0
        if gap_length < max_consecutive_missing:
            continue

        # The run of missing notes is long enough: close the segment that
        # ended just before the run began.
        segment_stop = note_index - gap_length + 1
        if segment_stop - segment_start >= min_segment_notes:
            low, high = segment_start * 8, segment_stop * 8
            x_parts.append(x[low:high])
            label_parts.append(label[low:high])
        segment_start = note_index + 1
        gap_length = 0

    # Whatever is left after the last cut.
    if total_notes - segment_start >= min_segment_notes:
        low = segment_start * 8
        x_parts.append(x[low:])
        label_parts.append(label[low:])
    return x_parts, label_parts
def align_score_and_performance(config, score_midi_obj, performance_midi_obj):
    """Align a performance to its score and build paired token sequences.

    Steps:
      1. Normalize both MIDI objects and dump them to ``temp/score.mid`` and
         ``temp/performance.mid``.
      2. Run ``./MIDIToMIDIAlign.sh`` inside ``./tools/AlignmentTool`` (limited
         to 120 s by ``timeout``) and read ``temp/score_corresp.txt``.
      3. For every score note, take velocity, start and duration from the
         matched performance note; unmatched notes are filled in with
         ``interpolate``. Matched start times are sorted before being
         assigned, so they are non-decreasing in score order.
      4. Encode the normalized score (``x``) and the aligned notes plus the
         performance's pedal events (``label``) with ``midi_to_ids`` and cut
         both at runs of 5 unmatched notes, keeping segments of >= 64 notes.

    Changes compared with the previous version:
      * a leftover ``temp/score_corresp.txt`` is removed before the aligner
        runs, so a failed or timed-out alignment raises when the file is read
        instead of silently reusing the previous pair's result;
      * the working directory is restored in a ``finally`` block;
      * ``temp/`` is created if missing;
      * unmatched-index lookups use a set instead of a list.

    Args:
        config: Token configuration forwarded to ``midi_to_ids``.
        score_midi_obj (MidiFile): Score MIDI.
        performance_midi_obj (MidiFile): Performance MIDI.

    Returns:
        tuple[list, list]: Segments of ``x`` and the matching segments of
        ``label``.

    Raises:
        FileNotFoundError: If the aligner produced no correspondence file.
        AssertionError: If score and label tokens differ in length or pitch.
    """
    norm_score_midi_obj = normalize_midi(score_midi_obj)
    norm_performance_midi_obj = normalize_midi(performance_midi_obj)

    os.makedirs("temp", exist_ok=True)
    corresp_path = "temp/score_corresp.txt"
    try:
        # A stale file from an earlier pair must never be mistaken for the
        # result of this run.
        os.remove(corresp_path)
    except FileNotFoundError:
        pass
    norm_score_midi_obj.dump("temp/score.mid")
    norm_performance_midi_obj.dump("temp/performance.mid")

    # The alignment script is invoked from its own directory.
    original_dir = os.getcwd()
    os.chdir("./tools/AlignmentTool")
    try:
        os.system("timeout 120s ./MIDIToMIDIAlign.sh ../../temp/performance ../../temp/score")
    finally:
        os.chdir(original_dir)
    corresp_list = read_corresp(corresp_path)

    score_notes = norm_score_midi_obj.instruments[0].notes
    performance_notes = norm_performance_midi_obj.instruments[0].notes

    # Collect the performance attributes of every score note (NaN = unmatched).
    score_start_list = []
    vel_list = []
    start_list = []
    duration_list = []
    unknown_ids = []
    for i, (score_idx, performance_idx) in enumerate(corresp_list):
        if performance_idx != -1:
            matched = performance_notes[performance_idx]
            vel_list.append(matched.velocity)
            start_list.append(matched.start)
            duration_list.append(matched.end - matched.start)
        else:
            vel_list.append(np.nan)
            duration_list.append(np.nan)
            unknown_ids.append(i)
        score_start_list.append(score_notes[score_idx].start)

    # Hand the matched start times out in ascending order, leaving NaN holes
    # at the unmatched positions.
    start_list.sort()
    unknown_set = set(unknown_ids)
    known_starts = iter(start_list)
    sparse_starts = [
        np.nan if i in unknown_set else next(known_starts)
        for i in range(len(corresp_list))
    ]

    start_list = interpolate(score_start_list, sparse_starts)
    vel_list = interpolate(start_list, vel_list)
    duration_list = interpolate(start_list, duration_list)

    output_notes = []
    end_list = []
    for i, (score_idx, _) in enumerate(corresp_list):
        end = start_list[i] + duration_list[i]
        end_list.append(end)
        output_notes.append(Note(vel_list[i], score_notes[score_idx].pitch, start_list[i], end))

    # Keep the performance's pedal events up to shortly after the last note;
    # the normalized control changes are sorted by time, so stop at the first
    # event past the limit.
    max_tick = max(end_list) + 4999
    output_ccs = []
    for cc in norm_performance_midi_obj.instruments[0].control_changes:
        if cc.time > max_tick:
            break
        output_ccs.append(cc)

    aligned_midi_obj = MidiFile(ticks_per_beat=500)
    aligned_midi_obj.instruments.append(
        Instrument(
            program=0,
            is_drum=False,
            name="Piano",
            notes=output_notes,
            control_changes=output_ccs,
        )
    )

    x = midi_to_ids(config, norm_score_midi_obj)
    label = midi_to_ids(config, aligned_midi_obj, normalize=False)

    # Sanity checks: same number of tokens and identical pitch tokens.
    assert len(x) == len(label)
    for i in range(0, len(x), 8):
        assert x[i] == label[i]

    total_notes = len(score_notes)
    xs, labels = segment_sequences(
        x,
        label,
        unknown_ids,
        total_notes,
        5,
        64,
    )
    return xs, labels
def enhanced_ids(config, ids):
    """Return a copy of ``ids`` with Gaussian jitter on the timing/velocity tokens.

    Tokens are grouped in eights; only slots 1, 2 and 3 of each group are
    touched. A slot's value is measured from ``config.valid_id_range[slot][0]``.
    Noise is ``round(np.random.randn() * scale)``, redrawn up to 10 times
    until it falls inside the allowed window, otherwise 0:

      * slot 3: scale 5, window [-9, 5] if the value is 10, else [-4, 5];
        result clamped to [0, 4999];
      * slot 2: scale 2.5, window [-4, 2] if the value is 5, [-2, 7] if it is
        120, else [-2, 2]; result clamped to [0, 127];
      * slot 1: scale 5, window [-4, 5]; result clamped to [0, 4990].

    Args:
        config: Object providing ``valid_id_range``.
        ids: Token sequence; it is shallow-copied, not modified.

    Returns:
        The jittered copy.
    """
    max_attempts = 10

    def bounded_noise(scale, low, high):
        # Rejection sampling: first draw inside [low, high], or 0 if none is.
        for _ in range(max_attempts):
            candidate = round(np.random.randn() * scale)
            if low <= candidate <= high:
                return candidate
        return 0

    jittered = copy(ids)
    for position, token in enumerate(jittered):
        slot = position % 8
        if slot not in (1, 2, 3):
            continue

        base = config.valid_id_range[slot][0]
        value = token - base
        if slot == 3:
            low = -9 if value == 10 else -4
            noise = bounded_noise(5, low, 5)
            ceiling = 4999
        elif slot == 2:
            if value == 5:
                low, high = -4, 2
            elif value == 120:
                low, high = -2, 7
            else:
                low, high = -2, 2
            noise = bounded_noise(2.5, low, high)
            ceiling = 127
        else:
            noise = bounded_noise(5, -4, 5)
            ceiling = 4990

        jittered[position] = base + min(max(value + noise, 0), ceiling)
    return jittered
def enhanced_ids_uniform(config, ids):
    """Return a copy of ``ids`` with uniform integer jitter on slots 1-3.

    Tokens are grouped in eights; a slot's value is measured from
    ``config.valid_id_range[slot][0]``. One ``random.randint`` draw is added
    per touched token:

      * slot 3: [-9, 5] if the value is 10, else [-4, 5]; clamped to [0, 4999];
      * slot 2: [-4, 2] if the value is 5, [-2, 7] if it is 120, else [-2, 2];
        clamped to [0, 127];
      * slot 1: [-4, 5]; clamped to [0, 4990].

    Args:
        config: Object providing ``valid_id_range``.
        ids: Token sequence; it is shallow-copied, not modified.

    Returns:
        The jittered copy.
    """
    ceilings = {1: 4990, 2: 127, 3: 4999}

    jittered = copy(ids)
    for position, token in enumerate(jittered):
        slot = position % 8
        if slot not in ceilings:
            continue

        base = config.valid_id_range[slot][0]
        value = token - base
        if slot == 3:
            low, high = (-9, 5) if value == 10 else (-4, 5)
        elif slot == 2:
            if value == 5:
                low, high = -4, 2
            elif value == 120:
                low, high = -2, 7
            else:
                low, high = -2, 2
        else:
            low, high = -4, 5

        shifted = value + random.randint(low, high)
        jittered[position] = base + min(max(shifted, 0), ceilings[slot])
    return jittered
#if __name__ == "__main__":
# midi_obj = MidiFile("data/midi/test/2.mid")
# ids = midi_to_ids(midi_obj)
# midi = ids_to_midi(ids)
# midi.dump("data/rebuild/2.mid")