Spaces:
Paused
Paused
File size: 3,211 Bytes
491457f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
import subprocess
from huggingface_hub import snapshot_download, hf_hub_download
# 权重保存的基础目录 (对应 README 中的 ./weights)
WEIGHTS_BASE = "weights"
def download_models():
print("🚀 [Step 1/2] 开始下载模型权重,请耐心等待...")
# 1. 下载 Wan2.1 底模
# README: huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./weights/Wan2.1-I2V-14B-480P
wan_path = os.path.join(WEIGHTS_BASE, "Wan2.1-I2V-14B-480P")
if not os.path.exists(wan_path):
print(f"正在下载 Wan2.1-I2V-14B-480P 到 {wan_path} ...")
snapshot_download(repo_id="Wan-AI/Wan2.1-I2V-14B-480P", local_dir=wan_path)
# 2. 下载 InfiniteTalk 权重
# README: huggingface-cli download MeiGen-AI/InfiniteTalk --local-dir ./weights/InfiniteTalk
it_path = os.path.join(WEIGHTS_BASE, "InfiniteTalk")
if not os.path.exists(it_path):
print(f"正在下载 InfiniteTalk 到 {it_path} ...")
snapshot_download(repo_id="MeiGen-AI/InfiniteTalk", local_dir=it_path)
# 3. 下载 Audio Encoder (wav2vec2)
# README 提到这个比较特殊,分为两步:主仓库文件 + 特殊分支的 model.safetensors
w2v_path = os.path.join(WEIGHTS_BASE, "chinese-wav2vec2-base")
# 3.1 下载主仓库
if not os.path.exists(w2v_path):
print(f"正在下载 chinese-wav2vec2-base 主体到 {w2v_path} ...")
snapshot_download(repo_id="TencentGameMate/chinese-wav2vec2-base", local_dir=w2v_path)
# 3.2 下载特定 Revision 的 model.safetensors
# README: ... model.safetensors --revision refs/pr/1 ...
safetensor_file = os.path.join(w2v_path, "model.safetensors")
# 简单的检查:如果文件不存在或文件太小(可能没下载对),则下载
if not os.path.exists(safetensor_file) or os.path.getsize(safetensor_file) < 1024:
print("正在下载 wav2vec2 的 model.safetensors (revision refs/pr/1) ...")
hf_hub_download(
repo_id="TencentGameMate/chinese-wav2vec2-base",
filename="model.safetensors",
revision="refs/pr/1",
local_dir=w2v_path,
local_dir_use_symlinks=False
)
print("✅ 所有模型下载完成!")
def start_app():
print("🚀 [Step 2/2] 启动 Gradio 应用...")
# 构建启动命令
# 参考 README "Run with Gradio" 部分
# python app.py --ckpt_dir ... --wav2vec_dir ...
cmd = [
"python", "app.py",
"--ckpt_dir", "weights/Wan2.1-I2V-14B-480P",
"--wav2vec_dir", "weights/chinese-wav2vec2-base",
# README 示例中单人模式指向 single/infinitetalk.safetensors
"--infinitetalk_dir", "weights/InfiniteTalk/single/infinitetalk.safetensors",
# 添加低显存参数,防止 Space OOM (参考 README: Run with very low VRAM)
"--num_persistent_param_in_dit", "0",
"--motion_frame", "9"
]
# 打印命令方便调试
print(f"执行命令: {' '.join(cmd)}")
# 运行
subprocess.run(cmd, check=True)
if __name__ == "__main__":
download_models()
start_app() |