# File size: 3,211 Bytes
# 491457f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import subprocess
import sys

from huggingface_hub import snapshot_download, hf_hub_download

# Root directory that every model checkpoint is downloaded into; it matches
# the ./weights directory used by the README's download commands.
WEIGHTS_BASE: str = "weights"

# Name of the file written into a snapshot directory once snapshot_download
# has returned successfully for it.
_DOWNLOAD_DONE_MARKER = ".download_complete"


def _ensure_snapshot(repo_id, local_dir, message):
    """Download Hub repository *repo_id* into *local_dir* unless already done.

    Completion is recorded with a marker file that is written only after
    ``snapshot_download`` returns. Checking for the directory alone is not
    enough, because an interrupted download still leaves *local_dir* behind.
    Such a directory would then be skipped on every later start. When the
    marker is missing, ``snapshot_download`` is simply called again. This
    includes directories fetched by earlier versions of this script.
    huggingface_hub is expected to reuse files already present in
    *local_dir*.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. "MeiGen-AI/InfiniteTalk".
        local_dir: Directory the repository snapshot is downloaded into.
        message: Progress message printed before a download is started.
    """
    marker = os.path.join(local_dir, _DOWNLOAD_DONE_MARKER)
    if os.path.isfile(marker):
        return
    print(message)
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    # Only reached when snapshot_download did not raise, so the snapshot
    # can be treated as complete from now on.
    with open(marker, "w", encoding="utf-8") as fh:
        fh.write(f"{repo_id}\n")


def download_models():
    """Fetch every model checkpoint the app needs into ``WEIGHTS_BASE``.

    Mirrors the README's ``huggingface-cli download`` steps:

    1. the Wan2.1-I2V-14B-480P base model;
    2. the InfiniteTalk weights;
    3. the chinese-wav2vec2-base audio encoder. Its ``model.safetensors``
       is fetched separately from the ``refs/pr/1`` revision.

    Repositories whose earlier download finished are not downloaded again.
    """
    print("🚀 [Step 1/2] 开始下载模型权重,请耐心等待...")

    # 1. Wan2.1 base model.
    # README: huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./weights/Wan2.1-I2V-14B-480P
    wan_path = os.path.join(WEIGHTS_BASE, "Wan2.1-I2V-14B-480P")
    _ensure_snapshot(
        "Wan-AI/Wan2.1-I2V-14B-480P",
        wan_path,
        f"正在下载 Wan2.1-I2V-14B-480P 到 {wan_path} ...",
    )

    # 2. InfiniteTalk weights.
    # README: huggingface-cli download MeiGen-AI/InfiniteTalk --local-dir ./weights/InfiniteTalk
    it_path = os.path.join(WEIGHTS_BASE, "InfiniteTalk")
    _ensure_snapshot(
        "MeiGen-AI/InfiniteTalk",
        it_path,
        f"正在下载 InfiniteTalk 到 {it_path} ...",
    )

    # 3. Audio encoder (wav2vec2). The README handles it in two steps:
    #    the main repository files, then model.safetensors from another revision.
    w2v_path = os.path.join(WEIGHTS_BASE, "chinese-wav2vec2-base")

    # 3.1 Main repository.
    _ensure_snapshot(
        "TencentGameMate/chinese-wav2vec2-base",
        w2v_path,
        f"正在下载 chinese-wav2vec2-base 主体到 {w2v_path} ...",
    )

    # 3.2 model.safetensors from revision refs/pr/1.
    # README: ... model.safetensors --revision refs/pr/1 ...
    safetensor_file = os.path.join(w2v_path, "model.safetensors")
    # Quick sanity check: (re)download if the file is missing or implausibly
    # small (under 1 KiB), which suggests it was not fetched correctly.
    if not os.path.exists(safetensor_file) or os.path.getsize(safetensor_file) < 1024:
        print("正在下载 wav2vec2 的 model.safetensors (revision refs/pr/1) ...")
        hf_hub_download(
            repo_id="TencentGameMate/chinese-wav2vec2-base",
            filename="model.safetensors",
            revision="refs/pr/1",
            local_dir=w2v_path,
            local_dir_use_symlinks=False,
        )

    print("✅ 所有模型下载完成!")

def start_app():
    """Launch the InfiniteTalk Gradio app (``app.py``) as a child process.

    The command follows the README's "Run with Gradio" example. It points the
    app at the same weight locations under ``WEIGHTS_BASE`` that
    ``download_models()`` fills. It also adds the README's "very low VRAM"
    options so the Space does not run out of memory.

    Raises:
        subprocess.CalledProcessError: if ``app.py`` exits with a non-zero status.
    """
    print("🚀 [Step 2/2] 启动 Gradio 应用...")

    # Run app.py with the interpreter executing this script, not whatever
    # "python" resolves to on PATH. That name may be absent, or may belong to
    # another environment without the project's dependencies.
    # sys.executable can be empty/None in unusual embeddings, so keep the old
    # behaviour as a fallback.
    python_exe = sys.executable or "python"

    # Derive every path from WEIGHTS_BASE so the launcher cannot drift out
    # of sync with where download_models() puts the weights.
    cmd = [
        python_exe, "app.py",
        "--ckpt_dir", os.path.join(WEIGHTS_BASE, "Wan2.1-I2V-14B-480P"),
        "--wav2vec_dir", os.path.join(WEIGHTS_BASE, "chinese-wav2vec2-base"),
        # The README's single-person example uses single/infinitetalk.safetensors.
        "--infinitetalk_dir",
        os.path.join(WEIGHTS_BASE, "InfiniteTalk", "single", "infinitetalk.safetensors"),
        # Low-VRAM options to avoid OOM on the Space (README: "Run with very low VRAM").
        "--num_persistent_param_in_dit", "0",
        "--motion_frame", "9",
    ]

    # Echo the exact command to make debugging easier.
    print(f"执行命令: {' '.join(cmd)}")

    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    # Order matters: the weights must exist on disk before the app that loads
    # them is launched.
    for _step in (download_models, start_app):
        _step()