import os import subprocess from huggingface_hub import snapshot_download, hf_hub_download # 权重保存的基础目录 (对应 README 中的 ./weights) WEIGHTS_BASE = "weights" def download_models(): print("🚀 [Step 1/2] 开始下载模型权重,请耐心等待...") # 1. 下载 Wan2.1 底模 # README: huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./weights/Wan2.1-I2V-14B-480P wan_path = os.path.join(WEIGHTS_BASE, "Wan2.1-I2V-14B-480P") if not os.path.exists(wan_path): print(f"正在下载 Wan2.1-I2V-14B-480P 到 {wan_path} ...") snapshot_download(repo_id="Wan-AI/Wan2.1-I2V-14B-480P", local_dir=wan_path) # 2. 下载 InfiniteTalk 权重 # README: huggingface-cli download MeiGen-AI/InfiniteTalk --local-dir ./weights/InfiniteTalk it_path = os.path.join(WEIGHTS_BASE, "InfiniteTalk") if not os.path.exists(it_path): print(f"正在下载 InfiniteTalk 到 {it_path} ...") snapshot_download(repo_id="MeiGen-AI/InfiniteTalk", local_dir=it_path) # 3. 下载 Audio Encoder (wav2vec2) # README 提到这个比较特殊,分为两步:主仓库文件 + 特殊分支的 model.safetensors w2v_path = os.path.join(WEIGHTS_BASE, "chinese-wav2vec2-base") # 3.1 下载主仓库 if not os.path.exists(w2v_path): print(f"正在下载 chinese-wav2vec2-base 主体到 {w2v_path} ...") snapshot_download(repo_id="TencentGameMate/chinese-wav2vec2-base", local_dir=w2v_path) # 3.2 下载特定 Revision 的 model.safetensors # README: ... model.safetensors --revision refs/pr/1 ... safetensor_file = os.path.join(w2v_path, "model.safetensors") # 简单的检查:如果文件不存在或文件太小(可能没下载对),则下载 if not os.path.exists(safetensor_file) or os.path.getsize(safetensor_file) < 1024: print("正在下载 wav2vec2 的 model.safetensors (revision refs/pr/1) ...") hf_hub_download( repo_id="TencentGameMate/chinese-wav2vec2-base", filename="model.safetensors", revision="refs/pr/1", local_dir=w2v_path, local_dir_use_symlinks=False ) print("✅ 所有模型下载完成!") def start_app(): print("🚀 [Step 2/2] 启动 Gradio 应用...") # 构建启动命令 # 参考 README "Run with Gradio" 部分 # python app.py --ckpt_dir ... --wav2vec_dir ... cmd = [ "python", "app.py", "--ckpt_dir", "weights/Wan2.1-I2V-14B-480P", "--wav2vec_dir", "weights/chinese-wav2vec2-base", # README 示例中单人模式指向 single/infinitetalk.safetensors "--infinitetalk_dir", "weights/InfiniteTalk/single/infinitetalk.safetensors", # 添加低显存参数,防止 Space OOM (参考 README: Run with very low VRAM) "--num_persistent_param_in_dit", "0", "--motion_frame", "9" ] # 打印命令方便调试 print(f"执行命令: {' '.join(cmd)}") # 运行 subprocess.run(cmd, check=True) if __name__ == "__main__": download_models() start_app()