Spaces:
Running
Running
Add application file
Browse files
app.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import torch
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import threading
|
| 6 |
+
import traceback
|
| 7 |
+
import moduleconf
|
| 8 |
+
import transkun.transcribe
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import tempfile
|
| 11 |
+
import shutil
|
| 12 |
+
|
| 13 |
+
# Exclude loopback addresses from any configured HTTP proxy.
os.environ['NO_PROXY'] = "localhost, 127.0.0.1, ::1"

# Locate the ffmpeg binaries expected in an "ffmpeg_bin" folder next to this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
ffmpeg_bin_path = os.path.join(current_dir, "ffmpeg_bin")

# Prepend the folder to PATH only once. PATH is read with a default so that
# an environment without PATH does not raise KeyError at import time.
_existing_path = os.environ.get('PATH', '')
if ffmpeg_bin_path not in _existing_path.split(os.pathsep):
    if _existing_path:
        os.environ['PATH'] = ffmpeg_bin_path + os.pathsep + _existing_path
    else:
        os.environ['PATH'] = ffmpeg_bin_path

# Probe CUDA once at import; the UI and process_audio both read this flag.
cuda_available = torch.cuda.is_available()
|
| 25 |
+
|
| 26 |
+
import mido
|
| 27 |
+
from collections import defaultdict, Counter
|
| 28 |
+
|
| 29 |
+
# 从原始代码复制MIDI处理函数
|
| 30 |
+
# MIDI note number of middle C; notes at or below it are treated as left hand.
_C4_NOTE = 60
# Note-ons within this many ticks of a group's first note belong to that group.
_GROUP_WINDOW_TICKS = 100
# A re-timed note is never made shorter than this many ticks.
_MIN_NOTE_TICKS = 100
# Re-timed note-offs are placed this many ticks before the next group starts.
_RELEASE_GAP_TICKS = 10

# Tie-break ranks for events that share a tick when a track is rebuilt.
_RANK_OTHER = 0
_RANK_NOTE_OFF = 1
_RANK_NOTE_ON = 2
_RANK_END_OF_TRACK = 3


def _split_track_events(track):
    """Split a track into note-on, note-off and other events with absolute tick times."""
    notes_on = []
    notes_off = []
    other_events = []
    current_time = 0

    for msg in track:
        current_time += msg.time

        if msg.type == 'note_on' and msg.velocity > 0:
            notes_on.append({
                'time': current_time,
                'note': msg.note,
                'velocity': msg.velocity,
                'channel': msg.channel,
                'msg': msg
            })
        elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
            # A note_on with velocity 0 is handled as a note-off.
            notes_off.append({
                'time': current_time,
                'note': msg.note,
                'velocity': msg.velocity if msg.type == 'note_off' else 0,
                'channel': msg.channel,
                'msg': msg
            })
        else:
            other_events.append({
                'time': current_time,
                'msg': msg
            })

    return notes_on, notes_off, other_events


def _align_hand_note_offs(hand_notes, notes_off, hand_name="", debug=False):
    """Re-time the note-offs of one hand so each group ends just before the next group starts.

    Only entries of ``notes_off`` are modified (their 'time' is rewritten and
    they are flagged 'processed'); note-on times are never changed. The last
    group of the hand has no following group and is left untouched.
    """
    if not hand_notes:
        return

    hand_notes.sort(key=lambda n: n['time'])

    if debug:
        print(f"\n处理{hand_name},共{len(hand_notes)}个音符:")
        for note in hand_notes:
            print(f" 音符{note['note']} 时间{note['time']}")

    i = 0
    while i < len(hand_notes):
        group_start = hand_notes[i]['time']

        # Extend the group over every note-on inside the window.
        j = i + 1
        while j < len(hand_notes) and hand_notes[j]['time'] - group_start <= _GROUP_WINDOW_TICKS:
            j += 1
        group = hand_notes[i:j]

        next_note_time = hand_notes[j]['time'] if j < len(hand_notes) else None

        if debug:
            note_names = [str(n['note']) for n in group]
            print(f" 同时音符组: {note_names} 在时间{group_start}, 下一组时间: {next_note_time}")

        if next_note_time is not None:
            for note in group:
                # Earliest not-yet-used note-off after this note-on for the
                # same pitch and channel; min() keeps the first one on ties.
                candidates = (
                    off for off in notes_off
                    if off['note'] == note['note']
                    and off['channel'] == note['channel']
                    and off['time'] > note['time']
                    and not off.get('processed', False)
                )
                best_off = min(candidates, key=lambda off: off['time'], default=None)
                if best_off is None:
                    continue

                old_time = best_off['time']
                best_off['time'] = max(note['time'] + _MIN_NOTE_TICKS,
                                       next_note_time - _RELEASE_GAP_TICKS)
                best_off['processed'] = True

                if debug:
                    print(f" 音符{note['note']} off时间: {old_time} -> {best_off['time']}")

        i = j


