shawnpi committed on
Commit
d28e6d7
·
verified ·
1 Parent(s): bde27c9

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +94 -126
gradio_app.py CHANGED
@@ -6,70 +6,27 @@ import gradio as gr
6
  import soundfile as sf
7
  import tempfile
8
  import hashlib
9
- import atexit
10
- from importlib.metadata import version, PackageNotFoundError
11
-
12
- # ================= 1. 增强型依赖追踪逻辑 =================
13
- def save_used_dependencies():
14
- """
15
- 导出当前运行环境下已加载的第三方库及其版本。
16
- """
17
- print("\n[System] 正在扫描内存中的依赖组件...")
18
- used_packages = set()
19
-
20
- # 映射表:导入名 -> PyPI 上的安装包名
21
- # 研三学生常用库映射
22
- mapping = {
23
- "yaml": "PyYAML",
24
- "cv2": "opencv-python",
25
- "sklearn": "scikit-learn",
26
- "skimage": "scikit-image",
27
- "faiss": "faiss-cpu", # 或 faiss-gpu
28
- "gradio": "gradio",
29
- "torch": "torch",
30
- "numpy": "numpy",
31
- "soundfile": "soundfile",
32
- "librosa": "librosa",
33
- "scipy": "scipy"
34
- }
35
-
36
- # 遍历当前所有已加载模块
37
- for name in list(sys.modules.keys()):
38
- root_package = name.split('.')[0]
39
-
40
- # 排除内置模块
41
- if root_package in sys.builtin_module_names:
42
- continue
43
-
44
- final_name = mapping.get(root_package, root_package)
45
- used_packages.add(final_name)
46
-
47
- # 过滤掉本地文件夹模块和 pip 相关工具
48
- # 根据你的项目结构,排除 logger 和 utils
49
- excluded = {'pip', 'setuptools', 'wheel', 'pkg_resources', 'logger', 'utils', 'importlib'}
50
- final_list = sorted(list(used_packages - excluded))
51
-
52
- output_path = 'used_requirements.txt'
53
- lines = []
54
- for pkg in final_list:
55
- try:
56
- # 获取版本号
57
- ver = version(pkg)
58
- lines.append(f"{pkg}=={ver}")
59
- except PackageNotFoundError:
60
- # 可能是本地库或者无法识别安装来源
61
- if pkg not in ['__main__', 'atexit', 'tempfile', 'hashlib']:
62
- lines.append(f"{pkg}")
63
-
64
- with open(output_path, 'w', encoding='utf-8') as f:
65
- f.write("\n".join(lines))
66
-
67
- msg = f"✨ 依赖清单已更新至: {os.path.abspath(output_path)}"
68
- print(msg)
69
- return msg
70
-
71
- # 注册正常退出时的钩子
72
- atexit.register(save_used_dependencies)
73
 
74
  # ================= 2. 路径与模型加载逻辑 =================
75
  now_dir = os.path.dirname(os.path.abspath(__file__))
@@ -78,7 +35,6 @@ utils_path = os.path.join(now_dir, 'utils')
78
  if utils_path not in sys.path:
79
  sys.path.append(utils_path)
80
 
81
- # 注意:这些导入需要确保你的目录结构正确
82
  from logger.utils import load_config
83
  from utils.models.models_v2_beta import load_hq_svc
84
  from utils.vocoder import Vocoder
@@ -107,47 +63,76 @@ def initialize_models(config_path):
107
  "content_encoder": None, "spk_encoder": None
108
  }
109
 
110
- # ================= 3. 推理逻辑 =================
111
  def predict(source_audio, target_files, shift_key, adjust_f0):
112
  global TARGET_CACHE
113
- if source_audio is None: return "错误: 未检测到源音频", None
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  sr, encoder_sr, device = ARGS.sample_rate, ARGS.encoder_sr, ARGS.device
115
 
