File size: 4,775 Bytes
1d0c0cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
"""Run Wan T2V inference with the sparse FP4 checkpoint-700 transformer."""

from __future__ import annotations

import argparse
import os
from pathlib import Path


# Text prompt used when --prompt is not supplied on the command line.
DEFAULT_PROMPT = (
    "In the video, a woman is elegantly showcasing her earrings, bringing "
    "attention to their intricate design with a gentle touch of her fingers. "
    "She is bathed in ambient purple and pink lighting, which casts a soft "
    "glow on her delicate features and enhances the vivid tones of her lipstick "
    "and eye makeup. Her hair is styled to frame her face smoothly, emphasizing "
    "the contours of her jawline and cheekbones. The background features a "
    "blurred neon light, adding an artistic and modern touch to the overall "
    "aesthetic."
)

# Negative prompt forwarded to generate_video() when --negative-prompt is not
# supplied on the command line.
# NOTE(review): "static" appears twice in this list; confirm whether the
# duplicate is intentional before editing (changing the text changes outputs).
DEFAULT_NEGATIVE_PROMPT = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, "
    "works, paintings, images, static, overall gray, worst quality, low quality, "
    "JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn "
    "hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused "
    "fingers, still picture, messy background, three legs, many people in the "
    "background, walking backwards"
)


def _resolve_weights(repo_id: str, weights: str | None, local_dir: str) -> str:
    if weights:
        path = Path(weights).expanduser()
        if path.exists():
            return str(path.resolve())
        raise FileNotFoundError(f"--weights does not exist: {path}")

    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id=repo_id,
        filename="transformer/diffusion_pytorch_model.safetensors",
        local_dir=local_dir,
        repo_type="model",
    )
    return str(Path(path).resolve())


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Defaults describe a 448x832, 77-frame, 50-step generation at 16 fps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        default="yitongl/sparse_quant_exp",
        help="Hugging Face repo to download the transformer weights from "
        "(only used when --weights is not given).",
    )
    parser.add_argument(
        "--model-path",
        default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        help="Base Wan Diffusers model repo/path.",
    )
    parser.add_argument(
        "--weights",
        default=None,
        help="Path to local transformer weights; when given, nothing is downloaded.",
    )
    parser.add_argument(
        "--local-dir",
        default="checkpoints/hf_download/sparse_quant_exp",
        help="Local Hugging Face download directory for the uploaded weights.",
    )
    parser.add_argument("--prompt", default=DEFAULT_PROMPT)
    parser.add_argument("--negative-prompt", default=DEFAULT_NEGATIVE_PROMPT)
    parser.add_argument("--output-path", default="outputs/sfp4_checkpoint_700")
    parser.add_argument("--height", type=int, default=448)
    parser.add_argument("--width", type=int, default=832)
    parser.add_argument("--num-frames", type=int, default=77)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--fps", type=int, default=16)
    parser.add_argument("--guidance-scale", type=float, default=5.0)
    parser.add_argument("--flow-shift", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=1000)
    parser.add_argument("--vsa-sparsity", type=float, default=0.9)
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--sp-size", type=int, default=1)
    parser.add_argument("--tp-size", type=int, default=1)
    # A lone store_true flag with default=True can never be switched off, so
    # pair it with a --no-... flag writing the same destination. The positive
    # flag is still accepted (it is a no-op against the default).
    parser.add_argument(
        "--text-encoder-cpu-offload",
        dest="text_encoder_cpu_offload",
        action="store_true",
        help="Offload the text encoder to CPU (default: on).",
    )
    parser.add_argument(
        "--no-text-encoder-cpu-offload",
        dest="text_encoder_cpu_offload",
        action="store_false",
        help="Disable text encoder CPU offload.",
    )
    # Must come after both add_argument calls so it overrides both defaults.
    parser.set_defaults(text_encoder_cpu_offload=True)
    parser.add_argument("--pin-cpu-memory", action="store_true", default=False)
    return parser


def main(argv: list[str] | None = None) -> int:
    """Generate one video with the sparse FP4 transformer and print the result.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``; ``None`` (the
            default) parses the real command line.

    Returns:
        Process exit status: 0 on success (errors propagate as exceptions).
    """
    args = _build_parser().parse_args(argv)

    # Default the attention backend and its high-precision-output switch,
    # without overriding values the caller already exported (setdefault).
    # Done before fastvideo is imported below so they are in place whenever
    # the library reads them.
    os.environ.setdefault("FASTVIDEO_ATTENTION_BACKEND", "SPARSE_FP4_OURS_P_ATTN")
    os.environ.setdefault("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")

    weights_path = _resolve_weights(args.repo_id, args.weights, args.local_dir)

    # Deferred import: --help and argument errors exit before the heavy
    # library loads, and the environment defaults above are already set.
    from fastvideo import VideoGenerator

    generator = VideoGenerator.from_pretrained(
        model_path=args.model_path,
        num_gpus=args.num_gpus,
        sp_size=args.sp_size,
        tp_size=args.tp_size,
        # Transformer weights resolved above (local path or Hub download).
        init_weights_from_safetensors=weights_path,
        dit_cpu_offload=False,
        vae_cpu_offload=False,
        text_encoder_cpu_offload=args.text_encoder_cpu_offload,
        pin_cpu_memory=args.pin_cpu_memory,
        flow_shift=args.flow_shift,
        VSA_sparsity=args.vsa_sparsity,
    )

    result = generator.generate_video(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        output_path=args.output_path,
        save_video=True,
        return_frames=False,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        num_inference_steps=args.num_inference_steps,
        fps=args.fps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
    )
    print(result)
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)