def _rebuild_track_messages(notes_on, notes_off, other_events):
    """Merge the three event lists back into delta-timed messages in playback order."""
    timeline = []
    for note in notes_on:
        timeline.append((note['time'], _RANK_NOTE_ON, 'note_on', note))
    for note in notes_off:
        timeline.append((note['time'], _RANK_NOTE_OFF, 'note_off', note))
    for event in other_events:
        rank = _RANK_END_OF_TRACK if event['msg'].type == 'end_of_track' else _RANK_OTHER
        timeline.append((event['time'], rank, 'other', event))

    # Sorting on the tick alone would place every note-on before every
    # note-off of the same tick, so a re-struck note would be silenced by the
    # previous note's note-off. The rank puts note-offs first and keeps
    # end_of_track last; the stable sort preserves original order otherwise.
    timeline.sort(key=lambda item: (item[0], item[1]))

    messages = []
    last_time = 0
    for event_time, _rank, kind, data in timeline:
        delta_time = event_time - last_time
        if kind == 'other':
            msg = data['msg'].copy(time=delta_time)
        else:
            msg = mido.Message(kind,
                               channel=data['channel'],
                               note=data['note'],
                               velocity=data['velocity'],
                               time=delta_time)
        messages.append(msg)
        last_time = event_time

    return messages


def midi_quantize(midi_path, debug=False, optimize_bpm=True):
    """Regularise note lengths in a MIDI file and save the result next to it.

    Within each track the notes are split into two hands at C4 (note 60 and
    below is the left hand). Per hand, note-ons that start within a short
    window are treated as one group, and each group's note-offs are moved to
    just before the next group starts. Note-on times are never changed.
    Tracks without note events are left as they are. Finally the file is
    passed to ``trim_midi_silence`` and written out.

    :param midi_path: Path of the MIDI file to read.
    :param debug: Print a trace of the grouping and re-timing when True.
    :param optimize_bpm: Accepted for interface compatibility; this
        implementation does not read it.
    :return: Path of the written file (input path with a ``_quantized.mid`` suffix).
    :raises Exception: On any failure; the original error is attached as the cause.
    """
    try:
        mid = mido.MidiFile(midi_path)

        for track in mid.tracks:
            has_notes = any(msg.type in ['note_on', 'note_off'] for msg in track)
            if not has_notes:
                continue

            notes_on, notes_off, other_events = _split_track_events(track)

            left_hand_notes = [n for n in notes_on if n['note'] <= _C4_NOTE]
            right_hand_notes = [n for n in notes_on if n['note'] > _C4_NOTE]

            _align_hand_note_offs(left_hand_notes, notes_off, "左手", debug)
            _align_hand_note_offs(right_hand_notes, notes_off, "右手", debug)

            new_messages = _rebuild_track_messages(notes_on, notes_off, other_events)

            track.clear()
            track.extend(new_messages)

        if debug:
            print("裁剪MIDI首尾空白...")
        trim_midi_silence(mid, debug)

        output_path = os.path.splitext(midi_path)[0] + '_quantized.mid'
        mid.save(output_path)

        return output_path

    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise Exception(f"处理MIDI文件时出错: {str(e)}") from e
