| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import re |
| import sys |
| import json |
| import time |
| import shutil |
| import zipfile |
| import hashlib |
| import subprocess |
| from pathlib import Path |
| from datetime import datetime |
| from typing import Any, Dict, Tuple, Optional |
|
|
| |
| |
| |
| def _pip_install(pkgs): |
| """在 Spaces 里尽量避免反复安装:用一个标记文件 + 简单 import 探测。""" |
| cmd = [sys.executable, "-m", "pip", "install", "-U"] + pkgs |
| print("[pip]", " ".join(cmd)) |
| subprocess.check_call(cmd) |
|
|
def _ensure_imports():
    """Install this app's direct dependencies when they cannot be imported.

    Only packages the app itself needs are handled here. LongCat-Video's own, much
    larger requirement set is deliberately not pre-installed: the official scripts
    import it at run time, and a missing package shows up in the run log so it can
    be added to the table below.

    All missing packages are installed with a single pip invocation.
    """
    import importlib

    # (module to import, attribute that must exist on it or None, pip requirement)
    requirements = (
        ("gradio", None, "gradio>=4.0.0"),
        ("requests", None, "requests>=2.31.0"),
        ("huggingface_hub", "snapshot_download", "huggingface_hub[cli]>=0.24.0"),
        ("spaces", None, "spaces>=0.27.0"),
    )

    missing = []
    for module_name, attr, requirement in requirements:
        try:
            module = importlib.import_module(module_name)
            if attr is not None:
                getattr(module, attr)
        except Exception:
            # Deliberately broad: a package that is present but fails to import
            # (broken or partial install) is reinstalled/upgraded as well.
            missing.append(requirement)

    if missing:
        _pip_install(missing)
        # The import system caches directory listings; packages installed after the
        # interpreter started may be invisible to later imports without this.
        importlib.invalidate_caches()
|
|
# Install any missing direct dependencies before importing them below; on a fresh
# environment these imports would otherwise fail with ImportError.
_ensure_imports()

import gradio as gr
import requests
from huggingface_hub import snapshot_download

# `spaces` supplies the @spaces.GPU decorator used on generate_single below
# (presumably the Hugging Face ZeroGPU helper — the UI title targets ZeroGPU).
import spaces
|
|
|
|
| |
| |
| |
# Source archive of the official LongCat-Video code (main branch).
GITHUB_ZIP_URL = "https://github.com/meituan-longcat/LongCat-Video/archive/refs/heads/main.zip"

# Hugging Face model repositories that hold the weights.
HF_MODEL_LONGCAT_VIDEO = "meituan-longcat/LongCat-Video"
HF_MODEL_LONGCAT_AVATAR = "meituan-longcat/LongCat-Video-Avatar"

# Local cache layout, rooted in the directory containing this file.
BASE_DIR = Path(__file__).parent.resolve()
CACHE_DIR = BASE_DIR / "_cache"
# Folder the GitHub archive is expected to extract to.
REPO_DIR = CACHE_DIR / "LongCat-Video-main"
WEIGHTS_DIR = CACHE_DIR / "weights"
# One sub-directory per HF repo, named after the repo (the part after "owner/").
WEIGHTS_LONGCAT_VIDEO = WEIGHTS_DIR / HF_MODEL_LONGCAT_VIDEO.split("/")[-1]
WEIGHTS_LONGCAT_AVATAR = WEIGHTS_DIR / HF_MODEL_LONGCAT_AVATAR.split("/")[-1]
OUTPUT_DIR = CACHE_DIR / "outputs"
TMP_DIR = CACHE_DIR / "tmp"

