File size: 27,051 Bytes
f5399d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651903a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5399d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651903a
f5399d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651903a
 
 
 
f5399d9
651903a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5399d9
 
 
 
 
 
 
 
 
651903a
f5399d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
from miditoolkit import MidiFile, Note, Instrument, TempoChange, ControlChange
import bisect
import numpy as np
import os
from copy import copy
import random
from collections import defaultdict

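# Earlier version of normalize_midi: it rescales the gap since the previous
# event using the tempo in effect at each event. It is wrapped in a string
# literal, so it is never executed; the active normalize_midi is defined below.
"""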
def normalize_midi(midi_obj, target_ticks_per_beat = 500, target_tempo = 120):
    ticks_per_beat = midi_obj.ticks_per_beat
    merged_events = []
    for i in range(len(midi_obj.instruments)):
        filter_control_changes = []
        for cc in midi_obj.instruments[i].control_changes:
            if cc.number == 64:
                filter_control_changes.append(cc)
        merged_events.extend(midi_obj.instruments[i].notes + filter_control_changes)
    merged_events.sort(key=lambda x: (x.start, x.pitch) if isinstance(x, Note) else (x.time, x.number))
    
    time_interval = []
    last_time = 0
    for note in merged_events:
        if isinstance(note, Note):
            time_interval.append(note.start - last_time)
            last_time = note.start
        else:
            time_interval.append(note.time - last_time)
            last_time = note.time

    output_notes = []
    output_cc = []
    ind = -1
    now_tempo = 120
    now_time = 0
    for i, note in enumerate(merged_events):
        if isinstance(note, Note):
            time = note.start
        else:
            time = note.time
        while ind + 1 < len(midi_obj.tempo_changes) and time >= midi_obj.tempo_changes[ind+1].time:
            now_tempo = midi_obj.tempo_changes[ind+1].tempo
            ind += 1
        ratio = target_ticks_per_beat * target_tempo / now_tempo / ticks_per_beat
        start_time = time_interval[i] * ratio + now_time
        if isinstance(note, Note):
            end_time = (note.end - note.start) * ratio + start_time
            output_notes.append(Note(note.velocity, note.pitch, round(start_time), round(end_time)))
        else:
            output_cc.append(ControlChange(64, note.value, round(start_time)))
        now_time = round(start_time)
    
    output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output_midi_obj.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=output_notes, control_changes=output_cc))
    output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
    for note in output_notes:
        output_midi_obj.max_tick = max(output_midi_obj.max_tick, note.end)
    for cc in output_cc:
        output_midi_obj.max_tick = max(output_midi_obj.max_tick, cc.time)
    return output_midi_obj
"""

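# Second earlier version of normalize_midi: tick -> seconds -> target tick
# conversion, but without the same-pitch overlap cleanup. It is also kept as an
# inert string literal; the active normalize_midi is defined below.
"""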
def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):    
    # 创建一个新的、干净的MidiFile对象用于输出
    output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
    
    # 获取原始MIDI的tick到秒的精确映射
    # 这是最关键的一步,partitura和miditoolkit都有类似功能
    # miditoolkit的get_tick_to_time_mapping()可以处理所有tempo变化
    tick_to_time_map = midi_obj.get_tick_to_time_mapping()
    
    # 计算从秒转换回目标tick的比例因子
    # 目标MIDI中,每秒对应的tick数 = target_ticks_per_beat * (target_tempo / 60)
    seconds_to_target_ticks_factor = target_ticks_per_beat * (target_tempo / 60.0)

    merged_notes = []
    merged_cc = []

    # 遍历所有乐器轨道
    for instrument in midi_obj.instruments:
        # 只处理非鼓组的乐器
        if not instrument.is_drum:
            # --- 处理音符 (Notes) ---
            for note in instrument.notes:
                # 1. 将原始tick转换为绝对秒数
                start_time_sec = tick_to_time_map[note.start]
                end_time_sec = tick_to_time_map[note.end]
                
                # 2. 将绝对秒数转换为目标tick
                new_start_tick = round(start_time_sec * seconds_to_target_ticks_factor)
                new_end_tick = round(end_time_sec * seconds_to_target_ticks_factor)
                
                # 避免duration为0的音符
                if new_start_tick == new_end_tick:
                    new_end_tick += 1

                merged_notes.append(Note(velocity=note.velocity, 
                                         pitch=note.pitch, 
                                         start=new_start_tick, 
                                         end=new_end_tick))
            
            # --- 处理延音踏板 (CC #64) ---
            for cc in instrument.control_changes:
                if cc.number == 64:
                    # 1. 将原始tick转换为绝对秒数
                    time_sec = tick_to_time_map[cc.time]
                    
                    # 2. 将绝对秒数转换为目标tick
                    new_time_tick = round(time_sec * seconds_to_target_ticks_factor)
                    
                    merged_cc.append(ControlChange(number=64, 
                                                   value=cc.value, 
                                                   time=new_time_tick))

    # --- 排序并创建新乐器 ---
    # 按开始时间排序,对于同时开始的事件,CC优先于Note
    merged_notes.sort(key=lambda x: (x.start, x.pitch))
    merged_cc.sort(key=lambda x: (x.time, x.number))
    
    output_instrument = Instrument(program=0, is_drum=False, name="Piano")
    output_instrument.notes = merged_notes
    output_instrument.control_changes = merged_cc
    output_midi_obj.instruments.append(output_instrument)
    
    # --- 正确计算 max_tick ---
    max_tick = 0
    if output_instrument.notes:
        max_tick = max(max_tick, max(n.end for n in output_instrument.notes))
    if output_instrument.control_changes:
        max_tick = max(max_tick, max(c.time for c in output_instrument.control_changes))
    
    output_midi_obj.max_tick = max_tick

    return output_midi_obj
"""

def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):
    """Normalize a MidiFile onto one tick resolution and one constant tempo.

    Steps:
    1. Every note of every non-drum instrument is converted from source ticks
       to absolute seconds (via ``midi_obj.get_tick_to_time_mapping()``, which
       accounts for the source tempo map) and then to target ticks at
       ``target_ticks_per_beat`` / ``target_tempo``. Notes that round to zero
       or negative length are stretched to exactly one tick.
    2. Same-pitch overlaps are removed. A note reaching into the next note of
       the same pitch is cut at that note's start. A note that would then have
       zero length (i.e. both start on the same tick) is dropped.
    3. Control change #64 events of non-drum instruments are converted the
       same way.
    4. Everything is merged into one instrument named "Piano" (program 0).
       Notes are sorted by (start, pitch) and CCs by (time, number).
    5. ``max_tick`` is the latest note end / CC time plus one beat of padding.

    Args:
        midi_obj (MidiFile): source MIDI object. Its own events are not modified.
        target_ticks_per_beat (int): ticks per beat of the output.
        target_tempo (float): tempo (BPM) of the single output tempo event.

    Returns:
        MidiFile: a new, normalized MidiFile.
    """
    result = MidiFile(ticks_per_beat=target_ticks_per_beat)
    result.tempo_changes.append(TempoChange(target_tempo, 0))

    seconds_of_tick = midi_obj.get_tick_to_time_mapping()
    ticks_per_second = target_ticks_per_beat * (target_tempo / 60.0)

    def to_target(tick):
        # Source tick -> absolute seconds -> nearest target tick.
        return round(seconds_of_tick[tick] * ticks_per_second)

    melodic = [inst for inst in midi_obj.instruments if not inst.is_drum]

    # 1. Re-time all notes, bucketed by pitch for the overlap pass.
    by_pitch = defaultdict(list)
    for inst in melodic:
        for src in inst.notes:
            onset = to_target(src.start)
            offset = to_target(src.end)
            # Guarantee a length of at least one target tick.
            offset = max(offset, onset + 1)
            by_pitch[src.pitch].append(
                Note(velocity=src.velocity, pitch=src.pitch, start=onset, end=offset)
            )

    # 2. Resolve overlaps between consecutive notes of the same pitch.
    notes = []
    for pitch in sorted(by_pitch):
        group = sorted(by_pitch[pitch], key=lambda n: n.start)
        for cur, nxt in zip(group, group[1:]):
            if cur.end >= nxt.start:
                cur.end = nxt.start
                if cur.start >= cur.end:
                    # Fully shadowed by a note starting on the same tick.
                    continue
            notes.append(cur)
        # The last note of a pitch never has a successor to collide with.
        notes.append(group[-1])

    # 3. Sustain-pedal (CC #64) events.
    pedals = [
        ControlChange(number=64, value=cc.value, time=to_target(cc.time))
        for inst in melodic
        for cc in inst.control_changes
        if cc.number == 64
    ]

    # 4. Sort and assemble the single output instrument.
    notes.sort(key=lambda n: (n.start, n.pitch))
    pedals.sort(key=lambda c: (c.time, c.number))

    piano = Instrument(program=0, is_drum=False, name="Piano")
    piano.notes = notes
    piano.control_changes = pedals
    result.instruments.append(piano)

    # 5. max_tick: latest event tick plus one beat of padding.
    last_tick = 0
    if piano.notes:
        last_tick = max(last_tick, max(n.end for n in piano.notes if n.end is not None))
    if piano.control_changes:
        last_tick = max(last_tick, max(c.time for c in piano.control_changes if c.time is not None))
    result.max_tick = last_tick + target_ticks_per_beat

    return result

def merge_and_sort(midi_obj):
    """Merge all non-drum notes into one piano track on a 500-ticks-per-beat grid.

    Unlike normalize_midi, there is no tempo-to-seconds conversion. Every tick
    (notes, time signatures, key signatures) is only rescaled by
    ``500 / midi_obj.ticks_per_beat`` and rounded.

    Same-pitch overlaps are resolved the same way as in normalize_midi:
    - A note reaching into the next note of the same pitch is cut at that
      note's start.
    - A note left with zero length is dropped.

    Args:
        midi_obj (MidiFile): source MIDI. It is not modified.

    Returns:
        MidiFile: a new object containing:
        - a single instrument named "Piano" (program 0), with notes sorted by
          (start, pitch);
        - copies of the rescaled time/key signature events;
        - ``max_tick`` set to the latest event tick.
    """
    target_ticks_per_beat = 500
    tick_ratio = target_ticks_per_beat / midi_obj.ticks_per_beat

    output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
    # Copy the signature events before rescaling them. Sharing the input's
    # lists/objects (as before) rescaled midi_obj's own events in place.
    output_midi_obj.time_signature_changes = [copy(ts) for ts in midi_obj.time_signature_changes]
    output_midi_obj.key_signature_changes = [copy(ks) for ks in midi_obj.key_signature_changes]
    signatures = output_midi_obj.time_signature_changes + output_midi_obj.key_signature_changes
    for event in signatures:
        event.time = round(event.time * tick_ratio)
    # NOTE(review): tempo changes are not copied, so the output carries no
    # tempo events at all — confirm this is intended.

    # Rescale notes and bucket them by pitch for the overlap pass.
    notes_by_pitch = defaultdict(list)
    for instrument in midi_obj.instruments:
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            start = round(note.start * tick_ratio)
            end = round(note.end * tick_ratio)
            # Keep at least one tick (as normalize_midi does); rounding can
            # otherwise collapse very short notes to zero length.
            notes_by_pitch[note.pitch].append(
                Note(velocity=note.velocity, pitch=note.pitch, start=start, end=max(end, start + 1))
            )

    merged_notes = []
    for pitch in sorted(notes_by_pitch):
        group = sorted(notes_by_pitch[pitch], key=lambda n: n.start)
        for current_note, next_note in zip(group, group[1:]):
            if current_note.end >= next_note.start:
                current_note.end = next_note.start
                if current_note.start >= current_note.end:
                    # Fully shadowed by a same-pitch note starting on the same tick.
                    continue
            merged_notes.append(current_note)
        # The last note of a pitch has no successor to collide with.
        merged_notes.append(group[-1])
    merged_notes.sort(key=lambda x: (x.start, x.pitch))

    output_instrument = Instrument(program=0, is_drum=False, name="Piano")
    output_instrument.notes = merged_notes
    output_midi_obj.instruments.append(output_instrument)

    # max_tick was previously never set on the result. Tick-indexed helpers
    # (e.g. normalize_midi, which indexes get_tick_to_time_mapping() by note
    # ticks) presumably need it to cover every event.
    event_ticks = [n.end for n in merged_notes] + [e.time for e in signatures]
    output_midi_obj.max_tick = max(event_ticks, default=0)
    return output_midi_obj

def midi_to_ids(config, midi_obj, normalize=True):
    """Encode the first instrument of a MIDI object as a flat list of token ids.

    Each note (in the instrument's note order) becomes 8 consecutive ids:
    ``[pitch, onset gap, velocity, duration, pedal@0, pedal@1/4, pedal@2/4, pedal@3/4]``.

    - The onset gap is the distance from the previous note's start (the first
      note measures from tick 0).
    - Each pedal id encodes the value of the last control change at or before
      a sample tick, or 0 before the first one. Control changes are assumed
      sorted by time.
    - The pedal sample ticks are the note onset and 1/4, 2/4 and 3/4 of the
      way to the next onset. After the last note, the gap to the "next" onset
      is taken to be 4990 ticks.
    - Every id in slot ``k`` is clamped into
      ``[config.valid_id_range[k][0], config.valid_id_range[k][1] - 1]``.

    Args:
        config: provides ``pitch_start``, ``timing_start``, ``velocity_start``,
            ``pedal_start`` and ``valid_id_range`` (one (low, high) pair per slot).
        midi_obj (MidiFile): source MIDI.
        normalize (bool): if true, run normalize_midi on ``midi_obj`` first.

    Returns:
        list[int]: ``8 * len(notes)`` token ids.
    """
    source = normalize_midi(midi_obj) if normalize else midi_obj
    track = source.instruments[0]
    pedal_events = track.control_changes
    pedal_times = [event.time for event in pedal_events]

    def pedal_at(tick):
        # Value of the last control change at or before `tick` (0 before the first).
        pos = bisect.bisect_right(pedal_times, tick)
        return pedal_events[pos - 1].value if pos else 0

    onsets = [note.start for note in track.notes]
    gaps = [cur - prev for prev, cur in zip([0] + onsets[:-1], onsets)]
    gaps.append(4990)  # pseudo gap after the final note

    ids = []
    clock = 0
    for k, note in enumerate(track.notes):
        clock += gaps[k]
        span = gaps[k + 1]
        raw = (
            config.pitch_start + note.pitch,
            config.timing_start + gaps[k],
            config.velocity_start + note.velocity,
            config.timing_start + note.duration,
            config.pedal_start + pedal_at(clock),
            config.pedal_start + pedal_at(clock + span * 1 / 4),
            config.pedal_start + pedal_at(clock + span * 2 / 4),
            config.pedal_start + pedal_at(clock + span * 3 / 4),
        )
        for slot, value in enumerate(raw):
            bounds = config.valid_id_range[slot]
            ids.append(min(bounds[1] - 1, max(bounds[0], value)))
    return ids

def ids_to_midi(config, ids, target_ticks_per_beat=500, target_tempo=120, pedal_ratio=1.0):
    """Decode a flat token-id list (8 ids per note, see midi_to_ids) into a MidiFile.

    Note onsets:
    - Each note starts at the previous note's start plus its decoded onset gap.

    Pedal events (CC #64), four per note:
    - Let ``span`` be the next note's gap (4990 after the last note) and
      ``step = span / 4 * pedal_ratio``.
    - Events are placed at ``onset``, ``onset + step``,
      ``onset + span - 2 * step`` and ``onset + span - step`` (rounded).
    - Only events whose value differs from the preceding event's value
      (the first is compared with 0) are written.

    Output:
    - ``max_tick`` is 1 + the latest note end or pedal-event time. This
      includes pedal events that were dropped as repeats.
    - The file has one "Piano" instrument (program 0) and a single tempo
      event of ``target_tempo`` at tick 0.
    """
    starts = range(0, len(ids), 8)
    gaps = [ids[k + 1] - config.timing_start for k in starts]
    gaps.append(4990)  # pseudo gap after the final note

    notes = []
    pedals = []
    clock = 0
    for n, k in enumerate(starts):
        onset = clock + gaps[n]
        length = ids[k + 3] - config.timing_start
        levels = [ids[k + 4 + p] - config.pedal_start for p in range(4)]
        notes.append(
            Note(ids[k + 2] - config.velocity_start, ids[k] - config.pitch_start, onset, onset + length)
        )
        clock = onset

        span = gaps[n + 1]
        step = span / 4 * pedal_ratio
        pedals.append(ControlChange(64, levels[0], clock))
        pedals.append(ControlChange(64, levels[1], round(clock + step)))
        pedals.append(ControlChange(64, levels[2], round(clock + span - step * 2)))
        pedals.append(ControlChange(64, levels[3], round(clock + span - step)))

    # Drop pedal events that repeat the previous event's value (initially 0).
    previous_values = [0] + [cc.value for cc in pedals]
    kept_pedals = [cc for cc, prev in zip(pedals, previous_values) if cc.value != prev]

    last_tick = max([0] + [note.end for note in notes] + [cc.time for cc in pedals]) + 1

    output = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output.instruments.append(
        Instrument(program=0, is_drum=False, name="Piano", notes=notes, control_changes=kept_pedals)
    )
    output.tempo_changes.append(TempoChange(target_tempo, 0))
    output.max_tick = last_tick

    return output

def read_corresp(corresp_path):
    """Parse an alignment correspondence file into (score_index, performance_index) pairs.

    File format:
    - The file is tab-separated; its first line is a header and is skipped.
      Blank lines are ignored.
    - Columns 0/1/3 give the score-side note id, onset and pitch.
    - Columns 5/6/8 give the performance-side note id, onset and pitch.
    - '*' in an id column means there is no note on that side.

    Index assignment:
    - Each side's notes are renumbered by sorting on (onset, pitch, id).
      Presumably this matches the (start, pitch) note order produced by
      normalize_midi.
    - Every score-side row yields ``(score_index, performance_index)``, with
      -1 when the score note has no match.

    Anchoring and ordering:
    - After sorting by score index, an unmatched first/last score note is
      anchored to the earliest/latest performance index. This gives the
      interpolation of unmatched notes (see align_score_and_performance) both
      endpoints. Anchoring is skipped when no performance notes exist.
    - The returned list is sorted by score index. It is empty if the file
      has no score-side rows.
    """
    with open(corresp_path, "r") as f:
        # Skip the header line; ignore blank lines, which previously crashed
        # the column indexing below.
        rows = [line.split("\t") for line in f.readlines()[1:] if line.strip()]

    score_keys = []
    performance_keys = set()
    for cols in rows:
        if cols[0] != '*':
            score_keys.append((float(cols[1]), int(cols[3]), int(cols[0])))
        if cols[5] != '*':
            performance_keys.add((float(cols[6]), int(cols[8]), int(cols[5])))
    score_ids_map = {key[2]: i for i, key in enumerate(sorted(score_keys))}
    performance_ids_map = {key[2]: i for i, key in enumerate(sorted(performance_keys))}

    out = []
    for cols in rows:
        if cols[0] == '*':
            # Performance-only row: skip it but keep scanning, so score rows
            # listed after it are not lost (previously a `break`).
            continue
        performance_index = performance_ids_map[int(cols[5])] if cols[5] != '*' else -1
        out.append((score_ids_map[int(cols[0])], performance_index))
    if not out:
        return out

    performance_id_list = [performance_ids_map[int(cols[5])] for cols in rows if cols[5] != '*']

    # Sort first so the anchoring below targets the first/last score notes,
    # not the first/last rows in file order.
    out.sort()
    if performance_id_list:
        if out[0][1] == -1:
            out[0] = (out[0][0], min(performance_id_list))
        if out[-1][1] == -1:
            out[-1] = (out[-1][0], max(performance_id_list))
    return out

def interpolate(a, b):
    """Fill NaN entries of ``b`` by linear interpolation over positions ``a``.

    - ``a`` receives a tiny, evenly spaced increasing offset (0 to 1e-5) so
      equal positions become distinct. ``np.interp`` expects increasing x
      values, so ``a`` should be non-decreasing.
    - Known (non-NaN) values of ``b`` are kept exactly.
    - NaNs before the first / after the last known value take that end
      value (``np.interp`` behaviour).
    - Every result is rounded to a Python int.
    """
    positions = np.array(a) + np.linspace(0, 1e-5, len(a))
    values = np.array(b)
    known = ~np.isnan(values)
    filled = np.interp(positions, positions[known], values[known])
    filled[known] = values[known]
    return list(map(round, filled.tolist()))

def segment_sequences(x, label, unknown_ids, total_notes, max_consecutive_missing, min_segment_notes):
    """Split paired token sequences wherever too many consecutive notes are unaligned.

    ``x`` and ``label`` hold 8 tokens per note, and ``unknown_ids`` lists the
    note indices without a performance match.

    Cutting rule:
    - Each time a run of unmatched notes reaches ``max_consecutive_missing``,
      the segment before that run is closed, and the next segment starts
      right after the current note.
    - Only segments with at least ``min_segment_notes`` notes are kept,
      including the trailing one.
    - With no unknown notes, the whole input is one segment.

    Returns:
        tuple[list, list]: the kept ``x`` slices and the matching ``label`` slices.
    """
    tokens_per_note = 8

    if not unknown_ids:
        keep_all = total_notes >= min_segment_notes
        return ([x], [label]) if keep_all else ([], [])

    missing = set(unknown_ids)
    x_out, label_out = [], []
    seg_start = 0
    run = 0

    for idx in range(total_notes):
        run = run + 1 if idx in missing else 0
        if run < max_consecutive_missing:
            continue

        run_start = idx - run + 1  # first note of the missing run
        if run_start - seg_start >= min_segment_notes:
            lo = seg_start * tokens_per_note
            hi = run_start * tokens_per_note
            x_out.append(x[lo:hi])
            label_out.append(label[lo:hi])

        seg_start = idx + 1
        run = 0

    if total_notes - seg_start >= min_segment_notes:
        lo = seg_start * tokens_per_note
        x_out.append(x[lo:])
        label_out.append(label[lo:])

    return x_out, label_out

def _run_alignment_tool():
    """Align temp/performance.mid against temp/score.mid with the external tool.

    Runs ``timeout 120s ./MIDIToMIDIAlign.sh ../../temp/performance ../../temp/score``
    inside ``./tools/AlignmentTool``, relative to the current working directory.
    Returns the path of the correspondence file the tool is expected to write.

    Raises:
        FileNotFoundError: if no correspondence file was produced, e.g. the
            tool failed or hit the 120 s timeout.
    """
    import subprocess

    corresp_path = "temp/score_corresp.txt"
    # Remove results from an earlier call so a failed/timed-out run cannot
    # silently be replaced by a stale alignment.
    if os.path.exists(corresp_path):
        os.remove(corresp_path)
    # Run in the tool directory via `cwd` rather than os.chdir, so this
    # process's working directory is never changed (even if the call raises).
    subprocess.run(
        ["timeout", "120s", "./MIDIToMIDIAlign.sh", "../../temp/performance", "../../temp/score"],
        cwd="./tools/AlignmentTool",
        check=False,
    )
    if not os.path.exists(corresp_path):
        raise FileNotFoundError(f"alignment tool did not produce {corresp_path}")
    return corresp_path


def align_score_and_performance(config, score_midi_obj, performance_midi_obj):
    """Align a score with a performance and build paired token sequences.

    Steps:
    - Both MIDIs are normalized, written to ``temp/`` and aligned with the
      external tool.
    - For every score note, the label note keeps the score pitch but takes
      onset, velocity and duration from its matched performance note.
    - Matched performance onsets are assigned in ascending order. Values of
      unmatched notes are filled by interpolation.
    - Performance pedal events up to 4999 ticks after the last label note end
      are kept.

    Returns:
        tuple[list, list]: ``(xs, labels)`` from segment_sequences. They are
        cut at runs of 5 unmatched notes, and only segments of at least 64
        notes are kept.
    """
    norm_score_midi_obj = normalize_midi(score_midi_obj)
    norm_performance_midi_obj = normalize_midi(performance_midi_obj)

    os.makedirs("temp", exist_ok=True)
    norm_score_midi_obj.dump("temp/score.mid")
    norm_performance_midi_obj.dump("temp/performance.mid")

    corresp_list = read_corresp(_run_alignment_tool())

    aligned_midi_obj = MidiFile(ticks_per_beat=500)
    score_notes = norm_score_midi_obj.instruments[0].notes
    performance_notes = norm_performance_midi_obj.instruments[0].notes

    score_start_list = []
    vel_list = []
    start_list = []
    duration_list = []
    unknown_ids = []
    for i, (score_idx, perf_idx) in enumerate(corresp_list):
        if perf_idx != -1:
            perf_note = performance_notes[perf_idx]
            vel_list.append(perf_note.velocity)
            start_list.append(perf_note.start)
            duration_list.append(perf_note.end - perf_note.start)
        else:
            vel_list.append(np.nan)
            duration_list.append(np.nan)
            unknown_ids.append(i)
        score_start_list.append(score_notes[score_idx].start)

    # Give the matched score notes (in score order) the matched performance
    # onsets in ascending order, keeping onsets monotonic. Unmatched slots are
    # NaN and get interpolated. The set makes the membership test O(1).
    unknown_set = set(unknown_ids)
    sorted_starts = iter(sorted(start_list))
    temp = [np.nan if i in unknown_set else next(sorted_starts) for i in range(len(corresp_list))]
    start_list = interpolate(score_start_list, temp)
    vel_list = interpolate(start_list, vel_list)
    duration_list = interpolate(start_list, duration_list)

    output_notes = []
    end_list = []
    for i, (score_idx, _) in enumerate(corresp_list):
        end = start_list[i] + duration_list[i]
        end_list.append(end)
        output_notes.append(Note(vel_list[i], score_notes[score_idx].pitch, start_list[i], end))

    max_tick = max(end_list) + 4999
    output_ccs = []
    for cc in norm_performance_midi_obj.instruments[0].control_changes:
        if cc.time > max_tick:
            break  # normalize_midi sorts CCs by time, so the rest are later
        output_ccs.append(cc)

    aligned_midi_obj.instruments.append(
        Instrument(program=0, is_drum=False, name="Piano", notes=output_notes, control_changes=output_ccs)
    )
    x = midi_to_ids(config, norm_score_midi_obj)
    label = midi_to_ids(config, aligned_midi_obj, normalize=False)
    assert len(x) == len(label)
    # Pitch tokens (slot 0 of every 8) must agree between input and label.
    for i in range(0, len(x), 8):
        assert x[i] == label[i]

    total_notes = len(score_notes)
    return segment_sequences(x, label, unknown_ids, total_notes, 5, 64)

def enhanced_ids(config, ids):
    """Return a copy of ``ids`` with Gaussian jitter on timing and velocity tokens.

    Which tokens are perturbed:
    - Ids are taken in groups of 8 (layout of midi_to_ids).
    - Only slot 1 (onset gap), slot 2 (velocity) and slot 3 (duration) change.

    How the noise is drawn:
    - Each noise value is ``round(np.random.randn() * sigma)``.
    - It is accepted only inside a slot-specific window; after 10 rejected
      draws the noise is 0.
    - The offset ``id - config.valid_id_range[slot][0]`` plus noise is clamped
      to ``[0, cap]``.

    Per-slot parameters:
    - duration: sigma 5, window [-9, 5] if the offset is 10 else [-4, 5], cap 4999
    - velocity: sigma 2.5, window [-4, 2] if 5, [-2, 7] if 120, else [-2, 2], cap 127
    - onset gap: sigma 5, window [-4, 5], cap 4990
    """
    def draw_noise(sigma, low, high, attempts=10):
        # Rejection sampling from a rounded Gaussian; fall back to no noise.
        for _ in range(attempts):
            candidate = round(np.random.randn() * sigma)
            if low <= candidate <= high:
                return candidate
        return 0

    res = copy(ids)
    for pos in range(len(res)):
        slot = pos % 8
        if slot not in (1, 2, 3):
            continue
        base = config.valid_id_range[slot][0]
        value = res[pos] - base
        if slot == 3:
            sigma, cap = 5, 4999
            low, high = (-9, 5) if value == 10 else (-4, 5)
        elif slot == 2:
            sigma, cap = 2.5, 127
            if value == 5:
                low, high = -4, 2
            elif value == 120:
                low, high = -2, 7
            else:
                low, high = -2, 2
        else:
            sigma, cap = 5, 4990
            low, high = -4, 5
        res[pos] = base + min(max(value + draw_noise(sigma, low, high), 0), cap)
    return res

def enhanced_ids_uniform(config, ids):
    """Return a copy of ``ids`` with uniform integer jitter on timing and velocity tokens.

    Which tokens are perturbed:
    - Ids are taken in groups of 8 (layout of midi_to_ids).
    - Only slot 1 (onset gap), slot 2 (velocity) and slot 3 (duration) change.

    How the noise is applied:
    - Noise is ``random.randint(low, high)`` (inclusive bounds).
    - The offset ``id - config.valid_id_range[slot][0]`` plus noise is clamped
      to ``[0, cap]``.

    Per-slot parameters:
    - duration: [-9, 5] if the offset is 10 else [-4, 5], cap 4999
    - velocity: [-4, 2] if 5, [-2, 7] if 120, else [-2, 2], cap 127
    - onset gap: [-4, 5], cap 4990
    """
    res = copy(ids)
    for pos in range(len(res)):
        slot = pos % 8
        if slot not in (1, 2, 3):
            continue
        base = config.valid_id_range[slot][0]
        value = res[pos] - base
        if slot == 3:
            cap = 4999
            low, high = (-9, 5) if value == 10 else (-4, 5)
        elif slot == 2:
            cap = 127
            if value == 5:
                low, high = -4, 2
            elif value == 120:
                low, high = -2, 7
            else:
                low, high = -2, 2
        else:
            cap = 4990
            low, high = -4, 5
        res[pos] = base + min(max(value + random.randint(low, high), 0), cap)
    return res

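#if __name__ == "__main__":
#    # Both helpers take the tokenizer config as their first argument;
#    # `config` must be constructed before running this example.
#    midi_obj = MidiFile("data/midi/test/2.mid")
#    ids = midi_to_ids(config, midi_obj)
#    midi = ids_to_midi(config, ids)
#    midi.dump("data/rebuild/2.mid")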