# TranskunGUI / app.py
# Lollikit's picture
# optimize cpu usage (?)
# 929272c
import os
import time
import torch
import gradio as gr
import threading
import traceback
import moduleconf
import transkun.transcribe
from pathlib import Path
import tempfile
import shutil
# Exempt local addresses from any configured HTTP proxy, so requests to the
# locally hosted Gradio server are not routed through a proxy.
os.environ['NO_PROXY'] = "localhost, 127.0.0.1, ::1"

# Put the bundled ffmpeg binaries (the "ffmpeg_bin" directory next to this
# file) at the front of PATH so they are found before any system copy.
current_dir = os.path.dirname(os.path.abspath(__file__))
ffmpeg_bin_path = os.path.join(current_dir, "ffmpeg_bin")
# PATH may be unset in stripped-down environments; treat that as empty.
_existing_path = os.environ.get('PATH', '')
# Only prepend once, so importing/reloading the module does not grow PATH.
if ffmpeg_bin_path not in _existing_path.split(os.pathsep):
    if _existing_path:
        os.environ['PATH'] = ffmpeg_bin_path + os.pathsep + _existing_path
    else:
        # Avoid a trailing separator, which some platforms read as "current dir".
        os.environ['PATH'] = ffmpeg_bin_path

# Whether torch can see a CUDA device; used to gate the "--device cuda" option.
cuda_available = torch.cuda.is_available()
import mido
from collections import defaultdict, Counter
# 从原始代码复制MIDI处理函数
# Sort priority for events that share the same absolute tick: note-offs go
# first so that releasing a pitch and re-striking it on that tick does not cut
# the new note off immediately; other (meta/control) events keep going last.
_QUANTIZE_EVENT_ORDER = {'note_off': 0, 'note_on': 1, 'other': 2}


def _split_track_events(track):
    """Split a MIDI track into note-on, note-off and other events with absolute tick times.

    A ``note_on`` with velocity 0 is treated as a note-off. Note-off dicts get a
    ``processed`` flag so each one is re-timed at most once.

    :return: ``(notes_on, notes_off, other_events)``, each in track order.
    """
    notes_on, notes_off, other_events = [], [], []
    abs_time = 0
    for msg in track:
        abs_time += msg.time
        if msg.type == 'note_on' and msg.velocity > 0:
            notes_on.append({
                'time': abs_time,
                'note': msg.note,
                'velocity': msg.velocity,
                'channel': msg.channel,
            })
        elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
            notes_off.append({
                'time': abs_time,
                'note': msg.note,
                'velocity': msg.velocity if msg.type == 'note_off' else 0,
                'channel': msg.channel,
                'processed': False,
            })
        else:
            other_events.append({'time': abs_time, 'msg': msg})
    return notes_on, notes_off, other_events


def _find_note_off(off_index, note):
    """Return the earliest unprocessed note-off strictly after ``note`` with the same pitch/channel, or None."""
    # Each index list is in chronological order (events were collected in track
    # order), and unprocessed entries still carry their original times, so the
    # first match is also the nearest one.
    for off_event in off_index.get((note['note'], note['channel']), ()):
        if not off_event['processed'] and off_event['time'] > note['time']:
            return off_event
    return None


