Lollikit commited on
Commit
e7ddfc6
·
1 Parent(s): 8cb0a75

Add application file

Browse files
Files changed (1) hide show
  1. app.py +550 -0
app.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import torch
4
+ import gradio as gr
5
+ import threading
6
+ import traceback
7
+ import moduleconf
8
+ import transkun.transcribe
9
+ from pathlib import Path
10
+ import tempfile
11
+ import shutil
12
+
13
+ os.environ['NO_PROXY'] = "localhost, 127.0.0.1, ::1"
14
+
15
+ # 设置ffmpeg路径
16
+ current_dir = os.path.dirname(os.path.abspath(__file__))
17
+ ffmpeg_bin_path = os.path.join(current_dir, "ffmpeg_bin")
18
+
19
+ # 检查路径是否已在PATH中,避免重复添加
20
+ if ffmpeg_bin_path not in os.environ['PATH'].split(os.pathsep):
21
+ os.environ['PATH'] = ffmpeg_bin_path + os.pathsep + os.environ['PATH']
22
+
23
+ # 检查CUDA是否可用
24
+ cuda_available = torch.cuda.is_available()
25
+
26
+ import mido
27
+ from collections import defaultdict, Counter
28
+
29
+ # 从原始代码复制MIDI处理函数
30
+ def midi_quantize(midi_path, debug=False, optimize_bpm=True):
31
+ """
32
+ 分贝对于左右手(左右手可以通过C4 上下进行分隔):
33
+ 对于当前同时按下的音符(按下时间间隔短),他们的时值统一到下一个音符被按下的时间
34
+ 注意只能拉伸尾端,音符的头端不能被动,相当于不能改按下事件
35
+
36
+ optimize_bpm: 是否进行BPM优化
37
+ """
38
+ try:
39
+ # 读取MIDI文件
40
+ mid = mido.MidiFile(midi_path)
41
+
42
+ # C4的MIDI音符号是60
43
+ C4_NOTE = 60
44
+
45
+ # 为每个音轨处理
46
+ for track_idx, track in enumerate(mid.tracks):
47
+ # 检查音轨是否包含音符事件
48
+ has_notes = any(msg.type in ['note_on', 'note_off'] for msg in track)
49
+ if not has_notes:
50
+ continue
51
+
52
+ # 收集所有音符事件
53
+ notes_on = [] # 存储note_on事件
54
+ notes_off = [] # 存储note_off事件
55
+ other_events = [] # 存储其他事件
56
+
57
+ current_time = 0
58
+
59
+ # 解析音轨中的所有事件
60
+ for msg in track:
61
+ current_time += msg.time
62
+
63
+ if msg.type == 'note_on' and msg.velocity > 0:
64
+ notes_on.append({
65
+ 'time': current_time,
66
+ 'note': msg.note,
67
+ 'velocity': msg.velocity,
68
+ 'channel': msg.channel,
69
+ 'msg': msg
70
+ })
71
+ elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
72
+ notes_off.append({
73
+ 'time': current_time,
74
+ 'note': msg.note,
75
+ 'velocity': msg.velocity if msg.type == 'note_off' else 0,
76
+ 'channel': msg.channel,
77
+ 'msg': msg
78
+ })
79
+ else:
80
+ other_events.append({
81
+ 'time': current_time,
82
+ 'msg': msg
83
+ })
84
+
85
+ # 按左右手分组(以C4为界)
86
+ left_hand_notes = [] # C4以下(包含C4)
87
+ right_hand_notes = [] # C4以上
88
+
89
+ for note in notes_on:
90
+ if note['note'] <= C4_NOTE:
91
+ left_hand_notes.append(note)
92
+ else:
93
+ right_hand_notes.append(note)
94
+
95
+ # 处理左右手的音符
96
+ def process_hand_notes(hand_notes, hand_name=""):
97
+ if len(hand_notes) < 1:
98
+ return
99
+
100
+ # 按时间排序
101
+ hand_notes.sort(key=lambda x: x['time'])
102
+
103
+ if debug:
104
+ print(f"\n处理{hand_name},共{len(hand_notes)}个音符:")
105
+ for note in hand_notes:
106
+ print(f" 音符{note['note']} 时间{note['time']}")
107
+
108
+ # 找到同时按下的音符组
109
+ TIME_THRESHOLD = 100 # 50 ticks内认为是同时按下
110
+
111
+ i = 0
112
+ while i < len(hand_notes):
113
+ # 找到当前时间点的所有同时按下的音符
114
+ current_time = hand_notes[i]['time']
115
+ simultaneous_notes = [hand_notes[i]]
116
+
117
+ j = i + 1
118
+ while j < len(hand_notes) and hand_notes[j]['time'] - current_time <= TIME_THRESHOLD:
119
+ simultaneous_notes.append(hand_notes[j])
120
+ j += 1
121
+
122
+ # 找到下一个音符组的开始时间
123
+ next_note_time = None
124
+ if j < len(hand_notes):
125
+ next_note_time = hand_notes[j]['time']
126
+
127
+ if debug:
128
+ note_names = [str(n['note']) for n in simultaneous_notes]
129
+ print(f" 同时音符组: {note_names} 在时间{current_time}, 下一组时间: {next_note_time}")
130
+
131
+ # 对于当前组的音符(不管是单个还是多个),都要调整时值
132
+ if next_note_time is not None:
133
+ # 调整当前组中所有音符的note_off时间
134
+ for note in simultaneous_notes:
135
+ # 找到对应的note_off事件(找到时间最近的且未被处理的那个)
136
+ best_off_event = None
137
+ min_time_diff = float('inf')
138
+
139
+ for off_event in notes_off:
140
+ if (off_event['note'] == note['note'] and
141
+ off_event['channel'] == note['channel'] and
142
+ off_event['time'] > note['time'] and
143
+ not off_event.get('processed', False)): # 确保未被处理过
144
+
145
+ time_diff = off_event['time'] - note['time']
146
+ if time_diff < min_time_diff:
147
+ min_time_diff = time_diff
148
+ best_off_event = off_event
149
+
150
+ if best_off_event is not None:
151
+ old_time = best_off_event['time']
152
+ # 将note_off时间设置到下一个音符开始前的一小段时间
153
+ # 确保不会太晚,也不会早于原始的最小持续时间
154
+ min_duration = 100 # 最小持续时间
155
+ target_time = next_note_time - 10
156
+ best_off_event['time'] = max(note['time'] + min_duration, target_time)
157
+ best_off_event['processed'] = True # 标记为已处理
158
+
159
+ if debug:
160
+ print(f" 音符{note['note']} off时间: {old_time} -> {best_off_event['time']}")
161
+
162
+ i = j
163
+
164
+ # 处理左右手
165
+ process_hand_notes(left_hand_notes, "左手")
166
+ process_hand_notes(right_hand_notes, "右手")
167
+
168
+ # 重建音轨
169
+ all_events = []
170
+
171
+ # 添加所有事件并按时间排序
172
+ for note in notes_on:
173
+ all_events.append(('note_on', note['time'], note))
174
+
175
+ for note in notes_off:
176
+ all_events.append(('note_off', note['time'], note))
177
+
178
+ for event in other_events:
179
+ all_events.append(('other', event['time'], event))
180
+
181
+ # 按时间排序
182
+ all_events.sort(key=lambda x: x[1])
183
+
184
+ # 重建MIDI消息
185
+ new_messages = []
186
+ last_time = 0
187
+
188
+ for event_type, event_time, event_data in all_events:
189
+ delta_time = event_time - last_time
190
+
191
+ if event_type == 'note_on':
192
+ msg = mido.Message('note_on',
193
+ channel=event_data['channel'],
194
+ note=event_data['note'],
195
+ velocity=event_data['velocity'],
196
+ time=delta_time)
197
+ elif event_type == 'note_off':
198
+ msg = mido.Message('note_off',
199
+ channel=event_data['channel'],
200
+ note=event_data['note'],
201
+ velocity=event_data['velocity'],
202
+ time=delta_time)
203
+ else:
204
+ msg = event_data['msg'].copy(time=delta_time)
205
+
206
+ new_messages.append(msg)
207
+ last_time = event_time
208
+
209
+ # 替换原音轨
210
+ track.clear()
211
+ track.extend(new_messages)
212
+
213
+ # 裁剪MIDI首尾空白
214
+ if debug:
215
+ print("裁剪MIDI首尾空白...")
216
+ trim_midi_silence(mid, debug)
217
+
218
+ # 保存处理后的文件
219
+ output_path = os.path.splitext(midi_path)[0] + '_quantized.mid'
220
+ mid.save(output_path)
221
+
222
+ return output_path
223
+
224
+ except Exception as e:
225
+ raise Exception(f"处理MIDI文件时出错: {str(e)}")
226
+
227
+ def trim_midi_silence(mid, debug=False):
228
+ """
229
+ 裁剪MIDI文件首尾的空白部分
230
+ """
231
+ try:
232
+ # 找到第一个和最后一个音符事件的时间
233
+ first_note_time = float('inf')
234
+ last_note_time = 0
235
+
236
+ for track in mid.tracks:
237
+ current_time = 0
238
+ track_first_note = None
239
+ track_last_note = 0
240
+
241
+ for msg in track:
242
+ current_time += msg.time
243
+
244
+ if msg.type == 'note_on' and msg.velocity > 0:
245
+ if track_first_note is None:
246
+ track_first_note = current_time
247
+ track_last_note = current_time
248
+
249
+ if track_first_note is not None:
250
+ first_note_time = min(first_note_time, track_first_note)
251
+ last_note_time = max(last_note_time, track_last_note)
252
+
253
+ if first_note_time == float('inf'):
254
+ if debug:
255
+ print("没有找到音符,跳过裁剪")
256
+ return
257
+
258
+ if debug:
259
+ print(f"音符时间范围: {first_note_time} - {last_note_time}")
260
+
261
+ # 调整所有音轨的时间
262
+ for track in mid.tracks:
263
+ if not track:
264
+ continue
265
+
266
+ # 重建消息,调整时间
267
+ new_messages = []
268
+ current_time = 0
269
+
270
+ for msg in track:
271
+ current_time += msg.time
272
+
273
+ # 只保留在音符范围内的事件,或者是重要的元事件
274
+ if (first_note_time <= current_time <= last_note_time + 1000 or # 音符范围内
275
+ msg.type in ['set_tempo', 'key_signature', 'time_signature'] or # 重要元事件
276
+ current_time < first_note_time): # 开头的设置事件
277
+
278
+ # 调整时间:减去开头的空白时间
279
+ adjusted_time = max(0, current_time - first_note_time)
280
+ new_messages.append((adjusted_time, msg))
281
+
282
+ # 重建track
283
+ track.clear()
284
+ if new_messages:
285
+ last_time = 0
286
+ for abs_time, msg in new_messages:
287
+ delta_time = abs_time - last_time
288
+ new_msg = msg.copy(time=delta_time)
289
+ track.append(new_msg)
290
+ last_time = abs_time
291
+
292
+ if debug:
293
+ print("MIDI裁剪完成")
294
+
295
+ except Exception as e:
296
+ if debug:
297
+ print(f"MIDI裁剪失败: {e}")
298
+
299
+ # 核心转换函数
300
+ def process_audio(input_file, use_cuda=True, use_quantize=True, progress=gr.Progress(), file_progress_offset=0.0, file_progress_scale=1.0):
301
+ """
302
+ 处理音频文件并生成MIDI文件。
303
+
304
+ :param input_file: 输入音频文件路径。
305
+ :param use_cuda: 是否使用CUDA加速。
306
+ :param use_quantize: 是否对生成的MIDI文件进行量化处理。
307
+ :param progress: Gradio进度条对象。
308
+ :param file_progress_offset: 进度条的起始偏移量,用于批量处理。
309
+ :param file_progress_scale: 进度条的缩放比例,用于批量处理。
310
+ :return: 包含处理结果的字典。
311
+ """
312
+ temp_dir = None
313
+ try:
314
+ # The fix: create a temporary directory to store all output files
315
+ # 修复:创建一个临时目录来存储所有的输出文件
316
+ temp_dir = tempfile.mkdtemp()
317
+
318
+ # Get a meaningful filename from the input file
319
+ # 从输入文件中获取一个有意义的文件名
320
+ input_name = Path(input_file).stem
321
+
322
+ # Create the path for the non-quantized MIDI file inside the temp directory
323
+ # 在临时目录中创建非量化MIDI文件的路径
324
+ output_file = Path(temp_dir) / f"{input_name}.mid"
325
+
326
+ quantized_output_file = None
327
+
328
+ device = "cuda" if use_cuda and cuda_available else "cpu"
329
+
330
+ start_time = time.time()
331
+ progress(file_progress_offset, desc="准备模型...")
332
+
333
+ # 加载模型和配置
334
+ default_weight = os.path.join(current_dir, "models\\2.0.pt")
335
+ default_conf = os.path.join(current_dir, "models\\2.0.conf")
336
+
337
+ # 检查模型文件是否存在
338
+ if not os.path.exists(default_weight) or not os.path.exists(default_conf):
339
+ raise FileNotFoundError(
340
+ f"找不到模型文件!请确保以下文件存在:\n"
341
+ f"{default_weight}\n"
342
+ f"{default_conf}"
343
+ )
344
+
345
+ # 加载配置
346
+ conf_manager = moduleconf.parseFromFile(default_conf)
347
+ TransKun = conf_manager["Model"].module.TransKun
348
+ conf = conf_manager["Model"].config
349
+
350
+ # 加载模型
351
+ checkpoint = torch.load(default_weight, map_location=device)
352
+ model = TransKun(conf=conf).to(device)
353
+ if "best_state_dict" not in checkpoint:
354
+ model.load_state_dict(checkpoint["state_dict"], strict=False)
355
+ else:
356
+ model.load_state_dict(checkpoint["best_state_dict"], strict=False)
357
+ model.eval()
358
+
359
+ progress(file_progress_offset + 0.2 * file_progress_scale, desc="读取音频...")
360
+ # 读取并处理音频
361
+ fs, audio = transkun.transcribe.readAudio(input_file)
362
+ if fs != model.fs:
363
+ import soxr
364
+ audio = soxr.resample(audio, fs, model.fs)
365
+
366
+ x = torch.from_numpy(audio).to(device)
367
+
368
+ progress(file_progress_offset + 0.4 * file_progress_scale, desc="转录中...")
369
+ # 转录
370
+ with torch.no_grad():
371
+ notes_est = model.transcribe(x)
372
+
373
+ progress(file_progress_offset + 0.7 * file_progress_scale, desc="保存MIDI...")
374
+ # 保存MIDI到临时目录,将 Path 对象转换为字符串
375
+ output_midi = transkun.transcribe.writeMidi(notes_est)
376
+ output_midi.write(str(output_file))
377
+
378
+ # 如果勾选了规整化选项,则进行MIDI规整化
379
+ if use_quantize:
380
+ progress(file_progress_offset + 0.8 * file_progress_scale, desc="规整化MIDI...")
381
+ try:
382
+ # The midi_quantize function will now write the output file with the expected name
383
+ # midi_quantize函数现在将以预期的名称写入输出文件
384
+ quantized_output_file = midi_quantize(str(output_file), debug=False, optimize_bpm=True)
385
+ except Exception as e:
386
+ print(f"规整化处理失败: {str(e)}")
387
+ # 规整化失败不影响主流程
388
+
389
+ end_time = time.time()
390
+ process_time = round(end_time - start_time, 2)
391
+
392
+ progress(file_progress_offset + 1.0 * file_progress_scale, desc="完成!")
393
+
394
+ # 返回结果
395
+ result_files = [str(output_file)]
396
+ if quantized_output_file:
397
+ result_files.append(quantized_output_file)
398
+
399
+ return {
400
+ "output": f"转换完成!用时 {process_time}秒",
401
+ "files": result_files
402
+ }
403
+
404
+ except Exception as e:
405
+ traceback.print_exc()
406
+ return {
407
+ "output": f"转换失败: {str(e)}",
408
+ "files": []
409
+ }
410
+ # Removed the manual cleanup block, Gradio will handle this now.
411
+ # 删除了手动清理代码块,现在由 Gradio 来处理。
412
+
413
+ # 创建Gradio界面
414
+ def create_interface():
415
+ with gr.Blocks(title="Transkun - Piano Audio to MIDI", theme=gr.themes.Soft(primary_hue="blue")) as app:
416
+ gr.Markdown(
417
+ """
418
+ # Transkun - 钢琴音频转MIDI
419
+ 将钢琴演奏音频转换为MIDI文件
420
+ """
421
+ )
422
+
423
+ with gr.Row():
424
+ with gr.Column(scale=2):
425
+ # 输入部分
426
+ gr.Markdown("### 1. 选择输入音频文件")
427
+ input_audio = gr.File(label="输入音频文件", file_count="multiple", file_types=["audio"], interactive=True)
428
+
429
+ gr.Markdown("### 2. 选择转换选项")
430
+ with gr.Row():
431
+ use_cuda = gr.Checkbox(
432
+ label=f"启用CUDA加速 (CUDA {'可用 ✓' if cuda_available else '不可用 ✗'})",
433
+ value=cuda_available,
434
+ interactive=cuda_available
435
+ )
436
+
437
+
438
+ def show_warning(is_checked):
439
+ """Show a warning message when the checkbox is checked."""
440
+ if is_checked:
441
+ gr.Warning("注意:开启此项可能影响踏板效果。")
442
+ return gr.update(info="")
443
+ else:
444
+ return gr.update(info="")
445
+ use_quantize = gr.Checkbox(
446
+ label="使用MIDI规整化(附带有_quantized后缀的输出文件),建议只用于阅读MIDI(❗注意:此项可能影响踏板效果)",
447
+ value=False
448
+ )
449
+
450
+ use_quantize.change(
451
+ fn=show_warning,
452
+ inputs=use_quantize,
453
+ outputs=use_quantize
454
+ )
455
+
456
+ convert_btn = gr.Button("开始转换", variant="primary")
457
+
458
+ with gr.Column(scale=1):
459
+ # 输出部分
460
+ status_output = gr.Textbox(label="状态", value="准备就绪", interactive=False)
461
+ file_output = gr.File(label="生成的MIDI文件", interactive=False, file_count="multiple")
462
+
463
+ # 创建一个隐藏的文本框来存储文件路径
464
+ file_paths_store = gr.State([])
465
+
466
+ # 下载按钮
467
+ with gr.Row():
468
+ download_all_btn = gr.Button("一键下载全部文件", variant="secondary", visible=False)
469
+ download_status = gr.Textbox(label="下载状态", value="", visible=False, interactive=False)
470
+
471
+ # 处理函数
472
+ def on_convert(audio_paths, use_cuda, use_quantize, progress=gr.Progress()):
473
+ if not audio_paths:
474
+ return "请选择输入音频文件", [], gr.update(visible=False)
475
+
476
+ all_files = []
477
+ results = []
478
+ total_files = len(audio_paths)
479
+
480
+ for i, audio_path in enumerate(audio_paths):
481
+ file_name = Path(audio_path).name
482
+ progress_offset = (i / total_files)
483
+ progress_scale = (1 / total_files) * 0.9
484
+
485
+ progress(progress_offset, desc=f"处理文件 {i+1}/{total_files}: {file_name}")
486
+ result = process_audio(audio_path, use_cuda, use_quantize, progress,
487
+ file_progress_offset=progress_offset,
488
+ file_progress_scale=progress_scale)
489
+ results.append(result["output"])
490
+ all_files.extend(result["files"])
491
+
492
+ progress(1.0, desc="全部完成!")
493
+ download_btn_update = gr.update(visible=True) if all_files else gr.update(visible=False)
494
+ download_status_update = gr.update(visible=False)
495
+ return f"转换完成!共处理 {total_files} 个文件\n" + "\n".join(results), all_files, download_btn_update, download_status_update, all_files
496
+
497
+ # 下载所有文件的函数
498
+ def download_all_files(file_paths, status_output=None):
499
+ import tempfile
500
+ import os
501
+ import shutil
502
+ import zipfile
503
+ from pathlib import Path
504
+
505
+ if not file_paths or len(file_paths) == 0:
506
+ return None, gr.update(value="没有文件可下载", visible=True), gr.update(visible=False)
507
+
508
+ try:
509
+ # 创建临时目录用于存放文件
510
+ temp_dir = tempfile.mkdtemp(prefix="midi_files_")
511
+
512
+ # 创建ZIP文件
513
+ zip_path = os.path.join(temp_dir, "all_midi_files.zip")
514
+
515
+ # 直接创建ZIP文件,不使用shutil.make_archive
516
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
517
+ for file_path in file_paths:
518
+ if os.path.exists(file_path):
519
+ # 只添加文件名,不包含路径
520
+ zipf.write(file_path, os.path.basename(file_path))
521
+
522
+ return zip_path, gr.update(value="下载准备完成,请点击上方文件链接下载", visible=True), gr.update(visible=False)
523
+ except Exception as e:
524
+ return None, gr.update(value=f"下载准备失败: {str(e)}", visible=True), gr.update(visible=True)
525
+
526
+ # 绑定按钮事件
527
+ convert_btn.click(
528
+ fn=on_convert,
529
+ inputs=[input_audio, use_cuda, use_quantize],
530
+ outputs=[status_output, file_output, download_all_btn, download_status, file_paths_store]
531
+ )
532
+
533
+ # 绑定下载按钮事件
534
+ download_all_btn.click(
535
+ fn=download_all_files,
536
+ inputs=[file_paths_store, download_status],
537
+ outputs=[file_output, download_status, download_all_btn]
538
+ )
539
+
540
+ return app
541
+
542
+ # 启动应用
543
+ def main():
544
+ app = create_interface()
545
+ # It's better to launch on 0.0.0.0 for broader access, though 127.0.0.1 is fine for local.
546
+ # 最好在0.0.0.0上启动以便更广泛的访问,不过127.0.0.1用于本地也是可以的。
547
+ app.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
548
+
549
+ if __name__ == "__main__":
550
+ main()