File size: 20,660 Bytes
e7ddfc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108a341
 
e7ddfc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108a341
e7ddfc6
 
 
 
 
108a341
e7ddfc6
 
 
 
 
108a341
 
 
 
929272c
 
 
 
108a341
 
 
 
 
 
 
 
e7ddfc6
 
 
 
 
 
 
108a341
e7ddfc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
import os
import time
import torch
import gradio as gr
import threading
import traceback
import moduleconf
import transkun.transcribe
from pathlib import Path
import tempfile
import shutil

os.environ['NO_PROXY'] = "localhost, 127.0.0.1, ::1"

# 设置ffmpeg路径
current_dir = os.path.dirname(os.path.abspath(__file__))
ffmpeg_bin_path = os.path.join(current_dir, "ffmpeg_bin")

# 检查路径是否已在PATH中,避免重复添加
if ffmpeg_bin_path not in os.environ['PATH'].split(os.pathsep):
    os.environ['PATH'] = ffmpeg_bin_path + os.pathsep + os.environ['PATH']

# 检查CUDA是否可用
cuda_available = torch.cuda.is_available()

import mido
from collections import defaultdict, Counter

# 从原始代码复制MIDI处理函数
def midi_quantize(midi_path, debug=False, optimize_bpm=True):
    """
    分贝对于左右手(左右手可以通过C4 上下进行分隔):
    对于当前同时按下的音符(按下时间间隔短),他们的时值统一到下一个音符被按下的时间
    注意只能拉伸尾端,音符的头端不能被动,相当于不能改按下事件

    optimize_bpm: 是否进行BPM优化
    """
    try:
        # 读取MIDI文件
        mid = mido.MidiFile(midi_path)

        # C4的MIDI音符号是60
        C4_NOTE = 60

        # 为每个音轨处理
        for track_idx, track in enumerate(mid.tracks):
            # 检查音轨是否包含音符事件
            has_notes = any(msg.type in ['note_on', 'note_off'] for msg in track)
            if not has_notes:
                continue

            # 收集所有音符事件
            notes_on = []  # 存储note_on事件
            notes_off = []  # 存储note_off事件
            other_events = []  # 存储其他事件

            current_time = 0

            # 解析音轨中的所有事件
            for msg in track:
                current_time += msg.time

                if msg.type == 'note_on' and msg.velocity > 0:
                    notes_on.append({
                        'time': current_time,
                        'note': msg.note,
                        'velocity': msg.velocity,
                        'channel': msg.channel,
                        'msg': msg
                    })
                elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                    notes_off.append({
                        'time': current_time,
                        'note': msg.note,
                        'velocity': msg.velocity if msg.type == 'note_off' else 0,
                        'channel': msg.channel,
                        'msg': msg
                    })
                else:
                    other_events.append({
                        'time': current_time,
                        'msg': msg
                    })

            # 按左右手分组(以C4为界)
            left_hand_notes = []  # C4以下(包含C4)
            right_hand_notes = []  # C4以上

            for note in notes_on:
                if note['note'] <= C4_NOTE:
                    left_hand_notes.append(note)
                else:
                    right_hand_notes.append(note)

            # 处理左右手的音符
            def process_hand_notes(hand_notes, hand_name=""):
                if len(hand_notes) < 1:
                    return

                # 按时间排序
                hand_notes.sort(key=lambda x: x['time'])

                if debug:
                    print(f"\n处理{hand_name},共{len(hand_notes)}个音符:")
                    for note in hand_notes:
                        print(f"  音符{note['note']} 时间{note['time']}")

                # 找到同时按下的音符组
                TIME_THRESHOLD = 100  # 50 ticks内认为是同时按下

                i = 0
                while i < len(hand_notes):
                    # 找到当前时间点的所有同时按下的音符
                    current_time = hand_notes[i]['time']
                    simultaneous_notes = [hand_notes[i]]

                    j = i + 1
                    while j < len(hand_notes) and hand_notes[j]['time'] - current_time <= TIME_THRESHOLD:
                        simultaneous_notes.append(hand_notes[j])
                        j += 1

                    # 找到下一个音符组的开始时间
                    next_note_time = None
                    if j < len(hand_notes):
                        next_note_time = hand_notes[j]['time']

                    if debug:
                        note_names = [str(n['note']) for n in simultaneous_notes]
                        print(f"  同时音符组: {note_names} 在时间{current_time}, 下一组时间: {next_note_time}")

                    # 对于当前组的音符(不管是单个还是多个),都要调整时值
                    if next_note_time is not None:
                        # 调整当前组中所有音符的note_off时间
                        for note in simultaneous_notes:
                            # 找到对应的note_off事件(找到时间最近的且未被处理的那个)
                            best_off_event = None
                            min_time_diff = float('inf')

                            for off_event in notes_off:
                                if (off_event['note'] == note['note'] and
                                    off_event['channel'] == note['channel'] and
                                    off_event['time'] > note['time'] and
                                    not off_event.get('processed', False)):  # 确保未被处理过

                                    time_diff = off_event['time'] - note['time']
                                    if time_diff < min_time_diff:
                                        min_time_diff = time_diff
                                        best_off_event = off_event

                            if best_off_event is not None:
                                old_time = best_off_event['time']
                                # 将note_off时间设置到下一个音符开始前的一小段时间
                                # 确保不会太晚,也不会早于原始的最小持续时间
                                min_duration = 100  # 最小持续时间
                                target_time = next_note_time - 10
                                best_off_event['time'] = max(note['time'] + min_duration, target_time)
                                best_off_event['processed'] = True  # 标记为已处理

                                if debug:
                                    print(f"    音符{note['note']} off时间: {old_time} -> {best_off_event['time']}")

                    i = j

            # 处理左右手
            process_hand_notes(left_hand_notes, "左手")
            process_hand_notes(right_hand_notes, "右手")

            # 重建音轨
            all_events = []

            # 添加所有事件并按时间排序
            for note in notes_on:
                all_events.append(('note_on', note['time'], note))

            for note in notes_off:
                all_events.append(('note_off', note['time'], note))

            for event in other_events:
                all_events.append(('other', event['time'], event))

            # 按时间排序
            all_events.sort(key=lambda x: x[1])

            # 重建MIDI消息
            new_messages = []
            last_time = 0

            for event_type, event_time, event_data in all_events:
                delta_time = event_time - last_time

                if event_type == 'note_on':
                    msg = mido.Message('note_on',
                                       channel=event_data['channel'],
                                       note=event_data['note'],
                                       velocity=event_data['velocity'],
                                       time=delta_time)
                elif event_type == 'note_off':
                    msg = mido.Message('note_off',
                                       channel=event_data['channel'],
                                       note=event_data['note'],
                                       velocity=event_data['velocity'],
                                       time=delta_time)
                else:
                    msg = event_data['msg'].copy(time=delta_time)

                new_messages.append(msg)
                last_time = event_time

            # 替换原音轨
            track.clear()
            track.extend(new_messages)

        # 裁剪MIDI首尾空白
        if debug:
            print("裁剪MIDI首尾空白...")
        trim_midi_silence(mid, debug)

        # 保存处理后的文件
        output_path = os.path.splitext(midi_path)[0] + '_quantized.mid'
        mid.save(output_path)

        return output_path

    except Exception as e:
        raise Exception(f"处理MIDI文件时出错: {str(e)}")