def _stretch_hand_notes(hand_notes, off_index, onset_index, hand_name, debug,
                        chord_threshold, min_duration, release_gap):
    """Re-time the note-offs of one hand so each chord lasts until just before the next chord.

    Notes whose onsets lie within ``chord_threshold`` ticks of a chord's first
    onset form one chord. Each note-off in a chord moves to ``release_gap`` ticks
    before the next chord's onset. It is never earlier than ``min_duration``
    ticks after its own onset, and never later than the next onset of the same
    pitch on the same channel. Notes of the final chord are left unchanged.
    """
    if not hand_notes:
        return
    hand_notes.sort(key=lambda n: n['time'])
    if debug:
        print(f"\n处理{hand_name},共{len(hand_notes)}个音符:")
        for note in hand_notes:
            print(f" 音符{note['note']} 时间{note['time']}")

    count = len(hand_notes)
    start = 0
    while start < count:
        chord_time = hand_notes[start]['time']
        end = start + 1
        while end < count and hand_notes[end]['time'] - chord_time <= chord_threshold:
            end += 1
        chord = hand_notes[start:end]
        next_onset = hand_notes[end]['time'] if end < count else None
        if debug:
            note_names = [str(n['note']) for n in chord]
            print(f" 同时音符组: {note_names} 在时间{chord_time}, 下一组时间: {next_onset}")

        if next_onset is not None:
            target_time = next_onset - release_gap
            for note in chord:
                off_event = _find_note_off(off_index, note)
                if off_event is None:
                    continue
                old_time = off_event['time']
                new_time = max(note['time'] + min_duration, target_time)
                # Never let the stretched release run past a re-strike of the
                # same key; otherwise it would silence the re-struck note.
                restrike = next(
                    (t for t in onset_index.get((note['note'], note['channel']), ())
                     if t > note['time']),
                    None,
                )
                if restrike is not None:
                    new_time = min(new_time, restrike)
                off_event['time'] = new_time
                off_event['processed'] = True
                if debug:
                    print(f" 音符{note['note']} off时间: {old_time} -> {off_event['time']}")
        start = end


def _rebuild_track(track, notes_on, notes_off, other_events):
    """Replace the contents of ``track`` with the given events, re-encoded as delta times."""
    events = (
        [('note_on', n['time'], n) for n in notes_on]
        + [('note_off', n['time'], n) for n in notes_off]
        + [('other', e['time'], e) for e in other_events]
    )
    events.sort(key=lambda ev: (ev[1], _QUANTIZE_EVENT_ORDER[ev[0]]))

    new_messages = []
    last_time = 0
    for kind, abs_time, data in events:
        delta_time = abs_time - last_time
        if kind == 'other':
            new_messages.append(data['msg'].copy(time=delta_time))
        else:
            new_messages.append(mido.Message(kind,
                                             channel=data['channel'],
                                             note=data['note'],
                                             velocity=data['velocity'],
                                             time=delta_time))
        last_time = abs_time
    track.clear()
    track.extend(new_messages)


def midi_quantize(midi_path, debug=False, optimize_bpm=True, *,
                  chord_threshold=100, min_duration=100, release_gap=10,
                  split_note=60):
    """Regularize note durations of a MIDI file and save the result as ``*_quantized.mid``.

    Left and right hand are handled separately, split at ``split_note``
    (default 60 = C4, which goes to the left hand). Within each hand, notes
    pressed together (onsets within ``chord_threshold`` ticks) are stretched or
    shortened to end just before the next chord starts. Only note-off events
    move; note-on times are never changed. Leading/trailing silence is then
    trimmed with ``trim_midi_silence``.

    :param midi_path: Path of the MIDI file to process.
    :param debug: Print diagnostic output while processing.
    :param optimize_bpm: Currently unused; kept for backward compatibility.
    :param chord_threshold: Max onset distance (ticks) for notes to count as one chord.
    :param min_duration: Minimum note length (ticks) after re-timing.
    :param release_gap: Ticks between a stretched release and the next chord's onset.
    :param split_note: Highest MIDI pitch assigned to the left hand.
    :return: Path of the written file: ``midi_path`` without extension + ``'_quantized.mid'``.
    :raises Exception: On any read/processing/save error; the original error is
        chained as ``__cause__``.
    """
    try:
        mid = mido.MidiFile(midi_path)
        for track in mid.tracks:
            # Tracks without note events (e.g. a tempo/conductor track) are left as-is.
            if not any(msg.type in ('note_on', 'note_off') for msg in track):
                continue
            notes_on, notes_off, other_events = _split_track_events(track)

            # Index events by (pitch, channel) so lookups don't scan the whole track.
            off_index = defaultdict(list)
            for off_event in notes_off:
                off_index[(off_event['note'], off_event['channel'])].append(off_event)
            onset_index = defaultdict(list)
            for note in notes_on:
                onset_index[(note['note'], note['channel'])].append(note['time'])

            left_hand = [n for n in notes_on if n['note'] <= split_note]
            right_hand = [n for n in notes_on if n['note'] > split_note]
            for hand_notes, hand_name in ((left_hand, "左手"), (right_hand, "右手")):
                _stretch_hand_notes(hand_notes, off_index, onset_index, hand_name, debug,
                                    chord_threshold, min_duration, release_gap)

            _rebuild_track(track, notes_on, notes_off, other_events)

        if debug:
            print("裁剪MIDI首尾空白...")
        trim_midi_silence(mid, debug)

        output_path = os.path.splitext(midi_path)[0] + '_quantized.mid'
        mid.save(output_path)
        return output_path
    except Exception as e:
        raise Exception(f"处理MIDI文件时出错: {str(e)}") from e
