File size: 2,688 Bytes
257f706
 
 
 
 
3297c4b
 
 
257f706
 
3297c4b
257f706
3297c4b
 
 
 
 
257f706
3297c4b
 
 
257f706
3297c4b
 
 
 
 
 
 
 
 
 
257f706
 
 
5fb16c4
3297c4b
257f706
3297c4b
 
5fb16c4
58f3778
5fb16c4
 
 
 
 
3297c4b
257f706
3297c4b
257f706
3297c4b
 
ea97ae7
 
3297c4b
 
 
 
 
 
 
 
257f706
7a6fd96
3297c4b
 
 
 
257f706
 
 
 
 
3297c4b
257f706
ea97ae7
 
 
257f706
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import argparse
from process_pipepline import ProcessPipeline

# Plain attribute container standing in for an argparse-style namespace.
class _Args:
    """Empty attribute bag; ``_parse_args`` assigns the settings onto it."""

def _parse_args():
    """Build the default preprocessing settings.

    Despite the name, nothing is read from the command line: every call
    creates a new ``_Args`` instance and assigns the fixed defaults
    listed below onto it.

    Returns:
        _Args: object carrying one attribute per setting.
    """
    defaults = {
        # general paths
        "ckpt_path": "./Wan2.2-Animate-14B/process_checkpoint",
        "video_path": None,
        "refer_path": None,
        "save_path": None,
        # processing parameters
        "resolution_area": [1280, 720],
        "fps": 30,
        # feature flags
        "replace_flag": True,
        "retarget_flag": False,
        "use_flux": False,
        # mask strategy parameters (replacement mode)
        "iterations": 3,
        "k": 7,
        "w_len": 1,
        "h_len": 1,
    }

    settings = _Args()
    for attr_name, attr_value in defaults.items():
        setattr(settings, attr_name, attr_value)
    return settings

def load_preprocess_models(max_duration_s, ckpt_path="./Wan2.2-Animate-14B/process_checkpoint"):
    """Construct the ``ProcessPipeline`` used for preprocessing.

    Args:
        max_duration_s: Number compared against zero to pick the SAM2
            variant (presumably a duration in seconds, going by the name).
            A negative value selects the small checkpoint; any other value
            selects the large one.
        ckpt_path: Directory containing the ``pose2d``, ``det`` and
            ``sam2`` checkpoint sub-directories. Defaults to the location
            that used to be hard-coded here, so existing callers are
            unaffected.

    Returns:
        The ``ProcessPipeline`` instance built from the resolved paths.
    """
    pose2d_checkpoint_path = os.path.join(ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')
    det_checkpoint_path = os.path.join(ckpt_path, 'det/yolov10m.onnx')

    # NOTE(review): with a `< 0` test the small model is only chosen for
    # negative inputs, i.e. never for a real duration -- confirm whether a
    # positive threshold was intended.
    if max_duration_s < 0:
        print("using small sam2")
        sam2_weights, sam2_config = 'sam2/sam2_hiera_small.pt', "sam2_hiera_s.yaml"
    else:
        sam2_weights, sam2_config = 'sam2/sam2_hiera_large.pt', "sam2_hiera_l.yaml"

    # The pipeline receives SAM2 as a [weights file, config name] pair.
    sam2_checkpoint_path = [os.path.join(ckpt_path, sam2_weights), sam2_config]

    # No FLUX Kontext model is configured by this loader.
    flux_kontext_path = None

    return ProcessPipeline(
        det_checkpoint_path=det_checkpoint_path,
        pose2d_checkpoint_path=pose2d_checkpoint_path,
        sam_checkpoint_path=sam2_checkpoint_path,
        flux_kontext_path=flux_kontext_path,
    )

def run(process_pipeline, input_video, edited_frame, preprocess_dir, w, h, tag_string,
        pts_by_frame: dict, lbs_by_frame: dict):
    """Invoke the preprocessing pipeline once, writing into ``preprocess_dir``.

    Args:
        process_pipeline: Callable pipeline object; invoked with keyword
            arguments only.
        input_video: Passed to the pipeline as ``video_path``.
        edited_frame: Passed to the pipeline as ``refer_image_path``.
        preprocess_dir: Output directory; created if it does not exist and
            passed to the pipeline as ``output_path``.
        w: First element of the ``resolution_area`` pair.
        h: Second element of the ``resolution_area`` pair.
        tag_string: Mode selector. The exact string ``"retarget_flag"``
            enables retargeting; any other value enables replacement.
        pts_by_frame: Forwarded unchanged to the pipeline (presumably
            per-frame point prompts -- confirm against ``ProcessPipeline``).
        lbs_by_frame: Forwarded unchanged to the pipeline (presumably the
            labels matching ``pts_by_frame`` -- confirm against
            ``ProcessPipeline``).

    Returns:
        None. The pipeline's return value is not propagated.
    """
    defaults = _parse_args()

    # The two modes are mutually exclusive; exactly one flag ends up True.
    retarget_flag = tag_string == "retarget_flag"
    replace_flag = not retarget_flag

    os.makedirs(preprocess_dir, exist_ok=True)

    pipeline_kwargs = {
        "video_path": input_video,
        "refer_image_path": edited_frame,
        "output_path": preprocess_dir,
        # Caller-supplied size overrides the default resolution_area.
        "resolution_area": [w, h],
        "fps": defaults.fps,
        "iterations": defaults.iterations,
        "k": defaults.k,
        "w_len": defaults.w_len,
        "h_len": defaults.h_len,
        "retarget_flag": retarget_flag,
        "use_flux": defaults.use_flux,
        "replace_flag": replace_flag,
        "pts_by_frame": pts_by_frame,
        "lbs_by_frame": lbs_by_frame,
    }
    process_pipeline(**pipeline_kwargs)