def trim_midi_silence(mid, debug=False):
    """
    裁剪MIDI文件首尾的空白部分
    """
    try:
        # 找到第一个和最后一个音符事件的时间
        first_note_time = float('inf')
        last_note_time = 0

        for track in mid.tracks:
            current_time = 0
            track_first_note = None
            track_last_note = 0

            for msg in track:
                current_time += msg.time

                if msg.type == 'note_on' and msg.velocity > 0:
                    if track_first_note is None:
                        track_first_note = current_time
                    track_last_note = current_time

            if track_first_note is not None:
                first_note_time = min(first_note_time, track_first_note)
                last_note_time = max(last_note_time, track_last_note)

        if first_note_time == float('inf'):
            if debug:
                print("没有找到音符,跳过裁剪")
            return

        if debug:
            print(f"音符时间范围: {first_note_time} - {last_note_time}")

        # 调整所有音轨的时间
        for track in mid.tracks:
            if not track:
                continue

            # 重建消息,调整时间
            new_messages = []
            current_time = 0

            for msg in track:
                current_time += msg.time

                # 只保留在音符范围内的事件,或者是重要的元事件
                if (first_note_time <= current_time <= last_note_time + 1000 or  # 音符范围内
                    msg.type in ['set_tempo', 'key_signature', 'time_signature'] or  # 重要元事件
                    current_time < first_note_time):  # 开头的设置事件

                    # 调整时间:减去开头的空白时间
                    adjusted_time = max(0, current_time - first_note_time)
                    new_messages.append((adjusted_time, msg))

            # 重建track
            track.clear()
            if new_messages:
                last_time = 0
                for abs_time, msg in new_messages:
                    delta_time = abs_time - last_time
                    new_msg = msg.copy(time=delta_time)
                    track.append(new_msg)
                    last_time = abs_time

        if debug:
            print("MIDI裁剪完成")

    except Exception as e:
        if debug:
            print(f"MIDI裁剪失败: {e}")

