File size: 4,623 Bytes
40cfce6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import argparse
from process_pipepline import ProcessPipeline


def _parse_args(argv=None):
    """Build and run the command-line parser for the Wan-animate preprocessing pipeline.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default, which is how the script's entry point calls this),
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed arguments.

    Exits:
        Through argparse's usage error (exit status 2) when ``--ckpt_path`` or
        ``--save_path`` is missing, or when ``--use_flux`` is given without
        ``--retarget_flag``.
    """
    parser = argparse.ArgumentParser(
        description="The preprocessing pipeline for Wan-animate."
    )

    # Required because the entry point joins every checkpoint path onto it.
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
        help="The path to the preprocessing model's checkpoint directory.")

    parser.add_argument(
        "--video_path",
        type=str,
        default=None,
        help="The path to the driving video.")
    parser.add_argument(
        "--refer_path",
        type=str,
        default=None,
        help="The path to the reference image.")
    # Required because the entry point creates this directory before running.
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="The path to save the processed results.")

    parser.add_argument(
        "--resolution_area",
        type=int,
        nargs=2,
        metavar=("WIDTH", "HEIGHT"),
        default=[1280, 720],
        help="The target resolution for processing, specified as [width, height]. To handle different aspect ratios, the video is resized to have a total area equivalent to width * height, while preserving the original aspect ratio."
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=30,
        help="The target FPS for processing the driving video. Set to -1 to use the video's original FPS."
    )

    parser.add_argument(
        "--replace_flag",
        action="store_true",
        default=False,
        help="Whether to use replacement mode.")
    parser.add_argument(
        "--retarget_flag",
        action="store_true",
        default=False,
        help="Whether to use pose retargeting. Currently only supported in animation mode.")
    parser.add_argument(
        "--use_flux",
        action="store_true",
        default=False,
        help="Whether to use image editing in pose retargeting. Recommended if the character in the reference image or the first frame of the driving video is not in a standard, front-facing pose.")

    # Parameters for the mask strategy in replacement mode. They control the
    # mask's size and shape. See https://arxiv.org/pdf/2502.06145
    parser.add_argument(
        "--iterations",
        type=int,
        default=3,
        help="Number of iterations for mask dilation."
    )
    parser.add_argument(
        "--k",
        type=int,
        default=7,
        help="Kernel size for mask dilation."
    )
    parser.add_argument(
        "--w_len",
        type=int,
        default=1,
        help="The number of subdivisions for the grid along the 'w' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed."
    )
    parser.add_argument(
        "--h_len",
        type=int,
        default=1,
        help="The number of subdivisions for the grid along the 'h' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed."
    )
    args = parser.parse_args(argv)

    # Cross-argument constraint: report it as a usage error instead of failing
    # later (or not at all, since `assert` is stripped under -O).
    if args.use_flux and not args.retarget_flag:
        parser.error("Image editing with FLUX can only be used when pose retargeting is enabled.")

    return args


if __name__ == '__main__':
    args = _parse_args()
    args_dict = vars(args)
    print(args_dict)

    # Validate with explicit checks rather than `assert`, which Python strips
    # when run with -O, letting bad input reach the pipeline.
    if args.ckpt_path is None:
        raise SystemExit("--ckpt_path is required: all model checkpoints are resolved relative to it.")
    if args.save_path is None:
        raise SystemExit("--save_path is required: processed results are written there.")
    if len(args.resolution_area) != 2:
        raise SystemExit("resolution_area should be a list of two integers [width, height]")
    if args.use_flux and not args.retarget_flag:
        raise SystemExit("Image editing with FLUX can only be used when pose retargeting is enabled.")

    # Detection and 2D pose models are always loaded.
    pose2d_checkpoint_path = os.path.join(args.ckpt_path, 'pose2d', 'vitpose_h_wholebody.onnx')
    det_checkpoint_path = os.path.join(args.ckpt_path, 'det', 'yolov10m.onnx')

    # Optional models are resolved only when the mode that needs them is on;
    # otherwise ProcessPipeline is given None for that path.
    sam2_checkpoint_path = (
        os.path.join(args.ckpt_path, 'sam2', 'sam2_hiera_large.pt')
        if args.replace_flag else None
    )
    flux_kontext_path = (
        os.path.join(args.ckpt_path, 'FLUX.1-Kontext-dev')
        if args.use_flux else None
    )
    process_pipeline = ProcessPipeline(
        det_checkpoint_path=det_checkpoint_path,
        pose2d_checkpoint_path=pose2d_checkpoint_path,
        sam_checkpoint_path=sam2_checkpoint_path,
        flux_kontext_path=flux_kontext_path,
    )
    os.makedirs(args.save_path, exist_ok=True)
    process_pipeline(video_path=args.video_path,
                     refer_image_path=args.refer_path,
                     output_path=args.save_path,
                     resolution_area=args.resolution_area,
                     fps=args.fps,
                     iterations=args.iterations,
                     k=args.k,
                     w_len=args.w_len,
                     h_len=args.h_len,
                     retarget_flag=args.retarget_flag,
                     use_flux=args.use_flux,
                     replace_flag=args.replace_flag)