def trim_midi_silence(mid, debug=False, tail_ticks=1000):
    """Trim leading and trailing silence from a MIDI file, in place.

    The earliest note-on (velocity > 0) across all tracks becomes tick 0;
    events before it are kept and moved to tick 0. Events are kept up to the
    later of (a) the last note-on plus ``tail_ticks`` and (b) the last note-off.
    Rule (b) keeps the note-offs of long final notes, which would otherwise be
    dropped and leave those notes hanging. Tempo, key-signature and
    time-signature events are always kept.

    Best-effort: any error is caught and printed, and the file may then be only
    partially trimmed.

    :param mid: A ``mido.MidiFile`` (or object with a ``tracks`` list of mutable
        message lists whose messages support ``copy(time=...)``).
    :param debug: Print diagnostic output.
    :param tail_ticks: Extra ticks kept after the last note-on.
    """
    always_kept = ('set_tempo', 'key_signature', 'time_signature')
    try:
        first_note_time = float('inf')
        last_note_time = 0
        last_off_time = 0
        for track in mid.tracks:
            abs_time = 0
            for msg in track:
                abs_time += msg.time
                if msg.type == 'note_on' and msg.velocity > 0:
                    first_note_time = min(first_note_time, abs_time)
                    last_note_time = max(last_note_time, abs_time)
                elif msg.type == 'note_off' or msg.type == 'note_on':
                    # note_off, or note_on with velocity 0 (also a release).
                    last_off_time = max(last_off_time, abs_time)

        if first_note_time == float('inf'):
            if debug:
                print("没有找到音符,跳过裁剪")
            return
        if debug:
            print(f"音符时间范围: {first_note_time} - {last_note_time}")

        cutoff = max(last_note_time + tail_ticks, last_off_time)
        for track in mid.tracks:
            if not track:
                continue
            kept = []
            abs_time = 0
            for msg in track:
                abs_time += msg.time
                if abs_time <= cutoff or msg.type in always_kept:
                    # Shift left by the leading silence; earlier events clamp to 0.
                    kept.append((max(0, abs_time - first_note_time), msg))
            track.clear()
            previous = 0
            for shifted_time, msg in kept:
                track.append(msg.copy(time=shifted_time - previous))
                previous = shifted_time

        if debug:
            print("MIDI裁剪完成")
    except Exception as e:
        # Trimming is optional; report the problem instead of hiding it.
        print(f"MIDI裁剪失败: {e}")