import subprocess

# 核心转换函数
def process_audio(input_file, use_cuda=True, use_quantize=True, progress=gr.Progress(), file_progress_offset=0.0, file_progress_scale=1.0):
    """
    处理音频文件并生成MIDI文件。

    :param input_file: 输入音频文件路径。
    :param use_cuda: 是否使用CUDA加速。
    :param use_quantize: 是否对生成的MIDI文件进行量化处理。
    :param progress: Gradio进度条对象。
    :param file_progress_offset: 进度条的起始偏移量,用于批量处理。
    :param file_progress_scale: 进度条的缩放比例,用于批量处理。
    :return: 包含处理结果的字典。
    """
    temp_dir = None
    try:
        # 创建一个临时目录来存储所有的输出文件
        temp_dir = tempfile.mkdtemp()

        # 从输入文件中获取一个有意义的文件名
        input_name = Path(input_file).stem

        # 在临时目录中创建MIDI文件的路径
        output_file = Path(temp_dir) / f"{input_name}.mid"

        quantized_output_file = None

        start_time = time.time()
        progress(file_progress_offset, desc="准备转录...")

        # 使用命令行调用transkun
        progress(file_progress_offset + 0.3 * file_progress_scale, desc="转录中...")
        if use_cuda and cuda_available:
            cmd = ["transkun", input_file, str(output_file), "--device", "cuda"]
        else:
            cmd = ["transkun", input_file, str(output_file)]

        # 执行命令
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        stdout, stderr = process.communicate()

        # 检查命令是否成功执行
        if process.returncode != 0:
            raise Exception(f"transkun命令执行失败: {stderr}")

        progress(file_progress_offset + 0.7 * file_progress_scale, desc="保存MIDI...")

        # 如果勾选了规整化选项,则进行MIDI规整化
        if use_quantize:
            progress(file_progress_offset + 0.8 * file_progress_scale, desc="规整化MIDI...")
            try:
                # midi_quantize函数将以预期的名称写入输出文件
                quantized_output_file = midi_quantize(str(output_file), debug=False, optimize_bpm=True)
            except Exception as e:
                print(f"规整化处理失败: {str(e)}")
                # 规整化失败不影响主流程

        end_time = time.time()
        process_time = round(end_time - start_time, 2)

        progress(file_progress_offset + 1.0 * file_progress_scale, desc="完成!")

        # 返回结果
        result_files = [str(output_file)]
        if quantized_output_file:
            result_files.append(quantized_output_file)

        return {
            "output": f"转换完成!用时 {process_time}秒",
            "files": result_files
        }

    except Exception as e:
        traceback.print_exc()
        return {
            "output": f"转换失败: {str(e)}",
            "files": []
        }
    # 删除了手动清理代码块,现在由 Gradio 来处理。

