TNOT commited on
Commit
75e21d7
·
1 Parent(s): 00a56c1

fix: 并发上线;音频格式

Browse files
Files changed (4) hide show
  1. app.py +44 -0
  2. src/gui_cloud.py +65 -1
  3. src/mfa_runner.py +22 -2
  4. src/pipeline.py +11 -8
app.py CHANGED
@@ -49,6 +49,42 @@ MODELS_DIR = None # 延迟初始化
49
  MFA_DIR = None
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def setup_environment():
53
  """初始化云端环境"""
54
  global MODELS_DIR, MFA_DIR
@@ -65,6 +101,10 @@ def setup_environment():
65
  Path("/home/studio_service").exists(), # 魔搭创空间特征目录
66
  ])
67
 
 
 
 
 
68
  # 魔搭创空间无法访问 HuggingFace,使用镜像
69
  if is_cloud and Path("/home/studio_service").exists():
70
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
@@ -538,6 +578,10 @@ def main():
538
  app = create_cloud_ui()
539
 
540
  # 云端配置
 
 
 
 
541
  app.launch(
542
  server_name="0.0.0.0",
543
  server_port=7860,
 
49
  MFA_DIR = None
50
 
51
 
52
+ def ensure_ffmpeg():
53
+ """确保 ffmpeg 已安装(用于音频格式转换,支持 m4a 等格式)"""
54
+ import shutil
55
+
56
+ if shutil.which("ffmpeg"):
57
+ logger.info("ffmpeg 已安装")
58
+ return True
59
+
60
+ logger.info("ffmpeg 未安装,尝试安装...")
61
+
62
+ try:
63
+ # 尝试使用 apt-get 安装(Debian/Ubuntu)
64
+ result = subprocess.run(
65
+ ["apt-get", "update"],
66
+ capture_output=True, text=True, timeout=60
67
+ )
68
+ result = subprocess.run(
69
+ ["apt-get", "install", "-y", "ffmpeg"],
70
+ capture_output=True, text=True, timeout=120
71
+ )
72
+
73
+ if shutil.which("ffmpeg"):
74
+ logger.info("ffmpeg 安装成功")
75
+ return True
76
+ else:
77
+ logger.warning("ffmpeg 安装后仍未找到")
78
+ return False
79
+
80
+ except subprocess.TimeoutExpired:
81
+ logger.warning("ffmpeg 安装超时")
82
+ return False
83
+ except Exception as e:
84
+ logger.warning(f"ffmpeg 安装失败: {e}")
85
+ return False
86
+
87
+
88
  def setup_environment():
89
  """初始化云端环境"""
90
  global MODELS_DIR, MFA_DIR
 
101
  Path("/home/studio_service").exists(), # 魔搭创空间特征目录
102
  ])
103
 
104
+ # 确保 ffmpeg 已安装(支持 m4a 等音频格式)
105
+ if is_cloud or platform.system() != "Windows":
106
+ ensure_ffmpeg()
107
+
108
  # 魔搭创空间无法访问 HuggingFace,使用镜像
109
  if is_cloud and Path("/home/studio_service").exists():
110
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
 
578
  app = create_cloud_ui()
579
 
580
  # 云端配置