|
| 226 |
+
|
| 227 |
+
def trim_midi_silence(mid, debug=False):
    """Remove leading silence and stray trailing events from a MIDI file, in place.

    All tracks are shifted so that the earliest note-on across the file lands
    on tick 0. Events later than 1000 ticks after the last note event are
    dropped, except ``set_tempo``, ``key_signature`` and ``time_signature``,
    which are always kept.

    The end of the music is measured on the last note-on OR note-off.
    Measuring on the last note-on alone would discard the note-off of a long
    final note and leave it hanging.

    :param mid: Object with a ``tracks`` list of message lists; each message
        needs ``type``, ``time``, ``copy(time=...)`` and, for note_on,
        ``velocity``.
    :param debug: Print progress and failures when True.
    :return: None. Failures are swallowed (printed only when ``debug``) so
        that trimming stays best-effort.
    """
    always_keep = ('set_tempo', 'key_signature', 'time_signature')
    tail_grace_ticks = 1000

    try:
        first_note_time = float('inf')
        last_note_time = 0

        for track in mid.tracks:
            current_time = 0
            for msg in track:
                current_time += msg.time
                if msg.type == 'note_on' and msg.velocity > 0:
                    first_note_time = min(first_note_time, current_time)
                if msg.type in ('note_on', 'note_off'):
                    last_note_time = max(last_note_time, current_time)

        if first_note_time == float('inf'):
            if debug:
                print("没有找到音符,跳过裁剪")
            return

        if debug:
            print(f"音符时间范围: {first_note_time} - {last_note_time}")

        cutoff = last_note_time + tail_grace_ticks

        for track in mid.tracks:
            if not track:
                continue

            # Absolute, shifted times of the messages that survive.
            kept = []
            current_time = 0
            for msg in track:
                current_time += msg.time
                if current_time <= cutoff or msg.type in always_keep:
                    # Events before the first note are clamped to tick 0.
                    kept.append((max(0, current_time - first_note_time), msg))

            # Build the replacement fully before touching the track, so a
            # failure while copying cannot leave the track half-emptied.
            rebuilt = []
            last_time = 0
            for abs_time, msg in kept:
                rebuilt.append(msg.copy(time=abs_time - last_time))
                last_time = abs_time

            track.clear()
            track.extend(rebuilt)

        if debug:
            print("MIDI裁剪完成")

    except Exception as e:
        if debug:
            print(f"MIDI裁剪失败: {e}")