# 创建Gradio界面
def create_interface():
    with gr.Blocks(title="Transkun - Piano Audio to MIDI", theme=gr.themes.Soft(primary_hue="blue")) as app:
        gr.Markdown(
            """
            # Transkun - 钢琴音频转MIDI
            将钢琴演奏音频转换为MIDI文件
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                # 输入部分
                gr.Markdown("### 1. 选择输入音频文件")
                input_audio = gr.File(label="输入音频文件", file_count="multiple", file_types=["audio"], interactive=True)

                gr.Markdown("### 2. 选择转换选项")
                with gr.Row():
                    use_cuda = gr.Checkbox(
                        label=f"启用CUDA加速 (CUDA {'可用 ✓' if cuda_available else '不可用 ✗'})",
                        value=cuda_available,
                        interactive=cuda_available
                    )


                def show_warning(is_checked):
                    """Show a warning message when the checkbox is checked."""
                    if is_checked:
                        gr.Warning("注意:开启此项可能影响踏板效果。")
                        return gr.update(info="")
                    else:
                        return gr.update(info="")
                use_quantize = gr.Checkbox(
                    label="使用MIDI规整化(附带有_quantized后缀的输出文件),建议只用于阅读MIDI(❗注意:此项可能影响踏板效果)",
                    value=False
                )

                use_quantize.change(
                    fn=show_warning,
                    inputs=use_quantize,
                    outputs=use_quantize
                )

                convert_btn = gr.Button("开始转换", variant="primary")

            with gr.Column(scale=1):
                # 输出部分
                status_output = gr.Textbox(label="状态", value="准备就绪", interactive=False)
                file_output = gr.File(label="生成的MIDI文件", interactive=False, file_count="multiple")

                # 创建一个隐藏的文本框来存储文件路径
                file_paths_store = gr.State([])

                # 下载按钮
                with gr.Row():
                    download_all_btn = gr.Button("一键下载全部文件", variant="secondary", visible=False)
                    download_status = gr.Textbox(label="下载状态", value="", visible=False, interactive=False)

        # 处理函数
        def on_convert(audio_paths, use_cuda, use_quantize, progress=gr.Progress()):
            if not audio_paths:
                return "请选择输入音频文件", [], gr.update(visible=False)

            all_files = []
            results = []
            total_files = len(audio_paths)

            for i, audio_path in enumerate(audio_paths):
                file_name = Path(audio_path).name
                progress_offset = (i / total_files)
                progress_scale = (1 / total_files) * 0.9

                progress(progress_offset, desc=f"处理文件 {i+1}/{total_files}: {file_name}")
                result = process_audio(audio_path, use_cuda, use_quantize, progress,
                                      file_progress_offset=progress_offset,
                                      file_progress_scale=progress_scale)
                results.append(result["output"])
                all_files.extend(result["files"])

            progress(1.0, desc="全部完成!")
            download_btn_update = gr.update(visible=True) if all_files else gr.update(visible=False)
            download_status_update = gr.update(visible=False)
            return f"转换完成!共处理 {total_files} 个文件\n" + "\n".join(results), all_files, download_btn_update, download_status_update, all_files

        # 下载所有文件的函数
        def download_all_files(file_paths, status_output=None):
            import tempfile
            import os
            import shutil
            import zipfile
            from pathlib import Path

            if not file_paths or len(file_paths) == 0:
                return None, gr.update(value="没有文件可下载", visible=True), gr.update(visible=False)

            try:
                # 创建临时目录用于存放文件
                temp_dir = tempfile.mkdtemp(prefix="midi_files_")

                # 创建ZIP文件
                zip_path = os.path.join(temp_dir, "all_midi_files.zip")

                # 直接创建ZIP文件,不使用shutil.make_archive
                with zipfile.ZipFile(zip_path, 'w') as zipf:
                    for file_path in file_paths:
                        if os.path.exists(file_path):
                            # 只添加文件名,不包含路径
                            zipf.write(file_path, os.path.basename(file_path))

                return zip_path, gr.update(value="下载准备完成,请点击上方文件链接下载", visible=True), gr.update(visible=False)
            except Exception as e:
                return None, gr.update(value=f"下载准备失败: {str(e)}", visible=True), gr.update(visible=True)

        # 绑定按钮事件
        convert_btn.click(
            fn=on_convert,
            inputs=[input_audio, use_cuda, use_quantize],
            outputs=[status_output, file_output, download_all_btn, download_status, file_paths_store]
        )

        # 绑定下载按钮事件
        download_all_btn.click(
            fn=download_all_files,
            inputs=[file_paths_store, download_status],
            outputs=[file_output, download_status, download_all_btn]
        )

    return app

# 启动应用
def main():
    app = create_interface()
    # It's better to launch on 0.0.0.0 for broader access, though 127.0.0.1 is fine for local.
    # 最好在0.0.0.0上启动以便更广泛的访问,不过127.0.0.1用于本地也是可以的。
    app.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)

if __name__ == "__main__":
    main()