Spaces:
Running
Running
| import os | |
| import time | |
| import torch | |
| import gradio as gr | |
| import threading | |
| import traceback | |
| import moduleconf | |
| import transkun.transcribe | |
| from pathlib import Path | |
| import tempfile | |
| import shutil | |
| os.environ['NO_PROXY'] = "localhost, 127.0.0.1, ::1" | |
| # 设置ffmpeg路径 | |
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |
| ffmpeg_bin_path = os.path.join(current_dir, "ffmpeg_bin") | |
| # 检查路径是否已在PATH中,避免重复添加 | |
| if ffmpeg_bin_path not in os.environ['PATH'].split(os.pathsep): | |
| os.environ['PATH'] = ffmpeg_bin_path + os.pathsep + os.environ['PATH'] | |
| # 检查CUDA是否可用 | |
| cuda_available = torch.cuda.is_available() | |
| import mido | |
| from collections import defaultdict, Counter | |
| # 从原始代码复制MIDI处理函数 | |
| def midi_quantize(midi_path, debug=False, optimize_bpm=True): | |
| """ | |
| 分贝对于左右手(左右手可以通过C4 上下进行分隔): | |
| 对于当前同时按下的音符(按下时间间隔短),他们的时值统一到下一个音符被按下的时间 | |
| 注意只能拉伸尾端,音符的头端不能被动,相当于不能改按下事件 | |
| optimize_bpm: 是否进行BPM优化 | |
| """ | |
| try: | |
| # 读取MIDI文件 | |
| mid = mido.MidiFile(midi_path) | |
| # C4的MIDI音符号是60 | |
| C4_NOTE = 60 | |
| # 为每个音轨处理 | |
| for track_idx, track in enumerate(mid.tracks): | |
| # 检查音轨是否包含音符事件 | |
| has_notes = any(msg.type in ['note_on', 'note_off'] for msg in track) | |
| if not has_notes: | |
| continue | |
| # 收集所有音符事件 | |
| notes_on = [] # 存储note_on事件 | |
| notes_off = [] # 存储note_off事件 | |
| other_events = [] # 存储其他事件 | |
| current_time = 0 | |
| # 解析音轨中的所有事件 | |
| for msg in track: | |
| current_time += msg.time | |
| if msg.type == 'note_on' and msg.velocity > 0: | |
| notes_on.append({ | |
| 'time': current_time, | |
| 'note': msg.note, | |
| 'velocity': msg.velocity, | |
| 'channel': msg.channel, | |
| 'msg': msg | |
| }) | |
| elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0): | |
| notes_off.append({ | |
| 'time': current_time, | |
| 'note': msg.note, | |
| 'velocity': msg.velocity if msg.type == 'note_off' else 0, | |
| 'channel': msg.channel, | |
| 'msg': msg | |
| }) | |
| else: | |
| other_events.append({ | |
| 'time': current_time, | |
| 'msg': msg | |
| }) | |
| # 按左右手分组(以C4为界) | |
| left_hand_notes = [] # C4以下(包含C4) | |
| right_hand_notes = [] # C4以上 | |
| for note in notes_on: | |
| if note['note'] <= C4_NOTE: | |
| left_hand_notes.append(note) | |
| else: | |
| right_hand_notes.append(note) | |
| # 处理左右手的音符 | |
| def process_hand_notes(hand_notes, hand_name=""): | |
| if len(hand_notes) < 1: | |
| return | |
| # 按时间排序 | |
| hand_notes.sort(key=lambda x: x['time']) | |
| if debug: | |
| print(f"\n处理{hand_name},共{len(hand_notes)}个音符:") | |
| for note in hand_notes: | |
| print(f" 音符{note['note']} 时间{note['time']}") | |
| # 找到同时按下的音符组 | |
| TIME_THRESHOLD = 100 # 50 ticks内认为是同时按下 | |
| i = 0 | |
| while i < len(hand_notes): | |
| # 找到当前时间点的所有同时按下的音符 | |
| current_time = hand_notes[i]['time'] | |
| simultaneous_notes = [hand_notes[i]] | |
| j = i + 1 | |
| while j < len(hand_notes) and hand_notes[j]['time'] - current_time <= TIME_THRESHOLD: | |
| simultaneous_notes.append(hand_notes[j]) | |
| j += 1 | |
| # 找到下一个音符组的开始时间 | |
| next_note_time = None | |
| if j < len(hand_notes): | |
| next_note_time = hand_notes[j]['time'] | |
| if debug: | |
| note_names = [str(n['note']) for n in simultaneous_notes] | |
| print(f" 同时音符组: {note_names} 在时间{current_time}, 下一组时间: {next_note_time}") | |
| # 对于当前组的音符(不管是单个还是多个),都要调整时值 | |
| if next_note_time is not None: | |
| # 调整当前组中所有音符的note_off时间 | |
| for note in simultaneous_notes: | |
| # 找到对应的note_off事件(找到时间最近的且未被处理的那个) | |
| best_off_event = None | |
| min_time_diff = float('inf') | |
| for off_event in notes_off: | |
| if (off_event['note'] == note['note'] and | |
| off_event['channel'] == note['channel'] and | |
| off_event['time'] > note['time'] and | |
| not off_event.get('processed', False)): # 确保未被处理过 | |
| time_diff = off_event['time'] - note['time'] | |
| if time_diff < min_time_diff: | |
| min_time_diff = time_diff | |
| best_off_event = off_event | |
| if best_off_event is not None: | |
| old_time = best_off_event['time'] | |
| # 将note_off时间设置到下一个音符开始前的一小段时间 | |
| # 确保不会太晚,也不会早于原始的最小持续时间 | |
| min_duration = 100 # 最小持续时间 | |
| target_time = next_note_time - 10 | |
| best_off_event['time'] = max(note['time'] + min_duration, target_time) | |
| best_off_event['processed'] = True # 标记为已处理 | |
| if debug: | |
| print(f" 音符{note['note']} off时间: {old_time} -> {best_off_event['time']}") | |
| i = j | |
| # 处理左右手 | |
| process_hand_notes(left_hand_notes, "左手") | |
| process_hand_notes(right_hand_notes, "右手") | |
| # 重建音轨 | |
| all_events = [] | |
| # 添加所有事件并按时间排序 | |
| for note in notes_on: | |
| all_events.append(('note_on', note['time'], note)) | |
| for note in notes_off: | |
| all_events.append(('note_off', note['time'], note)) | |
| for event in other_events: | |
| all_events.append(('other', event['time'], event)) | |
| # 按时间排序 | |
| all_events.sort(key=lambda x: x[1]) | |
| # 重建MIDI消息 | |
| new_messages = [] | |
| last_time = 0 | |
| for event_type, event_time, event_data in all_events: | |
| delta_time = event_time - last_time | |
| if event_type == 'note_on': | |
| msg = mido.Message('note_on', | |
| channel=event_data['channel'], | |
| note=event_data['note'], | |
| velocity=event_data['velocity'], | |
| time=delta_time) | |
| elif event_type == 'note_off': | |
| msg = mido.Message('note_off', | |
| channel=event_data['channel'], | |
| note=event_data['note'], | |
| velocity=event_data['velocity'], | |
| time=delta_time) | |
| else: | |
| msg = event_data['msg'].copy(time=delta_time) | |
| new_messages.append(msg) | |
| last_time = event_time | |
| # 替换原音轨 | |
| track.clear() | |
| track.extend(new_messages) | |
| # 裁剪MIDI首尾空白 | |
| if debug: | |
| print("裁剪MIDI首尾空白...") | |
| trim_midi_silence(mid, debug) | |
| # 保存处理后的文件 | |
| output_path = os.path.splitext(midi_path)[0] + '_quantized.mid' | |
| mid.save(output_path) | |
| return output_path | |
| except Exception as e: | |
| raise Exception(f"处理MIDI文件时出错: {str(e)}") | |
| def trim_midi_silence(mid, debug=False): | |
| """ | |
| 裁剪MIDI文件首尾的空白部分 | |
| """ | |
| try: | |
| # 找到第一个和最后一个音符事件的时间 | |
| first_note_time = float('inf') | |
| last_note_time = 0 | |
| for track in mid.tracks: | |
| current_time = 0 | |
| track_first_note = None | |
| track_last_note = 0 | |
| for msg in track: | |
| current_time += msg.time | |
| if msg.type == 'note_on' and msg.velocity > 0: | |
| if track_first_note is None: | |
| track_first_note = current_time | |
| track_last_note = current_time | |
| if track_first_note is not None: | |
| first_note_time = min(first_note_time, track_first_note) | |
| last_note_time = max(last_note_time, track_last_note) | |
| if first_note_time == float('inf'): | |
| if debug: | |
| print("没有找到音符,跳过裁剪") | |
| return | |
| if debug: | |
| print(f"音符时间范围: {first_note_time} - {last_note_time}") | |
| # 调整所有音轨的时间 | |
| for track in mid.tracks: | |
| if not track: | |
| continue | |
| # 重建消息,调整时间 | |
| new_messages = [] | |
| current_time = 0 | |
| for msg in track: | |
| current_time += msg.time | |
| # 只保留在音符范围内的事件,或者是重要的元事件 | |
| if (first_note_time <= current_time <= last_note_time + 1000 or # 音符范围内 | |
| msg.type in ['set_tempo', 'key_signature', 'time_signature'] or # 重要元事件 | |
| current_time < first_note_time): # 开头的设置事件 | |
| # 调整时间:减去开头的空白时间 | |
| adjusted_time = max(0, current_time - first_note_time) | |
| new_messages.append((adjusted_time, msg)) | |
| # 重建track | |
| track.clear() | |
| if new_messages: | |
| last_time = 0 | |
| for abs_time, msg in new_messages: | |
| delta_time = abs_time - last_time | |
| new_msg = msg.copy(time=delta_time) | |
| track.append(new_msg) | |
| last_time = abs_time | |
| if debug: | |
| print("MIDI裁剪完成") | |
| except Exception as e: | |
| if debug: | |
| print(f"MIDI裁剪失败: {e}") | |
| import subprocess | |
| # 核心转换函数 | |
| def process_audio(input_file, use_cuda=True, use_quantize=True, progress=gr.Progress(), file_progress_offset=0.0, file_progress_scale=1.0): | |
| """ | |
| 处理音频文件并生成MIDI文件。 | |
| :param input_file: 输入音频文件路径。 | |
| :param use_cuda: 是否使用CUDA加速。 | |
| :param use_quantize: 是否对生成的MIDI文件进行量化处理。 | |
| :param progress: Gradio进度条对象。 | |
| :param file_progress_offset: 进度条的起始偏移量,用于批量处理。 | |
| :param file_progress_scale: 进度条的缩放比例,用于批量处理。 | |
| :return: 包含处理结果的字典。 | |
| """ | |
| temp_dir = None | |
| try: | |
| # 创建一个临时目录来存储所有的输出文件 | |
| temp_dir = tempfile.mkdtemp() | |
| # 从输入文件中获取一个有意义的文件名 | |
| input_name = Path(input_file).stem | |
| # 在临时目录中创建MIDI文件的路径 | |
| output_file = Path(temp_dir) / f"{input_name}.mid" | |
| quantized_output_file = None | |
| start_time = time.time() | |
| progress(file_progress_offset, desc="准备转录...") | |
| # 使用命令行调用transkun | |
| progress(file_progress_offset + 0.3 * file_progress_scale, desc="转录中...") | |
| if use_cuda and cuda_available: | |
| cmd = ["transkun", input_file, str(output_file), "--device", "cuda"] | |
| else: | |
| cmd = ["transkun", input_file, str(output_file)] | |
| # 执行命令 | |
| process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) | |
| stdout, stderr = process.communicate() | |
| # 检查命令是否成功执行 | |
| if process.returncode != 0: | |
| raise Exception(f"transkun命令执行失败: {stderr}") | |
| progress(file_progress_offset + 0.7 * file_progress_scale, desc="保存MIDI...") | |
| # 如果勾选了规整化选项,则进行MIDI规整化 | |
| if use_quantize: | |
| progress(file_progress_offset + 0.8 * file_progress_scale, desc="规整化MIDI...") | |
| try: | |
| # midi_quantize函数将以预期的名称写入输出文件 | |
| quantized_output_file = midi_quantize(str(output_file), debug=False, optimize_bpm=True) | |
| except Exception as e: | |
| print(f"规整化处理失败: {str(e)}") | |
| # 规整化失败不影响主流程 | |
| end_time = time.time() | |
| process_time = round(end_time - start_time, 2) | |
| progress(file_progress_offset + 1.0 * file_progress_scale, desc="完成!") | |
| # 返回结果 | |
| result_files = [str(output_file)] | |
| if quantized_output_file: | |
| result_files.append(quantized_output_file) | |
| return { | |
| "output": f"转换完成!用时 {process_time}秒", | |
| "files": result_files | |
| } | |
| except Exception as e: | |
| traceback.print_exc() | |
| return { | |
| "output": f"转换失败: {str(e)}", | |
| "files": [] | |
| } | |
| # 删除了手动清理代码块,现在由 Gradio 来处理。 | |
| # 创建Gradio界面 | |
| def create_interface(): | |
| with gr.Blocks(title="Transkun - Piano Audio to MIDI", theme=gr.themes.Soft(primary_hue="blue")) as app: | |
| gr.Markdown( | |
| """ | |
| # Transkun - 钢琴音频转MIDI | |
| 将钢琴演奏音频转换为MIDI文件 | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| # 输入部分 | |
| gr.Markdown("### 1. 选择输入音频文件") | |
| input_audio = gr.File(label="输入音频文件", file_count="multiple", file_types=["audio"], interactive=True) | |
| gr.Markdown("### 2. 选择转换选项") | |
| with gr.Row(): | |
| use_cuda = gr.Checkbox( | |
| label=f"启用CUDA加速 (CUDA {'可用 ✓' if cuda_available else '不可用 ✗'})", | |
| value=cuda_available, | |
| interactive=cuda_available | |
| ) | |
| def show_warning(is_checked): | |
| """Show a warning message when the checkbox is checked.""" | |
| if is_checked: | |
| gr.Warning("注意:开启此项可能影响踏板效果。") | |
| return gr.update(info="") | |
| else: | |
| return gr.update(info="") | |
| use_quantize = gr.Checkbox( | |
| label="使用MIDI规整化(附带有_quantized后缀的输出文件),建议只用于阅读MIDI(❗注意:此项可能影响踏板效果)", | |
| value=False | |
| ) | |
| use_quantize.change( | |
| fn=show_warning, | |
| inputs=use_quantize, | |
| outputs=use_quantize | |
| ) | |
| convert_btn = gr.Button("开始转换", variant="primary") | |
| with gr.Column(scale=1): | |
| # 输出部分 | |
| status_output = gr.Textbox(label="状态", value="准备就绪", interactive=False) | |
| file_output = gr.File(label="生成的MIDI文件", interactive=False, file_count="multiple") | |
| # 创建一个隐藏的文本框来存储文件路径 | |
| file_paths_store = gr.State([]) | |
| # 下载按钮 | |
| with gr.Row(): | |
| download_all_btn = gr.Button("一键下载全部文件", variant="secondary", visible=False) | |
| download_status = gr.Textbox(label="下载状态", value="", visible=False, interactive=False) | |
| # 处理函数 | |
| def on_convert(audio_paths, use_cuda, use_quantize, progress=gr.Progress()): | |
| if not audio_paths: | |
| return "请选择输入音频文件", [], gr.update(visible=False) | |
| all_files = [] | |
| results = [] | |
| total_files = len(audio_paths) | |
| for i, audio_path in enumerate(audio_paths): | |
| file_name = Path(audio_path).name | |
| progress_offset = (i / total_files) | |
| progress_scale = (1 / total_files) * 0.9 | |
| progress(progress_offset, desc=f"处理文件 {i+1}/{total_files}: {file_name}") | |
| result = process_audio(audio_path, use_cuda, use_quantize, progress, | |
| file_progress_offset=progress_offset, | |
| file_progress_scale=progress_scale) | |
| results.append(result["output"]) | |
| all_files.extend(result["files"]) | |
| progress(1.0, desc="全部完成!") | |
| download_btn_update = gr.update(visible=True) if all_files else gr.update(visible=False) | |
| download_status_update = gr.update(visible=False) | |
| return f"转换完成!共处理 {total_files} 个文件\n" + "\n".join(results), all_files, download_btn_update, download_status_update, all_files | |
| # 下载所有文件的函数 | |
| def download_all_files(file_paths, status_output=None): | |
| import tempfile | |
| import os | |
| import shutil | |
| import zipfile | |
| from pathlib import Path | |
| if not file_paths or len(file_paths) == 0: | |
| return None, gr.update(value="没有文件可下载", visible=True), gr.update(visible=False) | |
| try: | |
| # 创建临时目录用于存放文件 | |
| temp_dir = tempfile.mkdtemp(prefix="midi_files_") | |
| # 创建ZIP文件 | |
| zip_path = os.path.join(temp_dir, "all_midi_files.zip") | |
| # 直接创建ZIP文件,不使用shutil.make_archive | |
| with zipfile.ZipFile(zip_path, 'w') as zipf: | |
| for file_path in file_paths: | |
| if os.path.exists(file_path): | |
| # 只添加文件名,不包含路径 | |
| zipf.write(file_path, os.path.basename(file_path)) | |
| return zip_path, gr.update(value="下载准备完成,请点击上方文件链接下载", visible=True), gr.update(visible=False) | |
| except Exception as e: | |
| return None, gr.update(value=f"下载准备失败: {str(e)}", visible=True), gr.update(visible=True) | |
| # 绑定按钮事件 | |
| convert_btn.click( | |
| fn=on_convert, | |
| inputs=[input_audio, use_cuda, use_quantize], | |
| outputs=[status_output, file_output, download_all_btn, download_status, file_paths_store] | |
| ) | |
| # 绑定下载按钮事件 | |
| download_all_btn.click( | |
| fn=download_all_files, | |
| inputs=[file_paths_store, download_status], | |
| outputs=[file_output, download_status, download_all_btn] | |
| ) | |
| return app | |
| # 启动应用 | |
| def main(): | |
| app = create_interface() | |
| # It's better to launch on 0.0.0.0 for broader access, though 127.0.0.1 is fine for local. | |
| # 最好在0.0.0.0上启动以便更广泛的访问,不过127.0.0.1用于本地也是可以的。 | |
| app.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True) | |
| if __name__ == "__main__": | |
| main() | |