|
| 298 |
+
|
| 299 |
+
# 核心转换函数
|
| 300 |
+
def process_audio(input_file, use_cuda=True, use_quantize=True, progress=gr.Progress(), file_progress_offset=0.0, file_progress_scale=1.0):
    """Transcribe one audio file to MIDI and optionally write a quantized copy.

    :param input_file: Path of the input audio file.
    :param use_cuda: Use CUDA when it is available; otherwise the CPU is used.
    :param use_quantize: Also run ``midi_quantize`` on the generated MIDI file.
    :param progress: Gradio progress callback.
    :param file_progress_offset: Progress value at which this file starts
        (used for batch processing).
    :param file_progress_scale: Share of the progress bar this file may use
        (used for batch processing).
    :return: Dict with "output" (status message) and "files" (list of
        generated file paths; empty when the conversion failed).
    """
    temp_dir = None
    try:
        # Every output of this call goes into its own temporary directory.
        temp_dir = tempfile.mkdtemp()

        # Name the MIDI file after the input file.
        input_name = Path(input_file).stem
        output_file = Path(temp_dir) / f"{input_name}.mid"

        quantized_output_file = None

        device = "cuda" if use_cuda and cuda_available else "cpu"

        start_time = time.time()
        progress(file_progress_offset, desc="准备模型...")

        # Join path components separately: a literal backslash is only a
        # separator on Windows, so "models\\2.0.pt" cannot resolve elsewhere.
        default_weight = os.path.join(current_dir, "models", "2.0.pt")
        default_conf = os.path.join(current_dir, "models", "2.0.conf")

        if not os.path.exists(default_weight) or not os.path.exists(default_conf):
            raise FileNotFoundError(
                f"找不到模型文件!请确保以下文件存在:\n"
                f"{default_weight}\n"
                f"{default_conf}"
            )

        # Load the model class and its configuration.
        conf_manager = moduleconf.parseFromFile(default_conf)
        TransKun = conf_manager["Model"].module.TransKun
        conf = conf_manager["Model"].config

        # Load the weights, preferring "best_state_dict" when the checkpoint has one.
        checkpoint = torch.load(default_weight, map_location=device)
        model = TransKun(conf=conf).to(device)
        state_key = "best_state_dict" if "best_state_dict" in checkpoint else "state_dict"
        model.load_state_dict(checkpoint[state_key], strict=False)
        model.eval()

        progress(file_progress_offset + 0.2 * file_progress_scale, desc="读取音频...")
        # Read the audio and resample it to the model's sample rate if needed.
        fs, audio = transkun.transcribe.readAudio(input_file)
        if fs != model.fs:
            import soxr
            audio = soxr.resample(audio, fs, model.fs)

        x = torch.from_numpy(audio).to(device)

        progress(file_progress_offset + 0.4 * file_progress_scale, desc="转录中...")
        with torch.no_grad():
            notes_est = model.transcribe(x)

        progress(file_progress_offset + 0.7 * file_progress_scale, desc="保存MIDI...")
        # The writer is given a plain string path rather than a Path object.
        output_midi = transkun.transcribe.writeMidi(notes_est)
        output_midi.write(str(output_file))

        if use_quantize:
            progress(file_progress_offset + 0.8 * file_progress_scale, desc="规整化MIDI...")
            try:
                quantized_output_file = midi_quantize(str(output_file), debug=False, optimize_bpm=True)
            except Exception as e:
                # A quantization failure must not fail the whole conversion.
                print(f"规整化处理失败: {str(e)}")

        end_time = time.time()
        process_time = round(end_time - start_time, 2)

        progress(file_progress_offset + 1.0 * file_progress_scale, desc="完成!")

        result_files = [str(output_file)]
        if quantized_output_file:
            result_files.append(quantized_output_file)

        # On success the temporary directory is kept: the returned paths point into it.
        return {
            "output": f"转换完成!用时 {process_time}秒",
            "files": result_files
        }

    except Exception as e:
        traceback.print_exc()
        # Nothing is returned from the directory on failure, so remove it
        # instead of leaking one directory per failed conversion.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return {
            "output": f"转换失败: {str(e)}",
            "files": []
        }