581
+ # 启用队列并设置并发数,允许多用户同时处理
582
+ app.queue(
583
+ default_concurrency_limit=25, # 同时处理的请求数
584
+ )
585
  app.launch(
586
  server_name="0.0.0.0",
587
  server_port=7860,
src/gui_cloud.py CHANGED
@@ -16,6 +16,7 @@ import tempfile
16
  import zipfile
17
  import shutil
18
  import uuid
 
19
  from pathlib import Path
20
  from typing import Optional, List, Dict, Tuple, Any
21
 
@@ -27,6 +28,33 @@ logging.basicConfig(
27
  )
28
  logger = logging.getLogger(__name__)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def safe_gradio_handler(func):
32
  """
@@ -204,6 +232,9 @@ def process_make_voicebank(
204
 
205
  返回: (状态, 日志, 下载文件路径, 会话存储的音源包路径)
206
  """
 
 
 
207
  logs = []
208
  workspace = None
209
 
@@ -216,16 +247,19 @@ def process_make_voicebank(
216
  from src.pipeline import PipelineConfig, VoiceBankPipeline
217
  except Exception as e:
218
  logger.error(f"导入模块失败: {e}", exc_info=True)
 
219
  return f"❌ 系统错误: 模块加载失败", str(e), None, None
220
 
221
  # 验证输入
222
  if not source_name or not source_name.strip():
 
223
  return "❌ 请输入音源名称", "", None, None
224
 
225
  source_name = source_name.strip()
226
 
227
  valid, msg, file_paths = validate_audio_upload(audio_files)
228
  if not valid:
 
229
  return f"❌ {msg}", "", None, None
230
 
231
  log(f"📁 {msg}")
@@ -342,12 +376,15 @@ def process_make_voicebank(
342
  log(f"📦 已打包: {os.path.basename(zip_path)}")
343
  progress(1.0, desc="完成")
344
  # 返回路径到会话状态,供导出页面使用
 
345
  return "✅ 音源制作完成", "\n".join(logs), zip_path, zip_path
346
  else:
 
347
  return "❌ 打包失败", "\n".join(logs), None, None
348
 
349
  except Exception as e:
350
  logger.error(f"制作音源失败: {e}", exc_info=True)
 
351
  return f"❌ 处理失败: {e}", "\n".join(logs), None, None
352
 
353
  finally:
@@ -456,6 +493,9 @@ def process_export_voicebank(
456
 
457
  返回: (状态, 日志, 下载文件路径)
458
  """
 
 
 
459
  logs = []
460
  def log(msg):
461
  logs.append(msg)
@@ -464,6 +504,7 @@ def process_export_voicebank(
464
  # 验证输入
465
  valid, msg, source_name = validate_voicebank_zip(zip_file)
466
  if not valid:
 
467
  return f"❌ {msg}", "", None
468
 
469
  log(f"📦 {msg}")
@@ -576,12 +617,15 @@ def process_export_voicebank(
576
  file_count = len([f for f in os.listdir(export_dir) if f.endswith(('.wav', '.ini'))])
577
  log(f"📦 已打包: {file_count} 个文件")
578
  progress(1.0, desc="完成")
 
579
  return "✅ 导出完成", "\n".join(logs), result_zip
580
  else:
 
581
  return "❌ 打包失败", "\n".join(logs), None
582
 
583
  except Exception as e:
584
  logger.error(f"导出失败: {e}", exc_info=True)
 
585
  return f"❌ 处理失败: {e}", "\n".join(logs), None
586
 
587
  finally:
@@ -793,7 +837,16 @@ def create_cloud_ui():
793
  # 会话状态:存储当前用户制作的音源包路径
794
  session_voicebank = gr.State(value=None)
795
 
796
- gr.Markdown("# 🎤 人力V助手 (JinrikiHelper)")
 
 
 
 
 
 
 
 
 
797
  gr.Markdown("语音数据集处理工具 - 自动化制作语音音源库")
798
  gr.Markdown("> ☁️ 云端版:上传音频 → 自动处理 → 下载结果")
799
 
@@ -1127,6 +1180,13 @@ def create_cloud_ui():
1127
 
1128
  本工具集成 Montreal Forced Aligner (MIT License)
1129
  """)
 
 
 
 
 
 
 
1130
 
1131
  return app
1132
 
@@ -1134,6 +1194,10 @@ def create_cloud_ui():
1134
  def main():
1135
  """云端入口"""
1136
  app = create_cloud_ui()
 
 
 
 
1137
  app.launch(
1138
  server_name="0.0.0.0",
1139
  server_port=7860,
 
16
  import zipfile
17
  import shutil
18
  import uuid
19
+ import threading
20
  from pathlib import Path
21
  from typing import Optional, List, Dict, Tuple, Any
22
 
 
28
  )
29
  logger = logging.getLogger(__name__)
30
 
31
+ # ==================== 并发计数器 ====================
32
+ MAX_CONCURRENCY = 25
33
+ _concurrency_lock = threading.Lock()
34
+ _current_concurrency = 0
35
+
36
+
37
+ def increment_concurrency():
38
+ """增加并发计数"""
39
+ global _current_concurrency
40
+ with _concurrency_lock:
41
+ _current_concurrency += 1
42
+ return _current_concurrency
43
+
44
+
45
+ def decrement_concurrency():
46
+ """减少并发计数"""
47
+ global _current_concurrency
48
+ with _concurrency_lock:
49
+ _current_concurrency = max(0, _current_concurrency - 1)
50
+ return _current_concurrency
51
+
52
+
53
+ def get_concurrency_status() -> str:
54
+ """获取当前并发状态文本"""
55
+ with _concurrency_lock:
56
+ return f"当前并发数:{_current_concurrency}/{MAX_CONCURRENCY}"
57
+
58
 
59
  def safe_gradio_handler(func):
60
  """
 
232
 
233
  返回: (状态, 日志, 下载文件路径, 会话存储的音源包路径)
234
  """
235
+ # 增加并发计数
236
+ increment_concurrency()
237
+
238
  logs = []
239
  workspace = None
240
 
 
247
  from src.pipeline import PipelineConfig, VoiceBankPipeline
248
  except Exception as e:
249
  logger.error(f"导入模块失败: {e}", exc_info=True)
250
+ decrement_concurrency()
251
  return f"❌ 系统错误: 模块加载失败", str(e), None, None
252
 
253
  # 验证输入
254
  if not source_name or not source_name.strip():
255
+ decrement_concurrency()
256
  return "❌ 请输入音源名称", "", None, None
257
 
258
  source_name = source_name.strip()
259
 
260
  valid, msg, file_paths = validate_audio_upload(audio_files)
261
  if not valid:
262
+ decrement_concurrency()
263
  return f"❌ {msg}", "", None, None
264
 
265
  log(f"📁 {msg}")
 
376
  log(f"📦 已打包: {os.path.basename(zip_path)}")
377
  progress(1.0, desc="完成")
378
  # 返回路径到会话状态,供导出页面使用
379
+ decrement_concurrency()
380
  return "✅ 音源制作完成", "\n".join(logs), zip_path, zip_path
381
  else:
382
+ decrement_concurrency()
383
  return "❌ 打包失败", "\n".join(logs), None, None
384
 
385
  except Exception as e:
386
  logger.error(f"制作音源失败: {e}", exc_info=True)
387
+ decrement_concurrency()
388
  return f"❌ 处理失败: {e}", "\n".join(logs), None, None
389
 
390
  finally:
 
493
 
494
  返回: (状态, 日志, 下载文件路径)
495
  """
496
+ # 增加并发计数
497
+ increment_concurrency()
498
+
499
  logs = []
500
  def log(msg):
501
  logs.append(msg)
 
504
  # 验证输入
505
  valid, msg, source_name = validate_voicebank_zip(zip_file)
506
  if not valid:
507
+ decrement_concurrency()
508
  return f"❌ {msg}", "", None
509
 
510
  log(f"📦 {msg}")
 
617
  file_count = len([f for f in os.listdir(export_dir) if f.endswith(('.wav', '.ini'))])
618
  log(f"📦 已打包: {file_count} 个文件")
619
  progress(1.0, desc="完成")
620
+ decrement_concurrency()
621
  return "✅ 导出完成", "\n".join(logs), result_zip
622
  else:
623
+ decrement_concurrency()
624
  return "❌ 打包失败", "\n".join(logs), None
625
 
626
  except Exception as e:
627
  logger.error(f"导出失败: {e}", exc_info=True)
628
+ decrement_concurrency()
629
  return f"❌ 处理失败: {e}", "\n".join(logs), None
630
 
631
  finally:
 
837
  # 会话状态:存储当前用户制作的音源包路径
838
  session_voicebank = gr.State(value=None)
839
 
840
+ # 标题行:左侧标题 + 右侧并发状态
841
+ with gr.Row():
842
+ with gr.Column(scale=4):
843
+ gr.Markdown("# 🎤 人力V助手 (JinrikiHelper)")
844
+ with gr.Column(scale=1, min_width=200):
845
+ concurrency_display = gr.Markdown(
846
+ value=get_concurrency_status(),
847
+ elem_id="concurrency-status"
848
+ )
849
+
850
  gr.Markdown("语音数据集处理工具 - 自动化制作语音音源库")
851
  gr.Markdown("> ☁️ 云端版:上传音频 → 自动处理 → 下载结果")
852
 
 
1180
 
1181
  本工具集成 Montreal Forced Aligner (MIT License)
1182
  """)
1183
+
1184
+ # 定时刷新并发状态(每3秒)
1185
+ app.load(
1186
+ fn=get_concurrency_status,
1187
+ outputs=[concurrency_display],
1188
+ every=3
1189
+ )
1190
 
1191
  return app
1192
 
 
1194
  def main():
1195
  """云端入口"""
1196
  app = create_cloud_ui()
1197
+ # 启用队列并设置并发数,允许多用户同时处理
1198
+ app.queue(
1199
+ default_concurrency_limit=MAX_CONCURRENCY, # 同时处理的请求数
1200
+ )
1201
  app.launch(
1202
  server_name="0.0.0.0",
1203
  server_port=7860,
src/mfa_runner.py CHANGED
@@ -182,7 +182,7 @@ def run_mfa_alignment(
182
  output_dir: TextGrid 输出目录
183
  dict_path: 字典文件路径,默认使用 models/mandarin.dict
184
  model_path: 声学模型路径,默认使用 models/mandarin.zip
185
- temp_dir: 临时目录,默认使用 mfa_temp
186
  single_speaker: 是否为单说话人模式
187
  clean: 是否清理旧缓存
188
  progress_callback: 进度回调函数
@@ -190,6 +190,8 @@ def run_mfa_alignment(
190
  返回:
191
  (成功标志, 输出信息或错误信息)
192
  """
 
 
193
  def log(msg: str):
194
  logger.info(msg)
195
  if progress_callback:
@@ -203,7 +205,11 @@ def run_mfa_alignment(
203
  # 设置默认路径
204
  dict_path = dict_path or str(DEFAULT_DICT_PATH)
205
  model_path = model_path or str(DEFAULT_MODEL_PATH)
206
- temp_dir = temp_dir or str(DEFAULT_TEMP_DIR)
 
 
 
 
207
 
208
  # 验证路径
209
  if not os.path.isdir(corpus_dir):
@@ -261,6 +267,13 @@ def run_mfa_alignment(
261
 
262
  if result.returncode == 0:
263
  log("MFA 对齐完成!")
 
 
 
 
 
 
 
264
  return True, result.stdout
265
  else:
266
  error_msg = result.stderr or result.stdout or "未知错误"
@@ -275,6 +288,13 @@ def run_mfa_alignment(
275
  msg = f"MFA 执行异常: {e}"
276
  log(msg)
277
  return False, msg
 
 
 
 
 
 
 
278
 
279
 
280
  def run_mfa_validate(
 
182
  output_dir: TextGrid 输出目录
183
  dict_path: 字典文件路径,默认使用 models/mandarin.dict
184
  model_path: 声学模型路径,默认使用 models/mandarin.zip
185
+ temp_dir: 临时目录,默认使用 mfa_temp(云端会自动创建独立目录)
186
  single_speaker: 是否为单说话人模式
187
  clean: 是否清理旧缓存
188
  progress_callback: 进度回调函数
 
190
  返回:
191
  (成功标志, 输出信息或错误信息)
192
  """
193
+ import uuid
194
+
195
  def log(msg: str):
196
  logger.info(msg)
197
  if progress_callback:
 
205
  # 设置默认路径
206
  dict_path = dict_path or str(DEFAULT_DICT_PATH)
207
  model_path = model_path or str(DEFAULT_MODEL_PATH)
208
+
209
+ # 临时目录:如果未指定,创建独立目录避免多用户冲突
210
+ if temp_dir is None:
211
+ session_id = uuid.uuid4().hex[:8]
212
+ temp_dir = str(DEFAULT_TEMP_DIR / f"session_{session_id}")
213
 
214
  # 验证路径
215
  if not os.path.isdir(corpus_dir):
 
267
 
268
  if result.returncode == 0:
269
  log("MFA 对齐完成!")
270
+ # 清理临时目录(仅清理会话独立目录)
271
+ if "session_" in temp_dir and os.path.exists(temp_dir):
272
+ try:
273
+ shutil.rmtree(temp_dir)
274
+ log(f"已清理临时目录: {temp_dir}")
275
+ except Exception as e:
276
+ logger.warning(f"清理临时目录失败: {e}")
277
  return True, result.stdout
278
  else:
279
  error_msg = result.stderr or result.stdout or "未知错误"
 
288
  msg = f"MFA 执行异常: {e}"
289
  log(msg)
290
  return False, msg
291
+ finally:
292
+ # 确保临时目录被清理(即使出错)
293
+ if "session_" in temp_dir and os.path.exists(temp_dir):
294
+ try:
295
+ shutil.rmtree(temp_dir)
296
+ except Exception:
297
+ pass
298
 
299
 
300
  def run_mfa_validate(
src/pipeline.py CHANGED
@@ -341,30 +341,33 @@ class VoiceBankPipeline:
341
  VAD切片
342
 
343
  输出格式统一为: 16bit 44.1kHz 单声道 WAV
 
344
  """
345
  import torch
 
346
  import soundfile as sf
347
  import numpy as np
348
 
349
  # 标准输出格式
350
  TARGET_SR = 44100
351
 
352
- # 读取并转换为标准格式
353
- audio, sr = sf.read(audio_path, dtype='float32')
354
 
355
  # 转换为单声道
356
- if len(audio.shape) > 1:
357
- audio = np.mean(audio, axis=1)
 
 
 
 
358
 
359
  # 重采样到 44.1kHz(标准格式)
360
  if sr != TARGET_SR:
361
- import torchaudio
362
- audio_tensor = torch.from_numpy(audio).float()
363
  resampler = torchaudio.transforms.Resample(sr, TARGET_SR)
364
- audio = resampler(audio_tensor).numpy()
365
 
366
  # VAD 需要 16kHz,单独重采样用于检测
367
- import torchaudio
368
  audio_tensor = torch.from_numpy(audio).float()
369
  resampler_16k = torchaudio.transforms.Resample(TARGET_SR, 16000)
370
  wav_16k = resampler_16k(audio_tensor)
 
341
  VAD切片
342
 
343
  输出格式统一为: 16bit 44.1kHz 单声道 WAV
344
+ 支持格式: wav, mp3, flac, ogg, m4a 等 (通过 torchaudio/ffmpeg)
345
  """
346
  import torch
347
+ import torchaudio
348
  import soundfile as sf
349
  import numpy as np
350
 
351
  # 标准输出格式
352
  TARGET_SR = 44100
353
 
354
+ # 使用 torchaudio 读取音频(支持更多格式,包括 m4a)
355
+ audio_tensor, sr = torchaudio.load(audio_path)
356
 
357
  # 转换为单声道
358
+ if audio_tensor.shape[0] > 1:
359
+ audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
360
+ audio_tensor = audio_tensor.squeeze(0) # [samples]
361
+
362
+ # 转为 numpy
363
+ audio = audio_tensor.numpy()
364
 
365
  # 重采样到 44.1kHz(标准格式)
366
  if sr != TARGET_SR:
 
 
367
  resampler = torchaudio.transforms.Resample(sr, TARGET_SR)
368
+ audio = resampler(torch.from_numpy(audio).float()).numpy()
369
 
370
  # VAD 需要 16kHz,单独重采样用于检测
 
371
  audio_tensor = torch.from_numpy(audio).float()
372
  resampler_16k = torchaudio.transforms.Resample(TARGET_SR, 16000)
373
  wav_16k = resampler_16k(audio_tensor)