Spaces:
Running
Running
fix: 并发上线;音频格式
Browse files- app.py +44 -0
- src/gui_cloud.py +65 -1
- src/mfa_runner.py +22 -2
- src/pipeline.py +11 -8
app.py
CHANGED
|
@@ -49,6 +49,42 @@ MODELS_DIR = None # 延迟初始化
|
|
| 49 |
MFA_DIR = None
|
| 50 |
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def setup_environment():
|
| 53 |
"""初始化云端环境"""
|
| 54 |
global MODELS_DIR, MFA_DIR
|
|
@@ -65,6 +101,10 @@ def setup_environment():
|
|
| 65 |
Path("/home/studio_service").exists(), # 魔搭创空间特征目录
|
| 66 |
])
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# 魔搭创空间无法访问 HuggingFace,使用镜像
|
| 69 |
if is_cloud and Path("/home/studio_service").exists():
|
| 70 |
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
@@ -538,6 +578,10 @@ def main():
|
|
| 538 |
app = create_cloud_ui()
|
| 539 |
|
| 540 |
# 云端配置
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
app.launch(
|
| 542 |
server_name="0.0.0.0",
|
| 543 |
server_port=7860,
|
|
|
|
| 49 |
MFA_DIR = None
|
| 50 |
|
| 51 |
|
| 52 |
+
def ensure_ffmpeg():
|
| 53 |
+
"""确保 ffmpeg 已安装(用于音频格式转换,支持 m4a 等格式)"""
|
| 54 |
+
import shutil
|
| 55 |
+
|
| 56 |
+
if shutil.which("ffmpeg"):
|
| 57 |
+
logger.info("ffmpeg 已安装")
|
| 58 |
+
return True
|
| 59 |
+
|
| 60 |
+
logger.info("ffmpeg 未安装,尝试安装...")
|
| 61 |
+
|
| 62 |
+
try:
|
| 63 |
+
# 尝试使用 apt-get 安装(Debian/Ubuntu)
|
| 64 |
+
result = subprocess.run(
|
| 65 |
+
["apt-get", "update"],
|
| 66 |
+
capture_output=True, text=True, timeout=60
|
| 67 |
+
)
|
| 68 |
+
result = subprocess.run(
|
| 69 |
+
["apt-get", "install", "-y", "ffmpeg"],
|
| 70 |
+
capture_output=True, text=True, timeout=120
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
if shutil.which("ffmpeg"):
|
| 74 |
+
logger.info("ffmpeg 安装成功")
|
| 75 |
+
return True
|
| 76 |
+
else:
|
| 77 |
+
logger.warning("ffmpeg 安装后仍未找到")
|
| 78 |
+
return False
|
| 79 |
+
|
| 80 |
+
except subprocess.TimeoutExpired:
|
| 81 |
+
logger.warning("ffmpeg 安装超时")
|
| 82 |
+
return False
|
| 83 |
+
except Exception as e:
|
| 84 |
+
logger.warning(f"ffmpeg 安装失败: {e}")
|
| 85 |
+
return False
|
| 86 |
+
|
| 87 |
+
|
| 88 |
def setup_environment():
|
| 89 |
"""初始化云端环境"""
|
| 90 |
global MODELS_DIR, MFA_DIR
|
|
|
|
| 101 |
Path("/home/studio_service").exists(), # 魔搭创空间特征目录
|
| 102 |
])
|
| 103 |
|
| 104 |
+
# 确保 ffmpeg 已安装(支持 m4a 等音频格式)
|
| 105 |
+
if is_cloud or platform.system() != "Windows":
|
| 106 |
+
ensure_ffmpeg()
|
| 107 |
+
|
| 108 |
# 魔搭创空间无法访问 HuggingFace,使用镜像
|
| 109 |
if is_cloud and Path("/home/studio_service").exists():
|
| 110 |
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
|
|
| 578 |
app = create_cloud_ui()
|
| 579 |
|
| 580 |
# 云端配置
|
| 581 |
+
# 启用队列并设置并发数,允许多用户同时处理
|
| 582 |
+
app.queue(
|
| 583 |
+
default_concurrency_limit=25, # 同时处理的请求数
|
| 584 |
+
)
|
| 585 |
app.launch(
|
| 586 |
server_name="0.0.0.0",
|
| 587 |
server_port=7860,
|
src/gui_cloud.py
CHANGED
|
@@ -16,6 +16,7 @@ import tempfile
|
|
| 16 |
import zipfile
|
| 17 |
import shutil
|
| 18 |
import uuid
|
|
|
|
| 19 |
from pathlib import Path
|
| 20 |
from typing import Optional, List, Dict, Tuple, Any
|
| 21 |
|
|
@@ -27,6 +28,33 @@ logging.basicConfig(
|
|
| 27 |
)
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
def safe_gradio_handler(func):
|
| 32 |
"""
|
|
@@ -204,6 +232,9 @@ def process_make_voicebank(
|
|
| 204 |
|
| 205 |
返回: (状态, 日志, 下载文件路径, 会话存储的音源包路径)
|
| 206 |
"""
|
|
|
|
|
|
|
|
|
|
| 207 |
logs = []
|
| 208 |
workspace = None
|
| 209 |
|
|
@@ -216,16 +247,19 @@ def process_make_voicebank(
|
|
| 216 |
from src.pipeline import PipelineConfig, VoiceBankPipeline
|
| 217 |
except Exception as e:
|
| 218 |
logger.error(f"导入模块失败: {e}", exc_info=True)
|
|
|
|
| 219 |
return f"❌ 系统错误: 模块加载失败", str(e), None, None
|
| 220 |
|
| 221 |
# 验证输入
|
| 222 |
if not source_name or not source_name.strip():
|
|
|
|
| 223 |
return "❌ 请输入音源名称", "", None, None
|
| 224 |
|
| 225 |
source_name = source_name.strip()
|
| 226 |
|
| 227 |
valid, msg, file_paths = validate_audio_upload(audio_files)
|
| 228 |
if not valid:
|
|
|
|
| 229 |
return f"❌ {msg}", "", None, None
|
| 230 |
|
| 231 |
log(f"📁 {msg}")
|
|
@@ -342,12 +376,15 @@ def process_make_voicebank(
|
|
| 342 |
log(f"📦 已打包: {os.path.basename(zip_path)}")
|
| 343 |
progress(1.0, desc="完成")
|
| 344 |
# 返回路径到会话状态,供导出页面使用
|
|
|
|
| 345 |
return "✅ 音源制作完成", "\n".join(logs), zip_path, zip_path
|
| 346 |
else:
|
|
|
|
| 347 |
return "❌ 打包失败", "\n".join(logs), None, None
|
| 348 |
|
| 349 |
except Exception as e:
|
| 350 |
logger.error(f"制作音源失败: {e}", exc_info=True)
|
|
|
|
| 351 |
return f"❌ 处理失败: {e}", "\n".join(logs), None, None
|
| 352 |
|
| 353 |
finally:
|
|
@@ -456,6 +493,9 @@ def process_export_voicebank(
|
|
| 456 |
|
| 457 |
返回: (状态, 日志, 下载文件路径)
|
| 458 |
"""
|
|
|
|
|
|
|
|
|
|
| 459 |
logs = []
|
| 460 |
def log(msg):
|
| 461 |
logs.append(msg)
|
|
@@ -464,6 +504,7 @@ def process_export_voicebank(
|
|
| 464 |
# 验证输入
|
| 465 |
valid, msg, source_name = validate_voicebank_zip(zip_file)
|
| 466 |
if not valid:
|
|
|
|
| 467 |
return f"❌ {msg}", "", None
|
| 468 |
|
| 469 |
log(f"📦 {msg}")
|
|
@@ -576,12 +617,15 @@ def process_export_voicebank(
|
|
| 576 |
file_count = len([f for f in os.listdir(export_dir) if f.endswith(('.wav', '.ini'))])
|
| 577 |
log(f"📦 已打包: {file_count} 个文件")
|
| 578 |
progress(1.0, desc="完成")
|
|
|
|
| 579 |
return "✅ 导出完成", "\n".join(logs), result_zip
|
| 580 |
else:
|
|
|
|
| 581 |
return "❌ 打包失败", "\n".join(logs), None
|
| 582 |
|
| 583 |
except Exception as e:
|
| 584 |
logger.error(f"导出失败: {e}", exc_info=True)
|
|
|
|
| 585 |
return f"❌ 处理失败: {e}", "\n".join(logs), None
|
| 586 |
|
| 587 |
finally:
|
|
@@ -793,7 +837,16 @@ def create_cloud_ui():
|
|
| 793 |
# 会话状态:存储当前用户制作的音源包路径
|
| 794 |
session_voicebank = gr.State(value=None)
|
| 795 |
|
| 796 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 797 |
gr.Markdown("语音数据集处理工具 - 自动化制作语音音源库")
|
| 798 |
gr.Markdown("> ☁️ 云端版:上传音频 → 自动处理 → 下载结果")
|
| 799 |
|
|
@@ -1127,6 +1180,13 @@ def create_cloud_ui():
|
|
| 1127 |
|
| 1128 |
本工具集成 Montreal Forced Aligner (MIT License)
|
| 1129 |
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1130 |
|
| 1131 |
return app
|
| 1132 |
|
|
@@ -1134,6 +1194,10 @@ def create_cloud_ui():
|
|
| 1134 |
def main():
|
| 1135 |
"""云端入口"""
|
| 1136 |
app = create_cloud_ui()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1137 |
app.launch(
|
| 1138 |
server_name="0.0.0.0",
|
| 1139 |
server_port=7860,
|
|
|
|
| 16 |
import zipfile
|
| 17 |
import shutil
|
| 18 |
import uuid
|
| 19 |
+
import threading
|
| 20 |
from pathlib import Path
|
| 21 |
from typing import Optional, List, Dict, Tuple, Any
|
| 22 |
|
|
|
|
| 28 |
)
|
| 29 |
logger = logging.getLogger(__name__)
|
| 30 |
|
| 31 |
+
# ==================== 并发计数器 ====================
|
| 32 |
+
MAX_CONCURRENCY = 25
|
| 33 |
+
_concurrency_lock = threading.Lock()
|
| 34 |
+
_current_concurrency = 0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def increment_concurrency():
|
| 38 |
+
"""增加并发计数"""
|
| 39 |
+
global _current_concurrency
|
| 40 |
+
with _concurrency_lock:
|
| 41 |
+
_current_concurrency += 1
|
| 42 |
+
return _current_concurrency
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def decrement_concurrency():
|
| 46 |
+
"""减少并发计数"""
|
| 47 |
+
global _current_concurrency
|
| 48 |
+
with _concurrency_lock:
|
| 49 |
+
_current_concurrency = max(0, _current_concurrency - 1)
|
| 50 |
+
return _current_concurrency
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_concurrency_status() -> str:
|
| 54 |
+
"""获取当前并发状态文本"""
|
| 55 |
+
with _concurrency_lock:
|
| 56 |
+
return f"当前并发数:{_current_concurrency}/{MAX_CONCURRENCY}"
|
| 57 |
+
|
| 58 |
|
| 59 |
def safe_gradio_handler(func):
|
| 60 |
"""
|
|
|
|
| 232 |
|
| 233 |
返回: (状态, 日志, 下载文件路径, 会话存储的音源包路径)
|
| 234 |
"""
|
| 235 |
+
# 增加并发计数
|
| 236 |
+
increment_concurrency()
|
| 237 |
+
|
| 238 |
logs = []
|
| 239 |
workspace = None
|
| 240 |
|
|
|
|
| 247 |
from src.pipeline import PipelineConfig, VoiceBankPipeline
|
| 248 |
except Exception as e:
|
| 249 |
logger.error(f"导入模块失败: {e}", exc_info=True)
|
| 250 |
+
decrement_concurrency()
|
| 251 |
return f"❌ 系统错误: 模块加载失败", str(e), None, None
|
| 252 |
|
| 253 |
# 验证输入
|
| 254 |
if not source_name or not source_name.strip():
|
| 255 |
+
decrement_concurrency()
|
| 256 |
return "❌ 请输入音源名称", "", None, None
|
| 257 |
|
| 258 |
source_name = source_name.strip()
|
| 259 |
|
| 260 |
valid, msg, file_paths = validate_audio_upload(audio_files)
|
| 261 |
if not valid:
|
| 262 |
+
decrement_concurrency()
|
| 263 |
return f"❌ {msg}", "", None, None
|
| 264 |
|
| 265 |
log(f"📁 {msg}")
|
|
|
|
| 376 |
log(f"📦 已打包: {os.path.basename(zip_path)}")
|
| 377 |
progress(1.0, desc="完成")
|
| 378 |
# 返回路径到会话状态,供导出页面使用
|
| 379 |
+
decrement_concurrency()
|
| 380 |
return "✅ 音源制作完成", "\n".join(logs), zip_path, zip_path
|
| 381 |
else:
|
| 382 |
+
decrement_concurrency()
|
| 383 |
return "❌ 打包失败", "\n".join(logs), None, None
|
| 384 |
|
| 385 |
except Exception as e:
|
| 386 |
logger.error(f"制作音源失败: {e}", exc_info=True)
|
| 387 |
+
decrement_concurrency()
|
| 388 |
return f"❌ 处理失败: {e}", "\n".join(logs), None, None
|
| 389 |
|
| 390 |
finally:
|
|
|
|
| 493 |
|
| 494 |
返回: (状态, 日志, 下载文件路径)
|
| 495 |
"""
|
| 496 |
+
# 增加并发计数
|
| 497 |
+
increment_concurrency()
|
| 498 |
+
|
| 499 |
logs = []
|
| 500 |
def log(msg):
|
| 501 |
logs.append(msg)
|
|
|
|
| 504 |
# 验证输入
|
| 505 |
valid, msg, source_name = validate_voicebank_zip(zip_file)
|
| 506 |
if not valid:
|
| 507 |
+
decrement_concurrency()
|
| 508 |
return f"❌ {msg}", "", None
|
| 509 |
|
| 510 |
log(f"📦 {msg}")
|
|
|
|
| 617 |
file_count = len([f for f in os.listdir(export_dir) if f.endswith(('.wav', '.ini'))])
|
| 618 |
log(f"📦 已打包: {file_count} 个文件")
|
| 619 |
progress(1.0, desc="完成")
|
| 620 |
+
decrement_concurrency()
|
| 621 |
return "✅ 导出完成", "\n".join(logs), result_zip
|
| 622 |
else:
|
| 623 |
+
decrement_concurrency()
|
| 624 |
return "❌ 打包失败", "\n".join(logs), None
|
| 625 |
|
| 626 |
except Exception as e:
|
| 627 |
logger.error(f"导出失败: {e}", exc_info=True)
|
| 628 |
+
decrement_concurrency()
|
| 629 |
return f"❌ 处理失败: {e}", "\n".join(logs), None
|
| 630 |
|
| 631 |
finally:
|
|
|
|
| 837 |
# 会话状态:存储当前用户制作的音源包路径
|
| 838 |
session_voicebank = gr.State(value=None)
|
| 839 |
|
| 840 |
+
# 标题行:左侧标题 + 右侧并发状态
|
| 841 |
+
with gr.Row():
|
| 842 |
+
with gr.Column(scale=4):
|
| 843 |
+
gr.Markdown("# 🎤 人力V助手 (JinrikiHelper)")
|
| 844 |
+
with gr.Column(scale=1, min_width=200):
|
| 845 |
+
concurrency_display = gr.Markdown(
|
| 846 |
+
value=get_concurrency_status(),
|
| 847 |
+
elem_id="concurrency-status"
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
gr.Markdown("语音数据集处理工具 - 自动化制作语音音源库")
|
| 851 |
gr.Markdown("> ☁️ 云端版:上传音频 → 自动处理 → 下载结果")
|
| 852 |
|
|
|
|
| 1180 |
|
| 1181 |
本工具集成 Montreal Forced Aligner (MIT License)
|
| 1182 |
""")
|
| 1183 |
+
|
| 1184 |
+
# 定时刷新并发状态(每3秒)
|
| 1185 |
+
app.load(
|
| 1186 |
+
fn=get_concurrency_status,
|
| 1187 |
+
outputs=[concurrency_display],
|
| 1188 |
+
every=3
|
| 1189 |
+
)
|
| 1190 |
|
| 1191 |
return app
|
| 1192 |
|
|
|
|
| 1194 |
def main():
|
| 1195 |
"""云端入口"""
|
| 1196 |
app = create_cloud_ui()
|
| 1197 |
+
# 启用队列并设置并发数,允许多用户同时处理
|
| 1198 |
+
app.queue(
|
| 1199 |
+
default_concurrency_limit=MAX_CONCURRENCY, # 同时处理的请求数
|
| 1200 |
+
)
|
| 1201 |
app.launch(
|
| 1202 |
server_name="0.0.0.0",
|
| 1203 |
server_port=7860,
|
src/mfa_runner.py
CHANGED
|
@@ -182,7 +182,7 @@ def run_mfa_alignment(
|
|
| 182 |
output_dir: TextGrid 输出目录
|
| 183 |
dict_path: 字典文件路径,默认使用 models/mandarin.dict
|
| 184 |
model_path: 声学模型路径,默认使用 models/mandarin.zip
|
| 185 |
-
temp_dir: 临时目录,默认使用 mfa_temp
|
| 186 |
single_speaker: 是否为单说话人模式
|
| 187 |
clean: 是否清理旧缓存
|
| 188 |
progress_callback: 进度回调函数
|
|
@@ -190,6 +190,8 @@ def run_mfa_alignment(
|
|
| 190 |
返回:
|
| 191 |
(成功标志, 输出信息或错误信息)
|
| 192 |
"""
|
|
|
|
|
|
|
| 193 |
def log(msg: str):
|
| 194 |
logger.info(msg)
|
| 195 |
if progress_callback:
|
|
@@ -203,7 +205,11 @@ def run_mfa_alignment(
|
|
| 203 |
# 设置默认路径
|
| 204 |
dict_path = dict_path or str(DEFAULT_DICT_PATH)
|
| 205 |
model_path = model_path or str(DEFAULT_MODEL_PATH)
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
# 验证路径
|
| 209 |
if not os.path.isdir(corpus_dir):
|
|
@@ -261,6 +267,13 @@ def run_mfa_alignment(
|
|
| 261 |
|
| 262 |
if result.returncode == 0:
|
| 263 |
log("MFA 对齐完成!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
return True, result.stdout
|
| 265 |
else:
|
| 266 |
error_msg = result.stderr or result.stdout or "未知错误"
|
|
@@ -275,6 +288,13 @@ def run_mfa_alignment(
|
|
| 275 |
msg = f"MFA 执行异常: {e}"
|
| 276 |
log(msg)
|
| 277 |
return False, msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
|
| 280 |
def run_mfa_validate(
|
|
|
|
| 182 |
output_dir: TextGrid 输出目录
|
| 183 |
dict_path: 字典文件路径,默认使用 models/mandarin.dict
|
| 184 |
model_path: 声学模型路径,默认使用 models/mandarin.zip
|
| 185 |
+
temp_dir: 临时目录,默认使用 mfa_temp(云端会自动创建独立目录)
|
| 186 |
single_speaker: 是否为单说话人模式
|
| 187 |
clean: 是否清理旧缓存
|
| 188 |
progress_callback: 进度回调函数
|
|
|
|
| 190 |
返回:
|
| 191 |
(成功标志, 输出信息或错误信息)
|
| 192 |
"""
|
| 193 |
+
import uuid
|
| 194 |
+
|
| 195 |
def log(msg: str):
|
| 196 |
logger.info(msg)
|
| 197 |
if progress_callback:
|
|
|
|
| 205 |
# 设置默认路径
|
| 206 |
dict_path = dict_path or str(DEFAULT_DICT_PATH)
|
| 207 |
model_path = model_path or str(DEFAULT_MODEL_PATH)
|
| 208 |
+
|
| 209 |
+
# 临时目录:如果未指定,创建独立目录避免多用户冲突
|
| 210 |
+
if temp_dir is None:
|
| 211 |
+
session_id = uuid.uuid4().hex[:8]
|
| 212 |
+
temp_dir = str(DEFAULT_TEMP_DIR / f"session_{session_id}")
|
| 213 |
|
| 214 |
# 验证路径
|
| 215 |
if not os.path.isdir(corpus_dir):
|
|
|
|
| 267 |
|
| 268 |
if result.returncode == 0:
|
| 269 |
log("MFA 对齐完成!")
|
| 270 |
+
# 清理临时目录(仅清理会话独立目录)
|
| 271 |
+
if "session_" in temp_dir and os.path.exists(temp_dir):
|
| 272 |
+
try:
|
| 273 |
+
shutil.rmtree(temp_dir)
|
| 274 |
+
log(f"已清理临时目录: {temp_dir}")
|
| 275 |
+
except Exception as e:
|
| 276 |
+
logger.warning(f"清理临时目录失败: {e}")
|
| 277 |
return True, result.stdout
|
| 278 |
else:
|
| 279 |
error_msg = result.stderr or result.stdout or "未知错误"
|
|
|
|
| 288 |
msg = f"MFA 执行异常: {e}"
|
| 289 |
log(msg)
|
| 290 |
return False, msg
|
| 291 |
+
finally:
|
| 292 |
+
# 确保临时目录被清理(即使出错)
|
| 293 |
+
if "session_" in temp_dir and os.path.exists(temp_dir):
|
| 294 |
+
try:
|
| 295 |
+
shutil.rmtree(temp_dir)
|
| 296 |
+
except Exception:
|
| 297 |
+
pass
|
| 298 |
|
| 299 |
|
| 300 |
def run_mfa_validate(
|
src/pipeline.py
CHANGED
|
@@ -341,30 +341,33 @@ class VoiceBankPipeline:
|
|
| 341 |
VAD切片
|
| 342 |
|
| 343 |
输出格式统一为: 16bit 44.1kHz 单声道 WAV
|
|
|
|
| 344 |
"""
|
| 345 |
import torch
|
|
|
|
| 346 |
import soundfile as sf
|
| 347 |
import numpy as np
|
| 348 |
|
| 349 |
# 标准输出格式
|
| 350 |
TARGET_SR = 44100
|
| 351 |
|
| 352 |
-
# 读取
|
| 353 |
-
|
| 354 |
|
| 355 |
# 转换为单声道
|
| 356 |
-
if
|
| 357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
|
| 359 |
# 重采样到 44.1kHz(标准格式)
|
| 360 |
if sr != TARGET_SR:
|
| 361 |
-
import torchaudio
|
| 362 |
-
audio_tensor = torch.from_numpy(audio).float()
|
| 363 |
resampler = torchaudio.transforms.Resample(sr, TARGET_SR)
|
| 364 |
-
audio = resampler(
|
| 365 |
|
| 366 |
# VAD 需要 16kHz,单独重采样用于检测
|
| 367 |
-
import torchaudio
|
| 368 |
audio_tensor = torch.from_numpy(audio).float()
|
| 369 |
resampler_16k = torchaudio.transforms.Resample(TARGET_SR, 16000)
|
| 370 |
wav_16k = resampler_16k(audio_tensor)
|
|
|
|
| 341 |
VAD切片
|
| 342 |
|
| 343 |
输出格式统一为: 16bit 44.1kHz 单声道 WAV
|
| 344 |
+
支持格式: wav, mp3, flac, ogg, m4a 等 (通过 torchaudio/ffmpeg)
|
| 345 |
"""
|
| 346 |
import torch
|
| 347 |
+
import torchaudio
|
| 348 |
import soundfile as sf
|
| 349 |
import numpy as np
|
| 350 |
|
| 351 |
# 标准输出格式
|
| 352 |
TARGET_SR = 44100
|
| 353 |
|
| 354 |
+
# 使用 torchaudio 读取音频(支持更多格式,包括 m4a)
|
| 355 |
+
audio_tensor, sr = torchaudio.load(audio_path)
|
| 356 |
|
| 357 |
# 转换为单声道
|
| 358 |
+
if audio_tensor.shape[0] > 1:
|
| 359 |
+
audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
|
| 360 |
+
audio_tensor = audio_tensor.squeeze(0) # [samples]
|
| 361 |
+
|
| 362 |
+
# 转为 numpy
|
| 363 |
+
audio = audio_tensor.numpy()
|
| 364 |
|
| 365 |
# 重采样到 44.1kHz(标准格式)
|
| 366 |
if sr != TARGET_SR:
|
|
|
|
|
|
|
| 367 |
resampler = torchaudio.transforms.Resample(sr, TARGET_SR)
|
| 368 |
+
audio = resampler(torch.from_numpy(audio).float()).numpy()
|
| 369 |
|
| 370 |
# VAD 需要 16kHz,单独重采样用于检测
|
|
|
|
| 371 |
audio_tensor = torch.from_numpy(audio).float()
|
| 372 |
resampler_16k = torchaudio.transforms.Resample(TARGET_SR, 16000)
|
| 373 |
wav_16k = resampler_16k(audio_tensor)
|