Spaces:
Paused
Paused
Upload 2 files
Browse files- Dockerfile.txt +56 -0
- run.py +76 -0
Dockerfile.txt
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Space image for MeiGen-AI/InfiniteTalk.
# 1. Base image pinned to the PyTorch version the upstream README requires (Torch 2.4.1).
FROM pytorch/pytorch:2.4.1-cuda12.1-cudnn9-devel

# Non-interactive apt; ARG (not ENV) so it does not leak into the runtime environment.
ARG DEBIAN_FRONTEND=noninteractive

# 2. System dependencies (the README requires FFmpeg; libGL / glib are the usual
#    OpenCV runtime libraries on slim images).
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    git-lfs \
    ffmpeg \
    libgl1-mesa-glx \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Working directory for the clone.
WORKDIR /app

# 3. Clone the InfiniteTalk repository.
RUN git clone https://github.com/MeiGen-AI/InfiniteTalk.git

# Enter the repository.
WORKDIR /app/InfiniteTalk

# 4. Python dependencies, strictly in the order the upstream README prescribes.
RUN pip install --no-cache-dir --upgrade pip

# xformers pinned to match the CUDA 12.1 torch build already in the base image.
# README: pip install -U xformers==0.0.28 --index-url https://download.pytorch.org/whl/cu121
RUN pip install --no-cache-dir -U xformers==0.0.28 --index-url https://download.pytorch.org/whl/cu121

# flash-attn build prerequisites (README step 2); the extras bracket is quoted so
# the shell can never glob-expand it.
RUN pip install --no-cache-dir "misaki[en]" ninja psutil packaging wheel
# Build/install flash-attn itself.
RUN pip install --no-cache-dir flash_attn==2.7.4.post1

# Remaining project requirements (README step 3).
RUN pip install --no-cache-dir -r requirements.txt
# Extra libraries the Space runtime needs.
RUN pip install --no-cache-dir gradio huggingface_hub

# 5. Drop root: Hugging Face Spaces runs containers as uid 1000.
#    Single layer so the chown result is not duplicated across layers.
RUN useradd -m -u 1000 user && chown -R user:user /app
USER user

# 6. Gradio must bind 0.0.0.0 to be reachable from outside the container;
#    Spaces routes external traffic to port 7860 (Gradio's default port).
ENV GRADIO_SERVER_NAME="0.0.0.0"
EXPOSE 7860

# 7. Launcher script (downloads weights, then starts the Gradio app).
COPY --chown=user:user run.py /app/InfiniteTalk/run.py

# 8. Entry point.
CMD ["python", "run.py"]
run.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Hugging Face Space launcher: download InfiniteTalk weights, then start the Gradio app."""
import os
import subprocess
from huggingface_hub import snapshot_download, hf_hub_download

# Base directory the weights are downloaded into
# (matches ``./weights`` in the upstream README's huggingface-cli commands).
WEIGHTS_BASE = "weights"
+
def download_models():
    """Fetch every model checkpoint the app needs into ``WEIGHTS_BASE``.

    Mirrors the manual ``huggingface-cli download`` steps from the upstream
    README: the Wan2.1 base model, the InfiniteTalk weights, and the
    chinese-wav2vec2-base audio encoder — whose ``model.safetensors`` must
    come from the special ``refs/pr/1`` revision.  Each download is skipped
    when its target directory already exists, so restarts are cheap.
    """
    print("🚀 [Step 1/2] 开始下载模型权重,请耐心等待...")

    # 1. Wan2.1 base model.
    # README: huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./weights/Wan2.1-I2V-14B-480P
    wan_path = os.path.join(WEIGHTS_BASE, "Wan2.1-I2V-14B-480P")
    if not os.path.exists(wan_path):
        print(f"正在下载 Wan2.1-I2V-14B-480P 到 {wan_path} ...")
        snapshot_download(repo_id="Wan-AI/Wan2.1-I2V-14B-480P", local_dir=wan_path)

    # 2. InfiniteTalk weights.
    # README: huggingface-cli download MeiGen-AI/InfiniteTalk --local-dir ./weights/InfiniteTalk
    it_path = os.path.join(WEIGHTS_BASE, "InfiniteTalk")
    if not os.path.exists(it_path):
        print(f"正在下载 InfiniteTalk 到 {it_path} ...")
        snapshot_download(repo_id="MeiGen-AI/InfiniteTalk", local_dir=it_path)

    # 3. Audio encoder (wav2vec2).  Two steps per the README: the main
    #    repository, then model.safetensors from a special PR revision.
    w2v_path = os.path.join(WEIGHTS_BASE, "chinese-wav2vec2-base")

    # 3.1 Main repository.
    if not os.path.exists(w2v_path):
        print(f"正在下载 chinese-wav2vec2-base 主体到 {w2v_path} ...")
        snapshot_download(repo_id="TencentGameMate/chinese-wav2vec2-base", local_dir=w2v_path)

    # 3.2 model.safetensors from revision refs/pr/1.  A file smaller than
    #     1 KiB is treated as a stale LFS pointer and re-downloaded.
    safetensor_file = os.path.join(w2v_path, "model.safetensors")
    if not os.path.exists(safetensor_file) or os.path.getsize(safetensor_file) < 1024:
        print("正在下载 wav2vec2 的 model.safetensors (revision refs/pr/1) ...")
        # The deprecated ``local_dir_use_symlinks`` kwarg was dropped: recent
        # huggingface_hub versions already write real files into local_dir by
        # default, warn on the kwarg, and will remove it entirely.
        hf_hub_download(
            repo_id="TencentGameMate/chinese-wav2vec2-base",
            filename="model.safetensors",
            revision="refs/pr/1",
            local_dir=w2v_path,
        )

    print("✅ 所有模型下载完成!")
+
|
| 50 |
+
def start_app():
    """Launch the InfiniteTalk Gradio demo as a child process.

    Builds the command line described in the upstream README's
    "Run with Gradio" section and blocks until the app exits; a non-zero
    exit status raises ``subprocess.CalledProcessError``.
    """
    print("🚀 [Step 2/2] 启动 Gradio 应用...")

    # Assemble the launch command piecewise for readability.
    cmd = ["python", "app.py"]
    cmd += ["--ckpt_dir", "weights/Wan2.1-I2V-14B-480P"]
    cmd += ["--wav2vec_dir", "weights/chinese-wav2vec2-base"]
    # Single-speaker checkpoint, as in the README example.
    cmd += ["--infinitetalk_dir", "weights/InfiniteTalk/single/infinitetalk.safetensors"]
    # Low-VRAM flags (README: "Run with very low VRAM") to avoid OOM on Space hardware.
    cmd += ["--num_persistent_param_in_dit", "0"]
    cmd += ["--motion_frame", "9"]

    # Echo the exact command for debugging.
    print(f"执行命令: {' '.join(cmd)}")

    # Run and propagate failures.
    subprocess.run(cmd, check=True)
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
    # Download weights first (skips anything already present), then hand
    # off to the Gradio app; start_app() blocks for the process lifetime.
    download_models()
    start_app()