# (removed page-scrape artifact lines: "Spaces:" / "Running" / "Running")
| # app.py | |
| # HuggingFace Spaces (Gradio + ZeroGPU) 单文件示例: | |
| # - 自动下载 LongCat-Video GitHub 代码(zip) | |
| # - 自动下载 LongCat-Video / LongCat-Video-Avatar 权重(HF Hub) | |
| # - 通过 spaces.GPU 在 ZeroGPU 环境下按需申请 GPU 执行推理 | |
| # - 支持单人:AT2V / AI2V | |
| # | |
| # 说明: | |
| # 1) 官方示例使用 torchrun nproc=2(多进程/可能更快): | |
| # 这里默认改为 nproc=1 + context_parallel_size=1,更适合 Spaces。 | |
| # 2) FlashAttention 默认在 config 开启,但在 Spaces 上未必能顺利安装; | |
| # 本示例会尝试把 config 里所有包含 "flash" 的 attention backend 字段递归替换为 "sdpa"。 | |
| # | |
| # 参考: | |
| # - ZeroGPU 官方用法:@spaces.GPU(duration=...) :contentReference[oaicite:5]{index=5}(用户侧不需要引用,代码内不写引用) | |
| # - LongCat-Video-Avatar 模型卡:推理命令/参数/权重目录结构 :contentReference[oaicite:6]{index=6} | |
| import os | |
| import re | |
| import sys | |
| import json | |
| import time | |
| import shutil | |
| import zipfile | |
| import hashlib | |
| import subprocess | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Any, Dict, Tuple, Optional | |
| # ---------------------------- | |
| # 运行时“尽量单文件”的依赖安装 | |
| # ---------------------------- | |
def _pip_install(pkgs):
    """Install *pkgs* with pip (upgrade mode), echoing the exact command first."""
    args = [sys.executable, "-m", "pip", "install", "-U"] + pkgs
    print("[pip]", " ".join(args))
    subprocess.check_call(args)
def _ensure_imports():
    """
    Install only the packages this app imports directly.

    LongCat-Video's own (large) requirements are deliberately NOT
    pre-installed; the official scripts import their dependencies at
    runtime, so a missing package shows up in the logs and can then be
    added to the probes below.
    """
    # Simple probes: module importable -> nothing to do.
    for module_name, pip_spec in (
        ("gradio", "gradio>=4.0.0"),
        ("requests", "requests>=2.31.0"),
    ):
        try:
            __import__(module_name)  # noqa
        except Exception:
            _pip_install([pip_spec])
    try:
        from huggingface_hub import snapshot_download  # noqa
    except Exception:
        _pip_install(["huggingface_hub[cli]>=0.24.0"])
    # The `spaces` package ships with most ZeroGPU images; install if absent.
    try:
        import spaces  # noqa
    except Exception:
        _pip_install(["spaces>=0.27.0"])
