InfiniteTalk / run.py
corwen's picture
Upload 2 files
491457f verified
import os
import subprocess
from huggingface_hub import snapshot_download, hf_hub_download
# 权重保存的基础目录 (对应 README 中的 ./weights)
WEIGHTS_BASE = "weights"
def download_models():
print("🚀 [Step 1/2] 开始下载模型权重,请耐心等待...")
# 1. 下载 Wan2.1 底模
# README: huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./weights/Wan2.1-I2V-14B-480P
wan_path = os.path.join(WEIGHTS_BASE, "Wan2.1-I2V-14B-480P")
if not os.path.exists(wan_path):
print(f"正在下载 Wan2.1-I2V-14B-480P 到 {wan_path} ...")
snapshot_download(repo_id="Wan-AI/Wan2.1-I2V-14B-480P", local_dir=wan_path)
# 2. 下载 InfiniteTalk 权重
# README: huggingface-cli download MeiGen-AI/InfiniteTalk --local-dir ./weights/InfiniteTalk
it_path = os.path.join(WEIGHTS_BASE, "InfiniteTalk")
if not os.path.exists(it_path):
print(f"正在下载 InfiniteTalk 到 {it_path} ...")
snapshot_download(repo_id="MeiGen-AI/InfiniteTalk", local_dir=it_path)
# 3. 下载 Audio Encoder (wav2vec2)
# README 提到这个比较特殊,分为两步:主仓库文件 + 特殊分支的 model.safetensors
w2v_path = os.path.join(WEIGHTS_BASE, "chinese-wav2vec2-base")
# 3.1 下载主仓库
if not os.path.exists(w2v_path):
print(f"正在下载 chinese-wav2vec2-base 主体到 {w2v_path} ...")
snapshot_download(repo_id="TencentGameMate/chinese-wav2vec2-base", local_dir=w2v_path)
# 3.2 下载特定 Revision 的 model.safetensors
# README: ... model.safetensors --revision refs/pr/1 ...
safetensor_file = os.path.join(w2v_path, "model.safetensors")
# 简单的检查:如果文件不存在或文件太小(可能没下载对),则下载
if not os.path.exists(safetensor_file) or os.path.getsize(safetensor_file) < 1024:
print("正在下载 wav2vec2 的 model.safetensors (revision refs/pr/1) ...")
hf_hub_download(
repo_id="TencentGameMate/chinese-wav2vec2-base",
filename="model.safetensors",
revision="refs/pr/1",
local_dir=w2v_path,
local_dir_use_symlinks=False
)
print("✅ 所有模型下载完成!")
def start_app():
print("🚀 [Step 2/2] 启动 Gradio 应用...")
# 构建启动命令
# 参考 README "Run with Gradio" 部分
# python app.py --ckpt_dir ... --wav2vec_dir ...
cmd = [
"python", "app.py",
"--ckpt_dir", "weights/Wan2.1-I2V-14B-480P",
"--wav2vec_dir", "weights/chinese-wav2vec2-base",
# README 示例中单人模式指向 single/infinitetalk.safetensors
"--infinitetalk_dir", "weights/InfiniteTalk/single/infinitetalk.safetensors",
# 添加低显存参数,防止 Space OOM (参考 README: Run with very low VRAM)
"--num_persistent_param_in_dit", "0",
"--motion_frame", "9"
]
# 打印命令方便调试
print(f"执行命令: {' '.join(cmd)}")
# 运行
subprocess.run(cmd, check=True)
if __name__ == "__main__":
download_models()
start_app()