# Enable PyTorch's expandable CUDA allocator segments unless the user already
# configured the allocator; child processes inherit this environment.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
| |
| |
| |
| def _sha1(s: str) -> str: |
| return hashlib.sha1(s.encode("utf-8")).hexdigest()[:10] |
|
|
| def _run(cmd, cwd: Optional[Path] = None, env: Optional[Dict[str, str]] = None) -> Tuple[int, str]: |
| """运行命令并返回 (code, stdout+stderr)。""" |
| print("[run]", " ".join(cmd)) |
| p = subprocess.Popen( |
| cmd, |
| cwd=str(cwd) if cwd else None, |
| env=env, |
| stdout=subprocess.PIPE, |
| stderr=subprocess.STDOUT, |
| text=True, |
| bufsize=1, |
| universal_newlines=True, |
| ) |
| out_lines = [] |
| while True: |
| line = p.stdout.readline() |
| if not line and p.poll() is not None: |
| break |
| if line: |
| out_lines.append(line) |
| code = p.wait() |
| return code, "".join(out_lines) |
|
|
def _download_file(url: str, dest: Path, timeout: float = 120) -> None:
    """Stream ``url`` into ``dest``; the file only appears under ``dest`` once complete."""
    partial = dest.with_name(dest.name + ".part")
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        with open(partial, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    fh.write(chunk)
    # Rename only after a full download, so an interruption never leaves a
    # truncated archive where the cache check below would trust it.
    partial.replace(dest)


def _download_and_extract_repo():
    """Download the GitHub zip of LongCat-Video and extract it into CACHE_DIR.

    Does nothing if the demo script already exists under REPO_DIR. An incomplete
    REPO_DIR is removed first. A cached zip is reused, but if it turns out to be
    corrupt it is deleted and downloaded once more.

    Raises:
        RuntimeError: if the extracted tree lacks the demo script.
        requests.HTTPError / zipfile.BadZipFile: if the download keeps failing.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    zip_path = CACHE_DIR / "LongCat-Video-main.zip"
    demo_script = REPO_DIR / "run_demo_avatar_single_audio_to_video.py"

    if demo_script.exists():
        return

    # A REPO_DIR without the demo script is treated as a broken extraction.
    if REPO_DIR.exists():
        shutil.rmtree(REPO_DIR, ignore_errors=True)

    for attempt in range(2):
        if not zip_path.exists():
            _download_file(GITHUB_ZIP_URL, zip_path)
        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                zf.extractall(CACHE_DIR)
            break
        except zipfile.BadZipFile:
            # Corrupt cached archive: discard it so the next attempt re-downloads.
            zip_path.unlink(missing_ok=True)
            if attempt == 1:
                raise

    if not demo_script.exists():
        raise RuntimeError("仓库解压后未找到 run_demo_avatar_single_audio_to_video.py,可能 GitHub 结构变化。")
|
|
def _download_weights():
    """Download the LongCat-Video and LongCat-Video-Avatar weights into WEIGHTS_DIR.

    Uses the ``HF_TOKEN`` environment variable as the Hub token when set. A repo is
    skipped only once a ``.download_complete`` marker exists in its directory. That
    marker is written after ``snapshot_download`` returns, so an interrupted download
    is resumed on the next call instead of being mistaken for a finished one.
    """
    WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)

    token = os.environ.get("HF_TOKEN", None)

    for repo_id, local_dir in (
        (HF_MODEL_LONGCAT_VIDEO, WEIGHTS_LONGCAT_VIDEO),
        (HF_MODEL_LONGCAT_AVATAR, WEIGHTS_LONGCAT_AVATAR),
    ):
        marker = local_dir / ".download_complete"
        if marker.exists():
            continue
        # Per the huggingface_hub docs, files already present and up to date in
        # local_dir are not fetched again, so re-running after an interruption only
        # downloads what is missing. local_dir_use_symlinks is not passed: it is
        # deprecated and ignored since huggingface_hub 0.23.
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(local_dir),
            token=token,
        )
        marker.touch()
|
|
| def _recursive_patch_attention_backend(obj: Any) -> Any: |
| """ |
| 递归把 config 里疑似 flash-attn backend 的字段替换为 sdpa。 |
| 不依赖具体 key 名,尽量“宽松匹配”: |
| - key 或 value 里出现 flash / flashattn / flash_attn => 改成 "sdpa" |
| """ |
| if isinstance(obj, dict): |
| new = {} |
| for k, v in obj.items(): |
| lk = str(k).lower() |
| if any(x in lk for x in ["attn", "attention", "backend"]): |
| |
| vv = _recursive_patch_attention_backend(v) |
| |
| if isinstance(vv, str) and ("flash" in vv.lower() or "flash_attn" in vv.lower() or "flashattn" in vv.lower()): |
| new[k] = "sdpa" |
| else: |
| new[k] = vv |
| else: |
| new[k] = _recursive_patch_attention_backend(v) |
| return new |
| elif isinstance(obj, list): |
| return [_recursive_patch_attention_backend(x) for x in obj] |
| else: |
| |
| if isinstance(obj, str): |
| lo = obj.lower() |
| if "flash_attn" in lo or "flashattn" in lo or lo.strip() == "flash" or "flash" == lo.strip(): |
| return "sdpa" |
| return obj |
|
|
def _try_patch_avatar_configs():
    """Best-effort: switch the avatar configs' attention backend from FlashAttention to SDPA.

    Per the upstream notes, ``avatar_single/config.json`` and ``avatar_multi/config.json``
    enable FlashAttention-2 by default; rewriting the backend to "sdpa" avoids having to
    install flash-attn. Missing files are skipped, unchanged files are not rewritten,
    and any failure is only logged, so preparation never fails because of this step.
    """
    cfgs = [
        WEIGHTS_LONGCAT_AVATAR / "avatar_single" / "config.json",
        WEIGHTS_LONGCAT_AVATAR / "avatar_multi" / "config.json",
    ]
    for cfg in cfgs:
        if not cfg.exists():
            continue
        try:
            raw = json.loads(cfg.read_text(encoding="utf-8"))
            patched = _recursive_patch_attention_backend(raw)
            if patched == raw:
                continue
            # Write a sibling temp file and rename it over the original, so an
            # interruption can never leave a truncated config.json behind.
            tmp = cfg.with_name(cfg.name + ".tmp")
            tmp.write_text(json.dumps(patched, ensure_ascii=False, indent=2), encoding="utf-8")
            tmp.replace(cfg)
        except Exception as e:  # deliberately broad: this step is best-effort
            print(f"[warn] patch config failed: {cfg} -> {e}")
|
|
| def _load_template_json(template_path: Path) -> Dict[str, Any]: |
| data = json.loads(template_path.read_text(encoding="utf-8")) |
| if not isinstance(data, dict): |
| raise ValueError("模板 JSON 不是 dict 结构,无法安全修改。") |
| return data |
|
|
| def _recursive_replace_first_match(data: Any, key_pred, value_pred, new_value) -> Tuple[Any, bool]: |
| """ |
| 在任意 JSON 结构中,找到第一个满足条件的 (key, value) 并替换 value。 |
| 返回 (new_data, replaced?) |
| """ |
| if isinstance(data, dict): |
| out = {} |
| replaced = False |
| for k, v in data.items(): |
| if (not replaced) and key_pred(k) and value_pred(v): |
| out[k] = new_value |
| replaced = True |
| else: |
| nv, r = _recursive_replace_first_match(v, key_pred, value_pred, new_value) |
| out[k] = nv |
| replaced = replaced or r |
| return out, replaced |
| elif isinstance(data, list): |
| out_list = [] |
| replaced = False |
| for item in data: |
| if replaced: |
| out_list.append(item) |
| continue |
| nv, r = _recursive_replace_first_match(item, key_pred, value_pred, new_value) |
| out_list.append(nv) |
| replaced = replaced or r |
| return out_list, replaced |
| else: |
| return data, False |
|
|
def _replace_first_str_leaf(data: Any, new_value: Any) -> Tuple[Any, bool]:
    """Replace the first string reached by a depth-first walk of ``data`` (input not mutated)."""
    if isinstance(data, str):
        return new_value, True
    if isinstance(data, dict):
        for key, value in data.items():
            new_child, done = _replace_first_str_leaf(value, new_value)
            if done:
                return {**data, key: new_child}, True
        return data, False
    if isinstance(data, list):
        for idx, item in enumerate(data):
            new_child, done = _replace_first_str_leaf(item, new_value)
            if done:
                return data[:idx] + [new_child] + data[idx + 1:], True
        return data, False
    return data, False


def _replace_str_under_key(data: Any, key_pred, new_value: Any) -> Tuple[Any, bool]:
    """Replace the first string leaf nested under the first key accepted by ``key_pred``.

    Handles a path wrapped in a container, e.g. ``{"x_audio": {"speaker": "a.wav"}}``,
    which a direct key/value match cannot see because the matching key's value is
    not itself a string.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            if key_pred(key):
                new_child, done = _replace_first_str_leaf(value, new_value)
            else:
                new_child, done = _replace_str_under_key(value, key_pred, new_value)
            if done:
                return {**data, key: new_child}, True
        return data, False
    if isinstance(data, list):
        for idx, item in enumerate(data):
            new_child, done = _replace_str_under_key(item, key_pred, new_value)
            if done:
                return data[:idx] + [new_child] + data[idx + 1:], True
        return data, False
    return data, False


def _set_required_path(data: Any, key_pred, new_value: str, what: str) -> Any:
    """Write ``new_value`` into the template's path field for ``what``; raise if there is none."""
    data, done = _recursive_replace_first_match(
        data,
        key_pred=key_pred,
        value_pred=lambda v: isinstance(v, str),
        new_value=new_value,
    )
    if not done:
        data, done = _replace_str_under_key(data, key_pred, new_value)
    if not done:
        # Continuing would silently feed the template's own example media to the model.
        raise RuntimeError(f"模板 single_example_1.json 中未找到可替换的{what}路径字段,已中止以免误用模板自带素材。")
    return data


def _build_input_json_single(
    mode: str,
    audio_path: Path,
    prompt: str,
    ref_image_path: Optional[Path],
    seed: int,
    resolution: str
) -> Path:
    """Create the ``--input_json`` file for the single-person avatar demo script.

    Starts from ``assets/avatar/single_example_1.json`` in REPO_DIR and substitutes:
    the prompt (first string under a key containing "prompt"/"text"; an empty or
    whitespace-only prompt falls back to a default), the audio path (required), the
    reference image path (required in "ai2v" mode when an image is given), the seed,
    and the resolution. The result is written to a uniquely named file in TMP_DIR.

    Returns:
        Path of the written JSON file.

    Raises:
        RuntimeError: if the template is missing, or it has no field the audio
            (or, in ai2v mode, image) path can be written to.
    """
    template = REPO_DIR / "assets" / "avatar" / "single_example_1.json"
    if not template.exists():
        raise RuntimeError("未找到模板 assets/avatar/single_example_1.json(仓库结构可能变化)。")

    data = _load_template_json(template)

    prompt_text = (prompt or "").strip() or "A person is talking."
    data, _ = _recursive_replace_first_match(
        data,
        key_pred=lambda k: "prompt" in str(k).lower() or "text" in str(k).lower(),
        value_pred=lambda v: isinstance(v, str),
        new_value=prompt_text,
    )

    data = _set_required_path(data, lambda k: "audio" in str(k).lower(), str(audio_path), "音频")

    if mode == "ai2v" and ref_image_path is not None:
        data = _set_required_path(
            data,
            lambda k: ("image" in str(k).lower()) or ("ref" in str(k).lower()),
            str(ref_image_path),
            "参考图",
        )

    # Seed and resolution are optional in the template; a missing key is fine.
    data, _ = _recursive_replace_first_match(
        data,
        key_pred=lambda k: "seed" in str(k).lower(),
        value_pred=lambda v: isinstance(v, (int, float, str)),
        new_value=int(seed),
    )
    data, _ = _recursive_replace_first_match(
        data,
        key_pred=lambda k: "resolution" in str(k).lower(),
        value_pred=lambda v: isinstance(v, str),
        new_value=str(resolution),
    )

    TMP_DIR.mkdir(parents=True, exist_ok=True)
    # time.time() in the hash keeps repeated identical requests from colliding.
    tag = _sha1(str(audio_path) + prompt_text + str(seed) + str(time.time()))
    out_path = TMP_DIR / f"single_{mode}_{tag}.json"
    out_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    return out_path
|
|
| def _find_latest_mp4(since_ts: float) -> Optional[Path]: |
| if not OUTPUT_DIR.exists(): |
| return None |
| candidates = [] |
| for p in OUTPUT_DIR.rglob("*.mp4"): |
| try: |
| if p.stat().st_mtime >= since_ts - 2: |
| candidates.append(p) |
| except Exception: |
| pass |
| if not candidates: |
| return None |
| candidates.sort(key=lambda x: x.stat().st_mtime, reverse=True) |
| return candidates[0] |
|
|
def _ensure_ready() -> str:
    """Get code, weights and output folder in place; return a status message.

    Steps, in order: fetch and extract the GitHub code, download the Hugging Face
    weights, best-effort patch the avatar configs' attention backend, and create
    OUTPUT_DIR. Exceptions from the steps propagate to the caller.
    """
    preparation_steps = (
        _download_and_extract_repo,
        _download_weights,
        _try_patch_avatar_configs,
    )
    for step in preparation_steps:
        step()

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    return "准备完成:代码与权重已就绪。"
|
|
|
|
| |
| |
| |
@spaces.GPU(duration=900)
def generate_single(
    mode: str,
    audio_file: str,
    prompt: str,
    ref_image_file: Optional[str],
    seed: int,
    resolution: str,
    num_segments: int,
    ref_img_index: int,
    mask_frame_range: int,
) -> Tuple[Optional[str], str]:
    """Generate a single-person avatar video with the official LongCat demo script.

    Writes an input JSON from the repo template, then runs
    ``run_demo_avatar_single_audio_to_video.py`` under ``torch.distributed.run``
    (one process, context-parallel size 1) with REPO_DIR as working directory, and
    finally looks for the newest mp4 written since this call started.

    Args:
        mode: "at2v" or "ai2v"; forwarded as ``--stage_1``.
        audio_file: path of the driving audio.
        prompt: text prompt.
        ref_image_file: reference image path; only substituted in "ai2v" mode.
        seed: seed written into the input JSON's seed field (if the template has one).
        resolution: e.g. "480P"; written into the JSON and passed as ``--resolution``.
        num_segments: values > 1 enable continuation and forward the next two args.
        ref_img_index: forwarded as ``--ref_img_index`` when continuing.
        mask_frame_range: forwarded as ``--mask_frame_range`` when continuing.

    Returns:
        ``(mp4 path or None, log text)``.
    """
    t0 = time.time()

    script_name = "run_demo_avatar_single_audio_to_video.py"
    # Without preparation the first missing file surfaces as an opaque exception
    # deep inside input building; report the actual cause instead.
    if not (REPO_DIR / script_name).exists() or not WEIGHTS_LONGCAT_AVATAR.exists():
        return None, "代码或权重尚未就绪,请先点击“一键准备(下载代码+权重)”。"

    audio_path = Path(audio_file).resolve()
    ref_image_path = Path(ref_image_file).resolve() if ref_image_file else None

    input_json = _build_input_json_single(
        mode=mode,
        audio_path=audio_path,
        prompt=prompt,
        ref_image_path=ref_image_path,
        seed=seed,
        resolution=resolution,
    )

    cmd = [
        sys.executable, "-m", "torch.distributed.run",
        "--nproc_per_node=1",
        script_name,
        "--context_parallel_size=1",
        f"--checkpoint_dir={WEIGHTS_LONGCAT_AVATAR}",
        f"--stage_1={mode}",
        f"--input_json={input_json}",
        f"--resolution={resolution}",
    ]

    # Video continuation parameters only matter with more than one segment.
    if num_segments and int(num_segments) > 1:
        cmd += [
            f"--num_segments={int(num_segments)}",
            f"--ref_img_index={int(ref_img_index)}",
            f"--mask_frame_range={int(mask_frame_range)}",
        ]

    env = dict(os.environ)
    env["PYTHONPATH"] = str(REPO_DIR) + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    env["HF_HOME"] = str(CACHE_DIR / "hf_home")
    env["TORCH_HOME"] = str(CACHE_DIR / "torch_home")
    # NOTE(review): presumably read by the script to choose its output folder —
    # confirm; outputs are searched in OUTPUT_DIR and then REPO_DIR regardless.
    env["OUTPUT_DIR"] = str(OUTPUT_DIR)

    code, log = _run(cmd, cwd=REPO_DIR, env=env)

    mp4 = _find_latest_mp4(t0)
    if mp4 is None:
        # Fall back to any fresh mp4 inside the repo tree (same 2 s slack as above).
        fresh = []
        for cand in REPO_DIR.rglob("*.mp4"):
            try:
                mtime = cand.stat().st_mtime
            except OSError:
                continue  # vanished between listing and stat
            if mtime >= t0 - 2:
                fresh.append((mtime, cand))
        if fresh:
            mp4 = max(fresh, key=lambda item: item[0])[1]

    if code != 0:
        return None, f"执行失败(exit={code})。日志如下:\n\n{log}"

    if mp4 is None or not mp4.exists():
        return None, f"执行完成,但未找到 mp4 输出文件。日志如下:\n\n{log}"

    return str(mp4), f"执行成功:{mp4}\n\n日志如下:\n\n{log}"
|
|
|
|
| |
| |
| |
def ui_prepare() -> str:
    """Click handler for the prepare button: run preparation and report the outcome.

    Any exception is converted into a status string for the textbox, not raised.
    """
    status: str
    try:
        status = _ensure_ready()
    except Exception as err:
        status = f"准备失败:{err}"
    return status
|
|
with gr.Blocks(title="LongCat-Video-Avatar (ZeroGPU) - Single File Space") as demo:
    gr.Markdown(
        """
# LongCat-Video-Avatar(ZeroGPU / 单文件 Space)

- 单人模式:**AT2V(音频+文本)** / **AI2V(音频+图片)**
- 续写(Video Continuation):把 **num_segments** 设为 > 1 即可(官方参数:ref_img_index / mask_frame_range)
- 提示:为了更自然的口型,prompt 里建议包含 talking/speaking 等动作词(模型卡建议)
"""
    )

    # Manual preparation (download code + weights); generation also prepares on demand.
    with gr.Row():
        btn_prepare = gr.Button("一键准备(下载代码+权重)", variant="primary")
        prep_status = gr.Textbox(label="准备状态", value="尚未准备。首次准备会下载较大权重。", lines=2)

    btn_prepare.click(fn=ui_prepare, outputs=prep_status)

    with gr.Row():
        mode = gr.Radio(
            choices=[("Audio-Text-to-Video (AT2V)", "at2v"), ("Audio-Image-to-Video (AI2V)", "ai2v")],
            value="at2v",
            label="模式"
        )

    with gr.Row():
        audio_in = gr.Audio(label="输入音频(wav/mp3等)", type="filepath")
        ref_img = gr.Image(label="参考图(仅 AI2V 需要)", type="filepath")

    prompt = gr.Textbox(
        label="Prompt(建议包含 talking/speaking 等动作词)",
        value="A young person is talking naturally, realistic style.",
        lines=2
    )

    with gr.Row():
        seed = gr.Number(label="Seed", value=0, precision=0)
        resolution = gr.Dropdown(label="分辨率", choices=["480P", "720P"], value="480P")

    with gr.Accordion("高级参数(续写/一致性/防重复)", open=False):
        num_segments = gr.Slider(label="num_segments(>1 启用续写)", minimum=1, maximum=8, step=1, value=1)
        ref_img_index = gr.Slider(label="ref_img_index(默认 10)", minimum=-30, maximum=60, step=1, value=10)
        mask_frame_range = gr.Slider(label="mask_frame_range(默认 3)", minimum=1, maximum=12, step=1, value=3)

    btn = gr.Button("生成视频", variant="primary")

    out_video = gr.Video(label="输出视频(mp4)")
    out_log = gr.Textbox(label="运行日志", lines=18)

    def _validate(mode_v, audio_fp, img_fp):
        """Reject requests missing required uploads with a user-visible gr.Error."""
        if not audio_fp:
            raise gr.Error("请先上传音频。")
        if mode_v == "ai2v" and not img_fp:
            raise gr.Error("AI2V 模式必须上传参考图。")

    def run(mode_v, audio_fp, prompt_v, img_fp, seed_v, res_v, seg_v, idx_v, mask_v):
        """Generate-button handler: validate, prepare if needed, then call the GPU function."""
        _validate(mode_v, audio_fp, img_fp)
        # Prepare here (this handler is not @spaces.GPU-decorated) so generation works
        # even if the prepare button was never clicked; the preparation helpers skip
        # code/weights that are already present.
        try:
            _ensure_ready()
        except Exception as err:
            raise gr.Error(f"准备失败:{err}") from err
        # A cleared Number box yields None; fall back to the default seed 0.
        seed_int = int(seed_v) if seed_v is not None else 0
        return generate_single(mode_v, audio_fp, prompt_v, img_fp, seed_int, res_v, int(seg_v), int(idx_v), int(mask_v))

    btn.click(
        fn=run,
        inputs=[mode, audio_in, prompt, ref_img, seed, resolution, num_segments, ref_img_index, mask_frame_range],
        outputs=[out_video, out_log],
    )

demo.queue(max_size=12).launch()
|
|