# Run the dependency probe at import time so the imports below succeed.
_ensure_imports()
import gradio as gr
import requests
from huggingface_hub import snapshot_download
# `spaces` should import safely even outside ZeroGPU; installed above if missing.
import spaces
# ----------------------------
# Configuration (edit as needed)
# ----------------------------
GITHUB_ZIP_URL = "https://github.com/meituan-longcat/LongCat-Video/archive/refs/heads/main.zip"
# HF weight repos (directory layout as documented on the model cards)
HF_MODEL_LONGCAT_VIDEO = "meituan-longcat/LongCat-Video"
HF_MODEL_LONGCAT_AVATAR = "meituan-longcat/LongCat-Video-Avatar"
# Local cache; on Spaces keep it under /home/user or the app directory
BASE_DIR = Path(__file__).parent.resolve()
CACHE_DIR = BASE_DIR / "_cache"
REPO_DIR = CACHE_DIR / "LongCat-Video-main"  # directory name after unzipping
WEIGHTS_DIR = CACHE_DIR / "weights"
WEIGHTS_LONGCAT_VIDEO = WEIGHTS_DIR / "LongCat-Video"
WEIGHTS_LONGCAT_AVATAR = WEIGHTS_DIR / "LongCat-Video-Avatar"
OUTPUT_DIR = CACHE_DIR / "outputs"
TMP_DIR = CACHE_DIR / "tmp"
# Reduce CUDA memory fragmentation in torch (sometimes helps)
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
| # ---------------------------- | |
| # 工具函数 | |
| # ---------------------------- | |
| def _sha1(s: str) -> str: | |
| return hashlib.sha1(s.encode("utf-8")).hexdigest()[:10] | |
| def _run(cmd, cwd: Optional[Path] = None, env: Optional[Dict[str, str]] = None) -> Tuple[int, str]: | |
| """运行命令并返回 (code, stdout+stderr)。""" | |
| print("[run]", " ".join(cmd)) | |
| p = subprocess.Popen( | |
| cmd, | |
| cwd=str(cwd) if cwd else None, | |
| env=env, | |
| stdout=subprocess.PIPE, | |
| stderr=subprocess.STDOUT, | |
| text=True, | |
| bufsize=1, | |
| universal_newlines=True, | |
| ) | |
| out_lines = [] | |
| while True: | |
| line = p.stdout.readline() | |
| if not line and p.poll() is not None: | |
| break | |
| if line: | |
| out_lines.append(line) | |
| code = p.wait() | |
| return code, "".join(out_lines) | |
def _download_and_extract_repo():
    """Download the LongCat-Video GitHub zip and extract it into CACHE_DIR.

    Idempotent: returns immediately when the repo looks already extracted.

    Fix: the zip is now downloaded to a ".part" file and atomically renamed
    on success. Previously an interrupted download left a truncated zip that
    `zip_path.exists()` treated as complete on every later run, breaking the
    app permanently. A corrupt zip is also deleted so the next attempt
    re-downloads instead of failing forever.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    zip_path = CACHE_DIR / "LongCat-Video-main.zip"
    sentinel = REPO_DIR / "run_demo_avatar_single_audio_to_video.py"
    if REPO_DIR.exists() and sentinel.exists():
        return
    # Drop any stale / partial extraction
    if REPO_DIR.exists():
        shutil.rmtree(REPO_DIR, ignore_errors=True)
    # Download the zip (to .part first, rename only when complete)
    if not zip_path.exists():
        part_path = zip_path.with_name(zip_path.name + ".part")
        r = requests.get(GITHUB_ZIP_URL, stream=True, timeout=120)
        r.raise_for_status()
        with open(part_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
        part_path.replace(zip_path)
    # Extract; remove a corrupt zip so a retry re-downloads it
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(CACHE_DIR)
    except zipfile.BadZipFile as e:
        zip_path.unlink(missing_ok=True)
        raise RuntimeError("下载的 zip 文件损坏,已删除;请重试。") from e
    # Sanity check on the expected repo layout
    if not sentinel.exists():
        raise RuntimeError("仓库解压后未找到 run_demo_avatar_single_audio_to_video.py,可能 GitHub 结构变化。")
def _download_weights():
    """Fetch both HF weight snapshots into WEIGHTS_DIR (skipping existing ones)."""
    WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
    # Honor HF_TOKEN when configured in the Space secrets
    token = os.environ.get("HF_TOKEN", None)
    targets = (
        (HF_MODEL_LONGCAT_VIDEO, WEIGHTS_LONGCAT_VIDEO),
        (HF_MODEL_LONGCAT_AVATAR, WEIGHTS_LONGCAT_AVATAR),
    )
    for repo_id, local_dir in targets:
        if local_dir.exists():
            continue
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(local_dir),
            token=token,
            local_dir_use_symlinks=False,
        )
| def _recursive_patch_attention_backend(obj: Any) -> Any: | |
| """ | |
| 递归把 config 里疑似 flash-attn backend 的字段替换为 sdpa。 | |
| 不依赖具体 key 名,尽量“宽松匹配”: | |
| - key 或 value 里出现 flash / flashattn / flash_attn => 改成 "sdpa" | |
| """ | |
| if isinstance(obj, dict): | |
| new = {} | |
| for k, v in obj.items(): | |
| lk = str(k).lower() | |
| if any(x in lk for x in ["attn", "attention", "backend"]): | |
| # 先递归处理 value | |
| vv = _recursive_patch_attention_backend(v) | |
| # 再判断是否需要替换 | |
| if isinstance(vv, str) and ("flash" in vv.lower() or "flash_attn" in vv.lower() or "flashattn" in vv.lower()): | |
| new[k] = "sdpa" | |
| else: | |
| new[k] = vv | |
| else: | |
| new[k] = _recursive_patch_attention_backend(v) | |
| return new | |
| elif isinstance(obj, list): | |
| return [_recursive_patch_attention_backend(x) for x in obj] | |
| else: | |
| # 普通标量 | |
| if isinstance(obj, str): | |
| lo = obj.lower() | |
| if "flash_attn" in lo or "flashattn" in lo or lo.strip() == "flash" or "flash" == lo.strip(): | |
| return "sdpa" | |
| return obj | |
def _try_patch_avatar_configs():
    """
    Best-effort patch of the avatar config files.

    Upstream docs note that avatar_single/config.json and
    avatar_multi/config.json enable FlashAttention-2 by default; rewriting
    the backend to sdpa avoids a hard flash-attn install requirement.
    Failures are logged and otherwise ignored.
    """
    for name in ("avatar_single", "avatar_multi"):
        cfg = WEIGHTS_LONGCAT_AVATAR / name / "config.json"
        if not cfg.exists():
            continue
        try:
            original = json.loads(cfg.read_text(encoding="utf-8"))
            patched = _recursive_patch_attention_backend(original)
            # Only rewrite the file when something actually changed.
            if patched != original:
                cfg.write_text(json.dumps(patched, ensure_ascii=False, indent=2), encoding="utf-8")
        except Exception as e:
            print(f"[warn] patch config failed: {cfg} -> {e}")
| def _load_template_json(template_path: Path) -> Dict[str, Any]: | |
| data = json.loads(template_path.read_text(encoding="utf-8")) | |
| if not isinstance(data, dict): | |
| raise ValueError("模板 JSON 不是 dict 结构,无法安全修改。") | |
| return data | |
| def _recursive_replace_first_match(data: Any, key_pred, value_pred, new_value) -> Tuple[Any, bool]: | |
| """ | |
| 在任意 JSON 结构中,找到第一个满足条件的 (key, value) 并替换 value。 | |
| 返回 (new_data, replaced?) | |
| """ | |
| if isinstance(data, dict): | |
| out = {} | |
| replaced = False | |
| for k, v in data.items(): | |
| if (not replaced) and key_pred(k) and value_pred(v): | |
| out[k] = new_value | |
| replaced = True | |
| else: | |
| nv, r = _recursive_replace_first_match(v, key_pred, value_pred, new_value) | |
| out[k] = nv | |
| replaced = replaced or r | |
| return out, replaced | |
| elif isinstance(data, list): | |
| out_list = [] | |
| replaced = False | |
| for item in data: | |
| if replaced: | |
| out_list.append(item) | |
| continue | |
| nv, r = _recursive_replace_first_match(item, key_pred, value_pred, new_value) | |
| out_list.append(nv) | |
| replaced = replaced or r | |
| return out_list, replaced | |
| else: | |
| return data, False | |
def _build_input_json_single(
    mode: str,
    audio_path: Path,
    prompt: str,
    ref_image_path: Optional[Path],
    seed: int,
    resolution: str
) -> Path:
    """
    Produce an input_json for the official single-avatar script by cloning
    assets/avatar/single_example_1.json and substituting our parameters.
    The official script consumes it via --input_json.
    """
    template = REPO_DIR / "assets" / "avatar" / "single_example_1.json"
    if not template.exists():
        raise RuntimeError("未找到模板 assets/avatar/single_example_1.json(仓库结构可能变化)。")
    data = _load_template_json(template)

    # Ordered substitution table: (key predicate, value predicate, replacement).
    substitutions = [
        # Prompt: first string-valued key containing prompt/text.
        (
            lambda k: "prompt" in str(k).lower() or "text" in str(k).lower(),
            lambda v: isinstance(v, str),
            prompt.strip() if prompt else "A person is talking.",
        ),
        # Audio path: first string-valued key containing "audio".
        (
            lambda k: "audio" in str(k).lower(),
            lambda v: isinstance(v, str),
            str(audio_path),
        ),
    ]
    # Reference image path is only relevant in AI2V mode.
    if mode == "ai2v" and ref_image_path is not None:
        substitutions.append((
            lambda k: ("image" in str(k).lower()) or ("ref" in str(k).lower()),
            lambda v: isinstance(v, str),
            str(ref_image_path),
        ))
    # Seed and resolution (no-ops when absent from the template).
    substitutions.append((
        lambda k: "seed" in str(k).lower(),
        lambda v: isinstance(v, (int, float, str)),
        int(seed),
    ))
    substitutions.append((
        lambda k: "resolution" in str(k).lower(),
        lambda v: isinstance(v, str),
        str(resolution),
    ))
    for key_pred, value_pred, new_value in substitutions:
        data, _ = _recursive_replace_first_match(data, key_pred, value_pred, new_value)

    TMP_DIR.mkdir(parents=True, exist_ok=True)
    out_path = TMP_DIR / f"single_{mode}_{_sha1(str(audio_path) + prompt + str(seed) + str(time.time()))}.json"
    out_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    return out_path
def _find_latest_mp4(since_ts: float) -> Optional[Path]:
    """Return the newest .mp4 under OUTPUT_DIR modified at/after since_ts - 2s, else None."""
    if not OUTPUT_DIR.exists():
        return None
    cutoff = since_ts - 2  # small grace window for clock/mtime skew
    recent = []
    for path in OUTPUT_DIR.rglob("*.mp4"):
        try:
            if path.stat().st_mtime >= cutoff:
                recent.append(path)
        except Exception:
            # File may vanish between rglob and stat; skip it.
            pass
    if not recent:
        return None
    return max(recent, key=lambda p: p.stat().st_mtime)
def _ensure_ready() -> str:
    """
    One-shot preparation:
    - download + extract the GitHub repo
    - download the HF weight snapshots
    - best-effort patch of attention backends in the avatar configs
    Returns a human-readable status string (shown in the UI).
    """
    _download_and_extract_repo()
    _download_weights()
    _try_patch_avatar_configs()
    # Make sure the directory we scan for produced videos exists
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    return "准备完成:代码与权重已就绪。"
# ----------------------------
# GPU inference (ZeroGPU core)
# ----------------------------
# A single generation usually takes well over 60s; request a generous window.
# FIX: on ZeroGPU a GPU is only attached while a @spaces.GPU-decorated
# function runs — the original comments described this, but the decorator was
# never applied, so inference would never be granted a GPU. Applied
# defensively so the app still runs outside ZeroGPU.
try:
    _gpu_task = spaces.GPU(duration=1200)
except Exception:
    def _gpu_task(fn):  # fallback: plain passthrough decorator
        return fn


@_gpu_task
def generate_single(
    mode: str,
    audio_file: str,
    prompt: str,
    ref_image_file: Optional[str],
    seed: int,
    resolution: str,
    num_segments: int,
    ref_img_index: int,
    mask_frame_range: int,
) -> Tuple[Optional[str], str]:
    """
    Run the official single-avatar script and return (mp4_path_or_None, log_text).

    mode: "at2v" (audio+text) or "ai2v" (audio+image).
    audio_file / ref_image_file: local temp-file paths handed over by Gradio.
    num_segments > 1 enables video continuation; ref_img_index and
    mask_frame_range are only forwarded in that case.
    """
    t0 = time.time()
    # Gradio passes local temp-file path strings.
    audio_path = Path(audio_file).resolve()
    ref_image_path = Path(ref_image_file).resolve() if ref_image_file else None
    # Build the --input_json payload the official script expects.
    input_json = _build_input_json_single(
        mode=mode,
        audio_path=audio_path,
        prompt=prompt,
        ref_image_path=ref_image_path,
        seed=seed,
        resolution=resolution,
    )
    # The official example uses torchrun --nproc_per_node=2 with
    # --context_parallel_size=2; a Space exposes one GPU, so run single-process.
    cmd = [
        sys.executable, "-m", "torch.distributed.run",
        "--nproc_per_node=1",
        "run_demo_avatar_single_audio_to_video.py",
        "--context_parallel_size=1",
        f"--checkpoint_dir={WEIGHTS_LONGCAT_AVATAR}",
        f"--stage_1={mode}",
        f"--input_json={input_json}",
        f"--resolution={resolution}",
    ]
    # Continuation parameters: only forwarded when the user asked for >1 segment.
    if num_segments and int(num_segments) > 1:
        cmd += [
            f"--num_segments={int(num_segments)}",
            f"--ref_img_index={int(ref_img_index)}",
            f"--mask_frame_range={int(mask_frame_range)}",
        ]
    # Environment: make the repo importable and keep caches inside CACHE_DIR.
    env = dict(os.environ)
    env["PYTHONPATH"] = str(REPO_DIR) + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    env["HF_HOME"] = str(CACHE_DIR / "hf_home")
    env["TORCH_HOME"] = str(CACHE_DIR / "torch_home")
    # Advisory output dir (if the script honors it); we also scan for fresh
    # mp4s below as a fallback.
    env["OUTPUT_DIR"] = str(OUTPUT_DIR)
    code, log = _run(cmd, cwd=REPO_DIR, env=env)
    # Locate the newest mp4 produced since t0.
    mp4 = _find_latest_mp4(t0)
    if mp4 is None:
        # Fallback: some script versions write next to the repo itself.
        repo_candidates = list(REPO_DIR.rglob("*.mp4"))
        repo_candidates.sort(key=lambda x: x.stat().st_mtime, reverse=True)
        if repo_candidates and repo_candidates[0].stat().st_mtime >= t0 - 2:
            mp4 = repo_candidates[0]
    if code != 0:
        return None, f"执行失败(exit={code})。日志如下:\n\n{log}"
    if mp4 is None or not mp4.exists():
        return None, f"执行完成,但未找到 mp4 输出文件。日志如下:\n\n{log}"
    return str(mp4), f"执行成功:{mp4}\n\n日志如下:\n\n{log}"
# ----------------------------
# Gradio UI
# ----------------------------
def ui_prepare() -> str:
    """Button callback: run preparation, turning any failure into a status string."""
    try:
        status = _ensure_ready()
    except Exception as e:
        return f"准备失败:{e}"
    return status
# Build the Gradio app. All widget labels/messages are user-facing and kept as-is.
with gr.Blocks(title="LongCat-Video-Avatar (ZeroGPU) - Single File Space") as demo:
    gr.Markdown(
        """