import subprocess
# 核心转换函数
def process_audio(input_file, use_cuda=True, use_quantize=True, progress=gr.Progress(), file_progress_offset=0.0, file_progress_scale=1.0):
    """Transcribe one audio file to MIDI with the ``transkun`` command-line tool.

    The MIDI file is written into a new temporary directory as
    ``<input stem>.mid``. If ``use_quantize`` is true, ``midi_quantize`` is also
    run and its ``*_quantized.mid`` output is added to the result. A
    regularization failure is printed and otherwise ignored.

    On success the temporary directory is kept, because the caller still has
    to read the returned paths. On failure it is deleted.

    :param input_file: Path of the input audio file.
    :param use_cuda: Pass ``--device cuda`` to transkun (ignored when CUDA is unavailable).
    :param use_quantize: Also produce a regularized (``_quantized``) MIDI file.
    :param progress: Gradio progress callback, called as ``progress(fraction, desc=...)``.
    :param file_progress_offset: Progress value where this file starts (batch mode).
    :param file_progress_scale: Fraction of the progress bar this file occupies (batch mode).
    :return: ``{"output": status message, "files": produced file paths}``;
        ``files`` is empty on failure. Exceptions are not propagated.
    """
    temp_dir = None
    try:
        temp_dir = tempfile.mkdtemp()
        output_file = Path(temp_dir) / f"{Path(input_file).stem}.mid"
        quantized_output_file = None
        start_time = time.time()

        progress(file_progress_offset, desc="准备转录...")
        progress(file_progress_offset + 0.3 * file_progress_scale, desc="转录中...")
        cmd = ["transkun", input_file, str(output_file)]
        if use_cuda and cuda_available:
            cmd += ["--device", "cuda"]
        # errors="replace": the CLI's output may not be in the locale encoding
        # (e.g. on Windows); a decode error must not hide the real failure.
        completed = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
        if completed.returncode != 0:
            raise Exception(f"transkun命令执行失败: {completed.stderr}")
        # A zero exit status without an output file would otherwise be reported
        # as success and fail later when the UI tries to serve the file.
        if not output_file.is_file():
            raise Exception(f"transkun未生成输出文件: {output_file}")

        progress(file_progress_offset + 0.7 * file_progress_scale, desc="保存MIDI...")
        if use_quantize:
            progress(file_progress_offset + 0.8 * file_progress_scale, desc="规整化MIDI...")
            try:
                # Writes "<stem>_quantized.mid" next to the raw MIDI file.
                quantized_output_file = midi_quantize(str(output_file), debug=False, optimize_bpm=True)
            except Exception as e:
                # Regularization is optional: keep the raw transcription.
                print(f"规整化处理失败: {str(e)}")

        process_time = round(time.time() - start_time, 2)
        progress(file_progress_offset + 1.0 * file_progress_scale, desc="完成!")

        result_files = [str(output_file)]
        if quantized_output_file:
            result_files.append(quantized_output_file)
        return {
            "output": f"转换完成!用时 {process_time}秒",
            "files": result_files
        }
    except Exception as e:
        traceback.print_exc()
        # Nothing usable is returned, so don't leave an orphaned temp directory.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return {
            "output": f"转换失败: {str(e)}",
            "files": []
        }
