Update gradio_app.py
Browse files- gradio_app.py +94 -126
gradio_app.py
CHANGED
|
@@ -6,70 +6,27 @@ import gradio as gr
|
|
| 6 |
import soundfile as sf
|
| 7 |
import tempfile
|
| 8 |
import hashlib
|
| 9 |
-
import
|
| 10 |
-
from
|
| 11 |
-
|
| 12 |
-
# ================= 1.
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
""
|
| 17 |
-
print("
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
"
|
| 26 |
-
|
| 27 |
-
"
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
"numpy": "numpy",
|
| 31 |
-
"soundfile": "soundfile",
|
| 32 |
-
"librosa": "librosa",
|
| 33 |
-
"scipy": "scipy"
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
# 遍历当前所有已加载模块
|
| 37 |
-
for name in list(sys.modules.keys()):
|
| 38 |
-
root_package = name.split('.')[0]
|
| 39 |
-
|
| 40 |
-
# 排除内置模块
|
| 41 |
-
if root_package in sys.builtin_module_names:
|
| 42 |
-
continue
|
| 43 |
-
|
| 44 |
-
final_name = mapping.get(root_package, root_package)
|
| 45 |
-
used_packages.add(final_name)
|
| 46 |
-
|
| 47 |
-
# 过滤掉本地文件夹模块和 pip 相关工具
|
| 48 |
-
# 根据你的项目结构,排除 logger 和 utils
|
| 49 |
-
excluded = {'pip', 'setuptools', 'wheel', 'pkg_resources', 'logger', 'utils', 'importlib'}
|
| 50 |
-
final_list = sorted(list(used_packages - excluded))
|
| 51 |
-
|
| 52 |
-
output_path = 'used_requirements.txt'
|
| 53 |
-
lines = []
|
| 54 |
-
for pkg in final_list:
|
| 55 |
-
try:
|
| 56 |
-
# 获取版本号
|
| 57 |
-
ver = version(pkg)
|
| 58 |
-
lines.append(f"{pkg}=={ver}")
|
| 59 |
-
except PackageNotFoundError:
|
| 60 |
-
# 可能是本地库或者无法识别安装来源
|
| 61 |
-
if pkg not in ['__main__', 'atexit', 'tempfile', 'hashlib']:
|
| 62 |
-
lines.append(f"{pkg}")
|
| 63 |
-
|
| 64 |
-
with open(output_path, 'w', encoding='utf-8') as f:
|
| 65 |
-
f.write("\n".join(lines))
|
| 66 |
-
|
| 67 |
-
msg = f"✨ 依赖清单已更新至: {os.path.abspath(output_path)}"
|
| 68 |
-
print(msg)
|
| 69 |
-
return msg
|
| 70 |
-
|
| 71 |
-
# 注册正常退出时的钩子
|
| 72 |
-
atexit.register(save_used_dependencies)
|
| 73 |
|
| 74 |
# ================= 2. 路径与模型加载逻辑 =================
|
| 75 |
now_dir = os.path.dirname(os.path.abspath(__file__))
|
|
@@ -78,7 +35,6 @@ utils_path = os.path.join(now_dir, 'utils')
|
|
| 78 |
if utils_path not in sys.path:
|
| 79 |
sys.path.append(utils_path)
|
| 80 |
|
| 81 |
-
# 注意:这些导入需要确保你的目录结构正确
|
| 82 |
from logger.utils import load_config
|
| 83 |
from utils.models.models_v2_beta import load_hq_svc
|
| 84 |
from utils.vocoder import Vocoder
|
|
@@ -107,47 +63,76 @@ def initialize_models(config_path):
|
|
| 107 |
"content_encoder": None, "spk_encoder": None
|
| 108 |
}
|
| 109 |
|
| 110 |
-
# ================= 3. 推理逻辑 =================
|
| 111 |
def predict(source_audio, target_files, shift_key, adjust_f0):
|
| 112 |
global TARGET_CACHE
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
sr, encoder_sr, device = ARGS.sample_rate, ARGS.encoder_sr, ARGS.device
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
# ================= 4. UI 界面 =================
|
| 153 |
custom_css = """
|
|
@@ -175,7 +160,7 @@ def build_ui():
|
|
| 175 |
</div>
|
| 176 |
</div>
|
| 177 |
""")
|
| 178 |
-
gr.Markdown("# 🎸
|
| 179 |
|
| 180 |
with gr.Row():
|
| 181 |
with gr.Column():
|
|
@@ -191,36 +176,19 @@ def build_ui():
|
|
| 191 |
result_audio = gr.Audio(label="OUTPUT (44.1kHz HQ)")
|
| 192 |
|
| 193 |
run_btn.click(predict, [src_audio, tar_files, key_shift, auto_f0], [status_box, result_audio])
|
| 194 |
-
|
| 195 |
-
# 底部管理按钮
|
| 196 |
-
with gr.Row():
|
| 197 |
-
export_btn = gr.Button("📦 导出依赖清单", variant="secondary")
|
| 198 |
-
exit_btn = gr.Button("🚫 关闭系统", variant="stop")
|
| 199 |
-
|
| 200 |
-
# 逻辑绑定
|
| 201 |
-
export_btn.click(fn=save_used_dependencies, inputs=None, outputs=status_box)
|
| 202 |
-
|
| 203 |
-
def safe_exit():
|
| 204 |
-
save_used_dependencies()
|
| 205 |
-
print("系统正在关闭...")
|
| 206 |
-
sys.exit(0) # 触发 atexit 钩子
|
| 207 |
-
|
| 208 |
-
exit_btn.click(fn=safe_exit, inputs=None, outputs=None)
|
| 209 |
|
| 210 |
return demo
|
| 211 |
|
| 212 |
if __name__ == "__main__":
|
| 213 |
-
# 确保配置文件路径正确
|
| 214 |
config_p = "configs/hq_svc_infer.yaml"
|
| 215 |
if os.path.exists(config_p):
|
| 216 |
initialize_models(config_p)
|
| 217 |
else:
|
| 218 |
-
print(f"警告: 找不到配置文件 {config_p}
|
| 219 |
|
| 220 |
demo = build_ui()
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
demo.launch(share=True, allowed_paths=[os.path.join(now_dir, "images")])
|
|
|
|
| 6 |
import soundfile as sf
|
| 7 |
import tempfile
|
| 8 |
import hashlib
|
| 9 |
+
import requests
|
| 10 |
+
from huggingface_hub import snapshot_download
|
| 11 |
+
|
| 12 |
+
# ================= 1. 环境与自动同步逻辑 =================
|
| 13 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
|
| 14 |
+
|
| 15 |
+
def sync_model_files():
|
| 16 |
+
repo_id = "shawnpi/HQ-SVC"
|
| 17 |
+
print(f">>> 正在同步模型权重 ({repo_id})...")
|
| 18 |
+
try:
|
| 19 |
+
snapshot_download(
|
| 20 |
+
repo_id=repo_id,
|
| 21 |
+
allow_patterns=["utils/pretrain/*", "config.json"],
|
| 22 |
+
local_dir=".",
|
| 23 |
+
local_dir_use_symlinks=False
|
| 24 |
+
)
|
| 25 |
+
print(">>> 权重同步完成")
|
| 26 |
+
except Exception as e:
|
| 27 |
+
print(f">>> 同步失败: {e}")
|
| 28 |
+
|
| 29 |
+
sync_model_files()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# ================= 2. 路径与模型加载逻辑 =================
|
| 32 |
now_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
| 35 |
if utils_path not in sys.path:
|
| 36 |
sys.path.append(utils_path)
|
| 37 |
|
|
|
|
| 38 |
from logger.utils import load_config
|
| 39 |
from utils.models.models_v2_beta import load_hq_svc
|
| 40 |
from utils.vocoder import Vocoder
|
|
|
|
| 63 |
"content_encoder": None, "spk_encoder": None
|
| 64 |
}
|
| 65 |
|
| 66 |
+
# ================= 3. 推理逻辑 (增强鲁棒性) =================
|
| 67 |
def predict(source_audio, target_files, shift_key, adjust_f0):
|
| 68 |
global TARGET_CACHE
|
| 69 |
+
|
| 70 |
+
# --- 鲁棒性检查 1: 检查源音频是否上传完毕 ---
|
| 71 |
+
if source_audio is None:
|
| 72 |
+
return "⚠️ 系统提示:未检测到源音频。请确认已选择文件,并等待上传进度条走完后再重新转换。", None
|
| 73 |
+
|
| 74 |
+
# --- 鲁棒性检查 2: 检查文件路径有效性 ---
|
| 75 |
+
if not os.path.exists(source_audio):
|
| 76 |
+
return "❌ 系统错误:音频文件传输中断,请刷新页面重新上传音频。", None
|
| 77 |
+
|
| 78 |
+
# --- 鲁棒性检查 3: 检查音频格式 (防止上传了奇怪的文件) ---
|
| 79 |
+
valid_exts = ['.wav', '.mp3', '.flac', '.m4a', '.ogg', '.opus']
|
| 80 |
+
if not any(source_audio.lower().endswith(ext) for ext in valid_exts):
|
| 81 |
+
return f"❌ 系统错误:不支持该文件格式。请上传 {', '.join(valid_exts)} 格式的音频。", None
|
| 82 |
+
|
| 83 |
sr, encoder_sr, device = ARGS.sample_rate, ARGS.encoder_sr, ARGS.device
|
| 84 |
|
| 85 |
+
try:
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
is_reconstruction = (target_files is None or len(target_files) == 0)
|
| 88 |
+
target_names = "".join([f.name if hasattr(f, 'name') else f for f in (target_files or [])])
|
| 89 |
+
current_hash = hashlib.md5(target_names.encode()).hexdigest()
|
| 90 |
+
|
| 91 |
+
if is_reconstruction:
|
| 92 |
+
t_data = get_processed_file(source_audio, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
|
| 93 |
+
spk_ave, all_tar_f0 = t_data['spk'].squeeze().to(device), t_data['f0_origin']
|
| 94 |
+
status = "✨ Super-Resolution"
|
| 95 |
+
elif TARGET_CACHE["file_hash"] == current_hash:
|
| 96 |
+
spk_ave, all_tar_f0 = TARGET_CACHE["spk_ave"], TARGET_CACHE["all_tar_f0"]
|
| 97 |
+
status = "🚀 Cache Loaded"
|
| 98 |
+
else:
|
| 99 |
+
spk_list, f0_list = [], []
|
| 100 |
+
for f in (target_files[:20] if target_files else []):
|
| 101 |
+
# 再次校验目标参考音频是否有效
|
| 102 |
+
f_path = f.name if hasattr(f, 'name') else f
|
| 103 |
+
if not f_path or not os.path.exists(f_path): continue
|
| 104 |
+
|
| 105 |
+
t_data = get_processed_file(f_path, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
|
| 106 |
+
if t_data:
|
| 107 |
+
spk_list.append(t_data['spk'])
|
| 108 |
+
f0_list.append(t_data['f0_origin'])
|
| 109 |
+
|
| 110 |
+
if not spk_list:
|
| 111 |
+
return "❌ 终端提示:目标参考音频上传失败或格式不正确,请重新上传。", None
|
| 112 |
+
|
| 113 |
+
spk_ave = torch.stack(spk_list).mean(dim=0).squeeze().to(device)
|
| 114 |
+
all_tar_f0 = np.concatenate(f0_list)
|
| 115 |
+
TARGET_CACHE.update({"file_hash": current_hash, "spk_ave": spk_ave, "all_tar_f0": all_tar_f0})
|
| 116 |
+
status = "✅ VOICE CONVERSION"
|
| 117 |
+
|
| 118 |
+
src_data = get_processed_file(source_audio, sr, encoder_sr, VOCODER, PREPROCESSORS["volume_extractor"], PREPROCESSORS["f0_extractor"], PREPROCESSORS["fa_encoder"], PREPROCESSORS["fa_decoder"], None, None, device=device)
|
| 119 |
+
f0 = src_data['f0'].unsqueeze(0).to(device)
|
| 120 |
+
|
| 121 |
+
if adjust_f0 and not is_reconstruction:
|
| 122 |
+
src_f0_valid = src_data['f0_origin'][src_data['f0_origin'] > 0]
|
| 123 |
+
tar_f0_valid = all_tar_f0[all_tar_f0 > 0]
|
| 124 |
+
if len(src_f0_valid) > 0 and len(tar_f0_valid) > 0:
|
| 125 |
+
shift_key = round(12 * np.log2(tar_f0_valid.mean() / src_f0_valid.mean()))
|
| 126 |
+
|
| 127 |
+
f0 = f0 * 2 ** (float(shift_key) / 12)
|
| 128 |
+
mel_g = NET_G(src_data['vq_post'].unsqueeze(0).to(device), f0, src_data['vol'].unsqueeze(0).to(device), spk_ave, gt_spec=None, infer=True, infer_speedup=ARGS.infer_speedup, method=ARGS.infer_method, vocoder=VOCODER)
|
| 129 |
+
wav_g = VOCODER.infer(mel_g, f0) if ARGS.vocoder == 'nsf-hifigan' else VOCODER.infer(mel_g)
|
| 130 |
+
|
| 131 |
+
out_p = tempfile.mktemp(suffix=".wav")
|
| 132 |
+
sf.write(out_p, wav_g.squeeze().cpu().numpy(), 44100)
|
| 133 |
+
return f"{status} | Pitch Shifted: {shift_key}", out_p
|
| 134 |
+
except Exception as e:
|
| 135 |
+
return f"❌ 推理运行出错:{str(e)}。请尝试刷新页面并重新上传音频。", None
|
| 136 |
|
| 137 |
# ================= 4. UI 界面 =================
|
| 138 |
custom_css = """
|
|
|
|
| 160 |
</div>
|
| 161 |
</div>
|
| 162 |
""")
|
| 163 |
+
gr.Markdown("# 🎸HQ-SVC: SINGING VOICE CONVERSION AND SUPER-RESOLUTION🍰")
|
| 164 |
|
| 165 |
with gr.Row():
|
| 166 |
with gr.Column():
|
|
|
|
| 176 |
result_audio = gr.Audio(label="OUTPUT (44.1kHz HQ)")
|
| 177 |
|
| 178 |
run_btn.click(predict, [src_audio, tar_files, key_shift, auto_f0], [status_box, result_audio])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
return demo
|
| 181 |
|
| 182 |
if __name__ == "__main__":
|
|
|
|
| 183 |
config_p = "configs/hq_svc_infer.yaml"
|
| 184 |
if os.path.exists(config_p):
|
| 185 |
initialize_models(config_p)
|
| 186 |
else:
|
| 187 |
+
print(f"警告: 找不到配置文件 {config_p}。")
|
| 188 |
|
| 189 |
demo = build_ui()
|
| 190 |
+
temp_dir = tempfile.gettempdir()
|
| 191 |
+
demo.launch(
|
| 192 |
+
share=True,
|
| 193 |
+
allowed_paths=[os.path.join(now_dir, "images"), now_dir, temp_dir]
|
| 194 |
+
)
|
|
|