# LongCat-Video-Avatar(ZeroGPU / 单文件 Space)
- 单人模式:**AT2V(音频+文本)** / **AI2V(音频+图片)**
- 续写(Video Continuation):把 **num_segments** 设为 > 1 即可(官方参数:ref_img_index / mask_frame_range)
- 提示:为了更自然的口型,prompt 里建议包含 talking/speaking 等动作词(模型卡建议)
"""
    )
    # Preparation row: downloads the repo + weights on demand.
    with gr.Row():
        btn_prepare = gr.Button("一键准备(下载代码+权重)", variant="primary")
        prep_status = gr.Textbox(label="准备状态", value="尚未准备。首次准备会下载较大权重。", lines=2)
    btn_prepare.click(fn=ui_prepare, outputs=prep_status)
    # Mode selection: AT2V (audio + text) vs AI2V (audio + image).
    with gr.Row():
        mode = gr.Radio(
            choices=[("Audio-Text-to-Video (AT2V)", "at2v"), ("Audio-Image-to-Video (AI2V)", "ai2v")],
            value="at2v",
            label="模式"
        )
    with gr.Row():
        audio_in = gr.Audio(label="输入音频(wav/mp3等)", type="filepath")
        ref_img = gr.Image(label="参考图(仅 AI2V 需要)", type="filepath")
    prompt = gr.Textbox(
        label="Prompt(建议包含 talking/speaking 等动作词)",
        value="A young person is talking naturally, realistic style.",
        lines=2
    )
    with gr.Row():
        seed = gr.Number(label="Seed", value=0, precision=0)
        resolution = gr.Dropdown(label="分辨率", choices=["480P", "720P"], value="480P")
    # Advanced: continuation / consistency / anti-repetition controls.
    with gr.Accordion("高级参数(续写/一致性/防重复)", open=False):
        num_segments = gr.Slider(label="num_segments(>1 启用续写)", minimum=1, maximum=8, step=1, value=1)
        ref_img_index = gr.Slider(label="ref_img_index(默认 10)", minimum=-30, maximum=60, step=1, value=10)
        mask_frame_range = gr.Slider(label="mask_frame_range(默认 3)", minimum=1, maximum=12, step=1, value=3)
    btn = gr.Button("生成视频", variant="primary")
    out_video = gr.Video(label="输出视频(mp4)")
    out_log = gr.Textbox(label="运行日志", lines=18)

    def _validate(mode_v, audio_fp, img_fp):
        # Raise a user-visible gr.Error early instead of failing in the backend.
        if not audio_fp:
            raise gr.Error("请先上传音频。")
        if mode_v == "ai2v" and not img_fp:
            raise gr.Error("AI2V 模式必须上传参考图。")

    def run(mode_v, audio_fp, prompt_v, img_fp, seed_v, res_v, seg_v, idx_v, mask_v):
        # Click handler: validate inputs, then delegate to the inference function.
        _validate(mode_v, audio_fp, img_fp)
        # seed=0 is allowed; swap in a random seed yourself if desired.
        return generate_single(mode_v, audio_fp, prompt_v, img_fp, int(seed_v), res_v, int(seg_v), int(idx_v), int(mask_v))

    btn.click(
        fn=run,
        inputs=[mode, audio_in, prompt, ref_img, seed, resolution, num_segments, ref_img_index, mask_frame_range],
        outputs=[out_video, out_log],
    )

demo.queue(max_size=12).launch()