116
- with torch.no_grad():
117
- is_reconstruction = (target_files is None or len(target_files) == 0)
118
- # 计算目标音频列表哈希以判断是否使用缓存
119
- current_hash = hashlib.md5("".join([f.name if hasattr(f, 'name') else f for f in (target_files or [])]).encode()).hexdigest()
120
-
121
- if is_reconstruction:
122
- t_data = get_processed_file(source_audio, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
123
- spk_ave, all_tar_f0 = t_data['spk'].squeeze().to(device), t_data['f0_origin']
124
- status = "✨ 超分模式"
125
- elif TARGET_CACHE["file_hash"] == current_hash:
126
- spk_ave, all_tar_f0 = TARGET_CACHE["spk_ave"], TARGET_CACHE["all_tar_f0"]
127
- status = "🚀 缓存命中"
128
- else:
129
- spk_list, f0_list = [], []
130
- for f in target_files[:20]:
131
- t_data = get_processed_file(f.name if hasattr(f, 'name') else f, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
132
- if t_data: spk_list.append(t_data['spk']); f0_list.append(t_data['f0_origin'])
133
- spk_ave = torch.stack(spk_list).mean(dim=0).squeeze().to(device)
134
- all_tar_f0 = np.concatenate(f0_list)
135
- TARGET_CACHE.update({"file_hash": current_hash, "spk_ave": spk_ave, "all_tar_f0": all_tar_f0})
136
- status = " 音色提取完成"
137
-
138
- src_data = get_processed_file(source_audio, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
139
- f0 = src_data['f0'].unsqueeze(0).to(device)
140
-
141
- if adjust_f0 and not is_reconstruction:
142
- shift_key = round(12 * np.log2(all_tar_f0[all_tar_f0>0].mean()/src_data['f0_origin'][src_data['f0_origin']>0].mean()))
143
-
144
- f0 = f0 * 2 ** (float(shift_key) / 12)
145
- mel_g = NET_G(src_data['vq_post'].unsqueeze(0).to(device), f0, src_data['vol'].unsqueeze(0).to(device), spk_ave, gt_spec=None, infer=True, infer_speedup=ARGS.infer_speedup, method=ARGS.infer_method, vocoder=VOCODER)
146
- wav_g = VOCODER.infer(mel_g, f0) if ARGS.vocoder == 'nsf-hifigan' else VOCODER.infer(mel_g)
147
-
148
- out_p = tempfile.mktemp(suffix=".wav")
149
- sf.write(out_p, wav_g.squeeze().cpu().numpy(), 44100)
150
- return f"{status} | 变调: {shift_key}", out_p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  # ================= 4. UI 界面 =================
153
  custom_css = """
@@ -175,7 +160,7 @@ def build_ui():
175
  </div>
176
  </div>
177
  """)
178
- gr.Markdown("# 🎸 HQ-SVC: SINGING VOICE CONVERSION 🍰")
179
 
180
  with gr.Row():
181
  with gr.Column():
@@ -191,36 +176,19 @@ def build_ui():
191
  result_audio = gr.Audio(label="OUTPUT (44.1kHz HQ)")
192
 
193
  run_btn.click(predict, [src_audio, tar_files, key_shift, auto_f0], [status_box, result_audio])
194
-
195
- # 底部管理按钮
196
- with gr.Row():
197
- export_btn = gr.Button("📦 导出依赖清单", variant="secondary")
198
- exit_btn = gr.Button("🚫 关闭系统", variant="stop")
199
-
200
- # 逻辑绑定
201
- export_btn.click(fn=save_used_dependencies, inputs=None, outputs=status_box)
202
-
203
- def safe_exit():
204
- save_used_dependencies()
205
- print("系统正在关闭...")
206
- sys.exit(0) # 触发 atexit 钩子
207
-
208
- exit_btn.click(fn=safe_exit, inputs=None, outputs=None)
209
 
210
  return demo
211
 
212
  if __name__ == "__main__":
213
- # 确保配置文件路径正确
214
  config_p = "configs/hq_svc_infer.yaml"
215
  if os.path.exists(config_p):
216
  initialize_models(config_p)
217
  else:
218
- print(f"警告: 找不到配置文件 {config_p},请检查路径。")
219
 
220
  demo = build_ui()
221
-
222
- print(">>> 界面启动成功。")
223
- print(">>> 提示:请进行至少一次转换推理,让系统加载动态依赖。")
224
-
225
- # allowed_paths 允许访问图片文件夹
226
- demo.launch(share=True, allowed_paths=[os.path.join(now_dir, "images")])
 
6
  import soundfile as sf
7
  import tempfile
8
  import hashlib
9
+ import requests
10
+ from huggingface_hub import snapshot_download
11
+
12
+ # ================= 1. 环境与自动同步逻辑 =================
13
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
14
+
15
def sync_model_files():
    """Fetch model weights from the Hugging Face Hub into the working directory.

    Best-effort: a failed download is reported on stdout but never aborts
    application startup.
    """
    repo_id = "shawnpi/HQ-SVC"
    print(f">>> 正在同步模型权重 ({repo_id})...")
    # Restrict the snapshot to the pretrained weights and the top-level config;
    # materialize real files (no symlinks) so downstream loaders can open them.
    download_opts = {
        "repo_id": repo_id,
        "allow_patterns": ["utils/pretrain/*", "config.json"],
        "local_dir": ".",
        "local_dir_use_symlinks": False,
    }
    try:
        snapshot_download(**download_opts)
        print(">>> 权重同步完成")
    except Exception as e:
        # Swallow errors deliberately: the app should still boot without weights.
        print(f">>> 同步失败: {e}")
28
+
29
+ sync_model_files()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # ================= 2. 路径与模型加载逻辑 =================
32
  now_dir = os.path.dirname(os.path.abspath(__file__))
 
35
  if utils_path not in sys.path:
36
  sys.path.append(utils_path)
37
 
 
38
  from logger.utils import load_config
39
  from utils.models.models_v2_beta import load_hq_svc
40
  from utils.vocoder import Vocoder
 
63
  "content_encoder": None, "spk_encoder": None
64
  }
65
 
66
+ # ================= 3. 推理逻辑 (增强鲁棒性) =================
67
def predict(source_audio, target_files, shift_key, adjust_f0):
    """Run voice conversion (or self super-resolution) on ``source_audio``.

    Args:
        source_audio: path to the source audio file (assumes the Gradio Audio
            component is configured with ``type="filepath"`` — TODO confirm).
        target_files: optional list of reference audio files defining the
            target timbre; None/empty triggers reconstruction mode.
        shift_key: manual pitch shift in semitones.
        adjust_f0: when True (and converting), derive ``shift_key`` from the
            mean voiced-F0 ratio between target and source.

    Returns:
        Tuple of (status message, output wav path); path is None on error.
    """
    global TARGET_CACHE

    # --- Robustness check 1: did the source upload finish? ---
    if source_audio is None:
        return "⚠️ 系统提示:未检测到源音频。请确认已选择文件,并等待上传进度条走完后再重新转换。", None

    # --- Robustness check 2: is the file path still valid? ---
    if not os.path.exists(source_audio):
        return "❌ 系统错误:音频文件传输中断,请刷新页面重新上传音频。", None

    # --- Robustness check 3: reject unsupported file extensions early ---
    valid_exts = ['.wav', '.mp3', '.flac', '.m4a', '.ogg', '.opus']
    if not any(source_audio.lower().endswith(ext) for ext in valid_exts):
        return f"❌ 系统错误:不支持该文件格式。请上传 {', '.join(valid_exts)} 格式的音频。", None

    sr, encoder_sr, device = ARGS.sample_rate, ARGS.encoder_sr, ARGS.device

    try:
        with torch.no_grad():
            is_reconstruction = (target_files is None or len(target_files) == 0)
            # Hash the concatenated target file names so repeated runs with the
            # same reference set reuse the cached speaker embedding / F0.
            target_names = "".join([f.name if hasattr(f, 'name') else f for f in (target_files or [])])
            current_hash = hashlib.md5(target_names.encode()).hexdigest()

            if is_reconstruction:
                # No references: the source itself supplies speaker embedding and F0.
                t_data = get_processed_file(source_audio, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
                spk_ave, all_tar_f0 = t_data['spk'].squeeze().to(device), t_data['f0_origin']
                status = "✨ Super-Resolution"
            elif TARGET_CACHE["file_hash"] == current_hash:
                spk_ave, all_tar_f0 = TARGET_CACHE["spk_ave"], TARGET_CACHE["all_tar_f0"]
                status = "🚀 Cache Loaded"
            else:
                spk_list, f0_list = [], []
                for f in (target_files[:20] if target_files else []):  # cap at 20 references
                    # Re-validate each reference path before processing.
                    f_path = f.name if hasattr(f, 'name') else f
                    if not f_path or not os.path.exists(f_path):
                        continue

                    t_data = get_processed_file(f_path, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
                    if t_data:
                        spk_list.append(t_data['spk'])
                        f0_list.append(t_data['f0_origin'])

                if not spk_list:
                    return "❌ 终端提示:目标参考音频上传失败或格式不正确,请重新上传。", None

                spk_ave = torch.stack(spk_list).mean(dim=0).squeeze().to(device)
                all_tar_f0 = np.concatenate(f0_list)
                TARGET_CACHE.update({"file_hash": current_hash, "spk_ave": spk_ave, "all_tar_f0": all_tar_f0})
                status = "✅ VOICE CONVERSION"

            src_data = get_processed_file(source_audio, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
            f0 = src_data['f0'].unsqueeze(0).to(device)

            if adjust_f0 and not is_reconstruction:
                # Auto key: semitone offset matching the mean voiced F0 of the
                # target; guarded so fully-unvoiced audio cannot divide by zero.
                src_f0_valid = src_data['f0_origin'][src_data['f0_origin'] > 0]
                tar_f0_valid = all_tar_f0[all_tar_f0 > 0]
                if len(src_f0_valid) > 0 and len(tar_f0_valid) > 0:
                    shift_key = round(12 * np.log2(tar_f0_valid.mean() / src_f0_valid.mean()))

            f0 = f0 * 2 ** (float(shift_key) / 12)
            mel_g = NET_G(src_data['vq_post'].unsqueeze(0).to(device), f0, src_data['vol'].unsqueeze(0).to(device), spk_ave, gt_spec=None, infer=True, infer_speedup=ARGS.infer_speedup, method=ARGS.infer_method, vocoder=VOCODER)
            wav_g = VOCODER.infer(mel_g, f0) if ARGS.vocoder == 'nsf-hifigan' else VOCODER.infer(mel_g)

            # NamedTemporaryFile replaces the deprecated, race-prone tempfile.mktemp;
            # delete=False keeps the file on disk for Gradio to serve.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                out_p = tmp.name
            sf.write(out_p, wav_g.squeeze().cpu().numpy(), 44100)
            return f"{status} | Pitch Shifted: {shift_key}", out_p
    except Exception as e:
        # Surface any inference failure in the status box instead of crashing the UI.
        return f"❌ 推理运行出错:{str(e)}。请尝试刷新页面并重新上传音频。", None
136
 
137
  # ================= 4. UI 界面 =================
138
  custom_css = """
 
160
  </div>
161
  </div>
162
  """)
163
+ gr.Markdown("# 🎸HQ-SVC: SINGING VOICE CONVERSION AND SUPER-RESOLUTION🍰")
164
 
165
  with gr.Row():
166
  with gr.Column():
 
176
  result_audio = gr.Audio(label="OUTPUT (44.1kHz HQ)")
177
 
178
  run_btn.click(predict, [src_audio, tar_files, key_shift, auto_f0], [status_box, result_audio])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  return demo
181
 
182
if __name__ == "__main__":

    # Initialize models only when the inference config is present; otherwise
    # the UI still launches (inference will report errors until configured).
    config_p = "configs/hq_svc_infer.yaml"
    if os.path.exists(config_p):
        initialize_models(config_p)
    else:
        print(f"警告: 找不到配置文件 {config_p}。")

    demo = build_ui()
    # Gradio may only serve files from whitelisted directories: the bundled
    # images folder, the app directory, and the temp dir holding output wavs.
    temp_dir = tempfile.gettempdir()
    demo.launch(
        share=True,
        allowed_paths=[os.path.join(now_dir, "images"), now_dir, temp_dir]
    )