|
| 412 |
+
|
| 413 |
+
# 创建Gradio界面
|
| 414 |
+
def create_interface():
    """Build the Gradio Blocks UI for batch audio-to-MIDI conversion.

    The left column holds the file input, the conversion options and the
    convert button. The right column holds the status box, the generated
    files and a button that bundles all generated files into one ZIP.

    :return: The assembled ``gr.Blocks`` application (not yet launched).
    """

    def show_warning(is_checked):
        """Show a warning message when the checkbox is checked."""
        if is_checked:
            gr.Warning("注意:开启此项可能影响踏板效果。")
        return gr.update(info="")

    def on_convert(audio_paths, use_cuda, use_quantize, progress=gr.Progress()):
        """Convert every selected file and collect the generated paths.

        Returns one value per output component wired to the convert button:
        status text, file list, download-button update, download-status
        update, stored file paths.
        """
        if not audio_paths:
            # The click handler declares five outputs, so this early exit must
            # also return five values.
            return "请选择输入音频文件", [], gr.update(visible=False), gr.update(visible=False), []

        all_files = []
        results = []
        total_files = len(audio_paths)

        for i, audio_path in enumerate(audio_paths):
            file_name = Path(audio_path).name
            # Each file gets an equal slice of the bar; 10% of every slice is
            # held back so the bar only reaches 1.0 on the final update.
            progress_offset = (i / total_files)
            progress_scale = (1 / total_files) * 0.9

            progress(progress_offset, desc=f"处理文件 {i+1}/{total_files}: {file_name}")
            result = process_audio(audio_path, use_cuda, use_quantize, progress,
                                   file_progress_offset=progress_offset,
                                   file_progress_scale=progress_scale)
            results.append(result["output"])
            all_files.extend(result["files"])

        progress(1.0, desc="全部完成!")
        download_btn_update = gr.update(visible=True) if all_files else gr.update(visible=False)
        download_status_update = gr.update(visible=False)
        return f"转换完成!共处理 {total_files} 个文件\n" + "\n".join(results), all_files, download_btn_update, download_status_update, all_files

    def download_all_files(file_paths, status_output=None):
        """Pack the generated files into one ZIP inside a new temporary directory.

        ``status_output`` receives the current download-status text from the
        click handler and is not used. Returns the ZIP path (or None), an
        update for the download-status box and an update for the download button.
        """
        import zipfile

        if not file_paths or len(file_paths) == 0:
            return None, gr.update(value="没有文件可下载", visible=True), gr.update(visible=False)

        try:
            temp_dir = tempfile.mkdtemp(prefix="midi_files_")
            zip_path = os.path.join(temp_dir, "all_midi_files.zip")

            with zipfile.ZipFile(zip_path, 'w') as zipf:
                for file_path in file_paths:
                    if os.path.exists(file_path):
                        # Store only the file name, without its directory.
                        zipf.write(file_path, os.path.basename(file_path))

            return zip_path, gr.update(value="下载准备完成,请点击上方文件链接下载", visible=True), gr.update(visible=False)
        except Exception as e:
            return None, gr.update(value=f"下载准备失败: {str(e)}", visible=True), gr.update(visible=True)

    # NOTE(review): the indentation of this function was lost in the source
    # this was reconstructed from; the nesting of the widgets below (which
    # row/column each belongs to) is inferred — confirm against the original.
    with gr.Blocks(title="Transkun - Piano Audio to MIDI", theme=gr.themes.Soft(primary_hue="blue")) as app:
        gr.Markdown(
            """
            # Transkun - 钢琴音频转MIDI
            将钢琴演奏音频转换为MIDI文件
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("### 1. 选择输入音频文件")
                input_audio = gr.File(label="输入音频文件", file_count="multiple", file_types=["audio"], interactive=True)

                gr.Markdown("### 2. 选择转换选项")
                with gr.Row():
                    # The checkbox is locked when CUDA is unavailable.
                    use_cuda = gr.Checkbox(
                        label=f"启用CUDA加速 (CUDA {'可用 ✓' if cuda_available else '不可用 ✗'})",
                        value=cuda_available,
                        interactive=cuda_available
                    )

                    use_quantize = gr.Checkbox(
                        label="使用MIDI规整化(附带有_quantized后缀的输出文件),建议只用于阅读MIDI(❗注意:此项可能影响踏板效果)",
                        value=False
                    )

                use_quantize.change(
                    fn=show_warning,
                    inputs=use_quantize,
                    outputs=use_quantize
                )

                convert_btn = gr.Button("开始转换", variant="primary")

            with gr.Column(scale=1):
                # Output section
                status_output = gr.Textbox(label="状态", value="准备就绪", interactive=False)
                file_output = gr.File(label="生成的MIDI文件", interactive=False, file_count="multiple")

                # Session state holding the generated file paths for the ZIP download.
                file_paths_store = gr.State([])

                # Download controls; both stay hidden until there is something to show.
                with gr.Row():
                    download_all_btn = gr.Button("一键下载全部文件", variant="secondary", visible=False)
                    download_status = gr.Textbox(label="下载状态", value="", visible=False, interactive=False)

        convert_btn.click(
            fn=on_convert,
            inputs=[input_audio, use_cuda, use_quantize],
            outputs=[status_output, file_output, download_all_btn, download_status, file_paths_store]
        )

        download_all_btn.click(
            fn=download_all_files,
            inputs=[file_paths_store, download_status],
            outputs=[file_output, download_status, download_all_btn]
        )

    return app
|
| 541 |
+
|
| 542 |
+
# 启动应用
|
| 543 |
+
def main():
    """Build the interface and start the Gradio server."""
    # Binding to 0.0.0.0 makes the server reachable from other machines;
    # 127.0.0.1 would be enough for purely local use.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "inbrowser": True,
    }
    create_interface().launch(**launch_options)


if __name__ == "__main__":
    main()
|