# 删除了手动清理代码块,现在由 Gradio 来处理。
# 创建Gradio界面
def create_interface():
    """Build the Gradio Blocks UI for batch piano audio-to-MIDI conversion.

    Left column: audio upload plus CUDA and MIDI-regularization options.
    Right column: status text, produced files, and a "download all" ZIP button.
    Pressing convert runs ``process_audio`` on each uploaded file in turn.

    :return: The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Transkun - Piano Audio to MIDI", theme=gr.themes.Soft(primary_hue="blue")) as app:
        gr.Markdown(
            "\n"
            "# Transkun - 钢琴音频转MIDI\n"
            "将钢琴演奏音频转换为MIDI文件\n"
        )
        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("### 1. 选择输入音频文件")
                input_audio = gr.File(label="输入音频文件", file_count="multiple", file_types=["audio"], interactive=True)
                gr.Markdown("### 2. 选择转换选项")
                with gr.Row():
                    # Unchecked and locked when torch reports no CUDA device.
                    use_cuda = gr.Checkbox(
                        label=f"启用CUDA加速 (CUDA {'可用 ✓' if cuda_available else '不可用 ✗'})",
                        value=cuda_available,
                        interactive=cuda_available
                    )

                    def show_warning(is_checked):
                        """Show a pedal-effect warning toast when regularization is switched on."""
                        if is_checked:
                            gr.Warning("注意:开启此项可能影响踏板效果。")
                        return gr.update(info="")

                    use_quantize = gr.Checkbox(
                        label="使用MIDI规整化(附带有_quantized后缀的输出文件),建议只用于阅读MIDI(❗注意:此项可能影响踏板效果)",
                        value=False
                    )
                    use_quantize.change(
                        fn=show_warning,
                        inputs=use_quantize,
                        outputs=use_quantize
                    )
                convert_btn = gr.Button("开始转换", variant="primary")
            with gr.Column(scale=1):
                # Output section
                status_output = gr.Textbox(label="状态", value="准备就绪", interactive=False)
                file_output = gr.File(label="生成的MIDI文件", interactive=False, file_count="multiple")
                # Hidden per-session list of produced file paths, used by the ZIP download.
                file_paths_store = gr.State([])
                with gr.Row():
                    download_all_btn = gr.Button("一键下载全部文件", variant="secondary", visible=False)
                    download_status = gr.Textbox(label="下载状态", value="", visible=False, interactive=False)

        def on_convert(audio_paths, use_cuda, use_quantize, progress=gr.Progress()):
            """Convert every uploaded file.

            :return: Values for the five outputs wired in ``convert_btn.click``:
                status text, produced files, download-button update,
                download-status update, and the file-path store.
            """
            if not audio_paths:
                # Must supply all five outputs, or Gradio raises an error.
                return "请选择输入音频文件", [], gr.update(visible=False), gr.update(visible=False), []
            all_files = []
            results = []
            total_files = len(audio_paths)
            for i, audio_path in enumerate(audio_paths):
                file_name = Path(audio_path).name
                progress_offset = i / total_files
                # Leave 10% of each file's slice as headroom between files.
                progress_scale = (1 / total_files) * 0.9
                progress(progress_offset, desc=f"处理文件 {i+1}/{total_files}: {file_name}")
                result = process_audio(audio_path, use_cuda, use_quantize, progress,
                                       file_progress_offset=progress_offset,
                                       file_progress_scale=progress_scale)
                results.append(result["output"])
                all_files.extend(result["files"])
            progress(1.0, desc="全部完成!")
            download_btn_update = gr.update(visible=bool(all_files))
            download_status_update = gr.update(visible=False)
            return f"转换完成!共处理 {total_files} 个文件\n" + "\n".join(results), all_files, download_btn_update, download_status_update, all_files

        def download_all_files(file_paths, _current_status=None):
            """Bundle all produced files into one ZIP and show it in the file output.

            :return: (ZIP path or None, download-status update, download-button update).
            """
            import zipfile

            if not file_paths:
                return None, gr.update(value="没有文件可下载", visible=True), gr.update(visible=False)
            try:
                temp_dir = tempfile.mkdtemp(prefix="midi_files_")
                zip_path = os.path.join(temp_dir, "all_midi_files.zip")
                used_names = set()
                with zipfile.ZipFile(zip_path, 'w') as zipf:
                    for file_path in file_paths:
                        if not os.path.exists(file_path):
                            continue
                        # Store files flat (basename only). Rename on clashes,
                        # e.g. two inputs with the same file name, so the archive
                        # never holds duplicate entries.
                        arcname = os.path.basename(file_path)
                        stem, ext = os.path.splitext(arcname)
                        counter = 1
                        while arcname in used_names:
                            arcname = f"{stem}_{counter}{ext}"
                            counter += 1
                        used_names.add(arcname)
                        zipf.write(file_path, arcname)
                return zip_path, gr.update(value="下载准备完成,请点击上方文件链接下载", visible=True), gr.update(visible=False)
            except Exception as e:
                return None, gr.update(value=f"下载准备失败: {str(e)}", visible=True), gr.update(visible=True)

        convert_btn.click(
            fn=on_convert,
            inputs=[input_audio, use_cuda, use_quantize],
            outputs=[status_output, file_output, download_all_btn, download_status, file_paths_store]
        )
        download_all_btn.click(
            fn=download_all_files,
            inputs=[file_paths_store, download_status],
            outputs=[file_output, download_status, download_all_btn]
        )
    return app
# 启动应用
def main():
    """Build the UI and start the Gradio server, opening it in the default browser.

    The bind address and port default to 0.0.0.0:7860. They can be overridden
    with the GRADIO_SERVER_NAME and GRADIO_SERVER_PORT environment variables.
    Binding to 0.0.0.0 listens on all network interfaces, so other machines on
    the network can reach this UI, which has no authentication. Set
    GRADIO_SERVER_NAME=127.0.0.1 to restrict access to this machine.
    """
    app = create_interface()
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    app.launch(server_name=server_name, server_port=server_port, share=False, inbrowser=True)


if __name__ == "__main__":
    main()